1111namespace onnxruntime {
1212namespace webgpu {
1313
14- #define REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL (ReduceOp, begin, end ) \
14+ #define REGISTER_REDUCE_VERSIONED_KERNEL (ReduceOp, begin, end ) \
1515 ONNX_OPERATOR_VERSIONED_KERNEL_EX ( \
1616 ReduceOp, \
1717 kOnnxDomain , \
@@ -20,7 +20,16 @@ namespace webgpu {
2020 (*KernelDefBuilder::Create ()).TypeConstraint(" T" , WebGpuSupportedNumberTypes()), \
2121 ReduceOp);
2222
23- #define REGISTER_UNARY_ELEMENTWISE_KERNEL (ReduceOp, version ) \
23+ #define REGISTER_REDUCE_VERSIONED_KERNEL_WITH_AXIS_IN_INPUT (ReduceOp, begin, end ) \
24+ ONNX_OPERATOR_VERSIONED_KERNEL_EX ( \
25+ ReduceOp, \
26+ kOnnxDomain , \
27+ begin, end, \
28+ kWebGpuExecutionProvider , \
29+ (*KernelDefBuilder::Create ()).TypeConstraint(" T" , WebGpuSupportedNumberTypes()).InputMemoryType(OrtMemTypeCPUInput, 1 ), \
30+ ReduceOp);
31+
32+ #define REGISTER_REDUCE_KERNEL (ReduceOp, version ) \
2433 ONNX_OPERATOR_KERNEL_EX ( \
2534 ReduceOp, \
2635 kOnnxDomain , \
@@ -29,58 +38,58 @@ namespace webgpu {
2938 (*KernelDefBuilder::Create ()).TypeConstraint(" T" , WebGpuSupportedNumberTypes()).InputMemoryType(OrtMemTypeCPUInput, 1 ), \
3039 ReduceOp);
3140
32- REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL (ReduceMean, 1 , 10 );
33- REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL (ReduceMean, 11 , 12 );
34- REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL (ReduceMean, 13 , 17 );
35- REGISTER_UNARY_ELEMENTWISE_KERNEL (ReduceMean, 18 );
41+ REGISTER_REDUCE_VERSIONED_KERNEL (ReduceMean, 1 , 10 );
42+ REGISTER_REDUCE_VERSIONED_KERNEL (ReduceMean, 11 , 12 );
43+ REGISTER_REDUCE_VERSIONED_KERNEL (ReduceMean, 13 , 17 );
44+ REGISTER_REDUCE_KERNEL (ReduceMean, 18 );
3645
37- REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL (ReduceMax, 1 , 10 );
38- REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL (ReduceMax, 11 , 11 );
39- REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL (ReduceMax, 12 , 12 );
40- REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL (ReduceMax, 13 , 17 );
41- REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL (ReduceMax, 18 , 19 );
42- REGISTER_UNARY_ELEMENTWISE_KERNEL (ReduceMax, 20 );
46+ REGISTER_REDUCE_VERSIONED_KERNEL (ReduceMax, 1 , 10 );
47+ REGISTER_REDUCE_VERSIONED_KERNEL (ReduceMax, 11 , 11 );
48+ REGISTER_REDUCE_VERSIONED_KERNEL (ReduceMax, 12 , 12 );
49+ REGISTER_REDUCE_VERSIONED_KERNEL (ReduceMax, 13 , 17 );
50+ REGISTER_REDUCE_VERSIONED_KERNEL_WITH_AXIS_IN_INPUT (ReduceMax, 18 , 19 );
51+ REGISTER_REDUCE_KERNEL (ReduceMax, 20 );
4352
44- REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL (ReduceMin, 1 , 10 );
45- REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL (ReduceMin, 11 , 11 );
46- REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL (ReduceMin, 12 , 12 );
47- REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL (ReduceMin, 13 , 17 );
48- REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL (ReduceMin, 18 , 19 );
49- REGISTER_UNARY_ELEMENTWISE_KERNEL (ReduceMin, 20 );
53+ REGISTER_REDUCE_VERSIONED_KERNEL (ReduceMin, 1 , 10 );
54+ REGISTER_REDUCE_VERSIONED_KERNEL (ReduceMin, 11 , 11 );
55+ REGISTER_REDUCE_VERSIONED_KERNEL (ReduceMin, 12 , 12 );
56+ REGISTER_REDUCE_VERSIONED_KERNEL (ReduceMin, 13 , 17 );
57+ REGISTER_REDUCE_VERSIONED_KERNEL_WITH_AXIS_IN_INPUT (ReduceMin, 18 , 19 );
58+ REGISTER_REDUCE_KERNEL (ReduceMin, 20 );
5059
51- REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL (ReduceSum, 1 , 10 );
52- REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL (ReduceSum, 11 , 12 );
53- REGISTER_UNARY_ELEMENTWISE_KERNEL (ReduceSum, 13 );
60+ REGISTER_REDUCE_VERSIONED_KERNEL (ReduceSum, 1 , 10 );
61+ REGISTER_REDUCE_VERSIONED_KERNEL (ReduceSum, 11 , 12 );
62+ REGISTER_REDUCE_KERNEL (ReduceSum, 13 );
5463
55- REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL (ReduceProd, 1 , 10 );
56- REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL (ReduceProd, 11 , 12 );
57- REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL (ReduceProd, 13 , 17 );
58- REGISTER_UNARY_ELEMENTWISE_KERNEL (ReduceProd, 18 );
64+ REGISTER_REDUCE_VERSIONED_KERNEL (ReduceProd, 1 , 10 );
65+ REGISTER_REDUCE_VERSIONED_KERNEL (ReduceProd, 11 , 12 );
66+ REGISTER_REDUCE_VERSIONED_KERNEL (ReduceProd, 13 , 17 );
67+ REGISTER_REDUCE_KERNEL (ReduceProd, 18 );
5968
60- REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL (ReduceL1, 1 , 10 );
61- REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL (ReduceL1, 11 , 12 );
62- REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL (ReduceL1, 13 , 17 );
63- REGISTER_UNARY_ELEMENTWISE_KERNEL (ReduceL1, 18 );
69+ REGISTER_REDUCE_VERSIONED_KERNEL (ReduceL1, 1 , 10 );
70+ REGISTER_REDUCE_VERSIONED_KERNEL (ReduceL1, 11 , 12 );
71+ REGISTER_REDUCE_VERSIONED_KERNEL (ReduceL1, 13 , 17 );
72+ REGISTER_REDUCE_KERNEL (ReduceL1, 18 );
6473
65- REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL (ReduceL2, 1 , 10 );
66- REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL (ReduceL2, 11 , 12 );
67- REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL (ReduceL2, 13 , 17 );
68- REGISTER_UNARY_ELEMENTWISE_KERNEL (ReduceL2, 18 );
74+ REGISTER_REDUCE_VERSIONED_KERNEL (ReduceL2, 1 , 10 );
75+ REGISTER_REDUCE_VERSIONED_KERNEL (ReduceL2, 11 , 12 );
76+ REGISTER_REDUCE_VERSIONED_KERNEL (ReduceL2, 13 , 17 );
77+ REGISTER_REDUCE_KERNEL (ReduceL2, 18 );
6978
70- REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL (ReduceLogSum, 1 , 10 );
71- REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL (ReduceLogSum, 11 , 12 );
72- REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL (ReduceLogSum, 13 , 17 );
73- REGISTER_UNARY_ELEMENTWISE_KERNEL (ReduceLogSum, 18 );
79+ REGISTER_REDUCE_VERSIONED_KERNEL (ReduceLogSum, 1 , 10 );
80+ REGISTER_REDUCE_VERSIONED_KERNEL (ReduceLogSum, 11 , 12 );
81+ REGISTER_REDUCE_VERSIONED_KERNEL (ReduceLogSum, 13 , 17 );
82+ REGISTER_REDUCE_KERNEL (ReduceLogSum, 18 );
7483
75- REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL (ReduceSumSquare, 1 , 10 );
76- REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL (ReduceSumSquare, 11 , 12 );
77- REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL (ReduceSumSquare, 13 , 17 );
78- REGISTER_UNARY_ELEMENTWISE_KERNEL (ReduceSumSquare, 18 );
84+ REGISTER_REDUCE_VERSIONED_KERNEL (ReduceSumSquare, 1 , 10 );
85+ REGISTER_REDUCE_VERSIONED_KERNEL (ReduceSumSquare, 11 , 12 );
86+ REGISTER_REDUCE_VERSIONED_KERNEL (ReduceSumSquare, 13 , 17 );
87+ REGISTER_REDUCE_KERNEL (ReduceSumSquare, 18 );
7988
80- REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL (ReduceLogSumExp, 1 , 10 );
81- REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL (ReduceLogSumExp, 11 , 12 );
82- REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL (ReduceLogSumExp, 13 , 17 );
83- REGISTER_UNARY_ELEMENTWISE_KERNEL (ReduceLogSumExp, 18 );
89+ REGISTER_REDUCE_VERSIONED_KERNEL (ReduceLogSumExp, 1 , 10 );
90+ REGISTER_REDUCE_VERSIONED_KERNEL (ReduceLogSumExp, 11 , 12 );
91+ REGISTER_REDUCE_VERSIONED_KERNEL (ReduceLogSumExp, 13 , 17 );
92+ REGISTER_REDUCE_KERNEL (ReduceLogSumExp, 18 );
8493
8594Status ReduceKernelProgram::GenerateShaderCode (ShaderHelper& shader) const {
8695 const auto & output = shader.AddOutput (" output" , ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
0 commit comments