15 changes: 9 additions & 6 deletions backends/vulkan/op_registry.py
@@ -397,14 +397,17 @@ def check_reduce_node(node: torch.fx.Node) -> bool:
# If we can't get memory layout information, we'll assume the dims aren't packed
pass

keepdim = node.args[2]
if isinstance(keepdim, bool) and not keepdim:
def try_find_keepdim_arg(node: torch.fx.Node) -> bool:
for arg in node.args:
if isinstance(arg, bool):
return arg

# Assume false by default
return False

if len(node.args) > 2:
keepdim = node.args[2]
if isinstance(keepdim, bool) and not keepdim:
return False
keepdim = try_find_keepdim_arg(node)
if isinstance(keepdim, bool) and not keepdim:
return False

return True

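The new helper makes the `keepdim` lookup robust to reduce ops whose schemas place the flag at different positions, rather than assuming it is always `args[2]`. A minimal standalone sketch of the same logic, using `SimpleNamespace` as a stand-in for `torch.fx.Node` (illustrative, not from the PR):

```python
from types import SimpleNamespace

def try_find_keepdim_arg(node) -> bool:
    # Scan the positional args for the first bool; reduce op schemas
    # place keepdim at different indices (it is not always args[2]).
    for arg in node.args:
        if isinstance(arg, bool):
            return arg
    # Assume false by default
    return False

# mean.dim-style call signature: (input, dims, keepdim)
assert try_find_keepdim_arg(SimpleNamespace(args=("x", [1], True))) is True
# a reduce call with no keepdim supplied falls back to False
assert try_find_keepdim_arg(SimpleNamespace(args=("x", [1]))) is False
```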
2 changes: 1 addition & 1 deletion backends/vulkan/partitioner/vulkan_partitioner.py
@@ -204,7 +204,7 @@ def is_in_local_scalar_dense_chain(self, node: torch.fx.Node) -> Tuple[bool, boo
def log_skip(self, node: torch.fx.Node, reason: str) -> None:
if node.op == "call_function":
logger.info(
f"[Vulkan Partitioner] Due to [{reason}], skipping {node.format_node()}"
f"[Vulkan Partitioner] Due to [{reason}], skipping {utils.node_io_str(node)}"
)

def is_node_supported(
15 changes: 14 additions & 1 deletion backends/vulkan/runtime/graph/ops/glsl/conv2d.glsl
@@ -30,6 +30,8 @@ ${layout_declare_ubo(8, "float", "out_min", "float", "out_max")}

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

${layout_declare_spec_const(C, "int", "ngroups", "1")}

/*
* Computes a 2D convolution. Each shader invocation calculates the output at
* a single output location.
@@ -74,7 +76,18 @@ void main() {
// Perform the convolution by iterating over the overlay region.
VEC4_T sum = texelFetch(t_bias, ivec2(pos.z, 0), 0);
const int ic4 = in_group_size / 4;
for (int z4 = 0; z4 < ic4; ++z4, kstart.x += kernel_size.x * 4) {

int z_start = 0;
int z_end = ic4;
if (ngroups > 1) {
const int group_size = (out_limits.z) / ngroups;
const int group_idx = pos.z / group_size;

z_start = ic4 * group_idx;
z_end = z_start + ic4;
}

for (int z4 = z_start; z4 < z_end; ++z4, kstart.x += kernel_size.x * 4) {
for (int y = start.y, ky = kstart.y; y < end.y; y += dilation.y, ++ky) {
for (int x = start.x, kx = kstart.x; x < end.x; x += dilation.x, kx += 4) {
const VEC4_T in_texel = texelFetch(t_in, ivec3(x, y, z4), 0);
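The per-group channel offset above can be checked with a small Python mirror of the shader arithmetic, using hypothetical values (this assumes `out_limits.z` counts output texels and that input channels are packed four per texel, as the `ic4 = in_group_size / 4` line suggests):

```python
def group_z_range(pos_z, out_limits_z, in_group_size, ngroups):
    ic4 = in_group_size // 4  # input-channel texels per group (4 channels per texel)
    z_start, z_end = 0, ic4
    if ngroups > 1:
        group_size = out_limits_z // ngroups  # output texels per group
        group_idx = pos_z // group_size       # group served by this invocation
        z_start = ic4 * group_idx
        z_end = z_start + ic4
    return z_start, z_end

# 64 input channels, 64 output channels, 2 groups:
# in_group_size = 32, so ic4 = 8, and there are 16 output texels.
assert group_z_range(0, 16, 32, 2) == (0, 8)    # group 0 reads input texels [0, 8)
assert group_z_range(10, 16, 32, 2) == (8, 16)  # group 1 reads input texels [8, 16)
```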
2 changes: 2 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/conv2d_dw.glsl
@@ -30,6 +30,8 @@ ${layout_declare_ubo(8, "float", "out_min", "float", "out_max")}

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

${layout_declare_spec_const(C, "int", "ngroups", "1")}

/*
* Computes a depthwise convolution. Each shader invocation calculates the
* output at a single output location.
2 changes: 2 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl
@@ -38,6 +38,8 @@ layout(push_constant) uniform restrict Block {

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

${layout_declare_spec_const(C, "int", "ngroups", "1")}

#extension GL_EXT_control_flow_attributes : require

/*
2 changes: 2 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.glsl
@@ -40,6 +40,8 @@ layout(push_constant) uniform restrict Block {

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

${layout_declare_spec_const(C, "int", "ngroups", "1")}

#extension GL_EXT_control_flow_attributes : require

/*
5 changes: 1 addition & 4 deletions backends/vulkan/runtime/graph/ops/impl/Convolution.cpp
@@ -280,9 +280,6 @@ Conv2dMethod get_conv2d_method(
if (!transposed && weight_sizes.at(0) == groups && weight_sizes.at(1) == 1) {
return Conv2dMethod::Depthwise;
}
if (groups > 1) {
VK_THROW("aten.convolution.default: groups > 1 is not supported yet!");
}
if (transposed) {
return Conv2dMethod::Transposed;
}
@@ -601,7 +598,7 @@ void add_conv2d_node(
// Push Constants
push_constants,
// Specialization Constants
{},
{utils::safe_downcast<int32_t>(groups_val)},
// Resize Args
{weight_data, stride, padding, dilation, transposed, output_padding},
// Resizing Logic
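Passing `groups_val` as a specialization constant lets the same compiled shader be specialized per pipeline, instead of throwing on `groups > 1` as before. The math the shader now implements is ordinary grouped convolution, which can be sanity-checked against per-group dense convolutions (an illustrative check with made-up shapes, not part of the PR):

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 64, 16, 16)
w = torch.randn(64, 32, 3, 3)  # (out_channels, in_channels // groups, kH, kW)
groups = 2

grouped = F.conv2d(x, w, padding=1, groups=groups)
# Equivalent: run each group's input slice through a dense conv with its
# slice of the filters, then concatenate along the channel dim.
per_group = torch.cat(
    [F.conv2d(x[:, 32 * g : 32 * (g + 1)], w[32 * g : 32 * (g + 1)], padding=1)
     for g in range(groups)],
    dim=1,
)
assert torch.allclose(grouped, per_group, atol=1e-5)
```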
22 changes: 22 additions & 0 deletions backends/vulkan/test/op_tests/cases.py
@@ -297,6 +297,28 @@ def get_conv_inputs():
)

test_cases = [
Test(
self=(1, 64, 256, 256),
weight=(64, 32, 3, 3),
bias=None,
stride=[1, 1],
padding=[1, 1],
dilation=[1, 1],
transposed=False,
output_padding=[0, 0],
groups=2,
),
Test(
self=(1, 16, 3, 3),
weight=(16, 8, 3, 3),
bias=None,
stride=[1, 1],
padding=[1, 1],
dilation=[1, 1],
transposed=False,
output_padding=[0, 0],
groups=2,
),
Test(
self=(1, 6, 40, 50),
weight=(8, 6, 3, 3),
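The two new cases above exercise `groups=2`, relying on the grouped-conv shape invariant that `weight.shape[1] * groups` equals the input channel count. A quick illustrative check of those shapes:

```python
cases = [
    ((1, 64, 256, 256), (64, 32, 3, 3), 2),
    ((1, 16, 3, 3), (16, 8, 3, 3), 2),
]
for self_shape, weight_shape, groups in cases:
    assert self_shape[1] == weight_shape[1] * groups  # 32 * 2 == 64, 8 * 2 == 16
    assert weight_shape[0] % groups == 0              # output channels split evenly
```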
2 changes: 2 additions & 0 deletions backends/vulkan/utils.py
@@ -1059,6 +1059,8 @@ def get_node_val_str(node: torch.fx.Node) -> str:
assert isinstance(node.meta["val"], (list, tuple))
return f"[{', '.join(get_tensor_val_str(t) for t in node.meta['val'])}]"
else:
if "val" not in node.meta:
return str(node)
return str(node.meta["val"])
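The added guard avoids a `KeyError` for nodes that never received fake-tensor metadata. A minimal sketch of the fallback behavior, using a stand-in node class (illustrative only):

```python
class FakeNode:
    def __init__(self, name, meta):
        self.name, self.meta = name, meta

    def __str__(self):
        return self.name

def val_str(node) -> str:
    if "val" not in node.meta:
        return str(node)  # fall back to the node itself rather than raising
    return str(node.meta["val"])

print(val_str(FakeNode("add_1", {"val": "FakeTensor(size=(1, 64))"})))  # metadata string
print(val_str(FakeNode("w_attr", {})))                                  # "w_attr"
```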

