|
4 | 4 | #include "core/providers/common.h" |
5 | 5 | #include "core/providers/webgpu/math/binary_elementwise_ops.h" |
6 | 6 | #include "core/providers/webgpu/shader_helper.h" |
| 7 | +#include "core/providers/webgpu/string_macros.h" |
7 | 8 | #include "core/providers/webgpu/webgpu_supported_types.h" |
8 | 9 |
|
9 | 10 | namespace onnxruntime { |
10 | 11 | namespace webgpu { |
11 | 12 | Status BinaryElementwiseProgram::GenerateShaderCode(ShaderHelper& shader) const { |
12 | | - const auto& a = shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); |
13 | | - const auto& b = shader.AddInput("input_b", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); |
| 13 | + const auto& a = shader.AddInput("input_a", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); |
| 14 | + const auto& b = shader.AddInput("input_b", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias); |
14 | 15 | const auto& c = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias); |
15 | 16 |
|
| 17 | + shader.AdditionalImplementation() << additional_impl_; |
| 18 | + |
16 | 19 | shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.vec_size"); |
17 | 20 |
|
18 | 21 | // check whether can use element-wise mode. |
@@ -142,8 +145,15 @@ Status BinaryElementwise::ComputeInternal(ComputeContext& context) const { |
142 | 145 | } |
143 | 146 |
|
144 | 147 | uint32_t vec_size = onnxruntime::narrow<uint32_t>((size + 3) / 4); |
| 148 | + |
| 149 | + std::string additional_impl; |
| 150 | + if (get_additional_impl_) { |
| 151 | + additional_impl = get_additional_impl_(lhs_tensor->GetElementType(), rhs_tensor->GetElementType()); |
| 152 | + } |
| 153 | + |
145 | 154 | BinaryElementwiseProgram program{kernel_name_, |
146 | 155 | expression_, |
| 156 | + additional_impl, |
147 | 157 | is_broadcast, |
148 | 158 | is_lhs_scalar, |
149 | 159 | is_rhs_scalar, |
@@ -273,7 +283,28 @@ WEBGPU_BINARY_VERSIONED_KERNEL(Sub, 7, 12, Sub, WebGpuSupportedNumberTypes()) |
273 | 283 | WEBGPU_BINARY_VERSIONED_KERNEL(Sub, 13, 13, Sub, WebGpuSupportedNumberTypes()) |
274 | 284 | WEBGPU_BINARY_KERNEL(Sub, 14, Sub, WebGpuSupportedNumberTypes()) |
275 | 285 |
|
276 | | -WEBGPU_BINARY_IMPL(Pow, "output_value_t(pow(vec4<f32>(a), vec4<f32>(b)))") |
| 286 | +std::string GetPowImpl(int lhs_element_type, int /* rhs_element_type */) { /* Builds the WGSL source text of two helper functions, pow_custom (scalar) and pow_v (vec4), and returns it as a string. Only the lhs element type affects the output; the rhs element type is ignored. */ |
| 287 | + SS(s, 1024); /* s receives the shader text via operator<< below. NOTE(review): SS and SS_GET presumably come from string_macros.h and 1024 looks like a size hint - confirm. */ |
| 288 | + std::string round_str; /* Name of the WGSL function wrapped around the pow() result; left empty means no wrapper is emitted. */ |
| 289 | + if (lhs_element_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32) { |
| 290 | + round_str = "round"; /* For an int32 base the f32 pow() result is passed through round() before being converted back to input_a_element_t. */ |
| 291 | + } |
| 292 | + |
| 293 | + s << "fn pow_custom(a : input_a_element_t, b : f32) -> input_a_element_t {\n" /* Scalar helper: base keeps the input_a element type, exponent is always f32. */ |
| 294 | + " if (b == 0.0) {\n" /* A zero exponent returns 1 regardless of the base. */ |
| 295 | + " return input_a_element_t(1.0);\n" |
| 296 | + " } else if (a < input_a_element_t(0.0) && b != floor(b)) {\n" /* Negative base with a non-integer exponent: hand the raw values to pow(); the emitted WGSL comment marks this result as NaN. */ |
| 297 | + " return input_a_element_t(pow(f32(a), b)); // NaN\n" |
| 298 | + " }\n" |
| 299 | + << " return select(sign(a), input_a_element_t(1.0), round(abs(b) % 2.0) != 1.0) * input_a_element_t(" << round_str << "(pow(f32(abs(a)), b)));\n" /* General case: pow() is applied to abs(a), then the sign is restored. WGSL select(f, t, cond) yields t when cond is true, so the factor is 1 unless round(abs(b) % 2.0) equals 1 (odd exponent), in which case it is sign(a). round_str is spliced in here as the optional rounding wrapper. */ |
| 300 | + << "}\n" |
| 301 | + "fn pow_v(a : vec4<input_a_element_t>, b : vec4<input_b_element_t>) -> vec4<input_a_element_t> {\n" /* Vector helper: applies pow_custom to each of the four lanes, converting each exponent lane to f32. */ |
| 302 | + " return vec4<input_a_element_t>(pow_custom(a.x, f32(b.x)), pow_custom(a.y, f32(b.y)), pow_custom(a.z, f32(b.z)), pow_custom(a.w, f32(b.w)));\n" |
| 303 | + "}\n"; |
| 304 | + return SS_GET(s); /* Return the text accumulated in s. */ |
| 305 | +} |
| 306 | + |
| 307 | +WEBGPU_BINARY_IMPL(Pow, "pow_v(a, b)", GetPowImpl) /* The Pow expression is now a call to pow_v, which is defined by the WGSL text GetPowImpl returns. NOTE(review): the third macro argument presumably ends up as the get_additional_impl_ callback used in ComputeInternal - confirm against the macro definition. */ |
277 | 308 | WEBGPU_BINARY_VERSIONED_KERNEL(Pow, 7, 11, Pow, WebGpuSupportedNumberTypes()) |
278 | 309 | WEBGPU_BINARY_VERSIONED_KERNEL_2(Pow, 12, 12, Pow, WebGpuSupportedNumberTypes(), WebGpuSupportedNumberTypes()) |
279 | 310 | WEBGPU_BINARY_VERSIONED_KERNEL_2(Pow, 13, 14, Pow, WebGpuSupportedNumberTypes(), WebGpuSupportedNumberTypes()) |
|
0 commit comments