From f532f651bbd8aa7b0fbf93354f6e4c2b542b9b2c Mon Sep 17 00:00:00 2001
From: ssjia
Date: Fri, 22 Aug 2025 07:30:58 -0700
Subject: [PATCH 1/3] [ET-VK][ez] Fix partitioner logic of finding keepdim arg
 of reduce ops

Title says it all. Reduce ops do not all share the same signature, so some
extra legwork is needed to identify the specific argument that needs to be
checked. Also included is a small update to partitioner logging to improve
debuggability.

Differential Revision: [D80741737](https://our.internmc.facebook.com/intern/diff/D80741737/)

[ghstack-poisoned]
---
 backends/vulkan/op_registry.py                    | 15 +++++++++------
 backends/vulkan/partitioner/vulkan_partitioner.py |  2 +-
 backends/vulkan/utils.py                          |  2 ++
 3 files changed, 12 insertions(+), 7 deletions(-)

diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py
index b7f8f3de955..a6cc59e26f0 100644
--- a/backends/vulkan/op_registry.py
+++ b/backends/vulkan/op_registry.py
@@ -397,14 +397,17 @@ def check_reduce_node(node: torch.fx.Node) -> bool:
         # If we can't get memory layout information, we'll assume the dims aren't packed
         pass
 
-    keepdim = node.args[2]
-    if isinstance(keepdim, bool) and not keepdim:
+    def try_find_keepdim_arg(node: torch.fx.Node) -> bool:
+        for arg in node.args:
+            if isinstance(arg, bool):
+                return arg
+
+        # Assume false by default
         return False
 
-    if len(node.args) > 2:
-        keepdim = node.args[2]
-        if isinstance(keepdim, bool) and not keepdim:
-            return False
+    keepdim = try_find_keepdim_arg(node)
+    if isinstance(keepdim, bool) and not keepdim:
+        return False
 
     return True
 
diff --git a/backends/vulkan/partitioner/vulkan_partitioner.py b/backends/vulkan/partitioner/vulkan_partitioner.py
index 1b5ff0a44e4..04a1a500b64 100644
--- a/backends/vulkan/partitioner/vulkan_partitioner.py
+++ b/backends/vulkan/partitioner/vulkan_partitioner.py
@@ -204,7 +204,7 @@ def is_in_local_scalar_dense_chain(self, node: torch.fx.Node) -> Tuple[bool, boo
     def log_skip(self, node: torch.fx.Node, reason: str) -> None:
         if node.op == "call_function":
             logger.info(
-                f"[Vulkan Partitioner] Due to [{reason}], skipping {node.format_node()}"
+                f"[Vulkan Partitioner] Due to [{reason}], skipping {utils.node_io_str(node)}"
             )
 
     def is_node_supported(
diff --git a/backends/vulkan/utils.py b/backends/vulkan/utils.py
index 1765f0b5e1c..bc03860ed3f 100644
--- a/backends/vulkan/utils.py
+++ b/backends/vulkan/utils.py
@@ -1059,6 +1059,8 @@ def get_node_val_str(node: torch.fx.Node) -> str:
         assert isinstance(node.meta["val"], (list, tuple))
         return f"[{', '.join(get_tensor_val_str(t) for t in node.meta['val'])}]"
     else:
+        if "val" not in node.meta:
+            return str(node)
         return str(node.meta["val"])
 
 
From 173946dad0e2d1e14a4af0cba55e93c24c2d54ed Mon Sep 17 00:00:00 2001
From: ssjia
Date: Fri, 22 Aug 2025 07:31:01 -0700
Subject: [PATCH 2/3] [ET-VK][ez] Support grouped convolutions

Title says it all!
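For reference (not part of this diff): in a grouped convolution, the input and
output channels are split into `groups` independent slices, and each
output-channel group reads only its matching input-channel slice. This is what
the new `z_start`/`z_end` bounds in `conv2d.glsl` select. Below is a minimal
sketch of the expected semantics in plain PyTorch, reusing the channel
configuration of the first new test case (spatial size reduced; everything
here is illustrative, not code from this change):

```python
import torch
import torch.nn.functional as F

# With groups=G, output-channel block g reads only input channels
# [g * C_in // G, (g + 1) * C_in // G).
x = torch.randn(1, 64, 16, 16)
w = torch.randn(64, 32, 3, 3)  # (C_out, C_in // groups, kH, kW)
out = F.conv2d(x, w, padding=1, groups=2)

# Equivalent computation, one group at a time: slice the input channels and
# the weight's output-channel rows for each group, then concatenate.
per_group = [
    F.conv2d(x[:, 32 * g : 32 * (g + 1)], w[32 * g : 32 * (g + 1)], padding=1)
    for g in range(2)
]
assert torch.allclose(out, torch.cat(per_group, dim=1), atol=1e-4)
```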
Differential Revision: [D80741734](https://our.internmc.facebook.com/intern/diff/D80741734/)

[ghstack-poisoned]
---
 .../vulkan/runtime/graph/ops/glsl/conv2d.glsl | 15 ++++++++++++++-
 .../runtime/graph/ops/glsl/conv2d_dw.glsl     |  2 ++
 .../runtime/graph/ops/glsl/conv2d_pw.glsl     |  2 ++
 .../graph/ops/glsl/conv2d_pw_s1p0.glsl        |  2 ++
 .../runtime/graph/ops/impl/Convolution.cpp    |  5 +----
 backends/vulkan/test/op_tests/cases.py        | 22 +++++++++++++++++++
 6 files changed, 43 insertions(+), 5 deletions(-)

diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d.glsl
index c0ed9204227..0f5dbc41273 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/conv2d.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d.glsl
@@ -30,6 +30,8 @@ ${layout_declare_ubo(8, "float", "out_min", "float", "out_max")}
 
 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
 
+${layout_declare_spec_const(C, "int", "ngroups", "1")}
+
 /*
  * Computes a 2D convolution. Each shader invocation calculates the output at
  * a single output location.
@@ -74,7 +76,18 @@ void main() {
   // Perform the convolution by iterating over the overlay region.
   VEC4_T sum = texelFetch(t_bias, ivec2(pos.z, 0), 0);
   const int ic4 = in_group_size / 4;
-  for (int z4 = 0; z4 < ic4; ++z4, kstart.x += kernel_size.x * 4) {
+
+  int z_start = 0;
+  int z_end = ic4;
+  if (ngroups > 1) {
+    const int group_size = (out_limits.z) / ngroups;
+    const int group_idx = pos.z / group_size;
+
+    z_start = ic4 * group_idx;
+    z_end = z_start + ic4;
+  }
+
+  for (int z4 = z_start; z4 < z_end; ++z4, kstart.x += kernel_size.x * 4) {
     for (int y = start.y, ky = kstart.y; y < end.y; y += dilation.y, ++ky) {
       for (int x = start.x, kx = kstart.x; x < end.x; x += dilation.x, kx += 4) {
         const VEC4_T in_texel = texelFetch(t_in, ivec3(x, y, z4), 0);
diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw.glsl
index 8a845b6a8a6..02fbef29b75 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw.glsl
@@ -30,6 +30,8 @@ ${layout_declare_ubo(8, "float", "out_min", "float", "out_max")}
 
 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
 
+${layout_declare_spec_const(C, "int", "ngroups", "1")}
+
 /*
  * Computes a depthwise convolution. Each shader invocation calculates the
  * output at a single output location.
diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl
index cf9714ca468..4c6031152ee 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl
@@ -38,6 +38,8 @@ layout(push_constant) uniform restrict Block {
 
 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
 
+${layout_declare_spec_const(C, "int", "ngroups", "1")}
+
 #extension GL_EXT_control_flow_attributes : require
 
 /*
diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.glsl
index a46f1e3b99c..9f84afeb1a1 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.glsl
@@ -40,6 +40,8 @@ layout(push_constant) uniform restrict Block {
 
 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
 
+${layout_declare_spec_const(C, "int", "ngroups", "1")}
+
 #extension GL_EXT_control_flow_attributes : require
 
 /*
diff --git a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp
index f5b5faa1c8b..ded1defe973 100644
--- a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp
@@ -280,9 +280,6 @@ Conv2dMethod get_conv2d_method(
   if (!transposed && weight_sizes.at(0) == groups && weight_sizes.at(1) == 1) {
     return Conv2dMethod::Depthwise;
  }
-  if (groups > 1) {
-    VK_THROW("aten.convolution.default: groups > 1 is not supported yet!");
-  }
   if (transposed) {
     return Conv2dMethod::Transposed;
   }
@@ -601,7 +598,7 @@ void add_conv2d_node(
       // Push Constants
       push_constants,
       // Specialization Constants
-      {},
+      {utils::safe_downcast<int32_t>(groups_val)},
       // Resize Args
       {weight_data, stride, padding, dilation, transposed, output_padding},
       // Resizing Logic
diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py
index ff35188be3e..5aaf00fe8bc 100644
--- a/backends/vulkan/test/op_tests/cases.py
+++ b/backends/vulkan/test/op_tests/cases.py
@@ -297,6 +297,28 @@ def get_conv_inputs():
     )
 
     test_cases = [
+        Test(
+            self=(1, 64, 256, 256),
+            weight=(64, 32, 3, 3),
+            bias=None,
+            stride=[1, 1],
+            padding=[1, 1],
+            dilation=[1, 1],
+            transposed=False,
+            output_padding=[0, 0],
+            groups=2,
+        ),
+        Test(
+            self=(1, 16, 3, 3),
+            weight=(16, 8, 3, 3),
+            bias=None,
+            stride=[1, 1],
+            padding=[1, 1],
+            dilation=[1, 1],
+            transposed=False,
+            output_padding=[0, 0],
+            groups=2,
+        ),
         Test(
             self=(1, 6, 40, 50),
             weight=(8, 6, 3, 3),
From 69c945be1bc169cf8734d623f358a1a95dbf9eaf Mon Sep 17 00:00:00 2001
From: ssjia
Date: Fri, 22 Aug 2025 14:12:09 -0700
Subject: [PATCH 3/3] Update base for Update on "[ET-VK][ez] Support grouped
 convolutions"

Title says it all!

Differential Revision: [D80741734](https://our.internmc.facebook.com/intern/diff/D80741734/)

[ghstack-poisoned]