Skip to content

Commit b3aa5a3

Browse files
authored
[WebGPU EP] fix for reduce min/max error on MacOS CI (microsoft#24077)
### Error

```
Traceback /onnxruntime/onnxruntime/core/providers/webgpu/reduction/reduction_ops.cc:146 [allow_multi_axes = true] Axes values must be in the range [-rank, rank-1]. Got: 446098880
```
1 parent a46d212 commit b3aa5a3

File tree

2 files changed

+57
-46
lines changed

2 files changed

+57
-46
lines changed

onnxruntime/core/providers/webgpu/reduction/reduction_ops.cc

Lines changed: 54 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
namespace onnxruntime {
1212
namespace webgpu {
1313

14-
#define REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceOp, begin, end) \
14+
#define REGISTER_REDUCE_VERSIONED_KERNEL(ReduceOp, begin, end) \
1515
ONNX_OPERATOR_VERSIONED_KERNEL_EX( \
1616
ReduceOp, \
1717
kOnnxDomain, \
@@ -20,7 +20,16 @@ namespace webgpu {
2020
(*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedNumberTypes()), \
2121
ReduceOp);
2222

23-
#define REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceOp, version) \
23+
#define REGISTER_REDUCE_VERSIONED_KERNEL_WITH_AXIS_IN_INPUT(ReduceOp, begin, end) \
24+
ONNX_OPERATOR_VERSIONED_KERNEL_EX( \
25+
ReduceOp, \
26+
kOnnxDomain, \
27+
begin, end, \
28+
kWebGpuExecutionProvider, \
29+
(*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedNumberTypes()).InputMemoryType(OrtMemTypeCPUInput, 1), \
30+
ReduceOp);
31+
32+
#define REGISTER_REDUCE_KERNEL(ReduceOp, version) \
2433
ONNX_OPERATOR_KERNEL_EX( \
2534
ReduceOp, \
2635
kOnnxDomain, \
@@ -29,58 +38,58 @@ namespace webgpu {
2938
(*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedNumberTypes()).InputMemoryType(OrtMemTypeCPUInput, 1), \
3039
ReduceOp);
3140

32-
REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceMean, 1, 10);
33-
REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceMean, 11, 12);
34-
REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceMean, 13, 17);
35-
REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceMean, 18);
41+
REGISTER_REDUCE_VERSIONED_KERNEL(ReduceMean, 1, 10);
42+
REGISTER_REDUCE_VERSIONED_KERNEL(ReduceMean, 11, 12);
43+
REGISTER_REDUCE_VERSIONED_KERNEL(ReduceMean, 13, 17);
44+
REGISTER_REDUCE_KERNEL(ReduceMean, 18);
3645

37-
REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceMax, 1, 10);
38-
REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceMax, 11, 11);
39-
REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceMax, 12, 12);
40-
REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceMax, 13, 17);
41-
REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceMax, 18, 19);
42-
REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceMax, 20);
46+
REGISTER_REDUCE_VERSIONED_KERNEL(ReduceMax, 1, 10);
47+
REGISTER_REDUCE_VERSIONED_KERNEL(ReduceMax, 11, 11);
48+
REGISTER_REDUCE_VERSIONED_KERNEL(ReduceMax, 12, 12);
49+
REGISTER_REDUCE_VERSIONED_KERNEL(ReduceMax, 13, 17);
50+
REGISTER_REDUCE_VERSIONED_KERNEL_WITH_AXIS_IN_INPUT(ReduceMax, 18, 19);
51+
REGISTER_REDUCE_KERNEL(ReduceMax, 20);
4352

44-
REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceMin, 1, 10);
45-
REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceMin, 11, 11);
46-
REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceMin, 12, 12);
47-
REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceMin, 13, 17);
48-
REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceMin, 18, 19);
49-
REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceMin, 20);
53+
REGISTER_REDUCE_VERSIONED_KERNEL(ReduceMin, 1, 10);
54+
REGISTER_REDUCE_VERSIONED_KERNEL(ReduceMin, 11, 11);
55+
REGISTER_REDUCE_VERSIONED_KERNEL(ReduceMin, 12, 12);
56+
REGISTER_REDUCE_VERSIONED_KERNEL(ReduceMin, 13, 17);
57+
REGISTER_REDUCE_VERSIONED_KERNEL_WITH_AXIS_IN_INPUT(ReduceMin, 18, 19);
58+
REGISTER_REDUCE_KERNEL(ReduceMin, 20);
5059

51-
REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceSum, 1, 10);
52-
REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceSum, 11, 12);
53-
REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceSum, 13);
60+
REGISTER_REDUCE_VERSIONED_KERNEL(ReduceSum, 1, 10);
61+
REGISTER_REDUCE_VERSIONED_KERNEL(ReduceSum, 11, 12);
62+
REGISTER_REDUCE_KERNEL(ReduceSum, 13);
5463

55-
REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceProd, 1, 10);
56-
REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceProd, 11, 12);
57-
REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceProd, 13, 17);
58-
REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceProd, 18);
64+
REGISTER_REDUCE_VERSIONED_KERNEL(ReduceProd, 1, 10);
65+
REGISTER_REDUCE_VERSIONED_KERNEL(ReduceProd, 11, 12);
66+
REGISTER_REDUCE_VERSIONED_KERNEL(ReduceProd, 13, 17);
67+
REGISTER_REDUCE_KERNEL(ReduceProd, 18);
5968

60-
REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceL1, 1, 10);
61-
REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceL1, 11, 12);
62-
REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceL1, 13, 17);
63-
REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceL1, 18);
69+
REGISTER_REDUCE_VERSIONED_KERNEL(ReduceL1, 1, 10);
70+
REGISTER_REDUCE_VERSIONED_KERNEL(ReduceL1, 11, 12);
71+
REGISTER_REDUCE_VERSIONED_KERNEL(ReduceL1, 13, 17);
72+
REGISTER_REDUCE_KERNEL(ReduceL1, 18);
6473

65-
REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceL2, 1, 10);
66-
REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceL2, 11, 12);
67-
REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceL2, 13, 17);
68-
REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceL2, 18);
74+
REGISTER_REDUCE_VERSIONED_KERNEL(ReduceL2, 1, 10);
75+
REGISTER_REDUCE_VERSIONED_KERNEL(ReduceL2, 11, 12);
76+
REGISTER_REDUCE_VERSIONED_KERNEL(ReduceL2, 13, 17);
77+
REGISTER_REDUCE_KERNEL(ReduceL2, 18);
6978

70-
REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceLogSum, 1, 10);
71-
REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceLogSum, 11, 12);
72-
REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceLogSum, 13, 17);
73-
REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceLogSum, 18);
79+
REGISTER_REDUCE_VERSIONED_KERNEL(ReduceLogSum, 1, 10);
80+
REGISTER_REDUCE_VERSIONED_KERNEL(ReduceLogSum, 11, 12);
81+
REGISTER_REDUCE_VERSIONED_KERNEL(ReduceLogSum, 13, 17);
82+
REGISTER_REDUCE_KERNEL(ReduceLogSum, 18);
7483

75-
REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceSumSquare, 1, 10);
76-
REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceSumSquare, 11, 12);
77-
REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceSumSquare, 13, 17);
78-
REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceSumSquare, 18);
84+
REGISTER_REDUCE_VERSIONED_KERNEL(ReduceSumSquare, 1, 10);
85+
REGISTER_REDUCE_VERSIONED_KERNEL(ReduceSumSquare, 11, 12);
86+
REGISTER_REDUCE_VERSIONED_KERNEL(ReduceSumSquare, 13, 17);
87+
REGISTER_REDUCE_KERNEL(ReduceSumSquare, 18);
7988

80-
REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceLogSumExp, 1, 10);
81-
REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceLogSumExp, 11, 12);
82-
REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceLogSumExp, 13, 17);
83-
REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceLogSumExp, 18);
89+
REGISTER_REDUCE_VERSIONED_KERNEL(ReduceLogSumExp, 1, 10);
90+
REGISTER_REDUCE_VERSIONED_KERNEL(ReduceLogSumExp, 11, 12);
91+
REGISTER_REDUCE_VERSIONED_KERNEL(ReduceLogSumExp, 13, 17);
92+
REGISTER_REDUCE_KERNEL(ReduceLogSumExp, 18);
8493

8594
Status ReduceKernelProgram::GenerateShaderCode(ShaderHelper& shader) const {
8695
const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);

onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -694,7 +694,9 @@
694694
"^test_gelu_tanh_2_expanded_cpu",
695695
"^test_dynamicquantizelinear_expanded_cpu",
696696
"^test_center_crop_pad_crop_negative_axes_hwc*", // failed due to new types or shape infer with negative axis for CenterCropPad.
697-
"^test_center_crop_pad_crop_negative_axes_hwc_expanded*" // failed due to new types or shape infer with negative axis for CenterCropPad.
697+
"^test_center_crop_pad_crop_negative_axes_hwc_expanded*", // failed due to new types or shape infer with negative axis for CenterCropPad.
698+
"^test_reduce_max_empty_set",
699+
"^test_reduce_min_empty_set"
698700
],
699701
"current_failing_tests_pure_DML": [
700702
"^test_negative_log_likelihood_loss_input_shape_is_NCd1d2d3_none_no_weight_negative_ignore_index_cpu",

0 commit comments

Comments
 (0)