Skip to content

Commit 8d21bf7

Browse files
authored
[WebGPU EP] Implements CumSum Operator (microsoft#24047)
Increases WebGPU EP op coverage.
1 parent da7874c commit 8d21bf7

File tree

3 files changed

+139
-2
lines changed

3 files changed

+139
-2
lines changed
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#include "core/providers/webgpu/math/cum_sum.h"
5+
#include "core/providers/webgpu/shader_helper.h"
6+
#include "core/providers/webgpu/webgpu_supported_types.h"
7+
8+
namespace onnxruntime {
9+
namespace webgpu {
10+
11+
// Kernel registration: CumSum in the ONNX domain, opset versions 11 through 13,
// for the WebGPU execution provider.
//   T  - data type, limited to the float types returned by WebGpuSupportedFloatTypes().
//   T2 - axis type, int32 or int64.
// Input 1 (the axis) is requested in CPU memory so ComputeInternal can read it directly.
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
    CumSum,
    kOnnxDomain,
    11, 13,
    kWebGpuExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", WebGpuSupportedFloatTypes())
        .TypeConstraint("T2", {DataTypeImpl::GetTensorType<int32_t>(),
                               DataTypeImpl::GetTensorType<int64_t>()})
        .InputMemoryType(OrtMemTypeCPU, 1),
    CumSum);
22+
23+
// Kernel registration: CumSum in the ONNX domain, opset version 14 and later,
// for the WebGPU execution provider. Uses the same constraints as the 11-13
// registration: float data types for T, int32/int64 for the axis type T2,
// and the axis input (index 1) placed in CPU memory.
ONNX_OPERATOR_KERNEL_EX(
    CumSum,
    kOnnxDomain,
    14,
    kWebGpuExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", WebGpuSupportedFloatTypes())
        .TypeConstraint("T2", {DataTypeImpl::GetTensorType<int32_t>(),
                               DataTypeImpl::GetTensorType<int64_t>()})
        .InputMemoryType(OrtMemTypeCPU, 1),
    CumSum);
34+
35+
// Emits the WGSL body for CumSum.
//
// Each invocation handles one output element (identified by `global_idx`):
// it converts the flat offset to input indices, determines the half-open range
// [first, last) along `uniforms.axis` that contributes to this element, sums the
// input over that range, and stores the result at the same flat offset.
//
// Range selection, driven by the `reverse` and `exclusive` uniforms:
//   forward, inclusive : [0, idx + 1)
//   forward, exclusive : [0, idx)
//   reverse, inclusive : [idx, dim)
//   reverse, exclusive : [idx + 1, dim)
// where `idx` is this element's index along the axis and `dim` is the input
// extent along the axis.
Status CumSumProgram::GenerateShaderCode(ShaderHelper& shader) const {
  const ShaderVariableHelper& input = shader.AddInput("input", ShaderUsage::UseUniform);
  const ShaderVariableHelper& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);

  auto& body = shader.MainFunctionBody();

  // Prologue: bounds guard, index decomposition and the accumulator.
  body << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size");
  body << "var input_indices = " << input.OffsetToIndices("global_idx") << ";\n";
  body << "var sum : output_value_t = 0;\n";

  // Lower bound of the accumulation range.
  body << "var first : i32 = 0;\n";
  body << "if (uniforms.reverse == 1) {\n";
  body << "  first = i32(" << input.IndicesGet("input_indices", "uniforms.axis") << ");\n";
  body << "  if (uniforms.exclusive == 1) { first += 1; }\n";
  body << "}\n\n";

  // Upper bound (exclusive) of the accumulation range.
  body << "var last : i32 = 0;\n";
  body << "if (uniforms.reverse == 1) {\n";
  body << "  last = i32(" << GetElementAt("uniforms.input_shape", "uniforms.axis", input.Rank()) << ");\n";
  body << "} else {\n";
  body << "  last = i32(" << input.IndicesGet("input_indices", "uniforms.axis") << ");\n";
  body << "  if (uniforms.exclusive == 0) { last += 1; }\n";
  body << "}\n\n";

  // Walk the axis over [first, last), reusing input_indices as the cursor.
  body << "for (var i : i32 = first; i < last; i++) {\n";
  body << "  " << input.IndicesSet("input_indices", "uniforms.axis", "u32(i)") << ";\n";
  body << "  sum = sum + " << input.GetByIndices("input_indices") << ";\n";
  body << "}\n";

  // Epilogue: write the accumulated value for this element.
  body << output.SetByOffset("global_idx", "sum");

  return Status::OK();
}
62+
63+
// Computes the cumulative sum of input 0 along the axis supplied by input 1 by
// dispatching a CumSumProgram.
//
// Inputs read from `context`:
//   0 - data tensor; the output is allocated with the same shape.
//   1 - axis tensor. The kernel registration constrains it to int32 or int64,
//       so the element type is checked at runtime and read with the matching
//       accessor (reading an int64 tensor through an `int` accessor is a type
//       mismatch).
//
// Throws (via ORT_ENFORCE) when the axis tensor is missing or empty, has an
// element type other than int32/int64, or holds a value outside
// [-input_rank, input_rank).
//
// Returns Status::OK() without dispatching when the output has no elements;
// otherwise returns the status of context.RunProgram().
Status CumSum::ComputeInternal(ComputeContext& context) const {
  const auto* input_tensor = context.Input(0);
  const TensorShape& input_shape = input_tensor->Shape();
  const int64_t input_rank = static_cast<int64_t>(input_shape.NumDimensions());

  const auto* axis_tensor = context.Input(1);
  // Guard the element read below: an absent or empty axis tensor has no axis_data[0].
  ORT_ENFORCE(axis_tensor != nullptr && axis_tensor->Shape().Size() >= 1,
              "CumSum requires a non-empty axis input.");

  const bool axis_is_int32 = axis_tensor->IsDataType<int32_t>();
  ORT_ENFORCE(axis_is_int32 || axis_tensor->IsDataType<int64_t>(),
              "CumSum axis input must be of type int32 or int64.");

  int64_t axis = axis_is_int32 ? static_cast<int64_t>(axis_tensor->Data<int32_t>()[0])
                               : axis_tensor->Data<int64_t>()[0];

  ORT_ENFORCE(-input_rank <= axis && axis < input_rank, "Axes attribute must be within range -input_rank <= axis < input_rank.");
  // Handle negative axis
  if (axis < 0) {
    axis += input_rank;
  }

  auto* output_tensor = context.Output(0, input_shape);
  const int64_t output_size = output_tensor->Shape().Size();

  // Nothing to compute for an empty tensor; skip the dispatch entirely.
  if (output_size == 0) {
    return Status::OK();
  }

  CumSumProgram program{};
  program
      .AddInput({input_tensor})
      .AddOutput({output_tensor, ProgramTensorMetadataDependency::TypeAndRank})
      // One invocation per output element, rounded up to whole workgroups.
      .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
      // Order must match WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES in CumSumProgram:
      // output_size, axis, exclusive, reverse.
      .AddUniformVariables({{static_cast<uint32_t>(output_size)},
                            {static_cast<uint32_t>(axis)},
                            {static_cast<uint32_t>(exclusive_)},
                            {static_cast<uint32_t>(reverse_)}});
  return context.RunProgram(program);
}
96+
97+
} // namespace webgpu
98+
} // namespace onnxruntime
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
6+
#include "core/providers/webgpu/webgpu_kernel.h"
7+
#include "core/providers/webgpu/program.h"
8+
9+
namespace onnxruntime {
10+
namespace webgpu {
11+
12+
// WebGPU program that generates the CumSum shader.
//
// Uniforms (all Uint32), in the order they must be supplied:
//   output_size - uniform used as the bound in the shader's out-of-range guard
//   axis        - axis along which the sum is accumulated
//   exclusive   - compared against 1/0 in the shader to exclude/include the current element
//   reverse     - compared against 1 in the shader to accumulate from the end of the axis
class CumSumProgram final : public Program<CumSumProgram> {
 public:
  CumSumProgram() : Program{"CumSum"} {}

  // Writes the shader body for this program into `shader`.
  Status GenerateShaderCode(ShaderHelper& shader) const override;

  WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"output_size", ProgramUniformVariableDataType::Uint32},
                                          {"axis", ProgramUniformVariableDataType::Uint32},
                                          {"exclusive", ProgramUniformVariableDataType::Uint32},
                                          {"reverse", ProgramUniformVariableDataType::Uint32});
};
23+
24+
class CumSum final : public WebGpuKernel {
25+
public:
26+
CumSum(const OpKernelInfo& info) : WebGpuKernel(info) {
27+
exclusive_ = info.GetAttrOrDefault<int64_t>("exclusive", 0);
28+
reverse_ = info.GetAttrOrDefault<int64_t>("reverse", 0);
29+
}
30+
31+
Status ComputeInternal(ComputeContext& context) const override;
32+
33+
private:
34+
int64_t exclusive_;
35+
int64_t reverse_;
36+
};
37+
38+
} // namespace webgpu
39+
} // namespace onnxruntime

onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -713,8 +713,8 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
713713
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 9, 13, BatchNormalization)>,
714714
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 14, 14, BatchNormalization)>,
715715
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSInternalNHWCDomain, 15, BatchNormalization)>,
716-
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 13, CumSum)>,
717-
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 14, CumSum)>,
716+
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 13, CumSum)>,
717+
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 14, CumSum)>,
718718
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 10, 12, uint8_t, DequantizeLinear)>,
719719
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 10, 12, int8_t, DequantizeLinear)>,
720720
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 10, 12, int32_t, DequantizeLinear)>,

0 commit comments

Comments
 (0)