
Commit 54030da

pytorchbot and Sicheng Jia authored
[ET-VK] Miscellaneous fixes (#14803)
This PR was created by the merge bot to help merge the original PR into the main branch.

ghstack PR number: #14732 by @SS-JIA
^ Please use this as the source of truth for the PR details, comments, and reviews
ghstack PR base: https://github.com/pytorch/executorch/tree/gh/SS-JIA/335/base
ghstack PR head: https://github.com/pytorch/executorch/tree/gh/SS-JIA/335/head
Merge bot PR base: https://github.com/pytorch/executorch/tree/main
Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/SS-JIA/335/orig

Differential Revision: [D83703496](https://our.internmc.facebook.com/intern/diff/D83703496/)

@diff-train-skip-merge

Co-authored-by: Sicheng Jia <[email protected]>
1 parent 5c520d3 commit 54030da

23 files changed: +298 −204 lines changed

.github/workflows/pull.yml

Lines changed: 6 additions & 1 deletion
@@ -958,11 +958,16 @@ jobs:
           PYTHON_EXECUTABLE=python bash backends/vulkan/test/scripts/test_model.sh --build

           # Test models serially
-          models="mv2 mv3 edsr resnet18 resnet50 dl3"
+          models="mv2 mv3 edsr resnet18 resnet50 dl3 w2l ic3 ic4"
           for model in $models; do
             python -m examples.vulkan.export --model_name=$model --test
           done
+
+          # For selected vision models, test with dynamic shapes
+          models="mv2 resnet18 resnet50 ic3 densenet161"
+          for model in $models; do
+            python -m examples.vulkan.export --model_name=$model --test -d
+          done

   test-vulkan-operators-linux:
     name: test-vulkan-operators-linux

backends/vulkan/_passes/fold_qdq.py

Lines changed: 2 additions & 3 deletions
@@ -17,9 +17,8 @@ class FoldQDQPass(ExportPass):
     valid quant op patterns have already been fused before this pass.
     """

-    def __init__(self, edge_program: torch.export.ExportedProgram):
-        super(FoldQDQPass, self).__init__()
-        self.edge_program = edge_program
+    def __init__(self):
+        super().__init__()

     def call(self, graph_module: torch.fx.GraphModule):
         for node in graph_module.graph.nodes:

backends/vulkan/_passes/fuse_patterns.py

Lines changed: 7 additions & 3 deletions
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+from typing import Optional
+
 import executorch.backends.vulkan.patterns as vk_patterns

 import torch
@@ -13,13 +15,15 @@


 class FusePatternsPass(ExportPass):
-    def __init__(self, exported_program: ExportedProgram) -> None:
+    def __init__(self) -> None:
         super().__init__()
-        self.program = exported_program
+        self._exported_program: Optional[ExportedProgram] = None

     def call(self, graph_module: torch.fx.GraphModule):
+        assert self._exported_program is not None
+
         total_replaced = vk_patterns.replace_all_fusable_subgraphs(
-            self.program, graph_module
+            self._exported_program, graph_module
         )

         if total_replaced > 0:

backends/vulkan/_passes/fuse_quantized_ops.py

Lines changed: 6 additions & 4 deletions
@@ -211,18 +211,20 @@ def fuse_into_linear_qcnw_node(


 class FuseQuantizedOpsTransform(ExportPass):
-    def __init__(self, exported_program: ExportedProgram) -> None:
+    def __init__(self) -> None:
         super().__init__()
-        self.program = exported_program
+        self._exported_program: Optional[ExportedProgram] = None

     def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
+        assert self._exported_program is not None
+
         for node in graph_module.graph.nodes:
             # Check for linear_qcnw pattern (weight-only quantization)
-            qcnw_details = matches_linear_qcnw_pattern(self.program, node)
+            qcnw_details = matches_linear_qcnw_pattern(self._exported_program, node)
             if qcnw_details is not None:
                 qcnw_method, qcnw_nbits = qcnw_details
                 fuse_into_linear_qcnw_node(
-                    self.program, graph_module, node, qcnw_method, qcnw_nbits
+                    self._exported_program, graph_module, node, qcnw_method, qcnw_nbits
                 )
                 continue

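
All three pass changes above share one refactor: FoldQDQPass, FusePatternsPass, and FuseQuantizedOpsTransform no longer receive the ExportedProgram at construction time; each holds a nullable _exported_program attribute and asserts on it inside call(). This implies that whatever drives the passes now injects the program before running them. A minimal sketch of that pattern under that assumption (the run_program_aware_passes driver below is hypothetical, not part of this PR):

from typing import List, Optional

import torch
from executorch.exir.pass_base import ExportPass
from torch.export import ExportedProgram


class ProgramAwarePass(ExportPass):
    # Constructed with no arguments; the driver attaches the program later.
    def __init__(self) -> None:
        super().__init__()
        self._exported_program: Optional[ExportedProgram] = None

    def call(self, graph_module: torch.fx.GraphModule):
        # Fail loudly if the driver forgot to attach the program.
        assert self._exported_program is not None
        ...


def run_program_aware_passes(ep: ExportedProgram, passes: List[ExportPass]) -> None:
    # Hypothetical driver: attach the program to each pass, then run it.
    for p in passes:
        p._exported_program = ep
        p.call(ep.graph_module)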

backends/vulkan/_passes/tag_memory_meta_pass.py

Lines changed: 4 additions & 0 deletions
@@ -230,6 +230,10 @@ def get_arg_tensor_source_repset(
     """
     arg_node = op_node.args[arg_i]

+    # For non-tensor arguments, return ANY_STORAGE
+    if not utils.is_tensor_arg_node(arg_node):
+        return utils.ANY_STORAGE
+
     # Special case for cat - use the first tensor in the list as representative
     if isinstance(arg_node, list):
         arg_node = arg_node[0]
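
The new guard matters because op_node.args[arg_i] is not always a tensor-producing node; scalars, dim indices, and other immediates have no tensor representation to constrain, so ANY_STORAGE is the right answer for them. A toy version of the check, assuming (as the real utils.is_tensor_arg_node presumably does) that tensor args are fx nodes whose "val" metadata is a tensor:

import torch


def is_tensor_arg(arg) -> bool:
    # Lists of nodes (e.g. the inputs of cat) count if every element is a tensor node.
    if isinstance(arg, (list, tuple)):
        return len(arg) > 0 and all(is_tensor_arg(a) for a in arg)
    return isinstance(arg, torch.fx.Node) and isinstance(
        arg.meta.get("val"), torch.Tensor
    )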

backends/vulkan/op_registry.py

Lines changed: 70 additions & 23 deletions
@@ -16,8 +16,6 @@

 import torch

-from executorch.backends.vulkan.serialization.vulkan_graph_schema import VkMemoryLayout
-
 from executorch.exir.dialects._ops import ops as exir_ops

 from executorch.exir.dialects.edge._ops import EdgeOpOverload
@@ -48,6 +46,9 @@ class OpFeatures:
         # Optional check function used during partitioning to determine if a node's
         # inputs are supported by the operator implementation.
         "are_node_inputs_supported_fn",
+        # Optional function to determine valid representation sets for input and outputs
+        # once a node's actual inputs are known.
+        "pick_io_storage_fn",
     ]

     def __init__(
@@ -61,6 +62,7 @@
         supports_resize: bool = False,
         supports_prepacking: bool = False,
         are_node_inputs_supported_fn: Optional[Callable] = allow_node,
+        pick_io_storage_fn: Optional[Callable] = None,
     ):
         self.inputs_storage: utils.TensorRepSetList = utils.TensorRepSetList(
             inputs_storage if inputs_storage is not None else []
@@ -77,15 +79,21 @@
         self.supports_prepacking = supports_prepacking

         self.are_node_inputs_supported_fn = are_node_inputs_supported_fn
+        self.pick_io_storage_fn = pick_io_storage_fn

     def make_op_repsets(
         self,
         op_node: torch.fx.Node,
         texture_limits: utils.ImageExtents = utils.DEFAULT_TEXTURE_LIMITS,
     ) -> utils.OpRepSets:
-        return utils.OpRepSets(
-            self.inputs_storage, self.outputs_storage, op_node, texture_limits
-        )
+        inputs_storage = self.inputs_storage
+        outputs_storage = self.outputs_storage
+        if self.pick_io_storage_fn is not None:
+            i_storage, o_storage = self.pick_io_storage_fn(op_node)
+            inputs_storage = utils.TensorRepSetList(i_storage)
+            outputs_storage = utils.TensorRepSetList(o_storage)
+
+        return utils.OpRepSets(inputs_storage, outputs_storage, op_node, texture_limits)


#######################
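
The effect of the new hook: make_op_repsets still uses the registered static storage lists by default, but an op can now swap in node-dependent ones once the actual inputs are visible. A hedged usage sketch in the context of op_registry.py, where OpFeatures and utils are in scope (the my_pick_io_storage name is illustrative):

def my_pick_io_storage(op_node: torch.fx.Node):
    # Inspect the concrete node and return (inputs_storage, outputs_storage);
    # both are wrapped in utils.TensorRepSetList by make_op_repsets.
    return utils.ANY_TEXTURE, utils.ANY_TEXTURE


features = OpFeatures(
    inputs_storage=utils.ANY_TEXTURE,       # static default
    pick_io_storage_fn=my_pick_io_storage,  # node-dependent override
)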
@@ -410,28 +418,16 @@ def register_softmax_op():
 )
 def register_reduce_op():
     def check_reduce_node(node: torch.fx.Node) -> bool:
+        # Only one argument implies that the reduction is over the entire tensor, which
+        # is not supported yet.
+        if len(node.args) == 1:
+            return False
+
         dim_list = node.args[1]
+        # Only 1D and 2D reductions are supported at the moment.
         if isinstance(dim_list, list) and len(dim_list) > 2:
             return False

-        if isinstance(dim_list, list) and len(dim_list) == 2:
-            # Try to get the memory layout for this node
-            try:
-                memory_layout = utils.get_node_memory_layout(node)
-
-                # If we have memory layout information, check if any dimension in dim_list corresponds to a packed dimension
-                if (
-                    memory_layout is not None
-                    and memory_layout != VkMemoryLayout.DEFAULT_LAYOUT
-                ):
-                    # For now only default layout is supported for 2D reduction.
-                    # Because we can't determine if the input is NCHW or NHWC here,
-                    # assume the reduction dimension is packed so we cannot support it.
-                    return False
-            except (AssertionError, KeyError, AttributeError):
-                # If we can't get memory layout information, we'll assume the dims aren't packed
-                pass
-
     def try_find_keepdim_arg(node: torch.fx.Node) -> bool:
         for arg in node.args:
             if isinstance(arg, bool):
@@ -446,10 +442,41 @@ def try_find_keepdim_arg(node: torch.fx.Node) -> bool:

         return True

+    def pick_io_storage_for_reduce(node: torch.fx.Node):
+        inputs_storage = utils.ANY_TEXTURE
+        outputs_storage = utils.ANY_TEXTURE
+
+        input_tensor = node.args[0]
+        ndim = input_tensor.meta["val"].ndim
+        dim_list = node.args[1]
+        if isinstance(dim_list, list) and len(dim_list) == 2:
+            reduce_dim1_whcn = utils.nchw_dim_to_whcn_dim(dim_list[0], ndim)
+            reduce_dim2_whcn = utils.nchw_dim_to_whcn_dim(dim_list[1], ndim)
+
+            possible_packed_dims = {0, 1, 2}
+            possible_packed_dims.discard(reduce_dim1_whcn)
+            possible_packed_dims.discard(reduce_dim2_whcn)
+
+            packed_dim = possible_packed_dims.pop()
+            assert packed_dim in [0, 1, 2]
+
+            if packed_dim == 0:
+                inputs_storage = utils.WIDTH_PACKED_TEXTURE
+                outputs_storage = utils.WIDTH_PACKED_TEXTURE
+            elif packed_dim == 1:
+                inputs_storage = utils.HEIGHT_PACKED_TEXTURE
+                outputs_storage = utils.HEIGHT_PACKED_TEXTURE
+            else:
+                inputs_storage = utils.CHANNELS_PACKED_TEXTURE
+                outputs_storage = utils.CHANNELS_PACKED_TEXTURE
+
+        return inputs_storage, outputs_storage
+
     return OpFeatures(
         inputs_storage=utils.ANY_TEXTURE,
         supports_resize=True,
         are_node_inputs_supported_fn=check_reduce_node,
+        pick_io_storage_fn=pick_io_storage_for_reduce,
     )

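
The deduction in pick_io_storage_for_reduce works because a texture is packed along exactly one of the three WHCN dims (0 = width, 1 = height, 2 = channels): removing the two reduced dims from {0, 1, 2} leaves the one dim that is safe to pack. A worked example, assuming nchw_dim_to_whcn_dim(d, ndim) maps NCHW dim d to ndim - 1 - d:

ndim = 4           # NCHW input
dim_list = [2, 3]  # reduce over H and W

whcn = [ndim - 1 - d for d in dim_list]  # [1, 0]: height and width

remaining = {0, 1, 2} - set(whcn)  # {2}
packed_dim = remaining.pop()       # 2 -> use channels-packed textures
print(packed_dim)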

@@ -474,6 +501,23 @@ def register_2d_pool_op():
     ]
 )
 def register_convolution_op():
+    def check_conv_node(node: torch.fx.Node) -> bool:
+        x = node.args[0]
+        x_shape = x.meta["val"].size()
+        # 4-D input implies 2D convolution
+        if len(x_shape) == 4:
+            batches = x.meta["val"].size()[0]
+            if batches != 1:
+                return False
+        # 3-D input implies 1D convolution
+        if len(x_shape) == 3:
+            transpose = node.args[6]
+            # Transposed 1D convolution is not supported yet
+            if transpose:
+                return False
+
+        return True
+
     return OpFeatures(
         inputs_storage=[
             utils.CHANNELS_PACKED_TEXTURE,  # input
@@ -490,6 +534,7 @@ def register_convolution_op():
         ],
         supports_resize=True,
         supports_prepacking=True,
+        are_node_inputs_supported_fn=check_conv_node,
     )

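
check_conv_node indexes positionally into node.args, following the schema of aten.convolution: (input, weight, bias, stride, padding, dilation, transposed, output_padding, groups), so node.args[6] is the transposed flag. A quick illustration of the case it rejects (values illustrative):

import torch

x = torch.randn(1, 8, 32)  # (N, C, L): 3-D input, i.e. a 1D convolution
w = torch.randn(16, 8, 3)
# Positional args in aten.convolution order; index 6 is `transposed`.
args = (x, w, None, [1], [0], [1], True, [0], 1)

if x.dim() == 3 and args[6]:
    print("transposed 1D convolution: rejected by check_conv_node")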

@@ -666,6 +711,7 @@ def register_ported_ops_with_prepacking():
     return OpFeatures(
         inputs_storage=utils.CHANNELS_PACKED_TEXTURE,
         supports_prepacking=True,
+        supports_resize=True,
     )


@@ -696,6 +742,7 @@ def register_ported_ops_with_prepacking_all_dims():
     return OpFeatures(
         inputs_storage=utils.ANY_TEXTURE,
         supports_prepacking=True,
+        supports_resize=True,
     )


backends/vulkan/partitioner/vulkan_partitioner.py

Lines changed: 6 additions & 4 deletions
@@ -36,7 +36,7 @@
     Partitioner,
     PartitionResult,
 )
-from executorch.exir.backend.utils import tag_constant_data
+from executorch.exir.backend.utils import tag_constant_data, tag_mutated_buffer
 from executorch.exir.dialects._ops import ops as exir_ops

 from torch.export.exported_program import ExportedProgram
@@ -254,9 +254,10 @@ def _is_node_supported(self, node: torch.fx.Node) -> bool:  # noqa: C901
             self.log_skip(node, "permute node of non compatible linear node")
             return False

-        is_in_local_scalar_dense_chain, dst_node_is_compatible = (
-            self.is_in_local_scalar_dense_chain(node)
-        )
+        (
+            is_in_local_scalar_dense_chain,
+            dst_node_is_compatible,
+        ) = self.is_in_local_scalar_dense_chain(node)
         if is_in_local_scalar_dense_chain and dst_node_is_compatible:
             return True
         elif is_in_local_scalar_dense_chain:
@@ -419,6 +420,7 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
         logger.info(f"Found {pl} Vulkan subgraphs to be partitioned.")

         tag_constant_data(exported_program)
+        tag_mutated_buffer(exported_program)

         return PartitionResult(
             tagged_exported_program=exported_program, partition_tags=partition_tags

backends/vulkan/patterns/quantized_linear.py

Lines changed: 7 additions & 5 deletions
@@ -92,9 +92,11 @@ def __init__(self, mm_node: torch.fx.Node) -> None:
             return

         # Identify input node
-        self.fp_input_node, self.quantize_input_node, dq_node = (
-            utils.maybe_skip_q_dq_arg_chain(self.anchor_node.args[0])
-        )
+        (
+            self.fp_input_node,
+            self.quantize_input_node,
+            dq_node,
+        ) = utils.maybe_skip_q_dq_arg_chain(self.anchor_node.args[0])
         assert self.fp_input_node is not None
         self.all_nodes.append(self.fp_input_node)

@@ -386,7 +388,7 @@ def make_linear_dq8ca_q4gsw_op(
     weight_sums_node = create_constant_placeholder(
         exp_program=ep,
         graph=graph_module.graph,
-        kind=InputKind.CONSTANT_TENSOR,
+        kind=InputKind.PARAMETER,
         name=sums_name,
         data=sum_per_quant_group,
     )
@@ -429,7 +431,7 @@ def make_linear_q8ta_q8csw_custom_op(
     weight_sums_node = create_constant_placeholder(
         exp_program=ep,
         graph=graph_module.graph,
-        kind=InputKind.CONSTANT_TENSOR,
+        kind=InputKind.PARAMETER,
         name=sums_name,
         data=sum_per_output_channel,
     )
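
The InputKind switch changes how the generated weight-sums tensor is recorded in the exported program's graph signature: PARAMETER registers it as a model parameter (stored in the program's state_dict), while CONSTANT_TENSOR would record it as a generic lifted constant (stored among the program's constants), which affects how downstream passes treat the data. Both members come from the same torch.export enum:

from torch.export.graph_signature import InputKind

# Lifted-input kinds relevant here; PARAMETER entries live in the program's
# state_dict, CONSTANT_TENSOR entries among its constants.
print(InputKind.PARAMETER, InputKind.CONSTANT_TENSOR)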

backends/vulkan/runtime/graph/ops/glsl/conv2d.glsl

Lines changed: 1 addition & 1 deletion
@@ -60,7 +60,7 @@ void main() {
     int num_steps = ((-ipos.y) + dilation.y - 1) / dilation.y;
     start.y = ipos.y + num_steps * dilation.y;
   }
-  const ivec2 end = min(ipos + overlay_region.xy, ivec2(in_sizes.xy));
+  const ivec2 end = min(ipos + overlay_region.xy, in_sizes.xy);
   // Compute the start of the kernel based on how far we are skipping ahead when
   // reading the input. Note that these are "canonical" indices.
   ivec2 kstart = (start - ipos) / dilation;

backends/vulkan/runtime/graph/ops/glsl/conv2d_dw.glsl

Lines changed: 1 addition & 1 deletion
@@ -54,7 +54,7 @@ void main() {
   // Compute the start and end of the input indices to load. Padding is assumed
   // to be constant 0 padding, so reads from the padding region are skipped.
   const ivec2 start = ipos;
-  const ivec2 end = ipos + overlay_region.xy;
+  const ivec2 end = min(ipos + overlay_region.xy, in_sizes.xy);

   VEC4_T sum = texelFetch(t_bias, ivec2(pos.z, 0), 0);
   int kx = 0;
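
Both shader fixes clamp the end of the input window to the input extent; without the clamp, a kernel window that overhangs the right or bottom edge computes read coordinates past in_sizes. A quick arithmetic check of the depthwise case (values illustrative):

in_size_x = 8   # input width in texels
ipos_x = 6      # window start, inside the input
overlay_x = 3   # kernel overlay extent in x

end_unclamped = ipos_x + overlay_x           # 9: one past the input edge
end_clamped = min(end_unclamped, in_size_x)  # 8: stops at the edge
print(end_unclamped, end_clamped)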
