Skip to content

Commit bef5f07

Browse files
authored
Add custom ops ReplaceZero (#739)
* Add custom ops ReplaceZero * fix merge conflicts
1 parent 05df33b commit bef5f07

File tree

5 files changed

+191
-2
lines changed

5 files changed

+191
-2
lines changed

operators/cuda/cuda_ops.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,12 @@
88
#include "cuda/fast_gelu.h"
99
#include "cuda/mul_sigmoid.h"
1010
#include "cuda/negxplus1.h"
11+
#include "cuda/replace_zero.h"
1112
#include "cuda/scatter_nd_of_shape.h"
1213
#include "cuda/transpose_cast.h"
1314
#endif
1415

1516
FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& {
16-
1717
using AddSharedInputFloat32Type = typename contrib::AddOrMulSharedInput<float, true>;
1818
using MulSharedInputFloat32Type = typename contrib::AddOrMulSharedInput<float, false>;
1919

@@ -24,7 +24,6 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& {
2424
using Transpose2DCastFloat16ToFloat32Type = typename contrib::Transpose2DCast<ortc::MFloat16, float>;
2525
#endif
2626

27-
2827
static OrtOpLoader op_loader(
2928
[]() { return nullptr; }
3029
#ifdef USE_CUDA
@@ -36,6 +35,7 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& {
3635
CustomCudaStructV2("MulMulSigmoid", contrib::MulMulSigmoid<float>),
3736
CustomCudaStructV2("MulSigmoid", contrib::MulSigmoid<float>),
3837
CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1<float>),
38+
CustomCudaStructV2("ReplaceZero", contrib::ReplaceZero<float>),
3939
CustomCudaStructV2("ScatterNDOfShape", contrib::ScatterNDOfShape<float>),
4040
#if ORT_API_VERSION >= 16
4141

@@ -47,6 +47,7 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& {
4747
CustomCudaStructV2("MulMulSigmoid", contrib::MulMulSigmoid<ortc::MFloat16>),
4848
CustomCudaStructV2("MulSigmoid", contrib::MulSigmoid<ortc::MFloat16>),
4949
CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1<ortc::MFloat16>),
50+
CustomCudaStructV2("ReplaceZero", contrib::ReplaceZero<ortc::MFloat16>),
5051
CustomCudaStructV2("ScatterNDOfShape", contrib::ScatterNDOfShape<ortc::MFloat16>),
5152
CustomCudaStructV2("Transpose2DCastFP16", Transpose2DCastFloat32ToFloat16Type),
5253
CustomCudaStructV2("Transpose2DCastFP32", Transpose2DCastFloat16ToFloat32Type)

operators/cuda/replace_zero.h

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
#include "ocos.h"
6+
#include "replace_zero_impl.cuh"
7+
#include "ortx_common.h"
8+
9+
namespace contrib {
10+
11+
/**
12+
* Y = ReplaceZero(X, by=c) is equivalent to:
13+
*
14+
* Y = X.copy()
15+
* X[X == 0] = c
16+
*
17+
* This operation usually appears when a tensor is updated with an operator Equal and Where.
18+
* This kernel avoids the creation of one null tensor.
19+
*/
20+
template <typename T>
21+
struct ReplaceZero {
22+
template <typename TDict>
23+
OrtxStatus OnModelAttach(const TDict& dict) {
24+
float default_value=0;
25+
by_ = dict.TryToGetAttributeWithDefault("by", default_value);
26+
return {};
27+
}
28+
OrtxStatus Compute(Ort::Custom::CUDAKernelContext* ctx,
29+
const ortc::Tensor<T>& input,
30+
ortc::Tensor<T>& output) const {
31+
const T* input_data = input.Data();
32+
auto input_shape = input.Shape();
33+
T* output_data = output.Allocate(input_shape);
34+
auto input_length = input.NumberOfElement();
35+
if (0 == input_length) {
36+
return {};
37+
}
38+
39+
LaunchReplaceZeroKernel<T>(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()),
40+
input_length,
41+
input_data,
42+
output_data,
43+
by_);
44+
return {};
45+
}
46+
47+
private:
48+
float by_;
49+
};
50+
51+
} // namespace contrib
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#include "device_prop.cuh"
5+
#include "utils.cuh"
6+
#include "replace_zero_impl.cuh"
7+
#include "cuda_type.h"
8+
9+
#ifndef CUDA_LONG
10+
#define CUDA_LONG int32_t
11+
#endif
12+
13+
using namespace Ort::Custom;
14+
15+
// Device helper: yields `by` when `x` compares equal to zero, otherwise yields `x` unchanged.
template <typename T>
__device__ __inline__ T _replace_zero(const T x, const T by) {
  if (x == (T)0) {
    return by;
  }
  return x;
}
19+
20+
// half specialization of _replace_zero.
// When __CUDA_ARCH__ is 700 or above the comparison is done directly on half values;
// otherwise (including passes where __CUDA_ARCH__ is undefined and therefore evaluates
// to 0 in the preprocessor) the value is widened to float before being compared.
template <>
__device__ __inline__ half _replace_zero(const half x, const half by) {
#if __CUDA_ARCH__ >= 700
  const bool is_zero = (x == (half)0);
#else
  const bool is_zero = (__half2float(x) == 0);
#endif
  return is_zero ? by : x;
}
28+
29+
// Element-wise kernel: the thread with flat index `idx` writes
// output_data[idx] = _replace_zero(input_data[idx], by). Threads whose index is
// not below N do nothing, so the launch may contain more threads than elements.
template <typename T>
__global__ void ReplaceZeroKernel(T* output_data, const T* input_data, CUDA_LONG N, const T by) {
  const CUDA_LONG idx = blockDim.x * blockIdx.x + threadIdx.x;
  if (idx < N) {
    output_data[idx] = _replace_zero(input_data[idx], by);
  }
}
36+
37+
// Converts the float attribute value to the element type T used by the kernel.
template <typename T>
T _cast(float value) {
  return static_cast<T>(value);
}
39+
40+
// half specialization: goes through __float2half instead of a plain cast.
template <>
half _cast(float value) {
  const half converted = __float2half(value);
  return converted;
}
42+
43+
// Launches ReplaceZeroKernel on `stream` for `input_length` elements.
//
// T is the tensor element type seen by the caller; the kernel itself runs on
// contrib::CudaT<T>::MappedType, and `by` is converted to that type first.
// The launch uses 256 threads per block and enough blocks to cover every element.
//
// Returns:
//   - cudaGetLastError() without launching anything when input_length == 0,
//   - cudaErrorInvalidValue when input_length is negative,
//   - cudaGetLastError() read right after the launch otherwise.
template <typename T>
cudaError_t _LaunchReplaceZeroKernel(cudaStream_t stream, int input_length, const T* input_data, T* output_data, float by) {
  if (input_length == 0)
    return cudaGetLastError();
  if (input_length < 0)
    return cudaErrorInvalidValue;
  using TT = typename contrib::CudaT<T>::MappedType;

  const CUDA_LONG N = static_cast<CUDA_LONG>(input_length);

  constexpr int num_threads_per_block = 256;
  // Ceil-division written without `N + num_threads_per_block - 1`, which overflows
  // a signed 32-bit integer when N is close to its maximum value.
  const int num_blocks = N / num_threads_per_block + (N % num_threads_per_block != 0 ? 1 : 0);

  const TT cby = _cast<TT>(by);
  ReplaceZeroKernel<TT><<<num_blocks, num_threads_per_block, 0, stream>>>(
      reinterpret_cast<TT*>(output_data), reinterpret_cast<const TT*>(input_data), N, cby);
  return cudaGetLastError();
}
59+
60+
// float entry point: forwards every argument to the shared templated launcher.
template <>
cudaError_t LaunchReplaceZeroKernel<float>(cudaStream_t stream, int input_length, const float* input_data, float* output_data, float by) {
  const cudaError_t status = _LaunchReplaceZeroKernel<float>(stream, input_length, input_data, output_data, by);
  return status;
}
64+
65+
// ortc::MFloat16 entry point: forwards every argument to the shared templated launcher.
template <>
cudaError_t LaunchReplaceZeroKernel<ortc::MFloat16>(cudaStream_t stream, int input_length, const ortc::MFloat16* input_data, ortc::MFloat16* output_data, float by) {
  const cudaError_t status =
      _LaunchReplaceZeroKernel<ortc::MFloat16>(stream, input_length, input_data, output_data, by);
  return status;
}
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
#include <cuda.h>
6+
#include <cuda_runtime.h>
7+
8+
// Copies `input_length` elements from `input_data` to `output_data` on `stream`,
// writing `by` (converted to the kernel's element type) wherever the input is zero.
// Explicit specializations are provided for float and ortc::MFloat16.
// The length is an `int`, so callers holding a wider element count must check it fits.
// Returns the value of cudaGetLastError() read after the launch; when input_length
// is 0 no kernel is launched.
template <typename T>
9+
cudaError_t LaunchReplaceZeroKernel(cudaStream_t stream, int input_length, const T* input_data, T* output_data, float by);

test/cuda/test_cudaops.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -652,6 +652,66 @@ def test_transpose_cast_cuda(self):
652652
self._transpose_cast_cuda(TensorProto.FLOAT)
653653
self._transpose_cast_cuda(TensorProto.FLOAT16)
654654

655+
def _replace_zero_cuda(self, itype):
    """Compare the custom ``ReplaceZero`` CUDA op with an ``Equal`` + ``Where`` graph.

    ``itype`` is the ONNX element type of the input and output; float32 data is
    used for ``TensorProto.FLOAT`` and float16 data for anything else.
    """
    dtype = np.float32 if itype == TensorProto.FLOAT else np.float16

    # Reference graph built from standard operators only: Y = Where(X == 0, 1.67, X).
    reference_nodes = [
        helper.make_node("Equal", ["X", "zero"], ["cond"]),
        helper.make_node("Where", ["cond", "cst", "X"], ["Y"]),
    ]
    reference_initializers = [
        numpy_helper.from_array(np.array([0], dtype=dtype), name="zero"),
        numpy_helper.from_array(np.array([1.67], dtype=dtype), name="cst"),
    ]
    model1 = helper.make_model(
        helper.make_graph(
            reference_nodes,
            "nd",
            [helper.make_tensor_value_info("X", itype, [None, None, None])],
            [helper.make_tensor_value_info("Y", itype, [None, None, None])],
            reference_initializers,
        ),
        opset_imports=[helper.make_opsetid("", 18)],
        ir_version=9,
    )

    # Graph under test: a single ReplaceZero node from the contrib domain.
    fused_node = helper.make_node(
        "ReplaceZero",
        ["X"],
        ["Y"],
        by=1.67,
        domain="ai.onnx.contrib",
    )
    model2 = helper.make_model(
        helper.make_graph(
            [fused_node],
            "nd",
            [helper.make_tensor_value_info("X", itype, [None, None, None])],
            [helper.make_tensor_value_info("Y", itype, [None, None, None])],
        ),
        opset_imports=[
            helper.make_opsetid("", 18),
            helper.make_opsetid("ai.onnx.contrib", 1),
        ],
        ir_version=9,
    )

    # Values -4..13 laid out as a (3, 2, 3) tensor, so exactly one element is zero.
    shifted = np.arange(18) - 4
    x = shifted.reshape((3, 2, 3)).astype(dtype)
    feeds1 = dict(X=x)

    # Expected output comes from the ONNX reference evaluator.
    expected = ReferenceEvaluator(model1).run(None, feeds1)[0]

    # Actual output comes from onnxruntime with the custom ops library on the CUDA provider.
    opts = _ort.SessionOptions()
    opts.register_custom_ops_library(_get_library_path())
    sess = _ort.InferenceSession(model2.SerializeToString(), opts, providers=["CUDAExecutionProvider"])
    got = sess.run(None, feeds1)[0]

    assert_allclose(expected, got, atol=1e-5)
709+
710+
@unittest.skipIf(not has_cuda(), reason="cuda not available")
def test_replace_zero_cuda(self):
    """Run the ReplaceZero comparison for float32 inputs, then for float16 inputs."""
    for itype in (TensorProto.FLOAT, TensorProto.FLOAT16):
        self._replace_zero_cuda(itype)
714+
655715

656716
if __name__ == "__main__":
657717
unittest.main(verbosity=2)

0 commit comments

Comments
 (0)