Skip to content

Commit f505546

Browse files
authored
Add custom kernel ScatterNDOfShape (#705)
* first draft * clang * Draft for ScatterNDOfShape * fix build * disable test when cuda is missing * fix implementation * update test * add MaskedScatterNDOfShape * fix merge conflicts
1 parent 79f3b04 commit f505546

File tree

5 files changed

+646
-0
lines changed

5 files changed

+646
-0
lines changed

operators/cuda/cuda_ops.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include "cuda/add_mul.h"
88
#include "cuda/fast_gelu.h"
99
#include "cuda/negxplus1.h"
10+
#include "cuda/scatter_nd_of_shape.h"
1011
#include "cuda/transpose_cast.h"
1112
#endif
1213

@@ -29,15 +30,19 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& {
2930
,
3031
CustomCudaStructV2("AddSharedInput", AddSharedInputFloat32Type),
3132
CustomCudaStructV2("FastGelu", contrib::FastGelu<float>),
33+
CustomCudaStructV2("MaskedScatterNDOfShape", contrib::MaskedScatterNDOfShape<float>),
3234
CustomCudaStructV2("MulSharedInput", MulSharedInputFloat32Type),
3335
CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1<float>),
36+
CustomCudaStructV2("ScatterNDOfShape", contrib::ScatterNDOfShape<float>),
3437
#if ORT_API_VERSION >= 16
3538

3639
CustomCudaStructV2("AddSharedInput", AddSharedInputFloat16Type),
3740
CustomCudaStructV2("FastGelu", contrib::FastGelu<ortc::MFloat16>),
3841
CustomCudaStructV2("FastGelu", contrib::FastGelu<ortc::BFloat16>),
42+
CustomCudaStructV2("MaskedScatterNDOfShape", contrib::MaskedScatterNDOfShape<ortc::MFloat16>),
3943
CustomCudaStructV2("MulSharedInput", MulSharedInputFloat16Type),
4044
CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1<ortc::MFloat16>),
45+
CustomCudaStructV2("ScatterNDOfShape", contrib::ScatterNDOfShape<ortc::MFloat16>),
4146
CustomCudaStructV2("Transpose2DCastFP16", Transpose2DCastFloat32ToFloat16Type),
4247
CustomCudaStructV2("Transpose2DCastFP32", Transpose2DCastFloat16ToFloat32Type)
4348
#endif
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
#include "ocos.h"
6+
#include "string_utils.h"
7+
#include "scatter_nd_of_shape_impl.cuh"
8+
9+
namespace contrib {
10+
11+
template <typename T>
12+
struct ScatterNDOfShape {
13+
OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) {
14+
std::string value;
15+
OrtStatusPtr status = OrtW::GetOpAttribute(info, "reduction", value);
16+
if (status != nullptr)
17+
return status;
18+
19+
if (value == "add")
20+
reduction_ = ScatterReduction::Add;
21+
else if (value == "mul")
22+
reduction_ = ScatterReduction::Mul;
23+
else if (value == "min")
24+
reduction_ = ScatterReduction::Min;
25+
else if (value == "max")
26+
reduction_ = ScatterReduction::Max;
27+
else
28+
ORTX_CXX_API_THROW("Unexpected reduction, only Add is implemented.", ORT_RUNTIME_EXCEPTION);
29+
30+
return nullptr;
31+
}
32+
33+
OrtStatusPtr Compute(Ort::Custom::CUDAKernelContext* ctx,
34+
const ortc::Tensor<int64_t>& output_shape,
35+
const ortc::Tensor<int64_t>& indices,
36+
const ortc::Tensor<T>& updates,
37+
ortc::Tensor<T>& output) const {
38+
auto& output_shape_shape = output_shape.Shape();
39+
auto& indices_shape = indices.Shape();
40+
auto& updates_shape = updates.Shape();
41+
42+
if (output_shape_shape.size() != 1 || output_shape_shape[0] == 0) {
43+
ORTX_CXX_API_THROW("output shape must be a 1D tensor", ORT_RUNTIME_EXCEPTION);
44+
}
45+
if (indices_shape[indices_shape.size() - 1] != 1) {
46+
ORTX_CXX_API_THROW("last dimension of the indices tensor should be one", ORT_RUNTIME_EXCEPTION);
47+
}
48+
49+
const int64_t* shape_data = output_shape.Data(); // CPU pointer
50+
const int64_t* indices_data = indices.Data(); // GPU pointer
51+
const T* updates_data = updates.Data(); // GPU pointer
52+
std::vector<int64_t> voutput_shape(shape_data, shape_data + output_shape_shape[0]);
53+
T* output_data = output.Allocate(voutput_shape); // GPU pointer
54+
LaunchScatterNDOfShapeKernel<T>(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()),
55+
voutput_shape,
56+
indices_shape,
57+
indices_data,
58+
updates_data,
59+
output_data,
60+
reduction_);
61+
return nullptr;
62+
}
63+
64+
static OrtMemType GetInputMemoryType(size_t input_index) {
65+
if (input_index == 0) // shape
66+
return OrtMemType::OrtMemTypeCPUInput;
67+
return OrtMemType::OrtMemTypeDefault;
68+
}
69+
70+
ScatterReduction reduction_;
71+
};
72+
73+
74+
template <typename T>
75+
struct MaskedScatterNDOfShape {
76+
OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) {
77+
std::string value;
78+
OrtStatusPtr status = OrtW::GetOpAttribute(info, "reduction", value);
79+
if (status != nullptr)
80+
return status;
81+
82+
if (value == "add")
83+
reduction_ = ScatterReduction::Add;
84+
else if (value == "mul")
85+
reduction_ = ScatterReduction::Mul;
86+
else if (value == "min")
87+
reduction_ = ScatterReduction::Min;
88+
else if (value == "max")
89+
reduction_ = ScatterReduction::Max;
90+
else
91+
ORTX_CXX_API_THROW("Unexpected reduction, only Add is implemented.", ORT_RUNTIME_EXCEPTION);
92+
93+
status = OrtW::GetOpAttribute(info, "maskedValue", masked_value_);
94+
if (status != nullptr)
95+
return status;
96+
97+
return nullptr;
98+
}
99+
100+
OrtStatusPtr Compute(Ort::Custom::CUDAKernelContext* ctx,
101+
const ortc::Tensor<int64_t>& output_shape,
102+
const ortc::Tensor<int64_t>& indices,
103+
const ortc::Tensor<T>& updates,
104+
ortc::Tensor<T>& output) const {
105+
auto& output_shape_shape = output_shape.Shape();
106+
auto& indices_shape = indices.Shape();
107+
auto& updates_shape = updates.Shape();
108+
109+
if (output_shape_shape.size() != 1 || output_shape_shape[0] == 0) {
110+
ORTX_CXX_API_THROW("output shape must be a 1D tensor", ORT_RUNTIME_EXCEPTION);
111+
}
112+
if (indices_shape[indices_shape.size() - 1] != 1) {
113+
ORTX_CXX_API_THROW("last dimension of the indices tensor should be one", ORT_RUNTIME_EXCEPTION);
114+
}
115+
116+
const int64_t* shape_data = output_shape.Data(); // CPU pointer
117+
const int64_t* indices_data = indices.Data(); // GPU pointer
118+
const T* updates_data = updates.Data(); // GPU pointer
119+
std::vector<int64_t> voutput_shape(shape_data, shape_data + output_shape_shape[0]);
120+
T* output_data = output.Allocate(voutput_shape); // GPU pointer
121+
LaunchMaskedScatterNDOfShapeKernel<T>(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()),
122+
voutput_shape,
123+
indices_shape,
124+
indices_data,
125+
updates_data,
126+
output_data,
127+
reduction_,
128+
masked_value_);
129+
return nullptr;
130+
}
131+
132+
static OrtMemType GetInputMemoryType(size_t input_index) {
133+
if (input_index == 0) // shape
134+
return OrtMemType::OrtMemTypeCPUInput;
135+
return OrtMemType::OrtMemTypeDefault;
136+
}
137+
138+
private:
139+
ScatterReduction reduction_;
140+
int64_t masked_value_;
141+
};
142+
143+
} // namespace contrib

0 commit comments

Comments
 (0)