@@ -26,20 +26,20 @@ using namespace cv::dnn::cuda4dnn::csl::device;
 namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {

 namespace raw {
-    template <class T, class Functor, std::size_t N, class ...FunctorArgs>
-    __global__ void generic_op_vec(Span<T> output, View<T> input, FunctorArgs ...functorArgs) {
+    template <class T, class ActivationOp, std::size_t N>
+    __global__ void generic_op_vec(Span<T> output, View<T> input, const typename ActivationOp::Params params) {
         using vector_type = get_vector_type_t<T, N>;

         auto output_vPtr = vector_type::get_pointer(output.data());
         auto input_vPtr = vector_type::get_pointer(input.data());

-        Functor functor(functorArgs...);
+        ActivationOp activation_op(params);

         for (auto i : grid_stride_range(output.size() / vector_type::size())) {
             vector_type vec;
             v_load(vec, input_vPtr[i]);
             for (int j = 0; j < vector_type::size(); j++)
-                vec.data[j] = functor(vec.data[j]);
+                vec.data[j] = activation_op(vec.data[j]);
             v_store(output_vPtr[i], vec);
         }
     }
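[Note: generic_op_vec now takes a single Params value instead of a variadic
argument pack, so every activation functor must expose a nested Params type, a
device-side constructor taking it, and a device-side operator(). The real
definitions live elsewhere in this PR (functors.hpp); the sketch below shows
the assumed minimal shape of that contract, not the verbatim OpenCV code.]

    template <class T>
    struct ReLUFunctor {
        struct Params {
            Params() : slope(0) { }              /* default: plain ReLU */
            Params(T slope_) : slope(slope_) { }
            T slope;                             /* leak factor for negative inputs */
        };

        __device__ ReLUFunctor(const Params& params) : slope(params.slope) { }

        __device__ T operator()(T value) {
            return value >= T(0) ? value : slope * value;
        }

        T slope;
    };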
@@ -51,9 +51,8 @@ namespace raw {
         auto output_vPtr = vector_type::get_pointer(output.data());
         auto input_vPtr = vector_type::get_pointer(input.data());

-        inner_size /= vector_type::size();
         for (auto i : grid_stride_range(output.size() / vector_type::size())) {
-            const index_type c = (i / inner_size) % static_cast<size_type>(slope.size());
+            const index_type c = (i / inner_size) % slope.size();

             vector_type vec;
             v_load(vec, input_vPtr[i]);
@@ -65,73 +64,73 @@ namespace raw {

 } /* namespace raw */

-template <class T, template <class> class Activation, std::size_t N, class ...ActivationArgs> static
-void launch_vectorized_generic_op(const Stream& stream, Span<T> output, View<T> input, ActivationArgs ...activationArgs) {
+template <class T, class ActivationOp, std::size_t N> static
+void launch_vectorized_generic_op(const Stream& stream, Span<T> output, View<T> input, const typename ActivationOp::Params& params) {
     CV_Assert(is_fully_aligned<T>(output, N));
     CV_Assert(is_fully_aligned<T>(input, N));

-    auto kernel = raw::generic_op_vec<T, Activation<T>, N, ActivationArgs...>;
+    auto kernel = raw::generic_op_vec<T, ActivationOp, N>;
     auto policy = make_policy(kernel, output.size() / N, 0, stream);
-    launch_kernel(kernel, policy, output, input, activationArgs...);
+    launch_kernel(kernel, policy, output, input, params);
 }

-template <class T, template <class> class Activation, class ...ActivationArgs> static
-void generic_op(const Stream& stream, Span<T> output, View<T> input, ActivationArgs ...activationArgs) {
+template <class T, class ActivationOp> static
+void generic_op(const Stream& stream, Span<T> output, View<T> input, const typename ActivationOp::Params& params = {}) {
     CV_Assert(input.size() == output.size());

     if (is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(input, 4)) {
-        launch_vectorized_generic_op<T, Activation, 4>(stream, output, input, activationArgs...);
+        launch_vectorized_generic_op<T, ActivationOp, 4>(stream, output, input, params);
     } else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(input, 2)) {
-        launch_vectorized_generic_op<T, Activation, 2>(stream, output, input, activationArgs...);
+        launch_vectorized_generic_op<T, ActivationOp, 2>(stream, output, input, params);
     } else {
-        launch_vectorized_generic_op<T, Activation, 1>(stream, output, input, activationArgs...);
+        launch_vectorized_generic_op<T, ActivationOp, 1>(stream, output, input, params);
     }
 }
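[Note: generic_op probes 4-, 2-, then 1-wide alignment so that float tensors
get float4/float2 loads whenever both buffers allow it. A hypothetical call
site, with illustrative stream/span names and the surrounding setup omitted:]

    // slope = 0 gives plain ReLU; a 16-byte-aligned float tensor takes the
    // N = 4 path and moves a float4 per loop iteration.
    kernels::relu<float>(stream, output_span, input_span, 0.0f);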

 template <class T>
-void abs(const Stream& stream, Span<T> output, View<T> input) {
-    generic_op<T, abs_functor>(stream, output, input);
+void relu(const Stream& stream, Span<T> output, View<T> input, T slope) {
+    generic_op<T, ReLUFunctor<T>>(stream, output, input, {slope});
+}
+
+template <class T>
+void clipped_relu(const Stream& stream, Span<T> output, View<T> input, T floor, T ceiling) {
+    CV_Assert(static_cast<double>(floor) <= static_cast<double>(ceiling));
+    generic_op<T, ClippedReLUFunctor<T>>(stream, output, input, {floor, ceiling});
 }

 template <class T>
 void tanh(const Stream& stream, Span<T> output, View<T> input) {
-    generic_op<T, tanh_functor>(stream, output, input);
+    generic_op<T, TanHFunctor<T>>(stream, output, input);
 }

 template <class T>
 void swish(const Stream& stream, Span<T> output, View<T> input) {
-    generic_op<T, swish_functor>(stream, output, input);
+    generic_op<T, SwishFunctor<T>>(stream, output, input);
 }

 template <class T>
 void mish(const Stream& stream, Span<T> output, View<T> input) {
-    generic_op<T, mish_functor>(stream, output, input);
+    generic_op<T, MishFunctor<T>>(stream, output, input);
 }

 template <class T>
 void sigmoid(const Stream& stream, Span<T> output, View<T> input) {
-    generic_op<T, sigmoid_functor>(stream, output, input);
-}
-
-template <class T>
-void bnll(const Stream& stream, Span<T> output, View<T> input) {
-    generic_op<T, bnll_functor>(stream, output, input);
+    generic_op<T, SigmoidFunctor<T>>(stream, output, input);
 }

 template <class T>
 void elu(const Stream& stream, Span<T> output, View<T> input) {
-    generic_op<T, elu_functor>(stream, output, input);
+    generic_op<T, ELUFunctor<T>>(stream, output, input);
 }

 template <class T>
-void relu(const Stream& stream, Span<T> output, View<T> input, T slope) {
-    generic_op<T, relu_functor>(stream, output, input, slope);
+void bnll(const Stream& stream, Span<T> output, View<T> input) {
+    generic_op<T, BNLLFunctor<T>>(stream, output, input);
 }

 template <class T>
-void clipped_relu(const Stream& stream, Span<T> output, View<T> input, T floor, T ceiling) {
-    CV_Assert(static_cast<double>(floor) <= static_cast<double>(ceiling));
-    generic_op<T, clipped_relu_functor>(stream, output, input, floor, ceiling);
+void abs(const Stream& stream, Span<T> output, View<T> input) {
+    generic_op<T, AbsFunctor<T>>(stream, output, input);
 }

 template <class T>
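[Note: the brace-initialized arguments ({slope}, {floor, ceiling}) construct
the functor's nested Params on the host before the launch. The CV_Assert in
clipped_relu compares through double so the same line compiles when T is
__half, which does not reliably provide host-side comparison operators. The
equivalent explicit form, with illustrative values:]

    // what {floor, ceiling} expands to for a ReLU6-style clip:
    ClippedReLUFunctor<float>::Params params{0.0f, 6.0f};
    generic_op<float, ClippedReLUFunctor<float>>(stream, output, input, params);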
@@ -143,31 +142,32 @@ void power(const Stream& stream, Span<T> output, View<T> input, T exp, T scale,
         return;
     }

-    generic_op<T, power_functor>(stream, output, input, exp, scale, shift);
+    generic_op<T, PowerFunctor<T>>(stream, output, input, {exp, scale, shift});
 }

 #if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
-template void abs<__half>(const Stream& stream, Span<__half> output, View<__half> input);
+template void relu<__half>(const Stream&, Span<__half>, View<__half>, __half);
+template void clipped_relu<__half>(const Stream&, Span<__half>, View<__half>, __half, __half);
 template void tanh<__half>(const Stream&, Span<__half>, View<__half>);
 template void swish<__half>(const Stream&, Span<__half>, View<__half>);
 template void mish<__half>(const Stream&, Span<__half>, View<__half>);
 template void sigmoid<__half>(const Stream&, Span<__half>, View<__half>);
-template void bnll<__half>(const Stream&, Span<__half>, View<__half>);
 template void elu<__half>(const Stream&, Span<__half>, View<__half>);
-template void relu<__half>(const Stream&, Span<__half>, View<__half>, __half);
-template void clipped_relu<__half>(const Stream&, Span<__half>, View<__half>, __half, __half);
+template void abs<__half>(const Stream& stream, Span<__half> output, View<__half> input);
+template void bnll<__half>(const Stream&, Span<__half>, View<__half>);
 template void power<__half>(const Stream&, Span<__half>, View<__half>, __half, __half, __half);
 #endif

-template void abs<float>(const Stream& stream, Span<float> output, View<float> input);
+
+template void relu<float>(const Stream&, Span<float>, View<float>, float);
+template void clipped_relu<float>(const Stream&, Span<float>, View<float>, float, float);
 template void tanh<float>(const Stream&, Span<float>, View<float>);
 template void swish<float>(const Stream&, Span<float>, View<float>);
 template void mish<float>(const Stream&, Span<float>, View<float>);
 template void sigmoid<float>(const Stream&, Span<float>, View<float>);
-template void bnll<float>(const Stream&, Span<float>, View<float>);
 template void elu<float>(const Stream&, Span<float>, View<float>);
-template void relu<float>(const Stream&, Span<float>, View<float>, float);
-template void clipped_relu<float>(const Stream&, Span<float>, View<float>, float, float);
+template void abs<float>(const Stream& stream, Span<float> output, View<float> input);
+template void bnll<float>(const Stream&, Span<float>, View<float>);
 template void power<float>(const Stream&, Span<float>, View<float>, float, float, float);

 template <class T, std::size_t N> static
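[Note: the __half instantiations stay behind the existing
"__CUDA_ARCH__ >= 530" guard because native fp16 arithmetic requires sm_53 or
newer. power now also packs its three scalars into a single Params value; the
functor presumably evaluates pow(shift + scale * x, exp), matching OpenCV's
Power layer semantics. An assumed shape, not the verbatim functor:]

    template <class T>
    struct PowerFunctor {
        struct Params {
            Params() : exp(1), scale(1), shift(0) { }
            Params(T exp_, T scale_, T shift_) : exp(exp_), scale(scale_), shift(shift_) { }
            T exp, scale, shift;
        };
        __device__ PowerFunctor(const Params& params)
            : exp(params.exp), scale(params.scale), shift(params.shift) { }
        __device__ T operator()(T value) {
            /* a device-side pow overload is assumed to be in scope */
            return pow(shift + scale * value, exp);
        }
        T exp, scale, shift;
    };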
@@ -178,7 +178,7 @@ void launch_vectorized_axiswise_relu(const Stream& stream, Span<T> output, View<

     auto kernel = raw::axiswise_relu_vec<T, N>;
     auto policy = make_policy(kernel, output.size() / N, 0, stream);
-    launch_kernel(kernel, policy, output, input, inner_size, slope);
+    launch_kernel(kernel, policy, output, input, inner_size / N, slope);
 }
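[Note: this pairs with the kernel-side change in the second hunk: instead of
every thread executing inner_size /= vector_type::size() inside
axiswise_relu_vec, the division by the vector width N now happens once on the
host at the launch site.]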

 template <class T>