Skip to content

Commit b626409

Browse files
authored
webgpu ep support for argmax/argmin (microsoft#24089)
1 parent 7444fee commit b626409

File tree

3 files changed

+56
-13
lines changed

3 files changed

+56
-13
lines changed

onnxruntime/core/providers/webgpu/reduction/reduction_ops.cc

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,14 @@ REGISTER_REDUCE_VERSIONED_KERNEL(ReduceLogSumExp, 11, 12);
9191
REGISTER_REDUCE_VERSIONED_KERNEL(ReduceLogSumExp, 13, 17);
9292
REGISTER_REDUCE_KERNEL(ReduceLogSumExp, 18);
9393

94+
REGISTER_REDUCE_VERSIONED_KERNEL(ArgMax, 1, 10);
95+
REGISTER_REDUCE_VERSIONED_KERNEL(ArgMax, 11, 12);
96+
REGISTER_REDUCE_KERNEL(ArgMax, 13);
97+
98+
REGISTER_REDUCE_VERSIONED_KERNEL(ArgMin, 1, 10);
99+
REGISTER_REDUCE_VERSIONED_KERNEL(ArgMin, 11, 12);
100+
REGISTER_REDUCE_KERNEL(ArgMin, 13);
101+
94102
Status ReduceKernelProgram::GenerateShaderCode(ShaderHelper& shader) const {
95103
const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
96104
if (is_input_empty_) {
@@ -114,6 +122,9 @@ Status ReduceKernelProgram::GenerateShaderCode(ShaderHelper& shader) const {
114122
std::stringstream ss;
115123
std::string index = "i" + std::to_string(i);
116124
ss << "for (var " << index << " : u32 = 0; " << index << " < " << input.IndicesGet("uniforms.input_shape", i) << "; " << index << "++) {\n";
125+
if (loop_body.find("last_index") != std::string::npos) {
126+
ss << "let last_index = " + index + ";\n";
127+
}
117128
ss << input.IndicesSet("input_indices", i, index) << ";\n";
118129
ss << loop_body << "\n";
119130
ss << "}\n";
@@ -337,5 +348,25 @@ ReduceOpSpecificCode ReduceLogSumExp::GetOpSpecificCode(const Tensor* input_tens
337348
return code;
338349
}
339350

351+
ReduceOpSpecificCode ArgMin::GetOpSpecificCode(const Tensor* input_tensor) const {
352+
ORT_UNUSED_PARAMETER(input_tensor);
353+
std::string op = (select_last_index_) ? "<=" : "<";
354+
std::string loop_header = "var best_element = first_element; var best_index = u32(0);";
355+
std::string loop_body = "if (current_element " + op + " best_element) { best_element = current_element; best_index = last_index; };";
356+
std::string loop_footer = "let output_value = output_value_t(best_index);";
357+
ReduceOpSpecificCode code({loop_header, loop_body, loop_footer});
358+
return code;
359+
}
360+
361+
ReduceOpSpecificCode ArgMax::GetOpSpecificCode(const Tensor* input_tensor) const {
362+
ORT_UNUSED_PARAMETER(input_tensor);
363+
std::string op = (select_last_index_) ? ">=" : ">";
364+
std::string loop_header = "var best_element = first_element; var best_index = u32(0);";
365+
std::string loop_body = "if (current_element " + op + " best_element) { best_element = current_element; best_index = last_index; };";
366+
std::string loop_footer = "let output_value = output_value_t(best_index);";
367+
ReduceOpSpecificCode code({loop_header, loop_body, loop_footer});
368+
return code;
369+
}
370+
340371
} // namespace webgpu
341-
} // namespace onnxruntime
372+
} // namespace onnxruntime

onnxruntime/core/providers/webgpu/reduction/reduction_ops.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,5 +119,17 @@ class ReduceLogSumExp final : public ReduceKernel<true> {
119119
ReduceOpSpecificCode GetOpSpecificCode(const Tensor* input_tensor) const override;
120120
};
121121

122+
class ArgMin final : public ReduceKernel<false> {
123+
public:
124+
ArgMin(const OpKernelInfo& info) : ReduceKernel<false>(info, "ArgMin", true) {}
125+
ReduceOpSpecificCode GetOpSpecificCode(const Tensor* input_tensor) const override;
126+
};
127+
128+
class ArgMax final : public ReduceKernel<false> {
129+
public:
130+
ArgMax(const OpKernelInfo& info) : ReduceKernel<false>(info, "ArgMax", true) {}
131+
ReduceOpSpecificCode GetOpSpecificCode(const Tensor* input_tensor) const override;
132+
};
133+
122134
} // namespace webgpu
123135
} // namespace onnxruntime

onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -297,12 +297,12 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13,
297297
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 12, MatMul);
298298
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, MatMul);
299299

300-
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, float, ArgMax);
301-
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, float, ArgMax);
302-
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, float, ArgMax);
303-
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, float, ArgMin);
304-
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, float, ArgMin);
305-
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, float, ArgMin);
300+
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, ArgMax);
301+
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, ArgMax);
302+
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, ArgMax);
303+
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, ArgMin);
304+
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, ArgMin);
305+
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, ArgMin);
306306

307307
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, Softmax);
308308
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Softmax);
@@ -624,13 +624,13 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
624624
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 12, MatMul)>,
625625
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, MatMul)>,
626626

627-
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, float, ArgMax)>,
628-
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, float, ArgMax)>,
629-
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, float, ArgMax)>,
627+
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, ArgMax)>,
628+
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, ArgMax)>,
629+
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, ArgMax)>,
630630

631-
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, float, ArgMin)>,
632-
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, float, ArgMin)>,
633-
// BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, float, ArgMin)>,
631+
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, ArgMin)>,
632+
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, ArgMin)>,
633+
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, ArgMin)>,
634634

635635
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, Softmax)>,
636636
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Softmax)>,

0 commit comments

Comments
 (0)