Skip to content

Commit 80441e4

Browse files
authored
[WebGPU EP] fix implementation of Pow (microsoft#24088)
### Description Use custom implementation for Pow to fix test failures.
1 parent d8ed4da commit 80441e4

File tree

2 files changed

+46
-7
lines changed

2 files changed

+46
-7
lines changed

onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.cc

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,18 @@
44
#include "core/providers/common.h"
55
#include "core/providers/webgpu/math/binary_elementwise_ops.h"
66
#include "core/providers/webgpu/shader_helper.h"
7+
#include "core/providers/webgpu/string_macros.h"
78
#include "core/providers/webgpu/webgpu_supported_types.h"
89

910
namespace onnxruntime {
1011
namespace webgpu {
1112
Status BinaryElementwiseProgram::GenerateShaderCode(ShaderHelper& shader) const {
12-
const auto& a = shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
13-
const auto& b = shader.AddInput("input_b", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
13+
const auto& a = shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
14+
const auto& b = shader.AddInput("input_b", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias);
1415
const auto& c = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias);
1516

17+
shader.AdditionalImplementation() << additional_impl_;
18+
1619
shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size");
1720

1821
// check whether can use element-wise mode.
@@ -142,8 +145,15 @@ Status BinaryElementwise::ComputeInternal(ComputeContext& context) const {
142145
}
143146

144147
uint32_t vec_size = onnxruntime::narrow<uint32_t>((size + 3) / 4);
148+
149+
std::string additional_impl;
150+
if (get_additional_impl_) {
151+
additional_impl = get_additional_impl_(lhs_tensor->GetElementType(), rhs_tensor->GetElementType());
152+
}
153+
145154
BinaryElementwiseProgram program{kernel_name_,
146155
expression_,
156+
additional_impl,
147157
is_broadcast,
148158
is_lhs_scalar,
149159
is_rhs_scalar,
@@ -273,7 +283,28 @@ WEBGPU_BINARY_VERSIONED_KERNEL(Sub, 7, 12, Sub, WebGpuSupportedNumberTypes())
273283
WEBGPU_BINARY_VERSIONED_KERNEL(Sub, 13, 13, Sub, WebGpuSupportedNumberTypes())
274284
WEBGPU_BINARY_KERNEL(Sub, 14, Sub, WebGpuSupportedNumberTypes())
275285

276-
WEBGPU_BINARY_IMPL(Pow, "output_value_t(pow(vec4<f32>(a), vec4<f32>(b)))")
286+
std::string GetPowImpl(int lhs_element_type, int /* rhs_element_type */) {
287+
SS(s, 1024);
288+
std::string round_str;
289+
if (lhs_element_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32) {
290+
round_str = "round";
291+
}
292+
293+
s << "fn pow_custom(a : input_a_element_t, b : f32) -> input_a_element_t {\n"
294+
" if (b == 0.0) {\n"
295+
" return input_a_element_t(1.0);\n"
296+
" } else if (a < input_a_element_t(0.0) && b != floor(b)) {\n"
297+
" return input_a_element_t(pow(f32(a), b)); // NaN\n"
298+
" }\n"
299+
<< " return select(sign(a), input_a_element_t(1.0), round(abs(b) % 2.0) != 1.0) * input_a_element_t(" << round_str << "(pow(f32(abs(a)), b)));\n"
300+
<< "}\n"
301+
"fn pow_v(a : vec4<input_a_element_t>, b : vec4<input_b_element_t>) -> vec4<input_a_element_t> {\n"
302+
" return vec4<input_a_element_t>(pow_custom(a.x, f32(b.x)), pow_custom(a.y, f32(b.y)), pow_custom(a.z, f32(b.z)), pow_custom(a.w, f32(b.w)));\n"
303+
"}\n";
304+
return SS_GET(s);
305+
}
306+
307+
WEBGPU_BINARY_IMPL(Pow, "pow_v(a, b)", GetPowImpl)
277308
WEBGPU_BINARY_VERSIONED_KERNEL(Pow, 7, 11, Pow, WebGpuSupportedNumberTypes())
278309
WEBGPU_BINARY_VERSIONED_KERNEL_2(Pow, 12, 12, Pow, WebGpuSupportedNumberTypes(), WebGpuSupportedNumberTypes())
279310
WEBGPU_BINARY_VERSIONED_KERNEL_2(Pow, 13, 14, Pow, WebGpuSupportedNumberTypes(), WebGpuSupportedNumberTypes())

onnxruntime/core/providers/webgpu/math/binary_elementwise_ops.h

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,13 @@ class BinaryElementwiseProgram final : public Program<BinaryElementwiseProgram>
1414
public:
1515
BinaryElementwiseProgram(const std::string& kernel_name,
1616
const std::string& expression,
17+
const std::string& additional_impl,
1718
const bool is_broadcast,
1819
const bool is_lhs_scalar,
1920
const bool is_rhs_scalar,
2021
const bool vectorize) : Program{kernel_name},
2122
expression_{expression},
23+
additional_impl_{additional_impl},
2224
is_broadcast_{is_broadcast},
2325
is_lhs_scalar_{is_lhs_scalar},
2426
is_rhs_scalar_{is_rhs_scalar},
@@ -29,7 +31,8 @@ class BinaryElementwiseProgram final : public Program<BinaryElementwiseProgram>
2931
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"vec_size", ProgramUniformVariableDataType::Uint32});
3032

3133
private:
32-
std::string expression_;
34+
std::string_view expression_;
35+
std::string_view additional_impl_;
3336
bool is_broadcast_;
3437
bool is_lhs_scalar_;
3538
bool is_rhs_scalar_;
@@ -38,18 +41,23 @@ class BinaryElementwiseProgram final : public Program<BinaryElementwiseProgram>
3841

3942
class BinaryElementwise : public WebGpuKernel {
4043
public:
44+
using GetAdditionalImplementationFunction = std::string (*)(int lhs_element_type, int rhs_element_type);
45+
4146
BinaryElementwise(const OpKernelInfo& info,
4247
const std::string& kernel_name,
43-
const std::string& expression) : WebGpuKernel{info},
44-
kernel_name_{kernel_name},
45-
expression_{expression} {}
48+
const std::string& expression,
49+
const GetAdditionalImplementationFunction get_additional_impl = nullptr) : WebGpuKernel{info},
50+
kernel_name_{kernel_name},
51+
expression_{expression},
52+
get_additional_impl_{get_additional_impl} {}
4653

4754
protected:
4855
Status ComputeInternal(ComputeContext& context) const final;
4956

5057
private:
5158
std::string kernel_name_;
5259
std::string expression_;
60+
const GetAdditionalImplementationFunction get_additional_impl_;
5361
};
5462

5563
} // namespace webgpu

0 commit comments

Comments
 (0)