Skip to content

Commit 5f25039

Browse files
committed
support vec2 and add unit test
1 parent 10d490c commit 5f25039

File tree

3 files changed

+89
-21
lines changed

3 files changed

+89
-21
lines changed

onnxruntime/core/providers/webgpu/nn/im2col_matmul.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ Status Im2ColMatMulProgram::GenerateShaderCode(ShaderHelper& shader) const {
7171

7272
ORT_ENFORCE(tile_m_ == 16 || tile_m_ == 32, "tile_m must be 16 or 32.");
7373
ORT_ENFORCE(tile_n_ == 64, "tile_n must be 64.");
74-
ORT_ENFORCE(vec_size_ == 1 || vec_size_ == 4, "vec_size must be 4 or 1.");
74+
ORT_ENFORCE(vec_size_ == 1 || vec_size_ == 2 || vec_size_ == 4, "vec_size must be 1, 2 or 4.");
7575

7676
return WGSL_TEMPLATE_APPLY(shader, "nn/im2col_matmul.wgsl.template",
7777
WGSL_TEMPLATE_PARAMETER(has_bias, has_bias_),
@@ -147,7 +147,7 @@ Status ApplyIm2ColMatMulProgram(ComputeContext& context,
147147
// Ensure the subgroup size must be greater than or equal to `tile_m` to safely enable `use_subgroup`.
148148
// If the status of this condition is uncertain, the feature must be disabled.
149149
const bool use_subgroup = false;
150-
const uint32_t vec_size = channel_input % 4 == 0 ? 4 : 1;
150+
const uint32_t vec_size = channel_input % 4 == 0 ? 4 : (channel_input % 2 == 0 ? 2 : 1);
151151
Im2ColMatMulProgram im2col_mm_program{has_bias, tile_m, tile_n, vec_size, use_subgroup};
152152
im2col_mm_program.SetWorkgroupSize(workgroup_size);
153153

onnxruntime/core/providers/webgpu/nn/im2col_matmul.wgsl.template

Lines changed: 8 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ fn write_output(batch : u32, m : u32, n : u32, value : output_element_t) {
8282
const TILE_M_SIZE : u32 = tile_m;
8383
const TILE_N_SIZE : u32 = tile_n;
8484
const TILE_K_VEC_SIZE : u32 = 16 / vec_size;
85+
const ADVANCE_DIM = 64 / TILE_K_VEC_SIZE;
8586

8687
var<workgroup> src_tile : array<array<src_value_t, TILE_M_SIZE>, TILE_K_VEC_SIZE>;
8788
var<workgroup> weight_tile : array<array<weight_value_t, TILE_N_SIZE>, TILE_K_VEC_SIZE>;
@@ -93,32 +94,20 @@ $MAIN {
9394

9495
var results : array<output_element_t, TILE_M_SIZE>;
9596
for (var k_idx = 0u; k_idx < uniforms.K_tiles; k_idx++) {
96-
#if vec_size != 4
97-
for (var src_m = 0u; src_m < TILE_M_SIZE; src_m += 4u) {
98-
let load_src_m = src_m + local_idx / 16;
99-
let load_src_k = local_idx % 16;
100-
#else
101-
for (var src_m = 0u; src_m < TILE_M_SIZE; src_m += 16u) {
97+
for (var src_m = 0u; src_m < TILE_M_SIZE; src_m += ADVANCE_DIM) {
10298
// Loads a 16x4 vec of src into the workgroup memory.
103-
let load_src_m = src_m + local_idx / 4;
104-
let load_src_k = local_idx % 4;
105-
#endif
99+
let load_src_m = src_m + local_idx / TILE_K_VEC_SIZE;
100+
let load_src_k = local_idx % TILE_K_VEC_SIZE;
106101

107102
src_tile[load_src_k][load_src_m] = load_src(batch,
108103
m_global_base + load_src_m,
109104
k_idx * TILE_K_VEC_SIZE + load_src_k);
110105
}
111106

112-
#if vec_size != 4
113-
for (var weight_n = 0u; weight_n < TILE_N_SIZE; weight_n += 4u) {
114-
let load_weight_n = weight_n + local_idx / 16;
115-
let load_weight_k = local_idx % 16;
116-
#else
117-
for (var weight_n = 0u; weight_n < TILE_N_SIZE; weight_n += 16u) {
107+
for (var weight_n = 0u; weight_n < TILE_N_SIZE; weight_n += ADVANCE_DIM) {
118108
// Loads a 16x4 vec of weight into the workgroup memory.
119-
let load_weight_n = weight_n + local_idx / 4;
120-
let load_weight_k = local_idx % 4;
121-
#endif
109+
let load_weight_n = weight_n + local_idx / TILE_K_VEC_SIZE;
110+
let load_weight_k = local_idx % TILE_K_VEC_SIZE;
122111

123112
weight_tile[load_weight_k][load_weight_n] = load_weight(n_global_base + load_weight_n,
124113
k_idx * TILE_K_VEC_SIZE + load_weight_k);
@@ -134,7 +123,7 @@ $MAIN {
134123
}
135124
#else
136125
for (var m_idx = 0u; m_idx < TILE_M_SIZE; m_idx++) {
137-
#if vec_size != 4
126+
#if vec_size == 1
138127
results[m_idx] += output_element_t(weight_data * src_tile[inner_k_idx][m_idx]);
139128
#else
140129
results[m_idx] += output_element_t(dot(weight_data, src_tile[inner_k_idx][m_idx]));

onnxruntime/test/providers/cpu/nn/conv_op_test.cc

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,85 @@ TEST(ConvTest, Conv2D_3) {
395395
TestConvOp(attrs, {X, W}, {X_shape, W_shape}, Y, Y_shape, true);
396396
}
397397

398+
TEST(ConvTest, Conv2D_4) {
399+
ConvOpAndTestAttributes attrs = {
400+
"", // auto_pad
401+
vector<int64_t>{1, 1}, // dilations
402+
1, // group
403+
vector<int64_t>{2, 2}, // kernel_shape
404+
vector<int64_t>{1, 2, 3, 1}, // pads
405+
vector<int64_t>{1, 1}, // strides
406+
{} // excluded EPs
407+
};
408+
409+
vector<int64_t> X_shape = {1, 4, 3, 3};
410+
vector<float> X(36, 1.f);
411+
412+
vector<int64_t> W_shape = {2, 4, 2, 2};
413+
vector<float> W(32, 1.f);
414+
415+
vector<int64_t> Y_shape = {1, 2, 6, 5};
416+
417+
auto Y = {
418+
0.f, 4.f, 8.f, 8.f, 4.f,
419+
0.f, 8.f, 16.f, 16.f, 8.f,
420+
0.f, 8.f, 16.f, 16.f, 8.f,
421+
0.f, 4.f, 8.f, 8.f, 4.f,
422+
0.f, 0.f, 0.f, 0.f, 0.f,
423+
0.f, 0.f, 0.f, 0.f, 0.f,
424+
425+
0.f, 4.f, 8.f, 8.f, 4.f,
426+
0.f, 8.f, 16.f, 16.f, 8.f,
427+
0.f, 8.f, 16.f, 16.f, 8.f,
428+
0.f, 4.f, 8.f, 8.f, 4.f,
429+
0.f, 0.f, 0.f, 0.f, 0.f,
430+
0.f, 0.f, 0.f, 0.f, 0.f};
431+
432+
TestConvOp(attrs, {X, W}, {X_shape, W_shape}, Y, Y_shape);
433+
TestConvOp(attrs, {X, W}, {X_shape, W_shape}, Y, Y_shape, true);
434+
}
435+
436+
TEST(ConvTest, Conv2D_5) {
437+
ConvOpAndTestAttributes attrs = {
438+
"", // auto_pad
439+
vector<int64_t>{1, 1}, // dilations
440+
1, // group
441+
vector<int64_t>{2, 2}, // kernel_shape
442+
vector<int64_t>{1, 2, 3, 1}, // pads
443+
vector<int64_t>{1, 1}, // strides
444+
{} // excluded EPs
445+
};
446+
447+
vector<int64_t> X_shape = {1, 6, 3, 3};
448+
vector<float> X(54);
449+
for (int i = 0; i < 54; ++i) {
450+
X[i] = static_cast<float>(i + 1);
451+
}
452+
453+
vector<int64_t> W_shape = {2, 6, 2, 2};
454+
vector<float> W(48, 1.f);
455+
456+
vector<int64_t> Y_shape = {1, 2, 6, 5};
457+
458+
auto Y = {
459+
0.f, 141.f, 288.f, 300.f, 153.f,
460+
0.f, 300.f, 612.f, 636.f, 324.f,
461+
0.f, 336.f, 684.f, 708.f, 360.f,
462+
0.f, 177.f, 360.f, 372.f, 189.f,
463+
0.f, 0.f, 0.f, 0.f, 0.f,
464+
0.f, 0.f, 0.f, 0.f, 0.f,
465+
466+
0.f, 141.f, 288.f, 300.f, 153.f,
467+
0.f, 300.f, 612.f, 636.f, 324.f,
468+
0.f, 336.f, 684.f, 708.f, 360.f,
469+
0.f, 177.f, 360.f, 372.f, 189.f,
470+
0.f, 0.f, 0.f, 0.f, 0.f,
471+
0.f, 0.f, 0.f, 0.f, 0.f};
472+
473+
TestConvOp(attrs, {X, W}, {X_shape, W_shape}, Y, Y_shape);
474+
TestConvOp(attrs, {X, W}, {X_shape, W_shape}, Y, Y_shape, true);
475+
}
476+
398477
TEST(ConvTest, Conv2D_Bias_1) {
399478
ConvOpAndTestAttributes attrs = {
400479
"", // auto_pad

0 commit comments

Comments (0)