Skip to content

Commit 690bed7

Browse files
xaduprewschin
andauthored
Add operator MulSigmoid, MulMulSigmoid (#741)
* Add operator MulSigmoid * add mul mul sigmoid * add comments * Apply suggestions from code review --------- Co-authored-by: Wei-Sheng Chin <[email protected]>
1 parent f505546 commit 690bed7

File tree

5 files changed

+335
-1
lines changed

5 files changed

+335
-1
lines changed

operators/cuda/cuda_ops.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#ifdef USE_CUDA
77
#include "cuda/add_mul.h"
88
#include "cuda/fast_gelu.h"
9+
#include "cuda/mul_sigmoid.h"
910
#include "cuda/negxplus1.h"
1011
#include "cuda/scatter_nd_of_shape.h"
1112
#include "cuda/transpose_cast.h"
@@ -32,6 +33,8 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& {
3233
CustomCudaStructV2("FastGelu", contrib::FastGelu<float>),
3334
CustomCudaStructV2("MaskedScatterNDOfShape", contrib::MaskedScatterNDOfShape<float>),
3435
CustomCudaStructV2("MulSharedInput", MulSharedInputFloat32Type),
36+
CustomCudaStructV2("MulMulSigmoid", contrib::MulMulSigmoid<float>),
37+
CustomCudaStructV2("MulSigmoid", contrib::MulSigmoid<float>),
3538
CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1<float>),
3639
CustomCudaStructV2("ScatterNDOfShape", contrib::ScatterNDOfShape<float>),
3740
#if ORT_API_VERSION >= 16
@@ -41,6 +44,8 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& {
4144
CustomCudaStructV2("FastGelu", contrib::FastGelu<ortc::BFloat16>),
4245
CustomCudaStructV2("MaskedScatterNDOfShape", contrib::MaskedScatterNDOfShape<ortc::MFloat16>),
4346
CustomCudaStructV2("MulSharedInput", MulSharedInputFloat16Type),
47+
CustomCudaStructV2("MulMulSigmoid", contrib::MulMulSigmoid<ortc::MFloat16>),
48+
CustomCudaStructV2("MulSigmoid", contrib::MulSigmoid<ortc::MFloat16>),
4449
CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1<ortc::MFloat16>),
4550
CustomCudaStructV2("ScatterNDOfShape", contrib::ScatterNDOfShape<ortc::MFloat16>),
4651
CustomCudaStructV2("Transpose2DCastFP16", Transpose2DCastFloat32ToFloat16Type),

operators/cuda/mul_sigmoid.h

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
#include "ocos.h"
6+
#include "mul_sigmoid_impl.cuh"
7+
#include "ortx_common.h"
8+
9+
namespace contrib {
10+
11+
/**
12+
* MulSigmoid(X) = X * Sigmoid(X)
13+
14+
No shape broadcasting supported.
15+
*/
16+
template <typename T>
17+
struct MulSigmoid {
18+
template <typename TDict>
19+
OrtxStatus OnModelAttach(const TDict& /*dict*/) {
20+
return {};
21+
}
22+
OrtxStatus Compute(Ort::Custom::CUDAKernelContext* ctx,
23+
const ortc::Tensor<T>& input,
24+
ortc::Tensor<T>& output) const {
25+
const T* input_data = input.Data();
26+
T* output_data = output.Allocate(input.Shape());
27+
auto input_length = input.NumberOfElement();
28+
if (0 == input_length) {
29+
return {};
30+
}
31+
LaunchMulSigmoidKernel<T>(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()),
32+
input_length,
33+
input_data,
34+
output_data);
35+
return {};
36+
}
37+
};
38+
39+
/**
40+
* MulSigmoid(X, Y) = X * Y * Sigmoid(Y)
41+
42+
No shape broadcasting supported.
43+
*/
44+
template <typename T>
45+
struct MulMulSigmoid {
46+
template <typename TDict>
47+
OrtxStatus OnModelAttach(const TDict& /*dict*/) {
48+
return {};
49+
}
50+
OrtxStatus Compute(Ort::Custom::CUDAKernelContext* ctx,
51+
const ortc::Tensor<T>& input_x,
52+
const ortc::Tensor<T>& input_y,
53+
ortc::Tensor<T>& output) const {
54+
const T* input_data_x = input_x.Data();
55+
const T* input_data_y = input_y.Data();
56+
auto input_length_x = input_x.NumberOfElement();
57+
auto input_length_y = input_y.NumberOfElement();
58+
if (0 == input_length_x || 0 == input_data_y) {
59+
return {};
60+
}
61+
T* output_data = output.Allocate(input_length_x > input_length_y ? input_x.Shape() : input_y.Shape());
62+
LaunchMulMulSigmoidKernel<T>(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()),
63+
input_length_x,
64+
input_length_y,
65+
input_data_x,
66+
input_data_y,
67+
output_data);
68+
return {};
69+
}
70+
};
71+
72+
} // namespace contrib

operators/cuda/mul_sigmoid_impl.cu

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#include "device_prop.cuh"
5+
#include "utils.cuh"
6+
#include "mul_sigmoid_impl.cuh"
7+
#include "cuda_type.h"
8+
9+
#ifndef CUDA_LONG
10+
#define CUDA_LONG int32_t
11+
#endif
12+
13+
using namespace Ort::Custom;
14+
15+
template <typename T> __device__ __inline__ T _exp_typed(const T x);
16+
17+
template <> __device__ __inline__ float _exp_typed(const float x) { return expf(x); }
18+
19+
#if __CUDA_ARCH__ < 700
20+
template <> __device__ __inline__ half _exp_typed(const half x) {
21+
return __float2half(expf(__half2float(x)));
22+
}
23+
#else
24+
template <> __device__ __inline__ half _exp_typed(const half x) { return hexp(x); }
25+
#endif
26+
27+
template <typename T> __device__ __inline__ T sigmoid(const T a) {
28+
return a > T(0) ? (T)1 / ((T)1. + _exp_typed<T>(-a))
29+
: (T)1 - (T)1 / ((T)1 + _exp_typed<T>(a));
30+
}
31+
32+
#if __CUDA_ARCH__ < 700
33+
template <> __device__ __inline__ half sigmoid(const half a) {
34+
return __float2half(sigmoid(__half2float(a)));
35+
}
36+
#endif
37+
38+
template <typename T> __device__ __inline__ T mul_sigmoid(const T a) { return a * sigmoid(a); }
39+
40+
#if __CUDA_ARCH__ < 700
41+
template <> __device__ __inline__ half mul_sigmoid(const half a) {
42+
float x = __half2float(a);
43+
return __float2half(x * sigmoid(x));
44+
}
45+
#endif
46+
47+
template <typename T> __device__ __inline__ T mul_mul_sigmoid(const T x, const T y) {
48+
return x * y * sigmoid(y);
49+
}
50+
51+
#if __CUDA_ARCH__ < 700
52+
template <> __device__ __inline__ half mul_mul_sigmoid(const half x, const half y) {
53+
float hy = __half2float(y);
54+
return __float2half(__half2float(x) * hy * sigmoid(hy));
55+
}
56+
#endif
57+
58+
template <typename T>
59+
__global__ void MulSigmoidKernel(T *output_data, const T *input_data, CUDA_LONG N) {
60+
CUDA_LONG id = blockDim.x * blockIdx.x + threadIdx.x;
61+
if (id >= N)
62+
return;
63+
output_data[id] = mul_sigmoid(input_data[id]);
64+
}
65+
66+
template <typename T>
67+
__global__ void MulMulSigmoidKernel(T *output_data, const T *px, const T *py, CUDA_LONG N,
68+
CUDA_LONG Nx, CUDA_LONG Ny) {
69+
CUDA_LONG id = blockDim.x * blockIdx.x + threadIdx.x;
70+
if (id >= N)
71+
return;
72+
output_data[id] = mul_mul_sigmoid(px[id % Nx], py[id % Ny]);
73+
}
74+
75+
template <typename T>
76+
cudaError_t _LaunchMulSigmoidKernel(cudaStream_t stream, int input_length, const T* input, T* output) {
77+
constexpr int blockSize = 256;
78+
const int gridSize = (input_length + blockSize - 1) / blockSize;
79+
using TT = typename contrib::CudaT<T>::MappedType;
80+
MulSigmoidKernel<TT><<<gridSize, blockSize, 0, stream>>>(reinterpret_cast<TT*>(output), reinterpret_cast<const TT*>(input), input_length);
81+
return cudaGetLastError();
82+
}
83+
84+
template <>
85+
cudaError_t LaunchMulSigmoidKernel<float>(cudaStream_t stream, int input_length, const float* input, float* output) {
86+
return _LaunchMulSigmoidKernel(stream, input_length, input, output);
87+
}
88+
89+
template <>
90+
cudaError_t LaunchMulSigmoidKernel<ortc::MFloat16>(cudaStream_t stream, int input_length, const ortc::MFloat16* input, ortc::MFloat16* output) {
91+
return _LaunchMulSigmoidKernel(stream, input_length, input, output);
92+
}
93+
94+
template <typename T>
95+
cudaError_t _LaunchMulMulSigmoidKernel(cudaStream_t stream, int input_length_x, int input_length_y,
96+
const T* input_data_x, const T* input_data_y, T* output) {
97+
int input_length = std::max(input_length_x, input_length_y);
98+
constexpr int blockSize = 256;
99+
const int gridSize = (input_length + blockSize - 1) / blockSize;
100+
using TT = typename contrib::CudaT<T>::MappedType;
101+
MulMulSigmoidKernel<TT><<<gridSize, blockSize, 0, stream>>>(reinterpret_cast<TT*>(output),
102+
reinterpret_cast<const TT*>(input_data_x),
103+
reinterpret_cast<const TT*>(input_data_y),
104+
input_length, input_length_x, input_length_y);
105+
return cudaGetLastError();
106+
}
107+
108+
template <>
109+
cudaError_t LaunchMulMulSigmoidKernel<float>(cudaStream_t stream, int input_length_x, int input_length_y,
110+
const float* input_data_x, const float* input_data_y, float* output) {
111+
return _LaunchMulMulSigmoidKernel(stream, input_length_x, input_length_y, input_data_x, input_data_y, output);
112+
}
113+
114+
template <>
115+
cudaError_t LaunchMulMulSigmoidKernel<ortc::MFloat16>(cudaStream_t stream, int input_length_x, int input_length_y,
116+
const ortc::MFloat16* input_data_x, const ortc::MFloat16* input_data_y,
117+
ortc::MFloat16* output) {
118+
return _LaunchMulMulSigmoidKernel(stream, input_length_x, input_length_y, input_data_x, input_data_y, output);
119+
}
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
#include <cuda.h>
6+
#include <cuda_runtime.h>
7+
8+
template <typename T>
9+
cudaError_t LaunchMulSigmoidKernel(cudaStream_t stream, int input_length, const T* input, T* output);
10+
11+
template <typename T>
12+
cudaError_t LaunchMulMulSigmoidKernel(cudaStream_t stream, int input_length_x, int input_length_y,
13+
const T* input_data_x, const T* input_data_y, T* output);

test/cuda/test_cudaops.py

Lines changed: 126 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import unittest
22
import numpy as np
3-
from numpy.testing import assert_almost_equal
3+
from numpy.testing import assert_almost_equal, assert_allclose
44
from onnx import helper, numpy_helper, onnx_pb as onnx_proto, TensorProto
55
from onnx.reference import ReferenceEvaluator
66
from onnx.reference.op_run import OpRun
@@ -128,6 +128,131 @@ def test_cuda_fastgelu_f16(self):
128128
else:
129129
print("CUDAExecutionProvider not available, test_cuda_fastgelu_f16 skipped.")
130130

131+
def _mulmulsigmoid_cuda(self, itype, broad=False, atol=1e-5, rtol=1e-3):
132+
model1 = helper.make_model(
133+
helper.make_graph(
134+
[
135+
helper.make_node("Mul", ["X", "Y"], ["xy"]),
136+
helper.make_node("Sigmoid", ["Y"], ["sy"]),
137+
helper.make_node("Mul", ["xy", "sy"], ["final"]),
138+
],
139+
"nd",
140+
[
141+
helper.make_tensor_value_info("X", itype, [None, None, None]),
142+
helper.make_tensor_value_info("Y", itype, [None, None, None]),
143+
],
144+
[helper.make_tensor_value_info("final", itype, [None, None, None])],
145+
),
146+
opset_imports=[helper.make_opsetid("", 18)],
147+
ir_version=9,
148+
)
149+
150+
model2 = helper.make_model(
151+
helper.make_graph(
152+
[
153+
helper.make_node(
154+
"MulMulSigmoid",
155+
["X", "Y"],
156+
["final"],
157+
domain="ai.onnx.contrib",
158+
)
159+
],
160+
"nd",
161+
[
162+
helper.make_tensor_value_info("X", itype, [None, None, None]),
163+
helper.make_tensor_value_info("Y", itype, [None, None, None]),
164+
],
165+
[helper.make_tensor_value_info("final", itype, [None, None, None])],
166+
),
167+
opset_imports=[
168+
helper.make_opsetid("", 18),
169+
helper.make_opsetid("ai.onnx.contrib", 1),
170+
],
171+
ir_version=9,
172+
)
173+
174+
dtype = np.float32 if itype == TensorProto.FLOAT else np.float16
175+
shapex = (1, 2, 3) if broad else (3, 2, 3)
176+
shapey = (3, 2, 3)
177+
x = (np.arange(np.prod(shapex)) + 1).reshape(shapex).astype(dtype)
178+
y = (np.arange(np.prod(shapey)) + 2).reshape(shapey).astype(dtype)
179+
x /= x.size
180+
y /= y.size
181+
182+
feeds1 = dict(X=x, Y=y)
183+
ref = ReferenceEvaluator(model1)
184+
expected = ref.run(None, feeds1)[0]
185+
186+
opts = _ort.SessionOptions()
187+
opts.register_custom_ops_library(_get_library_path())
188+
sess = _ort.InferenceSession(model2.SerializeToString(), opts, providers=["CUDAExecutionProvider"])
189+
got = sess.run(None, feeds1)[0]
190+
assert_allclose(expected, got, atol=atol, rtol=rtol)
191+
192+
@unittest.skipIf(not has_cuda(), reason="cuda not available")
193+
def test_mulmulsigmoid_cuda(self):
194+
self._mulmulsigmoid_cuda(TensorProto.FLOAT)
195+
self._mulmulsigmoid_cuda(TensorProto.FLOAT16)
196+
197+
@unittest.skipIf(not has_cuda(), reason="cuda not available")
198+
def test_mulmulsigmoid_cuda_broadcast(self):
199+
self._mulmulsigmoid_cuda(TensorProto.FLOAT, True)
200+
self._mulmulsigmoid_cuda(TensorProto.FLOAT16, True)
201+
202+
def _mul_sigmoid_cuda(self, itype):
203+
model1 = helper.make_model(
204+
helper.make_graph(
205+
[
206+
helper.make_node("Sigmoid", ["X"], ["sx"]),
207+
helper.make_node("Mul", ["X", "sx"], ["Y"]),
208+
],
209+
"nd",
210+
[helper.make_tensor_value_info("X", itype, [None, None, None])],
211+
[helper.make_tensor_value_info("Y", itype, [None, None, None])],
212+
),
213+
opset_imports=[helper.make_opsetid("", 18)],
214+
ir_version=9,
215+
)
216+
217+
model2 = helper.make_model(
218+
helper.make_graph(
219+
[
220+
helper.make_node(
221+
"MulSigmoid",
222+
["X"],
223+
["Y"],
224+
domain="ai.onnx.contrib",
225+
)
226+
],
227+
"nd",
228+
[helper.make_tensor_value_info("X", itype, [None, None, None])],
229+
[helper.make_tensor_value_info("Y", itype, [None, None, None])],
230+
),
231+
opset_imports=[
232+
helper.make_opsetid("", 18),
233+
helper.make_opsetid("ai.onnx.contrib", 1),
234+
],
235+
ir_version=9,
236+
)
237+
238+
dtype = np.float32 if itype == TensorProto.FLOAT else np.float16
239+
x = (np.arange(18) + 1).reshape((3, 2, 3)).astype(dtype)
240+
241+
feeds1 = dict(X=x)
242+
ref = ReferenceEvaluator(model1)
243+
expected = ref.run(None, feeds1)[0]
244+
245+
opts = _ort.SessionOptions()
246+
opts.register_custom_ops_library(_get_library_path())
247+
sess = _ort.InferenceSession(model2.SerializeToString(), opts, providers=["CUDAExecutionProvider"])
248+
got = sess.run(None, feeds1)[0]
249+
assert_allclose(expected, got, atol=1e-5 if itype == TensorProto.FLOAT else 1e-2)
250+
251+
@unittest.skipIf(not has_cuda(), reason="cuda not available")
252+
def test_mul_sigmoid_cuda(self):
253+
self._mul_sigmoid_cuda(TensorProto.FLOAT)
254+
self._mul_sigmoid_cuda(TensorProto.FLOAT16)
255+
131256
def _negxplus1_cuda(self, itype):
132257
dtype = np.float32 if itype == TensorProto.FLOAT else np.float16
133258
model1 = helper.make_model(

0 commit comments

Comments
 (0)