
Commit d0e6d24

Merge pull request opencv#17363 from YashasSamaga:cuda4dnn-eltwise-fusion2

cuda4dnn(conv): fuse eltwise with convolutions

* fuse eltwise with convolutions
* manually rebase to avoid bad git merge
1 parent 44d473f commit d0e6d24

22 files changed (+1609, -273 lines)
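Note on the fusion pattern: in the kernels added below, the activation is applied to the convolution output first and the elementwise operand is folded in afterwards, so each element becomes eltwise_op(activation_op(x), e) in a single in-place pass, and the intermediate activation result never takes an extra round trip through global memory. A host-side reference sketch of this computation follows the first file.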
New file: 121 additions & 0 deletions
@@ -0,0 +1,121 @@
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.

#include <cuda_runtime.h>
#include <cuda_fp16.h>

#include "functors.hpp"
#include "vector_traits.hpp"
#include "grid_stride_range.hpp"
#include "execution.hpp"

#include "../cuda4dnn/csl/stream.hpp"
#include "../cuda4dnn/csl/span.hpp"

using namespace cv::dnn::cuda4dnn::csl;
using namespace cv::dnn::cuda4dnn::csl::device;

namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {

namespace raw {
    template <class T, class ActivationOp, class EltwiseOp, std::size_t N>
    __global__ void generic_op_eltwise_op_inplace_vec(Span<T> inplace_output, View<T> eltwise, const typename ActivationOp::Params act_params, const typename EltwiseOp::Params eltwise_params) {
        using vector_type = get_vector_type_t<T, N>;

        auto inplace_output_vPtr = vector_type::get_pointer(inplace_output.data());
        auto eltwise_vPtr = vector_type::get_pointer(eltwise.data());

        ActivationOp activation_op(act_params);
        EltwiseOp eltwise_op(eltwise_params);

        for (auto i : grid_stride_range(inplace_output.size() / vector_type::size())) {
            vector_type output_vec, eltwise_vec;
            v_load(output_vec, inplace_output_vPtr[i]);
            v_load(eltwise_vec, eltwise_vPtr[i]);
            for (int j = 0; j < output_vec.size(); j++)
                output_vec.data[j] = eltwise_op(activation_op(output_vec.data[j]), eltwise_vec.data[j]);
            v_store(inplace_output_vPtr[i], output_vec);
        }
    }
}

template <class T, class ActivationOp, class EltwiseOp, std::size_t N> static
void launch_vectorized_generic_op_eltwise_op_inplace(const Stream& stream, Span<T> inplace_output, View<T> eltwise, const typename ActivationOp::Params& act_params, const typename EltwiseOp::Params& eltwise_params) {
    CV_Assert(is_fully_aligned<T>(inplace_output, N));
    CV_Assert(is_fully_aligned<T>(eltwise, N));

    auto kernel = raw::generic_op_eltwise_op_inplace_vec<T, ActivationOp, EltwiseOp, N>;
    auto policy = make_policy(kernel, inplace_output.size() / N, 0, stream);
    launch_kernel(kernel, policy, inplace_output, eltwise, act_params, eltwise_params);
}

template <class T, class ActivationOp, class EltwiseOp> static
void generic_op_eltwise_op_inplace(const Stream& stream, Span<T> inplace_output, View<T> eltwise, const typename ActivationOp::Params& act_params = {}, const typename EltwiseOp::Params& eltwise_params = {}) {
    CV_Assert(inplace_output.size() == eltwise.size());

    if (is_fully_aligned<T>(inplace_output, 4) && is_fully_aligned<T>(eltwise, 4)) {
        launch_vectorized_generic_op_eltwise_op_inplace<T, ActivationOp, EltwiseOp, 4>(stream, inplace_output, eltwise, act_params, eltwise_params);
    } else if (is_fully_aligned<T>(inplace_output, 2) && is_fully_aligned<T>(eltwise, 2)) {
        launch_vectorized_generic_op_eltwise_op_inplace<T, ActivationOp, EltwiseOp, 2>(stream, inplace_output, eltwise, act_params, eltwise_params);
    } else {
        launch_vectorized_generic_op_eltwise_op_inplace<T, ActivationOp, EltwiseOp, 1>(stream, inplace_output, eltwise, act_params, eltwise_params);
    }
}

template <class T>
void relu_eltwise_sum_2_inplace(const Stream& stream, Span<T> inplace_output, View<T> eltwise, T slope) {
    generic_op_eltwise_op_inplace<T, ReLUFunctor<T>, SumFunctor<T>>(stream, inplace_output, eltwise, {slope});
}

template <class T>
void clipped_relu_eltwise_sum_2_inplace(const Stream& stream, Span<T> inplace_output, View<T> eltwise, T floor, T ceiling) {
    CV_Assert(static_cast<double>(floor) <= static_cast<double>(ceiling));
    generic_op_eltwise_op_inplace<T, ClippedReLUFunctor<T>, SumFunctor<T>>(stream, inplace_output, eltwise, {floor, ceiling});
}

template <class T>
void tanh_eltwise_sum_2_inplace(const Stream& stream, Span<T> inplace_output, View<T> eltwise) {
    generic_op_eltwise_op_inplace<T, TanHFunctor<T>, SumFunctor<T>>(stream, inplace_output, eltwise);
}

template <class T>
void swish_eltwise_sum_2_inplace(const Stream& stream, Span<T> inplace_output, View<T> eltwise) {
    generic_op_eltwise_op_inplace<T, SwishFunctor<T>, SumFunctor<T>>(stream, inplace_output, eltwise);
}

template <class T>
void mish_eltwise_sum_2_inplace(const Stream& stream, Span<T> inplace_output, View<T> eltwise) {
    generic_op_eltwise_op_inplace<T, MishFunctor<T>, SumFunctor<T>>(stream, inplace_output, eltwise);
}

template <class T>
void sigmoid_eltwise_sum_2_inplace(const Stream& stream, Span<T> inplace_output, View<T> eltwise) {
    generic_op_eltwise_op_inplace<T, SigmoidFunctor<T>, SumFunctor<T>>(stream, inplace_output, eltwise);
}

template <class T>
void power_eltwise_sum_2_inplace(const Stream& stream, Span<T> inplace_output, View<T> eltwise, T exp, T scale, T shift) {
    generic_op_eltwise_op_inplace<T, PowerFunctor<T>, SumFunctor<T>>(stream, inplace_output, eltwise, {exp, scale, shift});
}

#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
template void relu_eltwise_sum_2_inplace<__half>(const Stream&, Span<__half>, View<__half>, __half);
template void clipped_relu_eltwise_sum_2_inplace<__half>(const Stream&, Span<__half>, View<__half>, __half, __half);
template void tanh_eltwise_sum_2_inplace<__half>(const Stream&, Span<__half>, View<__half>);
template void swish_eltwise_sum_2_inplace<__half>(const Stream&, Span<__half>, View<__half>);
template void mish_eltwise_sum_2_inplace<__half>(const Stream&, Span<__half>, View<__half>);
template void sigmoid_eltwise_sum_2_inplace<__half>(const Stream&, Span<__half>, View<__half>);
template void power_eltwise_sum_2_inplace<__half>(const Stream&, Span<__half>, View<__half>, __half, __half, __half);
#endif

template void relu_eltwise_sum_2_inplace<float>(const Stream&, Span<float>, View<float>, float);
template void clipped_relu_eltwise_sum_2_inplace<float>(const Stream&, Span<float>, View<float>, float, float);
template void tanh_eltwise_sum_2_inplace<float>(const Stream&, Span<float>, View<float>);
template void swish_eltwise_sum_2_inplace<float>(const Stream&, Span<float>, View<float>);
template void mish_eltwise_sum_2_inplace<float>(const Stream&, Span<float>, View<float>);
template void sigmoid_eltwise_sum_2_inplace<float>(const Stream&, Span<float>, View<float>);
template void power_eltwise_sum_2_inplace<float>(const Stream&, Span<float>, View<float>, float, float, float);

}}}} /* namespace cv::dnn::cuda4dnn::kernels */
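For orientation, here is a minimal host-side reference of what the relu variant above computes, assuming plain pointers in place of the csl Span/View wrappers and assuming ReLUFunctor implements the usual slope-parameterised (leaky) ReLU; the _ref name and free-standing setup are illustrative, not part of the commit:

#include <cstddef>

// Reference semantics of relu_eltwise_sum_2_inplace: activation first,
// then SumFunctor's two-operand addition, applied in place on the first operand.
template <class T>
void relu_eltwise_sum_2_inplace_ref(T* inplace_output, const T* eltwise, std::size_t n, T slope) {
    for (std::size_t i = 0; i < n; i++) {
        T x = inplace_output[i];
        T activated = x >= T(0) ? x : slope * x;    // ReLU with a configurable negative slope
        inplace_output[i] = activated + eltwise[i]; // eltwise sum with the second operand
    }
}

The CUDA version layers vectorisation on top of this: generic_op_eltwise_op_inplace picks the widest width (4, 2, or 1 elements) for which both buffers are fully aligned, and the kernel then processes one vector per grid-stride iteration.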

modules/dnn/src/cuda/activations.cu

Lines changed: 42 additions & 42 deletions
@@ -26,20 +26,20 @@ using namespace cv::dnn::cuda4dnn::csl::device;
 namespace cv { namespace dnn { namespace cuda4dnn { namespace kernels {
 
 namespace raw {
-    template <class T, class Functor, std::size_t N, class ...FunctorArgs>
-    __global__ void generic_op_vec(Span<T> output, View<T> input, FunctorArgs ...functorArgs) {
+    template <class T, class ActivationOp, std::size_t N>
+    __global__ void generic_op_vec(Span<T> output, View<T> input, const typename ActivationOp::Params params) {
         using vector_type = get_vector_type_t<T, N>;
 
         auto output_vPtr = vector_type::get_pointer(output.data());
         auto input_vPtr = vector_type::get_pointer(input.data());
 
-        Functor functor(functorArgs...);
+        ActivationOp activation_op(params);
 
         for (auto i : grid_stride_range(output.size() / vector_type::size())) {
             vector_type vec;
             v_load(vec, input_vPtr[i]);
             for (int j = 0; j < vector_type::size(); j++)
-                vec.data[j] = functor(vec.data[j]);
+                vec.data[j] = activation_op(vec.data[j]);
             v_store(output_vPtr[i], vec);
         }
     }
@@ -51,9 +51,8 @@ namespace raw {
         auto output_vPtr = vector_type::get_pointer(output.data());
         auto input_vPtr = vector_type::get_pointer(input.data());
 
-        inner_size /= vector_type::size();
         for (auto i : grid_stride_range(output.size() / vector_type::size())) {
-            const index_type c = (i / inner_size) % static_cast<size_type>(slope.size());
+            const index_type c = (i / inner_size) % slope.size();
 
             vector_type vec;
             v_load(vec, input_vPtr[i]);
@@ -65,73 +64,73 @@ namespace raw {
 
 } /* namespace raw */
 
-template <class T, template <class> class Activation, std::size_t N, class ...ActivationArgs> static
-void launch_vectorized_generic_op(const Stream& stream, Span<T> output, View<T> input, ActivationArgs ...activationArgs) {
+template <class T, class ActivationOp, std::size_t N> static
+void launch_vectorized_generic_op(const Stream& stream, Span<T> output, View<T> input, const typename ActivationOp::Params& params) {
     CV_Assert(is_fully_aligned<T>(output, N));
     CV_Assert(is_fully_aligned<T>(input, N));
 
-    auto kernel = raw::generic_op_vec<T, Activation<T>, N, ActivationArgs...>;
+    auto kernel = raw::generic_op_vec<T, ActivationOp, N>;
     auto policy = make_policy(kernel, output.size() / N, 0, stream);
-    launch_kernel(kernel, policy, output, input, activationArgs...);
+    launch_kernel(kernel, policy, output, input, params);
 }
 
-template <class T, template <class> class Activation, class ...ActivationArgs> static
-void generic_op(const Stream& stream, Span<T> output, View<T> input, ActivationArgs ...activationArgs) {
+template <class T, class ActivationOp> static
+void generic_op(const Stream& stream, Span<T> output, View<T> input, const typename ActivationOp::Params& params = {}) {
     CV_Assert(input.size() == output.size());
 
     if (is_fully_aligned<T>(output, 4) && is_fully_aligned<T>(input, 4)) {
-        launch_vectorized_generic_op<T, Activation, 4>(stream, output, input, activationArgs...);
+        launch_vectorized_generic_op<T, ActivationOp, 4>(stream, output, input, params);
     } else if (is_fully_aligned<T>(output, 2) && is_fully_aligned<T>(input, 2)) {
-        launch_vectorized_generic_op<T, Activation, 2>(stream, output, input, activationArgs...);
+        launch_vectorized_generic_op<T, ActivationOp, 2>(stream, output, input, params);
     } else {
-        launch_vectorized_generic_op<T, Activation, 1>(stream, output, input, activationArgs...);
+        launch_vectorized_generic_op<T, ActivationOp, 1>(stream, output, input, params);
     }
 }
 
 template <class T>
-void abs(const Stream& stream, Span<T> output, View<T> input) {
-    generic_op<T, abs_functor>(stream, output, input);
+void relu(const Stream& stream, Span<T> output, View<T> input, T slope) {
+    generic_op<T, ReLUFunctor<T>>(stream, output, input, {slope});
+}
+
+template <class T>
+void clipped_relu(const Stream& stream, Span<T> output, View<T> input, T floor, T ceiling) {
+    CV_Assert(static_cast<double>(floor) <= static_cast<double>(ceiling));
+    generic_op<T, ClippedReLUFunctor<T>>(stream, output, input, {floor, ceiling});
 }
 
 template <class T>
 void tanh(const Stream& stream, Span<T> output, View<T> input) {
-    generic_op<T, tanh_functor>(stream, output, input);
+    generic_op<T, TanHFunctor<T>>(stream, output, input);
 }
 
 template <class T>
 void swish(const Stream& stream, Span<T> output, View<T> input) {
-    generic_op<T, swish_functor>(stream, output, input);
+    generic_op<T, SwishFunctor<T>>(stream, output, input);
 }
 
 template <class T>
 void mish(const Stream& stream, Span<T> output, View<T> input) {
-    generic_op<T, mish_functor>(stream, output, input);
+    generic_op<T, MishFunctor<T>>(stream, output, input);
 }
 
 template <class T>
 void sigmoid(const Stream& stream, Span<T> output, View<T> input) {
-    generic_op<T, sigmoid_functor>(stream, output, input);
-}
-
-template <class T>
-void bnll(const Stream& stream, Span<T> output, View<T> input) {
-    generic_op<T, bnll_functor>(stream, output, input);
+    generic_op<T, SigmoidFunctor<T>>(stream, output, input);
 }
 
 template <class T>
 void elu(const Stream& stream, Span<T> output, View<T> input) {
-    generic_op<T, elu_functor>(stream, output, input);
+    generic_op<T, ELUFunctor<T>>(stream, output, input);
 }
 
 template <class T>
-void relu(const Stream& stream, Span<T> output, View<T> input, T slope) {
-    generic_op<T, relu_functor>(stream, output, input, slope);
+void bnll(const Stream& stream, Span<T> output, View<T> input) {
+    generic_op<T, BNLLFunctor<T>>(stream, output, input);
 }
 
 template <class T>
-void clipped_relu(const Stream& stream, Span<T> output, View<T> input, T floor, T ceiling) {
-    CV_Assert(static_cast<double>(floor) <= static_cast<double>(ceiling));
-    generic_op<T, clipped_relu_functor>(stream, output, input, floor, ceiling);
+void abs(const Stream& stream, Span<T> output, View<T> input) {
+    generic_op<T, AbsFunctor<T>>(stream, output, input);
 }
 
 template <class T>
@@ -143,31 +142,32 @@ void power(const Stream& stream, Span<T> output, View<T> input, T exp, T scale,
         return;
     }
 
-    generic_op<T, power_functor>(stream, output, input, exp, scale, shift);
+    generic_op<T, PowerFunctor<T>>(stream, output, input, {exp, scale, shift});
 }
 
 #if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 530)
-template void abs<__half>(const Stream& stream, Span<__half> output, View<__half> input);
+template void relu<__half>(const Stream&, Span<__half>, View<__half>, __half);
+template void clipped_relu<__half>(const Stream&, Span<__half>, View<__half>, __half, __half);
 template void tanh<__half>(const Stream&, Span<__half>, View<__half>);
 template void swish<__half>(const Stream&, Span<__half>, View<__half>);
 template void mish<__half>(const Stream&, Span<__half>, View<__half>);
 template void sigmoid<__half>(const Stream&, Span<__half>, View<__half>);
-template void bnll<__half>(const Stream&, Span<__half>, View<__half>);
 template void elu<__half>(const Stream&, Span<__half>, View<__half>);
-template void relu<__half>(const Stream&, Span<__half>, View<__half>, __half);
-template void clipped_relu<__half>(const Stream&, Span<__half>, View<__half>, __half, __half);
+template void abs<__half>(const Stream& stream, Span<__half> output, View<__half> input);
+template void bnll<__half>(const Stream&, Span<__half>, View<__half>);
 template void power<__half>(const Stream&, Span<__half>, View<__half>, __half, __half, __half);
 #endif
 
-template void abs<float>(const Stream& stream, Span<float> output, View<float> input);
+
+template void relu<float>(const Stream&, Span<float>, View<float>, float);
+template void clipped_relu<float>(const Stream&, Span<float>, View<float>, float, float);
 template void tanh<float>(const Stream&, Span<float>, View<float>);
 template void swish<float>(const Stream&, Span<float>, View<float>);
 template void mish<float>(const Stream&, Span<float>, View<float>);
 template void sigmoid<float>(const Stream&, Span<float>, View<float>);
-template void bnll<float>(const Stream&, Span<float>, View<float>);
 template void elu<float>(const Stream&, Span<float>, View<float>);
-template void relu<float>(const Stream&, Span<float>, View<float>, float);
-template void clipped_relu<float>(const Stream&, Span<float>, View<float>, float, float);
+template void abs<float>(const Stream& stream, Span<float> output, View<float> input);
+template void bnll<float>(const Stream&, Span<float>, View<float>);
 template void power<float>(const Stream&, Span<float>, View<float>, float, float, float);
 
 template <class T, std::size_t N> static
@@ -178,7 +178,7 @@ void launch_vectorized_axiswise_relu(const Stream& stream, Span<T> output, View<
 
     auto kernel = raw::axiswise_relu_vec<T, N>;
     auto policy = make_policy(kernel, output.size() / N, 0, stream);
-    launch_kernel(kernel, policy, output, input, inner_size, slope);
+    launch_kernel(kernel, policy, output, input, inner_size / N, slope);
 }
 
 template <class T>
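One behavioural detail in this last hunk: axiswise_relu_vec previously divided inner_size by vector_type::size() inside the kernel, and the division now happens once at the launch site (inner_size / N), so every thread receives the inner size already expressed in vector units and the channel index (i / inner_size) % slope.size() stays correct for the vectorised iteration. This reading assumes the dispatch logic only selects a width N that divides inner_size, which the (unshown) alignment and size checks at the call site are expected to guarantee.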
