Commit 809924c

Merge branch 'release/1.0' into cherry-pick-14825-by-pytorch_bot_bot_
2 parents: 6fbbb6e + 73e2346

File tree: 88 files changed, +1500 −428 lines


.github/workflows/pull.yml

Lines changed: 6 additions & 1 deletion

@@ -958,11 +958,16 @@ jobs:
       PYTHON_EXECUTABLE=python bash backends/vulkan/test/scripts/test_model.sh --build
 
       # Test models serially
-      models="mv2 mv3 edsr resnet18 resnet50 dl3"
+      models="mv2 mv3 edsr resnet18 resnet50 dl3 w2l ic3 ic4"
       for model in $models; do
         python -m examples.vulkan.export --model_name=$model --test
       done
 
+      # For selected vision models, test with dynamic shapes
+      models="mv2 resnet18 resnet50 ic3 densenet161"
+      for model in $models; do
+        python -m examples.vulkan.export --model_name=$model --test -d
+      done
 
   test-vulkan-operators-linux:
     name: test-vulkan-operators-linux
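
For reference, the expanded CI loop can be replicated locally. This is a sketch, not part of the commit; it assumes an ExecuTorch checkout with the Vulkan backend built, so that the examples.vulkan.export module resolves.

    # Local replica of the dynamic-shape test loop from pull.yml (sketch).
    import subprocess
    import sys

    models = ["mv2", "resnet18", "resnet50", "ic3", "densenet161"]
    for model in models:
        # "--test -d" mirrors the workflow: export each model, then run the
        # test with dynamic shapes enabled.
        subprocess.run(
            [sys.executable, "-m", "examples.vulkan.export",
             f"--model_name={model}", "--test", "-d"],
            check=True,
        )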

backends/arm/_passes/annotate_decomposed_matmul.py

Lines changed: 7 additions & 2 deletions

@@ -68,7 +68,10 @@ def call(self, graph_module: GraphModule) -> PassResult:
                 node for node in partition.nodes if node.target in matmul_targets
             ][0]
 
-            if quantized_input:
+            if quantized_input and not all(
+                input_node.target in DQ_OPS
+                for input_node in matmul_node.all_input_nodes
+            ):
                 matmul_args = matmul_node.all_input_nodes
                 for node in matmul_args:
                     # Find the dq-node connected to this mm/bmm arg
@@ -94,7 +97,9 @@ def call(self, graph_module: GraphModule) -> PassResult:
 
             partition_output = list(partition.output_nodes[0].users)[0]
             quantized_output = partition_output.target in Q_OPS
-            if quantized_output:
+            if quantized_output and not all(
+                user.target in Q_OPS for user in matmul_node.users
+            ):
                 with graph_module.graph.inserting_after(matmul_node):
                     # Create q-node after matmul
                     q_node = create_node(
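
The new guards skip q/dq insertion when every input (or user) of the matmul node already carries the op that would be inserted. A minimal sketch of the input-side check on a traced graph, with torch.relu as a hypothetical stand-in for the real DQ_OPS set of dequantize ops:

    import torch
    import torch.fx as fx

    DQ_OPS = {torch.relu}  # hypothetical stand-in for the dequantize op set

    class M(torch.nn.Module):
        def forward(self, x, y):
            return torch.matmul(torch.relu(x), torch.relu(y))

    gm = fx.symbolic_trace(M())
    for node in gm.graph.nodes:
        if node.target is torch.matmul:
            # Mirrors `quantized_input and not all(...)`: when every input is
            # already a dq-like node, there is nothing left to annotate.
            needs_dq_insertion = not all(
                n.target in DQ_OPS for n in node.all_input_nodes
            )
            print(node.name, "needs dq insertion:", needs_dq_insertion)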

backends/arm/test/ops/test_matmul.py

Lines changed: 7 additions & 0 deletions

@@ -22,6 +22,7 @@
 
 class MatMul(torch.nn.Module):
     test_data_generators = {
+        "rand_rand_2d": lambda: (torch.rand(5, 5), torch.rand(5, 2)),
         "rand_rand_3d": lambda: (torch.rand(2, 3, 5), torch.rand(2, 5, 2)),
         "rand_rand_4d": lambda: (torch.rand(1, 2, 3, 5), torch.rand(1, 2, 5, 2)),
     }
@@ -32,6 +33,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
 
 class MatMulSingleInput(torch.nn.Module):
     test_data_generators = {
+        "rand_2d": lambda: (torch.rand(5, 5),),
         "rand_3d": lambda: (torch.rand(2, 5, 5),),
         "rand_4d": lambda: (torch.rand(1, 2, 5, 5),),
     }
@@ -42,6 +44,11 @@ def forward(self, x: torch.Tensor):
 
 class MatMulCombo(torch.nn.Module):
     test_data_generators = {
+        "rand_rand_rand_2d": lambda: (
+            torch.rand(5, 5),
+            torch.rand(5, 2),
+            torch.rand(2, 5),
+        ),
         "rand_rand_rand_3d": lambda: (
             torch.rand(2, 5, 5),
             torch.rand(2, 5, 2),
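
A quick sanity check of the new 2D cases. The chaining below assumes MatMulCombo multiplies its inputs as matmul(matmul(x, y), z); the forward is not shown in this hunk, so that ordering is an assumption:

    import torch

    x, y, z = torch.rand(5, 5), torch.rand(5, 2), torch.rand(2, 5)
    xy = torch.matmul(x, y)    # (5, 5) @ (5, 2) -> (5, 2)
    xyz = torch.matmul(xy, z)  # (5, 2) @ (2, 5) -> (5, 5)
    print(xy.shape, xyz.shape)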

backends/arm/test/tester/arm_tester.py

Lines changed: 7 additions & 3 deletions

@@ -458,6 +458,10 @@ def run_method_and_compare_outputs(
         for run_iteration in range(num_runs):
             reference_input = inputs if inputs else next(self.generate_random_inputs())
 
+            # Avoid issues with inplace operators
+            test_input = copy.deepcopy(reference_input)
+            original_input = copy.deepcopy(reference_input)
+
             input_shapes = [
                 generated_input.shape if hasattr(generated_input, "shape") else (1,)
                 for generated_input in reference_input
@@ -472,16 +476,16 @@
                 # Run exported module directly
                 test_outputs, _ = pytree.tree_flatten(
                     self._calculate_reference_output(
-                        exported_program.module(), reference_input
+                        exported_program.module(), test_input
                     )
                 )
             else:
                 # Run lowered model with target
                 test_outputs, _ = pytree.tree_flatten(
-                    test_stage.run_artifact(reference_input)
+                    test_stage.run_artifact(test_input)
                 )
 
-            logger.info(f"\n Input: {reference_input}")
+            logger.info(f"\n Input: {original_input}")
             logger.info(f"\n Ref output: {reference_outputs}")
             logger.info(f"\nTest output: {test_outputs}")
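
The deep copies matter because an in-place op in the module mutates whatever tensor it is handed, so a shared input would change between the reference run and the test run, and the log would print the mutated value. A standalone sketch of the failure mode this guards against:

    import copy
    import torch

    class Inplace(torch.nn.Module):
        def forward(self, x):
            return x.add_(1)  # mutates the caller's tensor in place

    model = Inplace()
    reference_input = (torch.zeros(2),)
    test_input = copy.deepcopy(reference_input)      # what the fix does
    original_input = copy.deepcopy(reference_input)  # kept pristine for logging

    ref_out = model(*reference_input)  # reference_input is now all ones
    test_out = model(*test_input)      # unaffected by the reference run
    print(torch.equal(ref_out, test_out))  # True only because of the copies
    print(original_input)                  # still zeros, safe to log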

backends/vulkan/_passes/fold_qdq.py

Lines changed: 2 additions & 3 deletions

@@ -17,9 +17,8 @@ class FoldQDQPass(ExportPass):
     valid quant op patterns have already been fused before this pass.
     """
 
-    def __init__(self, edge_program: torch.export.ExportedProgram):
-        super(FoldQDQPass, self).__init__()
-        self.edge_program = edge_program
+    def __init__(self):
+        super().__init__()
 
     def call(self, graph_module: torch.fx.GraphModule):
        for node in graph_module.graph.nodes:

backends/vulkan/_passes/fuse_patterns.py

Lines changed: 7 additions & 3 deletions

@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+from typing import Optional
+
 import executorch.backends.vulkan.patterns as vk_patterns
 
 import torch
@@ -13,13 +15,15 @@
 
 
 class FusePatternsPass(ExportPass):
-    def __init__(self, exported_program: ExportedProgram) -> None:
+    def __init__(self) -> None:
         super().__init__()
-        self.program = exported_program
+        self._exported_program: Optional[ExportedProgram] = None
 
     def call(self, graph_module: torch.fx.GraphModule):
+        assert self._exported_program is not None
+
         total_replaced = vk_patterns.replace_all_fusable_subgraphs(
-            self.program, graph_module
+            self._exported_program, graph_module
         )
 
         if total_replaced > 0:

backends/vulkan/_passes/fuse_quantized_ops.py

Lines changed: 6 additions & 4 deletions

@@ -211,18 +211,20 @@ def fuse_into_linear_qcnw_node(
 
 
 class FuseQuantizedOpsTransform(ExportPass):
-    def __init__(self, exported_program: ExportedProgram) -> None:
+    def __init__(self) -> None:
         super().__init__()
-        self.program = exported_program
+        self._exported_program: Optional[ExportedProgram] = None
 
     def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
+        assert self._exported_program is not None
+
         for node in graph_module.graph.nodes:
             # Check for linear_qcnw pattern (weight-only quantization)
-            qcnw_details = matches_linear_qcnw_pattern(self.program, node)
+            qcnw_details = matches_linear_qcnw_pattern(self._exported_program, node)
             if qcnw_details is not None:
                 qcnw_method, qcnw_nbits = qcnw_details
                 fuse_into_linear_qcnw_node(
-                    self.program, graph_module, node, qcnw_method, qcnw_nbits
+                    self._exported_program, graph_module, node, qcnw_method, qcnw_nbits
                 )
                 continue
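
FoldQDQPass, FusePatternsPass, and FuseQuantizedOpsTransform all switch from taking the ExportedProgram in the constructor to an _exported_program attribute that must be populated before call() runs. A minimal sketch of that shape; the run_pass injector below is hypothetical, since the commit does not show which component sets the attribute:

    from typing import Optional

    class ProgramAwarePass:
        def __init__(self) -> None:
            # Set by the pass runner, not the constructor.
            self._exported_program: Optional[object] = None

        def call(self, graph_module) -> None:
            # Fail loudly if the runner forgot to inject the program.
            assert self._exported_program is not None
            print("running with", self._exported_program)

    def run_pass(p, exported_program, graph_module):
        # Hypothetical injector: attach the program, then execute the pass.
        p._exported_program = exported_program
        return p.call(graph_module)

    run_pass(ProgramAwarePass(), exported_program="ep", graph_module=None)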

backends/vulkan/_passes/tag_memory_meta_pass.py

Lines changed: 4 additions & 0 deletions

@@ -230,6 +230,10 @@ def get_arg_tensor_source_repset(
     """
     arg_node = op_node.args[arg_i]
 
+    # For non-tensor arguments, return ANY_STORAGE
+    if not utils.is_tensor_arg_node(arg_node):
+        return utils.ANY_STORAGE
+
     # Special case for cat - use the first tensor in the list as representative
     if isinstance(arg_node, list):
         arg_node = arg_node[0]
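
The early return relies on utils.is_tensor_arg_node to filter out scalar and other non-tensor arguments before any repset lookup. A plausible shape for such a check, assuming it inspects the value stored in node.meta["val"]; the real helper lives in backends/vulkan/utils.py and may differ:

    import torch

    def is_tensor_arg_node_sketch(arg) -> bool:
        # Lists of nodes (e.g. the inputs of cat) count if any element does.
        if isinstance(arg, (list, tuple)):
            return any(is_tensor_arg_node_sketch(a) for a in arg)
        if not isinstance(arg, torch.fx.Node):
            return False  # ints, floats, dim indices, etc.
        return isinstance(arg.meta.get("val"), torch.Tensor)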

backends/vulkan/op_registry.py

Lines changed: 70 additions & 23 deletions

@@ -16,8 +16,6 @@
 
 import torch
 
-from executorch.backends.vulkan.serialization.vulkan_graph_schema import VkMemoryLayout
-
 from executorch.exir.dialects._ops import ops as exir_ops
 
 from executorch.exir.dialects.edge._ops import EdgeOpOverload
@@ -48,6 +46,9 @@ class OpFeatures:
         # Optional check function used during partitioning to determine if a node's
         # inputs are supported by the operator implementation.
         "are_node_inputs_supported_fn",
+        # Optional function to determine valid representation sets for input and outputs
+        # once a node's actual inputs are known.
+        "pick_io_storage_fn",
     ]
 
     def __init__(
@@ -61,6 +62,7 @@ def __init__(
         supports_resize: bool = False,
         supports_prepacking: bool = False,
         are_node_inputs_supported_fn: Optional[Callable] = allow_node,
+        pick_io_storage_fn: Optional[Callable] = None,
     ):
         self.inputs_storage: utils.TensorRepSetList = utils.TensorRepSetList(
             inputs_storage if inputs_storage is not None else []
@@ -77,15 +79,21 @@ def __init__(
         self.supports_prepacking = supports_prepacking
 
         self.are_node_inputs_supported_fn = are_node_inputs_supported_fn
+        self.pick_io_storage_fn = pick_io_storage_fn
 
     def make_op_repsets(
         self,
         op_node: torch.fx.Node,
        texture_limits: utils.ImageExtents = utils.DEFAULT_TEXTURE_LIMITS,
     ) -> utils.OpRepSets:
-        return utils.OpRepSets(
-            self.inputs_storage, self.outputs_storage, op_node, texture_limits
-        )
+        inputs_storage = self.inputs_storage
+        outputs_storage = self.outputs_storage
+        if self.pick_io_storage_fn is not None:
+            i_storage, o_storage = self.pick_io_storage_fn(op_node)
+            inputs_storage = utils.TensorRepSetList(i_storage)
+            outputs_storage = utils.TensorRepSetList(o_storage)
+
+        return utils.OpRepSets(inputs_storage, outputs_storage, op_node, texture_limits)
 
 
 #######################
@@ -410,28 +418,16 @@ def register_softmax_op():
 )
 def register_reduce_op():
     def check_reduce_node(node: torch.fx.Node) -> bool:
+        # Only one argument implies that the reduction is over the entire tensor, which
+        # is not supported yet.
+        if len(node.args) == 1:
+            return False
+
         dim_list = node.args[1]
+        # Only 1D and 2D reductions are supported at the moment.
         if isinstance(dim_list, list) and len(dim_list) > 2:
             return False
 
-        if isinstance(dim_list, list) and len(dim_list) == 2:
-            # Try to get the memory layout for this node
-            try:
-                memory_layout = utils.get_node_memory_layout(node)
-
-                # If we have memory layout information, check if any dimension in dim_list corresponds to a packed dimension
-                if (
-                    memory_layout is not None
-                    and memory_layout != VkMemoryLayout.DEFAULT_LAYOUT
-                ):
-                    # For now only default layout is supported for 2D reduction.
-                    # Because we can't determine if the input is NCHW or NHWC here,
-                    # assume the reduction dimension is packed so we cannot support it.
-                    return False
-            except (AssertionError, KeyError, AttributeError):
-                # If we can't get memory layout information, we'll assume the dims aren't packed
-                pass
-
         def try_find_keepdim_arg(node: torch.fx.Node) -> bool:
             for arg in node.args:
                 if isinstance(arg, bool):
@@ -446,10 +442,41 @@ def try_find_keepdim_arg(node: torch.fx.Node) -> bool:
 
         return True
 
+    def pick_io_storage_for_reduce(node: torch.fx.Node):
+        inputs_storage = utils.ANY_TEXTURE
+        outputs_storage = utils.ANY_TEXTURE
+
+        input_tensor = node.args[0]
+        ndim = input_tensor.meta["val"].ndim
+        dim_list = node.args[1]
+        if isinstance(dim_list, list) and len(dim_list) == 2:
+            reduce_dim1_whcn = utils.nchw_dim_to_whcn_dim(dim_list[0], ndim)
+            reduce_dim2_whcn = utils.nchw_dim_to_whcn_dim(dim_list[1], ndim)
+
+            possible_packed_dims = {0, 1, 2}
+            possible_packed_dims.discard(reduce_dim1_whcn)
+            possible_packed_dims.discard(reduce_dim2_whcn)
+
+            packed_dim = possible_packed_dims.pop()
+            assert packed_dim in [0, 1, 2]
+
+            if packed_dim == 0:
+                inputs_storage = utils.WIDTH_PACKED_TEXTURE
+                outputs_storage = utils.WIDTH_PACKED_TEXTURE
+            elif packed_dim == 1:
+                inputs_storage = utils.HEIGHT_PACKED_TEXTURE
+                outputs_storage = utils.HEIGHT_PACKED_TEXTURE
+            else:
+                inputs_storage = utils.CHANNELS_PACKED_TEXTURE
+                outputs_storage = utils.CHANNELS_PACKED_TEXTURE
+
+        return inputs_storage, outputs_storage
+
     return OpFeatures(
         inputs_storage=utils.ANY_TEXTURE,
         supports_resize=True,
         are_node_inputs_supported_fn=check_reduce_node,
+        pick_io_storage_fn=pick_io_storage_for_reduce,
     )
 
 
@@ -474,6 +501,23 @@ def register_2d_pool_op():
     ]
 )
 def register_convolution_op():
+    def check_conv_node(node: torch.fx.Node) -> bool:
+        x = node.args[0]
+        x_shape = x.meta["val"].size()
+        # 4-D input implies 2D convolution
+        if len(x_shape) == 4:
+            batches = x.meta["val"].size()[0]
+            if batches != 1:
+                return False
+        # 3-D input implies 1D convolution
+        if len(x_shape) == 3:
+            transpose = node.args[6]
+            # Transposed 1D convolution is not supported yet
+            if transpose:
+                return False
+
+        return True
+
     return OpFeatures(
         inputs_storage=[
             utils.CHANNELS_PACKED_TEXTURE,  # input
@@ -490,6 +534,7 @@ def register_convolution_op():
         ],
         supports_resize=True,
         supports_prepacking=True,
+        are_node_inputs_supported_fn=check_conv_node,
     )
 
 
@@ -666,6 +711,7 @@ def register_ported_ops_with_prepacking():
     return OpFeatures(
         inputs_storage=utils.CHANNELS_PACKED_TEXTURE,
         supports_prepacking=True,
+        supports_resize=True,
     )
 
 
@@ -696,6 +742,7 @@ def register_ported_ops_with_prepacking_all_dims():
     return OpFeatures(
         inputs_storage=utils.ANY_TEXTURE,
         supports_prepacking=True,
+        supports_resize=True,
     )
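
Two of the additions above are worth unpacking. First, pick_io_storage_for_reduce deduces the single texture packed dim that survives a 2D reduction. A worked sketch of that deduction, assuming nchw_dim_to_whcn_dim maps an NCHW-style index to WHCN order by counting from the innermost dimension (whcn = ndim - 1 - nchw); the real helper lives in backends/vulkan/utils.py:

    def nchw_dim_to_whcn_dim(nchw_dim: int, ndim: int) -> int:
        # Assumed mapping: the innermost NCHW dim (W) becomes WHCN dim 0.
        return ndim - 1 - nchw_dim

    ndim = 4              # NCHW input tensor
    reduce_dims = [2, 3]  # reduce over H and W
    whcn = {nchw_dim_to_whcn_dim(d, ndim) for d in reduce_dims}  # {0, 1}

    # The packed dim must be a W/H/C axis that is NOT reduced; removing the
    # reduced axes from {0, 1, 2} leaves exactly one candidate.
    packed_dim = ({0, 1, 2} - whcn).pop()
    print(packed_dim)  # 2 -> channels-packed texture for an H+W reduction

Second, check_conv_node keys off the aten.convolution schema, where args[0] is the input tensor and args[6] is the transposed flag: convolution(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups). The same shape-based dispatch as a standalone sketch:

    import torch

    def conv_supported(x: torch.Tensor, transposed: bool) -> bool:
        if x.dim() == 4 and x.size(0) != 1:  # batched 2D conv: unsupported
            return False
        if x.dim() == 3 and transposed:      # transposed 1D conv: unsupported
            return False
        return True

    print(conv_supported(torch.rand(1, 3, 32, 32), False))  # True
    print(conv_supported(torch.rand(2, 3, 32, 32), False))  # False: batch > 1
    print(conv_supported(torch.rand(1, 3, 32), True))       # False: transposed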

backends/vulkan/partitioner/vulkan_partitioner.py

Lines changed: 6 additions & 4 deletions

@@ -36,7 +36,7 @@
     Partitioner,
     PartitionResult,
 )
-from executorch.exir.backend.utils import tag_constant_data
+from executorch.exir.backend.utils import tag_constant_data, tag_mutated_buffer
 from executorch.exir.dialects._ops import ops as exir_ops
 
 from torch.export.exported_program import ExportedProgram
@@ -254,9 +254,10 @@ def _is_node_supported(self, node: torch.fx.Node) -> bool:  # noqa: C901
             self.log_skip(node, "permute node of non compatible linear node")
             return False
 
-        is_in_local_scalar_dense_chain, dst_node_is_compatible = (
-            self.is_in_local_scalar_dense_chain(node)
-        )
+        (
+            is_in_local_scalar_dense_chain,
+            dst_node_is_compatible,
+        ) = self.is_in_local_scalar_dense_chain(node)
         if is_in_local_scalar_dense_chain and dst_node_is_compatible:
             return True
         elif is_in_local_scalar_dense_chain:
@@ -419,6 +420,7 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
         logger.info(f"Found {pl} Vulkan subgraphs to be partitioned.")
 
         tag_constant_data(exported_program)
+        tag_mutated_buffer(exported_program)
 
         return PartitionResult(
             tagged_exported_program=exported_program, partition_tags=partition_tags
