
Commit b49597b

Author: ssjia
Update on "[ET-VK][testing] Add scripts to facilitate operator testing"
Differential Revision: [D80800081](https://our.internmc.facebook.com/intern/diff/D80800081) [ghstack-poisoned]
2 parents 3e35ba2 + 0f68f01 commit b49597b

20 files changed: +103 −278 lines changed


backends/vulkan/op_registry.py

Lines changed: 9 additions & 6 deletions
@@ -397,14 +397,17 @@ def check_reduce_node(node: torch.fx.Node) -> bool:
         # If we can't get memory layout information, we'll assume the dims aren't packed
         pass

-    keepdim = node.args[2]
-    if isinstance(keepdim, bool) and not keepdim:
+    def try_find_keepdim_arg(node: torch.fx.Node) -> bool:
+        for arg in node.args:
+            if isinstance(arg, bool):
+                return arg
+
+        # Assume false by default
         return False

-    if len(node.args) > 2:
-        keepdim = node.args[2]
-        if isinstance(keepdim, bool) and not keepdim:
-            return False
+    keepdim = try_find_keepdim_arg(node)
+    if isinstance(keepdim, bool) and not keepdim:
+        return False

     return True

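Note on the change above: the reduce check previously assumed keepdim always sits at args[2], which does not hold for every reduce-op schema, so the helper now scans the positional args for the first boolean instead. A minimal standalone sketch of that logic (the SimpleNamespace stand-in for torch.fx.Node is illustrative only):

from types import SimpleNamespace

def try_find_keepdim_arg(node) -> bool:
    # Return the first boolean positional arg; reduce ops place keepdim
    # at varying argument positions depending on the overload.
    for arg in node.args:
        if isinstance(arg, bool):
            return arg
    # Assume false by default, matching the patched helper
    return False

# Stand-ins for torch.fx.Node args: (dims, keepdim) vs. no keepdim at all.
print(try_find_keepdim_arg(SimpleNamespace(args=([1], True))))  # True
print(try_find_keepdim_arg(SimpleNamespace(args=([1],))))       # False
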
backends/vulkan/partitioner/vulkan_partitioner.py

Lines changed: 1 addition & 1 deletion
@@ -204,7 +204,7 @@ def is_in_local_scalar_dense_chain(self, node: torch.fx.Node) -> Tuple[bool, boo
     def log_skip(self, node: torch.fx.Node, reason: str) -> None:
         if node.op == "call_function":
             logger.info(
-                f"[Vulkan Partitioner] Due to [{reason}], skipping {node.format_node()}"
+                f"[Vulkan Partitioner] Due to [{reason}], skipping {utils.node_io_str(node)}"
             )

     def is_node_supported(

backends/vulkan/runtime/graph/ops/glsl/conv2d.glsl

Lines changed: 14 additions & 1 deletion
@@ -30,6 +30,8 @@ ${layout_declare_ubo(8, "float", "out_min", "float", "out_max")}

 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

+${layout_declare_spec_const(C, "int", "ngroups", "1")}
+
 /*
  * Computes a 2D convolution. Each shader invocation calculates the output at
  * a single output location.
@@ -74,7 +76,18 @@ void main() {
   // Perform the convolution by iterating over the overlay region.
   VEC4_T sum = texelFetch(t_bias, ivec2(pos.z, 0), 0);
   const int ic4 = in_group_size / 4;
-  for (int z4 = 0; z4 < ic4; ++z4, kstart.x += kernel_size.x * 4) {
+
+  int z_start = 0;
+  int z_end = ic4;
+  if (ngroups > 1) {
+    const int group_size = (out_limits.z) / ngroups;
+    const int group_idx = pos.z / group_size;
+
+    z_start = ic4 * group_idx;
+    z_end = z_start + ic4;
+  }
+
+  for (int z4 = z_start; z4 < z_end; ++z4, kstart.x += kernel_size.x * 4) {
     for (int y = start.y, ky = kstart.y; y < end.y; y += dilation.y, ++ky) {
       for (int x = start.x, kx = kstart.x; x < end.x; x += dilation.x, kx += 4) {
         const VEC4_T in_texel = texelFetch(t_in, ivec3(x, y, z4), 0);

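The ngroups specialization constant restricts each output texel to its group's slice of input channels: ic4 is already in_group_size / 4 (input channels per group, in texels), so group g reads input texels [ic4 * g, ic4 * g + ic4). A host-side check of the same index math, using assumed values that match the first new test case in cases.py below (64 in/out channels, groups = 2, batch 1, channels packed 4 per texel):

# Assumed example values; batch size 1, so out_limits.z is just C_out / 4.
in_group_size = 64 // 2               # input channels per group = 32
ic4 = in_group_size // 4              # 8 input texels per group
out_limits_z = 64 // 4                # 16 output texels in total
ngroups = 2

group_size = out_limits_z // ngroups  # 8 output texels per group
for pos_z in (0, 7, 8, 15):
    group_idx = pos_z // group_size
    z_start = ic4 * group_idx
    print(pos_z, "-> group", group_idx, "reads texels", (z_start, z_start + ic4))
# pos_z 0..7  -> group 0 reads texels (0, 8)
# pos_z 8..15 -> group 1 reads texels (8, 16)
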
backends/vulkan/runtime/graph/ops/glsl/conv2d_dw.glsl

Lines changed: 2 additions & 0 deletions
@@ -30,6 +30,8 @@ ${layout_declare_ubo(8, "float", "out_min", "float", "out_max")}

 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

+${layout_declare_spec_const(C, "int", "ngroups", "1")}
+
 /*
  * Computes a depthwise convolution. Each shader invocation calculates the
  * output at a single output location.

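conv2d_dw and the two pointwise variants below get the same declaration so that all conv2d shaders accept the ngroups constant uniformly, even where the body does not consume it yet. For readers unfamiliar with the template: layout_declare_spec_const presumably expands to a Vulkan specialization-constant declaration, roughly as sketched here (illustrative only; the real helper lives in the shader codegen, where C appears to track the next free constant_id slot):

# Illustrative sketch, not the actual codegen implementation.
def layout_declare_spec_const(slot: int, dtype: str, name: str, default: str) -> str:
    # Specialization constants are bound by constant_id and can be overridden
    # at pipeline-creation time without recompiling the shader.
    return f"layout(constant_id = {slot}) const {dtype} {name} = {default};"

print(layout_declare_spec_const(3, "int", "ngroups", "1"))
# layout(constant_id = 3) const int ngroups = 1;
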
backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl

Lines changed: 2 additions & 0 deletions
@@ -38,6 +38,8 @@ layout(push_constant) uniform restrict Block {

 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

+${layout_declare_spec_const(C, "int", "ngroups", "1")}
+
 #extension GL_EXT_control_flow_attributes : require

 /*

backends/vulkan/runtime/graph/ops/glsl/conv2d_pw_s1p0.glsl

Lines changed: 2 additions & 0 deletions
@@ -40,6 +40,8 @@ layout(push_constant) uniform restrict Block {

 layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

+${layout_declare_spec_const(C, "int", "ngroups", "1")}
+
 #extension GL_EXT_control_flow_attributes : require

 /*

backends/vulkan/runtime/graph/ops/impl/Convolution.cpp

Lines changed: 1 addition & 4 deletions
@@ -280,9 +280,6 @@ Conv2dMethod get_conv2d_method(
   if (!transposed && weight_sizes.at(0) == groups && weight_sizes.at(1) == 1) {
     return Conv2dMethod::Depthwise;
   }
-  if (groups > 1) {
-    VK_THROW("aten.convolution.default: groups > 1 is not supported yet!");
-  }
   if (transposed) {
     return Conv2dMethod::Transposed;
   }
@@ -601,7 +598,7 @@ void add_conv2d_node(
       // Push Constants
       push_constants,
       // Specialization Constants
-      {},
+      {utils::safe_downcast<int32_t>(groups_val)},
       // Resize Args
       {weight_data, stride, padding, dilation, transposed, output_padding},
       // Resizing Logic

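With the VK_THROW removed, grouped convolutions fall through to the ordinary method dispatch, and groups_val is handed to the shader as its sole specialization constant, matching the ngroups declaration added above. A hedged Python mirror of the dispatch, only as far as this diff shows it (the default branch is an assumption from context):

def get_conv2d_method(transposed: bool, weight_sizes: list, groups: int) -> str:
    # Depthwise: one filter per input channel.
    if not transposed and weight_sizes[0] == groups and weight_sizes[1] == 1:
        return "Depthwise"
    # groups > 1 no longer throws; grouped convs use the regular paths and
    # rely on the ngroups spec constant inside the shader.
    if transposed:
        return "Transposed"
    return "SlidingWindow"  # assumed default; the real code may pick Pointwise

# weight (64, 32, 3, 3) with groups=2 is grouped but not depthwise:
print(get_conv2d_method(False, [64, 32, 3, 3], 2))  # SlidingWindow
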
backends/vulkan/targets.bzl

Lines changed: 2 additions & 0 deletions
@@ -387,6 +387,8 @@ def define_common_targets(is_fbcode = False):
             "//executorch/backends/transforms:view_copy_to_squeeze_unsqueeze",
             "//executorch/backends/vulkan/_passes:vulkan_passes",
             "//executorch/backends/vulkan/serialization:lib",
+            "//executorch/backends/transforms:remove_getitem_op",
+            "//executorch/backends/xnnpack/_passes:xnnpack_passes",
             "//executorch/exir/backend:backend_details",
         ],
     )

backends/vulkan/test/op_tests/cases.py

Lines changed: 22 additions & 0 deletions
@@ -297,6 +297,28 @@ def get_conv_inputs():
     )

     test_cases = [
+        Test(
+            self=(1, 64, 256, 256),
+            weight=(64, 32, 3, 3),
+            bias=None,
+            stride=[1, 1],
+            padding=[1, 1],
+            dilation=[1, 1],
+            transposed=False,
+            output_padding=[0, 0],
+            groups=2,
+        ),
+        Test(
+            self=(1, 16, 3, 3),
+            weight=(16, 8, 3, 3),
+            bias=None,
+            stride=[1, 1],
+            padding=[1, 1],
+            dilation=[1, 1],
+            transposed=False,
+            output_padding=[0, 0],
+            groups=2,
+        ),
         Test(
             self=(1, 6, 40, 50),
             weight=(8, 6, 3, 3),

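Both new cases use groups=2 with weights shaped (out_channels, in_channels // groups, kH, kW), as aten.convolution expects. They can be sanity-checked against the eager reference (assumes a local PyTorch install; shown for the first case):

import torch
import torch.nn.functional as F

x = torch.randn(1, 64, 256, 256)
w = torch.randn(64, 32, 3, 3)  # in_channels per group = 64 / 2 = 32

out = F.conv2d(x, w, bias=None, stride=1, padding=1, dilation=1, groups=2)
print(out.shape)  # torch.Size([1, 64, 256, 256])
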
backends/vulkan/utils.py

Lines changed: 2 additions & 0 deletions
@@ -1059,6 +1059,8 @@ def get_node_val_str(node: torch.fx.Node) -> str:
         assert isinstance(node.meta["val"], (list, tuple))
         return f"[{', '.join(get_tensor_val_str(t) for t in node.meta['val'])}]"
     else:
+        if "val" not in node.meta:
+            return str(node)
         return str(node.meta["val"])

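This guard keeps get_node_val_str (and hence the partitioner's log_skip above, via utils.node_io_str) from raising a KeyError on nodes that have no "val" metadata yet, such as graphs logged before shape propagation. A trimmed illustration of the fallback (FakeNode stands in for torch.fx.Node):

class FakeNode:
    # Minimal stand-in for torch.fx.Node, for illustration only.
    def __init__(self, name, meta):
        self.name, self.meta = name, meta
    def __str__(self):
        return self.name

def get_node_val_str(node) -> str:
    # Fall back to the node's own repr when no "val" metadata is attached.
    if "val" not in node.meta:
        return str(node)
    return str(node.meta["val"])

print(get_node_val_str(FakeNode("conv_out", {})))              # conv_out
print(get_node_val_str(FakeNode("conv_out", {"val": "f32"})))  # f32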