Skip to content
4 changes: 2 additions & 2 deletions backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ conv2d_pw:
OPERATOR: X
NDIM: 3
DTYPE: float
TILE_SIZE_X: 2
TILE_SIZE_Y: 2
TILE_SIZE_X: 1
TILE_SIZE_Y: 4
generate_variant_forall:
DTYPE:
- VALUE: half
Expand Down
4 changes: 2 additions & 2 deletions backends/vulkan/runtime/graph/ops/impl/Convolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -305,8 +305,8 @@ utils::uvec3 create_conv2d_global_wg_size(
if (method == Conv2dMethod::Pointwise) {
const utils::uvec3 image_extents = graph.logical_limits_of(out);
return {
utils::div_up(image_extents[0u], 2u),
utils::div_up(image_extents[1u], 2u),
utils::div_up(image_extents[0u], 1u),
utils::div_up(image_extents[1u], 4u),
image_extents[2u]};
} else if (method == Conv2dMethod::Depthwise && stride_equals_dilation) {
const utils::uvec3 image_extents = graph.create_global_wg_size(out);
Expand Down
Loading