diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py
index b7f8f3de955..a6cc59e26f0 100644
--- a/backends/vulkan/op_registry.py
+++ b/backends/vulkan/op_registry.py
@@ -397,14 +397,24 @@ def check_reduce_node(node: torch.fx.Node) -> bool:
         # If we can't get memory layout information, we'll assume the dims aren't packed
         pass
 
-    keepdim = node.args[2]
-    if isinstance(keepdim, bool) and not keepdim:
+    def try_find_keepdim_arg(node: torch.fx.Node) -> bool:
+        # keepdim may be passed by keyword, which positional scanning misses.
+        kwarg = node.kwargs.get("keepdim")
+        if isinstance(kwarg, bool):
+            return kwarg
+
+        # NOTE(review): the first positional bool is taken to be keepdim; an op
+        # with an earlier bool arg (e.g. var.dim's `unbiased`) would be misread.
+        for arg in node.args:
+            if isinstance(arg, bool):
+                return arg
+
+        # Assume false by default
         return False
 
-    if len(node.args) > 2:
-        keepdim = node.args[2]
-        if isinstance(keepdim, bool) and not keepdim:
-            return False
+    # The helper always returns a bool, so no isinstance re-check is needed.
+    if not try_find_keepdim_arg(node):
+        return False
 
     return True
 
diff --git a/backends/vulkan/partitioner/vulkan_partitioner.py b/backends/vulkan/partitioner/vulkan_partitioner.py
index 1b5ff0a44e4..04a1a500b64 100644
--- a/backends/vulkan/partitioner/vulkan_partitioner.py
+++ b/backends/vulkan/partitioner/vulkan_partitioner.py
@@ -204,7 +204,7 @@ def is_in_local_scalar_dense_chain(self, node: torch.fx.Node) -> Tuple[bool, boo
     def log_skip(self, node: torch.fx.Node, reason: str) -> None:
         if node.op == "call_function":
             logger.info(
-                f"[Vulkan Partitioner] Due to [{reason}], skipping {node.format_node()}"
+                f"[Vulkan Partitioner] Due to [{reason}], skipping {utils.node_io_str(node)}"
             )
 
     def is_node_supported(
diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d.glsl
index c0ed9204227..0f5dbc41273 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/conv2d.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d.glsl
@@ -30,6 +30,8 @@ ${layout_declare_ubo(8, "float", "out_min", "float", "out_max")}
 
 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
 
+${layout_declare_spec_const(C, "int", "ngroups", "1")}
+
 /*
  * Computes a 2D convolution. Each shader invocation calculates the output at
  * a single output location.
@@ -74,7 +76,20 @@ void main() {
   // Perform the convolution by iterating over the overlay region.
   VEC4_T sum = texelFetch(t_bias, ivec2(pos.z, 0), 0);
   const int ic4 = in_group_size / 4;
-  for (int z4 = 0; z4 < ic4; ++z4, kstart.x += kernel_size.x * 4) {
+
+  int z_start = 0;
+  int z_end = ic4;
+  if (ngroups > 1) {
+    // NOTE(review): group_size is 0 when out_limits.z < ngroups, making the
+    // division below undefined - confirm callers rule that case out.
+    const int group_size = (out_limits.z) / ngroups;
+    const int group_idx = pos.z / group_size;
+
+    z_start = ic4 * group_idx;
+    z_end = z_start + ic4;
+  }
+
+  for (int z4 = z_start; z4 < z_end; ++z4, kstart.x += kernel_size.x * 4) {
     for (int y = start.y, ky = kstart.y; y < end.y; y += dilation.y, ++ky) {
       for (int x = start.x, kx = kstart.x; x < end.x; x += dilation.x, kx += 4) {
         const VEC4_T in_texel = texelFetch(t_in, ivec3(x, y, z4), 0);
diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw.glsl
index 8a845b6a8a6..02fbef29b75 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_dw.glsl
@@ -30,6 +30,8 @@ ${layout_declare_ubo(8, "float", "out_min", "float", "out_max")}
 
 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
 
+${layout_declare_spec_const(C, "int", "ngroups", "1")}
+
 /*
  * Computes a depthwise convolution. Each shader invocation calculates the
  * output at a single output location.
diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl
index cf9714ca468..4c6031152ee 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl
@@ -38,6 +38,8 @@ layout(push_constant) uniform restrict Block {
 
 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
 
+${layout_declare_spec_const(C, "int", "ngroups", "1")}
+
 #extension GL_EXT_control_flow_attributes : require
 
 /*
diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.glsl
index a46f1e3b99c..9f84afeb1a1 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.glsl
+++ b/backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.glsl
@@ -40,6 +40,8 @@ layout(push_constant) uniform restrict Block {
 
 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
 
+${layout_declare_spec_const(C, "int", "ngroups", "1")}
+
 #extension GL_EXT_control_flow_attributes : require
 
 /*
diff --git a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp
index f5b5faa1c8b..ded1defe973 100644
--- a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp
@@ -280,9 +280,6 @@ Conv2dMethod get_conv2d_method(
   if (!transposed && weight_sizes.at(0) == groups && weight_sizes.at(1) == 1) {
     return Conv2dMethod::Depthwise;
   }
-  if (groups > 1) {
-    VK_THROW("aten.convolution.default: groups > 1 is not supported yet!");
-  }
   if (transposed) {
     return Conv2dMethod::Transposed;
   }
@@ -601,7 +598,7 @@ void add_conv2d_node(
       // Push Constants
       push_constants,
       // Specialization Constants
-      {},
+      {utils::safe_downcast<int32_t>(groups_val)},
       // Resize Args
       {weight_data, stride, padding, dilation, transposed, output_padding},
       // Resizing Logic
diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py
index ff35188be3e..5aaf00fe8bc 100644
--- a/backends/vulkan/test/op_tests/cases.py
+++ b/backends/vulkan/test/op_tests/cases.py
@@ -297,6 +297,28 @@ def get_conv_inputs():
     )
 
     test_cases = [
+        Test(
+            self=(1, 64, 256, 256),
+            weight=(64, 32, 3, 3),
+            bias=None,
+            stride=[1, 1],
+            padding=[1, 1],
+            dilation=[1, 1],
+            transposed=False,
+            output_padding=[0, 0],
+            groups=2,
+        ),
+        Test(
+            self=(1, 16, 3, 3),
+            weight=(16, 8, 3, 3),
+            bias=None,
+            stride=[1, 1],
+            padding=[1, 1],
+            dilation=[1, 1],
+            transposed=False,
+            output_padding=[0, 0],
+            groups=2,
+        ),
         Test(
             self=(1, 6, 40, 50),
             weight=(8, 6, 3, 3),
diff --git a/backends/vulkan/utils.py b/backends/vulkan/utils.py
index 1765f0b5e1c..bc03860ed3f 100644
--- a/backends/vulkan/utils.py
+++ b/backends/vulkan/utils.py
@@ -1059,6 +1059,8 @@ def get_node_val_str(node: torch.fx.Node) -> str:
         assert isinstance(node.meta["val"], (list, tuple))
         return f"[{', '.join(get_tensor_val_str(t) for t in node.meta['val'])}]"
     else:
+        if "val" not in node.meta:
+            return str(node)
         return str(node.meta["val"])
 
 