Skip to content

Commit 10e51d2

Browse files
authored
[webgpu] fix the reflect mode issue of Pad (microsoft#24202)
### Description <!-- Describe your changes. --> ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
1 parent f83e661 commit 10e51d2

File tree

1 file changed

+28
-18
lines changed
  • onnxruntime/core/providers/webgpu/tensor

1 file changed

+28
-18
lines changed

onnxruntime/core/providers/webgpu/tensor/pad.cc

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include <vector>
66

77
#include "core/util/math.h"
8+
#include "core/providers/webgpu/string_macros.h"
89
#include "core/providers/webgpu/tensor/pad.h"
910
#include "core/providers/webgpu/shader_helper.h"
1011
#include "core/providers/webgpu/webgpu_supported_types.h"
@@ -38,38 +39,47 @@ Status PadProgram::GenerateShaderCode(ShaderHelper& shader) const {
3839
std::string lower_pads_str = GetElementAt("uniforms.lower_pads", "dim", rank);
3940
std::string data_shape_str = "i32(" + GetElementAt("uniforms.data_shape", "dim", rank) + ")";
4041
std::string data_stride_str = rank == 1 ? "" : " * " + GetElementAt("uniforms.data_stride", "dim", rank - 1);
41-
std::string begin_axis_statement = "in_coord = ";
42-
std::string end_axis_statement = "in_coord = ";
43-
std::string in_axis_statement = "in_coord = " + output_indices_str + " - " + lower_pads_str + ";\n";
42+
SS(axis_body_ss, 1024);
4443
switch (mode_) {
4544
case Mode::Constant:
46-
begin_axis_statement = "use_pad_value = true;\n";
47-
end_axis_statement = "use_pad_value = true;\n";
45+
axis_body_ss << " if (" << output_indices_str << " < " << lower_pads_str << " || " << output_indices_str << " >= " << lower_pads_str << " + " << data_shape_str << ") {\n"
46+
<< " use_pad_value = true;\n";
4847
break;
4948
case Mode::Edge:
50-
begin_axis_statement += "0;\n";
51-
end_axis_statement += data_shape_str + " - 1;\n";
49+
axis_body_ss << " if (" << output_indices_str << " < " << lower_pads_str << ") {\n"
50+
<< " in_coord = 0;\n"
51+
<< " } else if (" << output_indices_str << " >= " << lower_pads_str << " + " << data_shape_str << ") {\n"
52+
<< " in_coord = " << data_shape_str + " - 1;\n";
5253
break;
5354
case Mode::Reflect:
54-
begin_axis_statement += lower_pads_str + " - " + output_indices_str + ";\n";
55-
end_axis_statement += data_shape_str + " - 2 - (" + output_indices_str +
56-
" - (" + lower_pads_str + " + " + data_shape_str + "));\n";
55+
axis_body_ss << " if (" << output_indices_str << " < " << lower_pads_str << " || " << output_indices_str << " >= " << lower_pads_str << " + " << data_shape_str << ") {\n"
56+
<< " in_coord = " << output_indices_str << " - " << lower_pads_str << ";\n"
57+
<< " if (in_coord < 0) {\n"
58+
<< " in_coord = -in_coord;\n"
59+
<< " }\n"
60+
<< " {\n"
61+
<< " let _2n_1 = 2 * (" << data_shape_str << " - 1);\n"
62+
<< " in_coord = in_coord % _2n_1;\n"
63+
<< " if(in_coord >= " << data_shape_str << ") {\n"
64+
<< " in_coord = _2n_1 - in_coord;\n"
65+
<< " }\n"
66+
<< " }\n";
5767
break;
5868
case Mode::Wrap:
59-
begin_axis_statement += data_shape_str + " + " + output_indices_str + " - " + lower_pads_str + ";\n";
60-
end_axis_statement += output_indices_str + " - " + lower_pads_str + " - " + data_shape_str + ";\n";
69+
axis_body_ss << " if (" << output_indices_str << " < " << lower_pads_str << ") {\n"
70+
<< " in_coord = " << data_shape_str << " + " << output_indices_str << " - " << lower_pads_str + ";\n"
71+
<< " } else if (" << output_indices_str << " >= " << lower_pads_str << " + " << data_shape_str << ") {\n"
72+
<< " in_coord = " << output_indices_str << " - " << lower_pads_str << " - " << data_shape_str << ";\n";
6173
break;
6274
default:
6375
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported mode type: ", static_cast<int>(mode_));
6476
}
77+
axis_body_ss << " } else {\n"
78+
<< " " << "in_coord = " << output_indices_str << " - " << lower_pads_str << ";\n"
79+
<< " }\n";
6580

6681
shader.MainFunctionBody() << " for (var dim = 0; dim < " << rank << " && !use_pad_value; dim++) {\n"
67-
<< " if (" << output_indices_str << " < " << lower_pads_str << ") {\n"
68-
<< " " << begin_axis_statement << " }\n"
69-
<< " else if (" << output_indices_str << " >= " << lower_pads_str << " + " << data_shape_str << ") {\n"
70-
<< " " << end_axis_statement << " }\n"
71-
<< " else {\n"
72-
<< " " << in_axis_statement << " }\n"
82+
<< SS_GET(axis_body_ss)
7383
<< " input_index += select(u32(in_coord)" << data_stride_str << ", u32(in_coord), dim == " << rank - 1 << ");\n"
7484
<< " }\n"
7585
<< " " << constant_value_str

0 commit comments

Comments
 (0)