1+ // Copyright (c) Microsoft Corporation. All rights reserved.
2+ // Licensed under the MIT License.
3+
4+ #include " core/providers/webgpu/math/cum_sum.h"
5+ #include " core/providers/webgpu/shader_helper.h"
6+ #include " core/providers/webgpu/webgpu_supported_types.h"
7+
8+ namespace onnxruntime {
9+ namespace webgpu {
10+
11+ ONNX_OPERATOR_VERSIONED_KERNEL_EX (
12+ CumSum,
13+ kOnnxDomain ,
14+ 11 , 13 ,
15+ kWebGpuExecutionProvider ,
16+ (*KernelDefBuilder::Create ())
17+ .TypeConstraint(" T" , WebGpuSupportedFloatTypes())
18+ .TypeConstraint(" T2" , {DataTypeImpl::GetTensorType<int32_t >(),
19+ DataTypeImpl::GetTensorType<int64_t >()})
20+ .InputMemoryType(OrtMemTypeCPU, 1 ),
21+ CumSum);
22+
23+ ONNX_OPERATOR_KERNEL_EX (
24+ CumSum,
25+ kOnnxDomain ,
26+ 14 ,
27+ kWebGpuExecutionProvider ,
28+ (*KernelDefBuilder::Create ())
29+ .TypeConstraint(" T" , WebGpuSupportedFloatTypes())
30+ .TypeConstraint(" T2" , {DataTypeImpl::GetTensorType<int32_t >(),
31+ DataTypeImpl::GetTensorType<int64_t >()})
32+ .InputMemoryType(OrtMemTypeCPU, 1 ),
33+ CumSum);
34+
35+ Status CumSumProgram::GenerateShaderCode (ShaderHelper& shader) const {
36+ const ShaderVariableHelper& input = shader.AddInput (" input" , ShaderUsage::UseUniform);
37+ const ShaderVariableHelper& output = shader.AddOutput (" output" , ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
38+
39+ shader.MainFunctionBody () << shader.GuardAgainstOutOfBoundsWorkgroupSizes (" uniforms.output_size" )
40+ << " var input_indices = " << input.OffsetToIndices (" global_idx" ) << " ;\n "
41+ << " var sum : output_value_t = 0;\n "
42+ << " var first : i32 = 0;\n "
43+ << " if (uniforms.reverse == 1) {\n "
44+ << " first = i32(" + input.IndicesGet (" input_indices" , " uniforms.axis" ) + " );\n "
45+ << " if (uniforms.exclusive == 1) { first += 1; }\n "
46+ << " }\n\n "
47+ << " var last : i32 = 0;\n "
48+ << " if (uniforms.reverse == 1) {\n "
49+ << " last = i32(" << GetElementAt (" uniforms.input_shape" , " uniforms.axis" , input.Rank ()) << " );\n "
50+ << " } else {\n "
51+ << " last = i32(" + input.IndicesGet (" input_indices" , " uniforms.axis" ) + " );\n "
52+ << " if (uniforms.exclusive == 0) { last += 1; }\n "
53+ << " }\n\n "
54+ << " for (var i : i32 = first; i < last; i++) {\n "
55+ << " " << input.IndicesSet (" input_indices" , " uniforms.axis" , " u32(i)" ) << " ;\n "
56+ << " sum = sum + " << input.GetByIndices (" input_indices" ) << " ;\n "
57+ << " }\n "
58+ << output.SetByOffset (" global_idx" , " sum" );
59+
60+ return Status::OK ();
61+ }
62+
63+ Status CumSum::ComputeInternal (ComputeContext& context) const {
64+ const auto * input_tensor = context.Input (0 );
65+ const TensorShape& input_shape = input_tensor->Shape ();
66+ int64_t input_rank = input_shape.NumDimensions ();
67+
68+ const auto * axis_tensor = context.Input (1 );
69+ const auto * axis_data = axis_tensor->Data <int >();
70+ int64_t axis = static_cast <int64_t >(axis_data[0 ]);
71+
72+ ORT_ENFORCE (-input_rank <= axis && axis < input_rank, " Axes attribute must be within range -input_rank <= axis < input_rank." );
73+ // Handle negative axis
74+ if (axis < 0 ) {
75+ axis += input_rank;
76+ }
77+
78+ auto * output_tensor = context.Output (0 , input_shape);
79+ int64_t output_size = output_tensor->Shape ().Size ();
80+
81+ if (output_size == 0 ) {
82+ return Status::OK ();
83+ }
84+
85+ CumSumProgram program{};
86+ program
87+ .AddInput ({input_tensor})
88+ .AddOutput ({output_tensor, ProgramTensorMetadataDependency::TypeAndRank})
89+ .SetDispatchGroupSize ((output_size + WORKGROUP_SIZE - 1 ) / WORKGROUP_SIZE)
90+ .AddUniformVariables ({{static_cast <uint32_t >(output_size)},
91+ {static_cast <uint32_t >(axis)},
92+ {static_cast <uint32_t >(exclusive_)},
93+ {static_cast <uint32_t >(reverse_)}});
94+ return context.RunProgram (program);
95+ }
96+
97+ } // namespace webgpu
98+ } // namespace onnxruntime
0 commit comments