3
3
4
4
#include "core/providers/cuda/tensor/scatter_nd.h"
#include "core/providers/cuda/tensor/scatter_nd_impl.h"
#include "core/providers/cuda/tensor/scatter_nd_common.h"
#include "core/providers/cuda/shared_inc/cuda_utils.h"
#include "core/providers/cpu/tensor/utils.h"
8
9
@@ -16,18 +17,61 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(ScatterND,
16
17
(*KernelDefBuilder::Create ())
17
18
.TypeConstraint(" T" , DataTypeImpl::AllFixedSizeTensorTypes())
18
19
.MayInplace(0 , 0 ),
19
- ScatterND);
20
+ ScatterNDDisjointAndNoReduction);
21
+
22
+ ONNX_OPERATOR_VERSIONED_KERNEL_EX (ScatterND,
23
+ kOnnxDomain ,
24
+ 13 , 15 ,
25
+ kCudaExecutionProvider ,
26
+ (*KernelDefBuilder::Create ())
27
+ .TypeConstraint(" T" , DataTypeImpl::AllFixedSizeTensorTypes())
28
+ .MayInplace(0 , 0 ),
29
+ ScatterNDWithAtomicReduction);
30
+
31
+ ONNX_OPERATOR_VERSIONED_KERNEL_EX (ScatterND,
32
+ kOnnxDomain ,
33
+ 16 , 17 ,
34
+ kCudaExecutionProvider ,
35
+ (*KernelDefBuilder::Create ())
36
+ .TypeConstraint(" T" , DataTypeImpl::AllFixedSizeTensorTypes())
37
+ .MayInplace(0 , 0 ),
38
+ ScatterNDWithAtomicReduction);
20
39
21
40
ONNX_OPERATOR_KERNEL_EX (ScatterND,
22
41
kOnnxDomain ,
23
- 13 ,
42
+ 18 ,
24
43
kCudaExecutionProvider ,
25
44
(*KernelDefBuilder::Create ())
26
45
.TypeConstraint(" T" , DataTypeImpl::AllFixedSizeTensorTypes())
27
46
.MayInplace(0 , 0 ),
28
- ScatterND );
47
+ ScatterNDWithAtomicReduction );
29
48
30
- Status ScatterND::ComputeInternal (OpKernelContext* context) const {
49
+ static Status InitiliazeElementCountsAndInputDimsSpanOrGpu (int64_t last_index_dimension, const TensorShape& input_shape,
50
+ ElementCountsAndInputDimsSpanOrGpu& element_counts_and_input_dims,
51
+ CudaKernel::CudaAsyncBuffer<int64_t >& element_counts_and_input_dims_gpu,
52
+ onnxruntime::OpKernelContext* context) {
53
+ TensorPitches input_strides (input_shape);
54
+
55
+ if (last_index_dimension < 6 ) {
56
+ element_counts_and_input_dims.gpu_ptr = nullptr ;
57
+ for (int64_t i = 0 ; i < last_index_dimension; ++i) {
58
+ element_counts_and_input_dims.stack_ptr [i] = input_strides[i];
59
+ element_counts_and_input_dims.stack_ptr [i + last_index_dimension] = input_shape[i];
60
+ }
61
+ } else {
62
+ element_counts_and_input_dims_gpu.AllocCpuPtr (last_index_dimension * 2 );
63
+ memset (element_counts_and_input_dims_gpu.CpuPtr (), 0 , sizeof (int64_t ) * last_index_dimension * 2 );
64
+ for (int64_t i = 0 ; i < last_index_dimension; ++i) {
65
+ element_counts_and_input_dims_gpu.CpuPtr ()[i] = input_strides[i];
66
+ element_counts_and_input_dims_gpu.CpuPtr ()[i + last_index_dimension] = input_shape[i];
67
+ }
68
+ ORT_RETURN_IF_ERROR (element_counts_and_input_dims_gpu.CopyToGpu (context->GetComputeStream ()));
69
+ element_counts_and_input_dims.gpu_ptr = element_counts_and_input_dims_gpu.GpuPtr ();
70
+ }
71
+ return Status::OK ();
72
+ }
73
+
74
+ Status ScatterNDDisjointAndNoReduction::ComputeInternal (OpKernelContext* context) const {
31
75
const auto * input_tensor = context->Input <Tensor>(0 );
32
76
const auto * indices_tensor = context->Input <Tensor>(1 );
33
77
const auto * updates_tensor = context->Input <Tensor>(2 );
@@ -44,8 +88,6 @@ Status ScatterND::ComputeInternal(OpKernelContext* context) const {
44
88
const void * input_data = input_tensor->DataRaw ();
45
89
void * output_data = output_tensor->MutableDataRaw ();
46
90
47
- size_t element_size = input_tensor->DataType ()->Size ();
48
-
49
91
if (input_data != output_data) {
50
92
// TODO: Run benchmarks to determine if a dedicated kernel doing data copy will be faster than invoking cudaMemcpy ?
51
93
CUDA_RETURN_IF_ERROR (
@@ -58,18 +100,17 @@ Status ScatterND::ComputeInternal(OpKernelContext* context) const {
58
100
}
59
101
60
102
auto last_index_dimension = indices_shape[indices_shape.NumDimensions () - 1 ];
103
+ size_t element_size = input_tensor->DataType ()->Size ();
61
104
62
105
// We need element counts for each dimension and the input dim value for each dimension
63
106
// for the range [0, last_index_dimension).
64
107
// To avoid multiple GPU data transfers, we combine this into one array and send it through
65
- TensorPitches input_strides (input_shape);
66
- std::vector<int64_t > element_counts_and_input_dims (last_index_dimension * 2 , 0LL );
67
- for (int64_t i = 0 ; i < last_index_dimension; ++i) {
68
- element_counts_and_input_dims[i] = input_strides[i];
69
- element_counts_and_input_dims[i + last_index_dimension] = input_shape[i];
70
- }
71
- CudaAsyncBuffer<int64_t > element_counts_and_input_dims_gpu (this , element_counts_and_input_dims);
72
- ORT_RETURN_IF_ERROR (element_counts_and_input_dims_gpu.CopyToGpu (context->GetComputeStream ()));
108
+ ElementCountsAndInputDimsSpanOrGpu element_counts_and_input_dims;
109
+ CudaAsyncBuffer<int64_t > element_counts_and_input_dims_gpu (this );
110
+ ORT_RETURN_IF_ERROR (InitiliazeElementCountsAndInputDimsSpanOrGpu (last_index_dimension, input_shape,
111
+ element_counts_and_input_dims,
112
+ element_counts_and_input_dims_gpu,
113
+ context));
73
114
74
115
ORT_RETURN_IF_ERROR (ScatterNDImpl (
75
116
Stream (context),
@@ -78,12 +119,89 @@ Status ScatterND::ComputeInternal(OpKernelContext* context) const {
78
119
indices_shape.Size () / static_cast <size_t >(last_index_dimension),
79
120
indices_tensor->Data <int64_t >(), // only int64_t is supported for indices as per the onnx spec
80
121
last_index_dimension,
81
- element_counts_and_input_dims_gpu. GpuPtr () ,
122
+ element_counts_and_input_dims ,
82
123
updates_tensor->DataRaw (),
83
124
input_shape.SizeFromDimension (last_index_dimension)));
84
125
85
126
return Status::OK ();
86
127
}
87
128
129
+ Status ScatterNDWithAtomicReduction::ComputeInternal (OpKernelContext* context) const {
130
+ const auto * input_tensor = context->Input <Tensor>(0 );
131
+ const auto * indices_tensor = context->Input <Tensor>(1 );
132
+ const auto * updates_tensor = context->Input <Tensor>(2 );
133
+
134
+ const auto & input_shape = input_tensor->Shape ();
135
+ const auto & indices_shape = indices_tensor->Shape ();
136
+ const auto & updates_shape = updates_tensor->Shape ();
137
+
138
+ // Validate input shapes
139
+ ORT_RETURN_IF_ERROR (onnxruntime::ScatterND::ValidateShapes (input_shape, indices_shape, updates_shape));
140
+
141
+ auto * output_tensor = context->Output (0 , input_shape);
142
+
143
+ const void * input_data = input_tensor->DataRaw ();
144
+ void * output_data = output_tensor->MutableDataRaw ();
145
+
146
+ if (input_data != output_data) {
147
+ // TODO: Run benchmarks to determine if a dedicated kernel doing data copy will
148
+ // be faster than invoking cudaMemcpy ?
149
+ CUDA_RETURN_IF_ERROR (
150
+ cudaMemcpyAsync (output_data, input_data, input_tensor->SizeInBytes (),
151
+ cudaMemcpyDeviceToDevice, Stream (context)));
152
+ }
153
+
154
+ // Bail out early
155
+ if (indices_shape.Size () == 0 ) {
156
+ return Status::OK ();
157
+ }
158
+
159
+ auto last_index_dimension = indices_shape[indices_shape.NumDimensions () - 1 ];
160
+ ElementCountsAndInputDimsSpanOrGpu element_counts_and_input_dims;
161
+ CudaAsyncBuffer<int64_t > element_counts_and_input_dims_gpu (this );
162
+ ORT_RETURN_IF_ERROR (InitiliazeElementCountsAndInputDimsSpanOrGpu (last_index_dimension, input_shape,
163
+ element_counts_and_input_dims,
164
+ element_counts_and_input_dims_gpu,
165
+ context));
166
+
167
+ switch (reduction_) {
168
+ case ScatterNDReduction::None: {
169
+ size_t element_size = input_tensor->DataType ()->Size ();
170
+ ORT_RETURN_IF_ERROR (ScatterNDImpl (
171
+ Stream (context),
172
+ output_data,
173
+ element_size,
174
+ indices_shape.Size () / static_cast <size_t >(last_index_dimension),
175
+ indices_tensor->Data <int64_t >(), // only int64_t is supported for indices as per the onnx spec
176
+ last_index_dimension,
177
+ element_counts_and_input_dims,
178
+ updates_tensor->DataRaw (),
179
+ input_shape.SizeFromDimension (last_index_dimension)));
180
+ } break ;
181
+ case ScatterNDReduction::Add:
182
+ case ScatterNDReduction::Min:
183
+ case ScatterNDReduction::Max:
184
+ case ScatterNDReduction::Mul: {
185
+ auto element_type = input_tensor->DataType ()->AsPrimitiveDataType ()->GetDataType ();
186
+ ORT_RETURN_IF_ERROR (ScatterNDImplReduction (
187
+ Stream (context),
188
+ output_data,
189
+ element_type,
190
+ indices_shape.Size () / static_cast <size_t >(last_index_dimension),
191
+ indices_tensor->Data <int64_t >(), // only int64_t is supported for indices as per the onnx spec
192
+ last_index_dimension,
193
+ element_counts_and_input_dims,
194
+ updates_tensor->DataRaw (),
195
+ input_shape.SizeFromDimension (last_index_dimension),
196
+ reduction_));
197
+ } break ;
198
+ default :
199
+ ORT_THROW (" ScatterND not supported for other reduction than Add, None." );
200
+ break ;
201
+ }
202
+
203
+ return Status::OK ();
204
+ }
205
+
88
206
} // namespace cuda
89
207
} // namespace onnxruntime
0 commit comments