Skip to content

Commit 80213a9

Browse files
authored
Add implementation for ScatterND (#19540)
### Description

onnxruntime falls back to the CPU implementation of ScatterND for opsets above 13. This change extends the CUDA implementation to the higher opsets (16–17 and 18+).
1 parent 14fcf0a commit 80213a9

15 files changed

+868
-41
lines changed

docs/OperatorKernels.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -774,7 +774,9 @@ Do not modify directly.*
774774
|||[16, 17]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
775775
|||[13, 15]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
776776
|||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
777-
|ScatterND|*in* data:**T**<br> *in* indices:**tensor(int64)**<br> *in* updates:**T**<br> *out* output:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
777+
|ScatterND|*in* data:**T**<br> *in* indices:**tensor(int64)**<br> *in* updates:**T**<br> *out* output:**T**|18+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
778+
|||[16, 17]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
779+
|||[13, 15]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
778780
|||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
779781
|Selu|*in* X:**T**<br> *out* Y:**T**|6+|**T** = tensor(double), tensor(float), tensor(float16)|
780782
|SequenceAt|*in* input_sequence:**S**<br> *in* position:**I**<br> *out* tensor:**T**|11+|**I** = tensor(int32), tensor(int64)<br/> **S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))<br/> **T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|

onnxruntime/core/providers/cuda/cuda_execution_provider.cc

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1157,7 +1157,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
11571157
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, LRN);
11581158
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, LRN);
11591159
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, Identity);
1160-
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, ScatterND);
1160+
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 15, ScatterND);
11611161
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, float, Pad);
11621162
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, double, Pad);
11631163
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, MLFloat16, Pad);
@@ -1295,6 +1295,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
12951295
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, double, LessOrEqual);
12961296
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, MLFloat16, LessOrEqual);
12971297
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 17, ScatterElements);
1298+
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 17, ScatterND);
12981299
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, float, GridSample);
12991300

13001301
// Opset 17
@@ -1312,6 +1313,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
13121313
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int32_t, ReduceMax);
13131314
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int64_t, ReduceMax);
13141315
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, ScatterElements);
1316+
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, ScatterND);
13151317
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, Pad);
13161318
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, Pad);
13171319
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, Pad);
@@ -2071,7 +2073,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
20712073
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, LRN)>,
20722074
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, LRN)>,
20732075
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, Identity)>,
2074-
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, ScatterND)>,
2076+
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 15, ScatterND)>,
20752077
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, float, Pad)>,
20762078
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, double, Pad)>,
20772079
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 17, MLFloat16, Pad)>,
@@ -2202,6 +2204,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
22022204
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, double, LessOrEqual)>,
22032205
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, MLFloat16, LessOrEqual)>,
22042206
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 17, ScatterElements)>,
2207+
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 17, ScatterND)>,
22052208
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, float, GridSample)>,
22062209

22072210
// Opset 17
@@ -2225,6 +2228,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
22252228
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int32_t, ReduceMax)>,
22262229
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, int64_t, ReduceMax)>,
22272230
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, ScatterElements)>,
2231+
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, ScatterND)>,
22282232
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, float, Pad)>,
22292233
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, double, Pad)>,
22302234
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 18, MLFloat16, Pad)>,

onnxruntime/core/providers/cuda/tensor/scatter_elements.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ Status ScatterElements::ComputeInternal(OpKernelContext* context) const {
133133
} else if (reduction_ == "max") {
134134
args.operation = GatherScatterElementsArgs::Operation::MAX;
135135
} else {
136-
ORT_THROW("Unsupported reduction type");
136+
ORT_THROW("Unsupported reduction type for ScatterElements.");
137137
}
138138

139139
// Use element size instead of concrete types so we can specialize less template functions to reduce binary size.

onnxruntime/core/providers/cuda/tensor/scatter_nd.cc

Lines changed: 133 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
#include "core/providers/cuda/tensor/scatter_nd.h"
55
#include "core/providers/cuda/tensor/scatter_nd_impl.h"
6+
#include "core/providers/cuda/tensor/scatter_nd_common.h"
67
#include "core/providers/cuda/shared_inc/cuda_utils.h"
78
#include "core/providers/cpu/tensor/utils.h"
89

@@ -16,18 +17,61 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(ScatterND,
1617
(*KernelDefBuilder::Create())
1718
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes())
1819
.MayInplace(0, 0),
19-
ScatterND);
20+
ScatterNDDisjointAndNoReduction);
21+
22+
ONNX_OPERATOR_VERSIONED_KERNEL_EX(ScatterND,
23+
kOnnxDomain,
24+
13, 15,
25+
kCudaExecutionProvider,
26+
(*KernelDefBuilder::Create())
27+
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes())
28+
.MayInplace(0, 0),
29+
ScatterNDWithAtomicReduction);
30+
31+
ONNX_OPERATOR_VERSIONED_KERNEL_EX(ScatterND,
32+
kOnnxDomain,
33+
16, 17,
34+
kCudaExecutionProvider,
35+
(*KernelDefBuilder::Create())
36+
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes())
37+
.MayInplace(0, 0),
38+
ScatterNDWithAtomicReduction);
2039

2140
ONNX_OPERATOR_KERNEL_EX(ScatterND,
2241
kOnnxDomain,
23-
13,
42+
18,
2443
kCudaExecutionProvider,
2544
(*KernelDefBuilder::Create())
2645
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes())
2746
.MayInplace(0, 0),
28-
ScatterND);
47+
ScatterNDWithAtomicReduction);
2948

30-
Status ScatterND::ComputeInternal(OpKernelContext* context) const {
49+
// Builds the lookup table consumed by the ScatterND CUDA kernels.
//
// The table has 2 * last_index_dimension entries:
//   [0, last_index_dimension)                         -> pitch (stride) of each indexed input dimension
//   [last_index_dimension, 2 * last_index_dimension)  -> extent of each indexed input dimension
//
// Parameters:
//   last_index_dimension              size of the innermost dimension of the indices tensor
//   input_shape                       shape of the data tensor being scattered into
//   element_counts_and_input_dims     receives the table: either filled in place (stack_ptr, with
//                                     gpu_ptr set to nullptr) or pointing at the device copy (gpu_ptr)
//   element_counts_and_input_dims_gpu staging buffer, only allocated and uploaded on the large path
//   context                           supplies the compute stream used for the upload
//
// Returns Status::OK(), or the error reported by CopyToGpu.
//
// NOTE: the misspelling in the function name is kept on purpose; the callers in this file use it.
static Status InitiliazeElementCountsAndInputDimsSpanOrGpu(int64_t last_index_dimension, const TensorShape& input_shape,
                                                           ElementCountsAndInputDimsSpanOrGpu& element_counts_and_input_dims,
                                                           CudaKernel::CudaAsyncBuffer<int64_t>& element_counts_and_input_dims_gpu,
                                                           onnxruntime::OpKernelContext* context) {
  TensorPitches input_strides(input_shape);

  // NOTE(review): the threshold 6 presumably mirrors the capacity of `stack_ptr` declared in
  // scatter_nd_common.h (2 * last_index_dimension slots are written) -- confirm before changing it.
  if (last_index_dimension < 6) {
    // Small case: fill the in-struct array, no device buffer is needed.
    element_counts_and_input_dims.gpu_ptr = nullptr;
    for (int64_t i = 0; i < last_index_dimension; ++i) {
      element_counts_and_input_dims.stack_ptr[i] = input_strides[i];
      element_counts_and_input_dims.stack_ptr[i + last_index_dimension] = input_shape[i];
    }
    return Status::OK();
  }

  // Large case: stage the table on the host and upload it with a single copy.
  element_counts_and_input_dims_gpu.AllocCpuPtr(last_index_dimension * 2);
  // No zero-fill is required: the loop below writes every one of the 2 * last_index_dimension slots.
  auto* host_table = element_counts_and_input_dims_gpu.CpuPtr();
  for (int64_t i = 0; i < last_index_dimension; ++i) {
    host_table[i] = input_strides[i];
    host_table[i + last_index_dimension] = input_shape[i];
  }
  ORT_RETURN_IF_ERROR(element_counts_and_input_dims_gpu.CopyToGpu(context->GetComputeStream()));
  element_counts_and_input_dims.gpu_ptr = element_counts_and_input_dims_gpu.GpuPtr();

  return Status::OK();
}
73+
74+
Status ScatterNDDisjointAndNoReduction::ComputeInternal(OpKernelContext* context) const {
3175
const auto* input_tensor = context->Input<Tensor>(0);
3276
const auto* indices_tensor = context->Input<Tensor>(1);
3377
const auto* updates_tensor = context->Input<Tensor>(2);
@@ -44,8 +88,6 @@ Status ScatterND::ComputeInternal(OpKernelContext* context) const {
4488
const void* input_data = input_tensor->DataRaw();
4589
void* output_data = output_tensor->MutableDataRaw();
4690

47-
size_t element_size = input_tensor->DataType()->Size();
48-
4991
if (input_data != output_data) {
5092
// TODO: Run benchmarks to determine if a dedicated kernel doing data copy will be faster than invoking cudaMemcpy ?
5193
CUDA_RETURN_IF_ERROR(
@@ -58,18 +100,17 @@ Status ScatterND::ComputeInternal(OpKernelContext* context) const {
58100
}
59101

60102
auto last_index_dimension = indices_shape[indices_shape.NumDimensions() - 1];
103+
size_t element_size = input_tensor->DataType()->Size();
61104

62105
// We need element counts for each dimension and the input dim value for each dimension
63106
// for the range [0, last_index_dimension).
64107
// To avoid multiple GPU data transfers, we combine this into one array and send it through
65-
TensorPitches input_strides(input_shape);
66-
std::vector<int64_t> element_counts_and_input_dims(last_index_dimension * 2, 0LL);
67-
for (int64_t i = 0; i < last_index_dimension; ++i) {
68-
element_counts_and_input_dims[i] = input_strides[i];
69-
element_counts_and_input_dims[i + last_index_dimension] = input_shape[i];
70-
}
71-
CudaAsyncBuffer<int64_t> element_counts_and_input_dims_gpu(this, element_counts_and_input_dims);
72-
ORT_RETURN_IF_ERROR(element_counts_and_input_dims_gpu.CopyToGpu(context->GetComputeStream()));
108+
ElementCountsAndInputDimsSpanOrGpu element_counts_and_input_dims;
109+
CudaAsyncBuffer<int64_t> element_counts_and_input_dims_gpu(this);
110+
ORT_RETURN_IF_ERROR(InitiliazeElementCountsAndInputDimsSpanOrGpu(last_index_dimension, input_shape,
111+
element_counts_and_input_dims,
112+
element_counts_and_input_dims_gpu,
113+
context));
73114

74115
ORT_RETURN_IF_ERROR(ScatterNDImpl(
75116
Stream(context),
@@ -78,12 +119,89 @@ Status ScatterND::ComputeInternal(OpKernelContext* context) const {
78119
indices_shape.Size() / static_cast<size_t>(last_index_dimension),
79120
indices_tensor->Data<int64_t>(), // only int64_t is supported for indices as per the onnx spec
80121
last_index_dimension,
81-
element_counts_and_input_dims_gpu.GpuPtr(),
122+
element_counts_and_input_dims,
82123
updates_tensor->DataRaw(),
83124
input_shape.SizeFromDimension(last_index_dimension)));
84125

85126
return Status::OK();
86127
}
87128

129+
// Runs ScatterND on the CUDA stream of `context`, honoring the `reduction_` member
// (None, Add, Mul, Min or Max).
//
// Steps:
//   1. Validate the data / indices / updates shapes.
//   2. Copy the input into the output unless both already refer to the same buffer.
//   3. Build the stride / extent table for the indexed dimensions.
//   4. Dispatch to ScatterNDImpl (no reduction) or ScatterNDImplReduction.
//
// Returns the first error Status produced by validation, the device copy, the table setup or
// the scatter helpers; otherwise Status::OK(). Raises through ORT_THROW if `reduction_` holds a
// value outside the handled set.
Status ScatterNDWithAtomicReduction::ComputeInternal(OpKernelContext* context) const {
  const auto* input_tensor = context->Input<Tensor>(0);
  const auto* indices_tensor = context->Input<Tensor>(1);
  const auto* updates_tensor = context->Input<Tensor>(2);

  const auto& input_shape = input_tensor->Shape();
  const auto& indices_shape = indices_tensor->Shape();
  const auto& updates_shape = updates_tensor->Shape();

  // Validate input shapes
  ORT_RETURN_IF_ERROR(onnxruntime::ScatterND::ValidateShapes(input_shape, indices_shape, updates_shape));

  auto* output_tensor = context->Output(0, input_shape);

  const void* input_data = input_tensor->DataRaw();
  void* output_data = output_tensor->MutableDataRaw();

  // The scatter updates the output in place, so the output must start as a copy of the input.
  if (input_data != output_data) {
    // TODO: Run benchmarks to determine if a dedicated kernel doing data copy will
    // be faster than invoking cudaMemcpy ?
    CUDA_RETURN_IF_ERROR(
        cudaMemcpyAsync(output_data, input_data, input_tensor->SizeInBytes(),
                        cudaMemcpyDeviceToDevice, Stream(context)));
  }

  // Nothing to scatter: the output is already the (copied) input.
  // This also guards the division by last_index_dimension below.
  if (indices_shape.Size() == 0) {
    return Status::OK();
  }

  const auto last_index_dimension = indices_shape[indices_shape.NumDimensions() - 1];
  ElementCountsAndInputDimsSpanOrGpu element_counts_and_input_dims;
  CudaAsyncBuffer<int64_t> element_counts_and_input_dims_gpu(this);
  ORT_RETURN_IF_ERROR(InitiliazeElementCountsAndInputDimsSpanOrGpu(last_index_dimension, input_shape,
                                                                   element_counts_and_input_dims,
                                                                   element_counts_and_input_dims_gpu,
                                                                   context));

  // Arguments shared by both dispatch paths.
  const auto num_indices = indices_shape.Size() / static_cast<size_t>(last_index_dimension);
  const auto* indices_data = indices_tensor->Data<int64_t>();  // only int64_t is supported for indices as per the onnx spec
  const auto num_updates_elements = input_shape.SizeFromDimension(last_index_dimension);

  switch (reduction_) {
    case ScatterNDReduction::None: {
      // Plain overwrite only needs the element width, not the concrete type.
      size_t element_size = input_tensor->DataType()->Size();
      ORT_RETURN_IF_ERROR(ScatterNDImpl(
          Stream(context),
          output_data,
          element_size,
          num_indices,
          indices_data,
          last_index_dimension,
          element_counts_and_input_dims,
          updates_tensor->DataRaw(),
          num_updates_elements));
    } break;
    case ScatterNDReduction::Add:
    case ScatterNDReduction::Min:
    case ScatterNDReduction::Max:
    case ScatterNDReduction::Mul: {
      // Reductions are passed the concrete element type rather than just its size.
      auto element_type = input_tensor->DataType()->AsPrimitiveDataType()->GetDataType();
      ORT_RETURN_IF_ERROR(ScatterNDImplReduction(
          Stream(context),
          output_data,
          element_type,
          num_indices,
          indices_data,
          last_index_dimension,
          element_counts_and_input_dims,
          updates_tensor->DataRaw(),
          num_updates_elements,
          reduction_));
    } break;
    default:
      // Message kept in line with the ScatterElements kernel; the previous text claimed that only
      // Add and None were supported although Min, Max and Mul are handled above.
      ORT_THROW("Unsupported reduction type for ScatterND.");
  }

  return Status::OK();
}
205+
88206
} // namespace cuda
89207
} // namespace onnxruntime

0 commit comments

Comments
 (0)