Conversation

@daijh (Contributor) commented Nov 19, 2025

Description

This PR optimizes the Conv operation by implementing two new compute shaders: oihw_to_ohwi and im2col-matmul.

oihw_to_ohwi:
Improves performance over the default Transpose shader by staging tiles in workgroup memory so that both memory reads and writes stay contiguous (coalesced).
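For reference, the layout change itself is a plain [0, 2, 3, 1] permutation; a minimal CPU sketch of the index mapping (the helper name `OihwToOhwi` is hypothetical, and the GPU shader additionally tiles through workgroup memory):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// CPU reference for the OIHW -> OHWI permutation ([0, 2, 3, 1]).
// The shader computes the same mapping, but stages tiles in workgroup
// memory so reads and writes both stay contiguous.
std::vector<float> OihwToOhwi(const std::vector<float>& src,
                              uint32_t O, uint32_t I, uint32_t H, uint32_t W) {
  std::vector<float> dst(src.size());
  for (uint32_t o = 0; o < O; ++o)
    for (uint32_t i = 0; i < I; ++i)
      for (uint32_t h = 0; h < H; ++h)
        for (uint32_t w = 0; w < W; ++w)
          // dst is OHWI-contiguous; src is OIHW-contiguous.
          dst[((o * H + h) * W + w) * I + i] =
              src[((o * I + i) * H + h) * W + w];
  return dst;
}
```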

im2col-matmul:

  • Employs a workgroup size of 64.
  • Dynamically selects tile sizes (32x64 or 16x64) based on the source/weight shape.
  • Each invocation handles a dedicated weight element.
  • Uses subgroupShuffle to efficiently access the source tile, leveraging k_vec4 vectorization for better memory throughput.

Testing on Lunar Lake demonstrated up to an 87% performance improvement in Conv_2D operations.
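As background on the im2col-matmul approach: im2col rewrites the convolution as a matrix multiply by gathering each receptive field into one row, after which the conv is a (out_h*out_w) x (kh*kw*C) by (kh*kw*C) x O GEMM against the OHWI weights. A minimal single-image CPU sketch, assuming stride 1 and no padding (the function name is hypothetical; the actual shader fuses this with the matmul and vectorizes along K):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// im2col for one NHWC image, stride 1, no padding (assumptions for this
// sketch).  Each output row holds one kh x kw x C receptive field.
std::vector<float> Im2Col(const std::vector<float>& src,
                          uint32_t H, uint32_t W, uint32_t C,
                          uint32_t kh, uint32_t kw) {
  const uint32_t out_h = H - kh + 1, out_w = W - kw + 1;
  std::vector<float> col(out_h * out_w * kh * kw * C);
  uint32_t row = 0;
  for (uint32_t oh = 0; oh < out_h; ++oh)
    for (uint32_t ow = 0; ow < out_w; ++ow, ++row) {
      uint32_t k = 0;
      for (uint32_t r = 0; r < kh; ++r)
        for (uint32_t s = 0; s < kw; ++s)
          for (uint32_t c = 0; c < C; ++c, ++k)
            col[row * kh * kw * C + k] =
                src[((oh + r) * W + (ow + s)) * C + c];
    }
  return col;
}
```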

Motivation and Context

See above.

@daijh (Contributor, Author) commented Nov 19, 2025

Lunar Lake, onnxruntime commit d55ade0

| Operation | conv2d-mm (ms) | im2col-matmul (ms) |
| --- | --- | --- |
| src: 1x128x512x512, weight: 128x128x3x3 | 56.071 | 42.824 |
| src: 1x2560x8x8, weight: 1280x2560x3x3 | 21.066 | 11.263 |
| src: 1x1280x8x8, weight: 1280x1280x3x3 | 10.384 | 6.357 |

sd-turbo

| Model | conv2d-mm (ms) | im2col-matmul (ms) |
| --- | --- | --- |
| sd-turbo-unet-fp16-demo.onnx | 1010.245 | 612.092 |
| sd-turbo-vae-decoder-fp16-demo.onnx | 2317.391 | 1848.545 |

@daijh (Contributor, Author) commented Nov 19, 2025

@guschmue @fs-eire @qjia7 PTAL.

@guschmue added the `ep:WebGPU` (ort-web webgpu provider) label on Nov 21, 2025
const uint32_t kernel_height = onnxruntime::narrow<uint32_t>(kernel_shape[2]);
const uint32_t kernel_width = onnxruntime::narrow<uint32_t>(kernel_shape[3]);

TensorShape nhwc_kernel_shape{channel_output, kernel_height, kernel_width, channel_input};

Suggested change:
- TensorShape nhwc_kernel_shape{channel_output, kernel_height, kernel_width, channel_input};
+ TensorShape ohwi_kernel_shape{channel_output, kernel_height, kernel_width, channel_input};

const uint32_t kernel_width = onnxruntime::narrow<uint32_t>(kernel_shape[3]);

TensorShape nhwc_kernel_shape{channel_output, kernel_height, kernel_width, channel_input};
Tensor nhwc_kernel = context.CreateGPUTensor(kernel->DataType(), nhwc_kernel_shape);
Suggested change:
- Tensor nhwc_kernel = context.CreateGPUTensor(kernel->DataType(), nhwc_kernel_shape);
+ Tensor ohwi_kernel = context.CreateGPUTensor(kernel->DataType(), nhwc_kernel_shape);


const uint32_t M_tiles = ceil_div(im2col_m, tile_m);
const uint32_t N_tiles = ceil_div(im2col_n, tile_n);
im2col_mm_program.SetDispatchGroupSize(M_tiles, N_tiles, batch);
How about enhancing the current TransposeProgram with a shared-memory path instead of adding a new shader?

You are transposing from perm [0, 1, 2, 3] to perm [0, 2, 3, 1], which is equivalent to transposing [o, i, hw] to [o, hw, i]. You could extend DoTranspose's shared path to support any transpose that only swaps the last two dimensions and leaves the preceding dimensions unchanged. Currently, the shared path only supports a 2-d transpose from perm [0, 1] to perm [1, 0]; it could be extended to handle [0, 1, 2] -> [0, 2, 1] by reshaping the tensor into a 3-d tensor [d0 * d1 * ... * dn-3, dn-2, dn-1] whenever only the last two dimensions are transposed.

for (var inner_k_idx = 0u; inner_k_idx < TILE_K_VEC_SIZE; inner_k_idx++) {
let weight_data = weight_tile[inner_k_idx][local_idx];
#if use_subgroup
let src_data = src_tile[inner_k_idx][sg_id];

What happens if sg_size is larger or smaller than TILE_M_SIZE?
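For context on the question: `subgroupShuffle` can only read values held by lanes within the caller's own subgroup, so distributing a TILE_M_SIZE-wide tile row across lanes implicitly assumes the subgroup is at least that wide. A toy sufficiency check (the helper and its use are hypothetical, not code from the PR):

```cpp
#include <cassert>
#include <cstdint>

// subgroupShuffle(value, id) reads from lane `id` of the caller's own
// subgroup only.  If a tile row is spread across TILE_M_SIZE lanes, every
// lane index 0..TILE_M_SIZE-1 must exist inside one subgroup; otherwise the
// kernel would need a non-subgroup fallback path.  Hypothetical helper.
bool ShuffleCoversTile(uint32_t sg_size, uint32_t tile_m_size) {
  return sg_size >= tile_m_size;
}
```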
