Skip to content

Commit 29c460f

Browse files
committed
fix merge conflicts
2 parents f67d3b1 + f505546 commit 29c460f

File tree

5 files changed

+649
-0
lines changed

5 files changed

+649
-0
lines changed

operators/cuda/cuda_ops.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,11 @@
77
#include "cuda/add_mul.h"
88
#include "cuda/fast_gelu.h"
99
#include "cuda/negxplus1.h"
10+
<<<<<<< HEAD
1011
#include "cuda/rotary.h"
12+
=======
13+
#include "cuda/scatter_nd_of_shape.h"
14+
>>>>>>> f5055466d5376059c2ea74e3cea46e16a537bc0d
1115
#include "cuda/transpose_cast.h"
1216
#endif
1317

@@ -30,17 +34,21 @@ FxLoadCustomOpFactory LoadCustomOpClasses_Contrib = []() -> CustomOpArray& {
3034
,
3135
CustomCudaStructV2("AddSharedInput", AddSharedInputFloat32Type),
3236
CustomCudaStructV2("FastGelu", contrib::FastGelu<float>),
37+
CustomCudaStructV2("MaskedScatterNDOfShape", contrib::MaskedScatterNDOfShape<float>),
3338
CustomCudaStructV2("MulSharedInput", MulSharedInputFloat32Type),
3439
CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1<float>),
3540
CustomCudaStructV2("Rotary", contrib::Rotary<float>),
41+
CustomCudaStructV2("ScatterNDOfShape", contrib::ScatterNDOfShape<float>),
3642
#if ORT_API_VERSION >= 16
3743

3844
CustomCudaStructV2("AddSharedInput", AddSharedInputFloat16Type),
3945
CustomCudaStructV2("FastGelu", contrib::FastGelu<ortc::MFloat16>),
4046
CustomCudaStructV2("FastGelu", contrib::FastGelu<ortc::BFloat16>),
47+
CustomCudaStructV2("MaskedScatterNDOfShape", contrib::MaskedScatterNDOfShape<ortc::MFloat16>),
4148
CustomCudaStructV2("MulSharedInput", MulSharedInputFloat16Type),
4249
CustomCudaStructV2("NegXPlus1", contrib::NegXPlus1<ortc::MFloat16>),
4350
CustomCudaStructV2("Rotary", contrib::Rotary<ortc::MFloat16>),
51+
CustomCudaStructV2("ScatterNDOfShape", contrib::ScatterNDOfShape<ortc::MFloat16>),
4452
CustomCudaStructV2("Transpose2DCastFP16", Transpose2DCastFloat32ToFloat16Type),
4553
CustomCudaStructV2("Transpose2DCastFP32", Transpose2DCastFloat16ToFloat32Type)
4654
#endif
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
#include "ocos.h"
6+
#include "string_utils.h"
7+
#include "scatter_nd_of_shape_impl.cuh"
8+
9+
namespace contrib {
10+
11+
template <typename T>
12+
struct ScatterNDOfShape {
13+
OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) {
14+
std::string value;
15+
OrtStatusPtr status = OrtW::GetOpAttribute(info, "reduction", value);
16+
if (status != nullptr)
17+
return status;
18+
19+
if (value == "add")
20+
reduction_ = ScatterReduction::Add;
21+
else if (value == "mul")
22+
reduction_ = ScatterReduction::Mul;
23+
else if (value == "min")
24+
reduction_ = ScatterReduction::Min;
25+
else if (value == "max")
26+
reduction_ = ScatterReduction::Max;
27+
else
28+
ORTX_CXX_API_THROW("Unexpected reduction, only Add is implemented.", ORT_RUNTIME_EXCEPTION);
29+
30+
return nullptr;
31+
}
32+
33+
OrtStatusPtr Compute(Ort::Custom::CUDAKernelContext* ctx,
34+
const ortc::Tensor<int64_t>& output_shape,
35+
const ortc::Tensor<int64_t>& indices,
36+
const ortc::Tensor<T>& updates,
37+
ortc::Tensor<T>& output) const {
38+
auto& output_shape_shape = output_shape.Shape();
39+
auto& indices_shape = indices.Shape();
40+
auto& updates_shape = updates.Shape();
41+
42+
if (output_shape_shape.size() != 1 || output_shape_shape[0] == 0) {
43+
ORTX_CXX_API_THROW("output shape must be a 1D tensor", ORT_RUNTIME_EXCEPTION);
44+
}
45+
if (indices_shape[indices_shape.size() - 1] != 1) {
46+
ORTX_CXX_API_THROW("last dimension of the indices tensor should be one", ORT_RUNTIME_EXCEPTION);
47+
}
48+
49+
const int64_t* shape_data = output_shape.Data(); // CPU pointer
50+
const int64_t* indices_data = indices.Data(); // GPU pointer
51+
const T* updates_data = updates.Data(); // GPU pointer
52+
std::vector<int64_t> voutput_shape(shape_data, shape_data + output_shape_shape[0]);
53+
T* output_data = output.Allocate(voutput_shape); // GPU pointer
54+
LaunchScatterNDOfShapeKernel<T>(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()),
55+
voutput_shape,
56+
indices_shape,
57+
indices_data,
58+
updates_data,
59+
output_data,
60+
reduction_);
61+
return nullptr;
62+
}
63+
64+
static OrtMemType GetInputMemoryType(size_t input_index) {
65+
if (input_index == 0) // shape
66+
return OrtMemType::OrtMemTypeCPUInput;
67+
return OrtMemType::OrtMemTypeDefault;
68+
}
69+
70+
ScatterReduction reduction_;
71+
};
72+
73+
74+
template <typename T>
75+
struct MaskedScatterNDOfShape {
76+
OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) {
77+
std::string value;
78+
OrtStatusPtr status = OrtW::GetOpAttribute(info, "reduction", value);
79+
if (status != nullptr)
80+
return status;
81+
82+
if (value == "add")
83+
reduction_ = ScatterReduction::Add;
84+
else if (value == "mul")
85+
reduction_ = ScatterReduction::Mul;
86+
else if (value == "min")
87+
reduction_ = ScatterReduction::Min;
88+
else if (value == "max")
89+
reduction_ = ScatterReduction::Max;
90+
else
91+
ORTX_CXX_API_THROW("Unexpected reduction, only Add is implemented.", ORT_RUNTIME_EXCEPTION);
92+
93+
status = OrtW::GetOpAttribute(info, "maskedValue", masked_value_);
94+
if (status != nullptr)
95+
return status;
96+
97+
return nullptr;
98+
}
99+
100+
OrtStatusPtr Compute(Ort::Custom::CUDAKernelContext* ctx,
101+
const ortc::Tensor<int64_t>& output_shape,
102+
const ortc::Tensor<int64_t>& indices,
103+
const ortc::Tensor<T>& updates,
104+
ortc::Tensor<T>& output) const {
105+
auto& output_shape_shape = output_shape.Shape();
106+
auto& indices_shape = indices.Shape();
107+
auto& updates_shape = updates.Shape();
108+
109+
if (output_shape_shape.size() != 1 || output_shape_shape[0] == 0) {
110+
ORTX_CXX_API_THROW("output shape must be a 1D tensor", ORT_RUNTIME_EXCEPTION);
111+
}
112+
if (indices_shape[indices_shape.size() - 1] != 1) {
113+
ORTX_CXX_API_THROW("last dimension of the indices tensor should be one", ORT_RUNTIME_EXCEPTION);
114+
}
115+
116+
const int64_t* shape_data = output_shape.Data(); // CPU pointer
117+
const int64_t* indices_data = indices.Data(); // GPU pointer
118+
const T* updates_data = updates.Data(); // GPU pointer
119+
std::vector<int64_t> voutput_shape(shape_data, shape_data + output_shape_shape[0]);
120+
T* output_data = output.Allocate(voutput_shape); // GPU pointer
121+
LaunchMaskedScatterNDOfShapeKernel<T>(reinterpret_cast<cudaStream_t>(ctx->GetCudaStream()),
122+
voutput_shape,
123+
indices_shape,
124+
indices_data,
125+
updates_data,
126+
output_data,
127+
reduction_,
128+
masked_value_);
129+
return nullptr;
130+
}
131+
132+
static OrtMemType GetInputMemoryType(size_t input_index) {
133+
if (input_index == 0) // shape
134+
return OrtMemType::OrtMemTypeCPUInput;
135+
return OrtMemType::OrtMemTypeDefault;
136+
}
137+
138+
private:
139+
ScatterReduction reduction_;
140+
int64_t masked_value_;
141+
};
142+
143+
} // namespace contrib

0 commit comments

Comments
 (0)