
Commit 3f0896a

[ET-VK] Miscellaneous fixes (#14801)
Collecting fixes for various models/ops in this diff/PR. They have all been squashed into this single change to make it easier to cherry-pick.

# Fixes

## Wav2Letter

Type: Output correctness failure

This is caused by a bug in SwiftShader and is not reproducible on any other platform. Specifically, the issue is in the softmax shader; the exact cause is unknown, but it is related to using shared memory within shaders. The workaround is to use separate shared memory arrays for the shared max and the shared sum.

## ConvNeXT

Type: Exception during runtime

This is caused by an incompatible memory layout being used for mean2d. More specifically, the packed dimension of the tensor cannot be one of the dims being reduced. The operator registry system previously had no way to select valid tensor representations based on the actual arguments of an op, so this change introduces a mechanism for ops to specify valid representations once a node's arguments are known. Once the model is exported with a supported memory layout, the model test passes.

## Inception_V3/ViT

Type: Exception during runtime

The root cause was an interaction between the fuse batch norm pass and how `vulkan_preprocess.py` was applying passes. The fuse batch norm pass creates a new param node for the fused weight, but after the pass is applied, `_copy_module` is used to copy the transformed graph back into the ExportedProgram. However, it seems that `_copy_module` lowercases node names without updating the exported program's graph signature, so subsequent passes could not recognize the weight tensor of convolution nodes as a constant/parameter node. The solution was to migrate `vulkan_preprocess.py` to use the `_transform()` API instead of `_copy_module`.

## DenseNet 161 (w/ dynamic shapes)

Type: Output mismatch

Cause: the native_batch_norm op does not support dynamic shapes, but the backend test runner does not set the compile option that filters out ops without dynamic shape support.

Differential Revision: [D83703496](https://our.internmc.facebook.com/intern/diff/D83703496/)

[ghstack-poisoned]
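To make the Inception_V3/ViT fix concrete, the sketch below contrasts the two ways of applying graph passes to an ExportedProgram. This is a minimal, hypothetical sketch; the helper name `apply_vulkan_passes` and the pass list are placeholders, not the actual `vulkan_preprocess.py` code.

```python
# Minimal sketch, not the actual vulkan_preprocess.py implementation.
from executorch.exir.pass_base import ExportPass
from torch.export import ExportedProgram


def apply_vulkan_passes(ep: ExportedProgram, passes: list[ExportPass]) -> ExportedProgram:
    # ExportedProgram._transform() applies the passes and returns a new
    # ExportedProgram whose graph signature is rebuilt to match the transformed
    # graph, so a parameter node created by a pass (e.g. the fused batch norm
    # weight) is still recognized as a parameter by later passes.
    return ep._transform(*passes)
```

The previous approach transformed `ep.graph_module` directly and wrote the result back with `_copy_module`, which left the graph signature pointing at the old node names; `_transform()` avoids that mismatch.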
1 parent ca9fc06 commit 3f0896a

23 files changed: +298 -204 lines

.github/workflows/pull.yml

Lines changed: 6 additions & 1 deletion
@@ -970,11 +970,16 @@ jobs:
 PYTHON_EXECUTABLE=python bash backends/vulkan/test/scripts/test_model.sh --build

 # Test models serially
-models="mv2 mv3 edsr resnet18 resnet50 dl3"
+models="mv2 mv3 edsr resnet18 resnet50 dl3 w2l ic3 ic4"
 for model in $models; do
 python -m examples.vulkan.export --model_name=$model --test
 done

+# For selected vision models, test with dynamic shapes
+models="mv2 resnet18 resnet50 ic3 densenet161"
+for model in $models; do
+python -m examples.vulkan.export --model_name=$model --test -d
+done

 test-vulkan-operators-linux:
 name: test-vulkan-operators-linux

backends/vulkan/_passes/fold_qdq.py

Lines changed: 2 additions & 3 deletions
@@ -17,9 +17,8 @@ class FoldQDQPass(ExportPass):
     valid quant op patterns have already been fused before this pass.
     """

-    def __init__(self, edge_program: torch.export.ExportedProgram):
-        super(FoldQDQPass, self).__init__()
-        self.edge_program = edge_program
+    def __init__(self):
+        super().__init__()

     def call(self, graph_module: torch.fx.GraphModule):
         for node in graph_module.graph.nodes:

backends/vulkan/_passes/fuse_patterns.py

Lines changed: 7 additions & 3 deletions
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+from typing import Optional
+
 import executorch.backends.vulkan.patterns as vk_patterns

 import torch
@@ -13,13 +15,15 @@


 class FusePatternsPass(ExportPass):
-    def __init__(self, exported_program: ExportedProgram) -> None:
+    def __init__(self) -> None:
         super().__init__()
-        self.program = exported_program
+        self._exported_program: Optional[ExportedProgram] = None

     def call(self, graph_module: torch.fx.GraphModule):
+        assert self._exported_program is not None
+
         total_replaced = vk_patterns.replace_all_fusable_subgraphs(
-            self.program, graph_module
+            self._exported_program, graph_module
         )

         if total_replaced > 0:

backends/vulkan/_passes/fuse_quantized_ops.py

Lines changed: 6 additions & 4 deletions
@@ -211,18 +211,20 @@ def fuse_into_linear_qcnw_node(


 class FuseQuantizedOpsTransform(ExportPass):
-    def __init__(self, exported_program: ExportedProgram) -> None:
+    def __init__(self) -> None:
         super().__init__()
-        self.program = exported_program
+        self._exported_program: Optional[ExportedProgram] = None

     def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
+        assert self._exported_program is not None
+
         for node in graph_module.graph.nodes:
             # Check for linear_qcnw pattern (weight-only quantization)
-            qcnw_details = matches_linear_qcnw_pattern(self.program, node)
+            qcnw_details = matches_linear_qcnw_pattern(self._exported_program, node)
             if qcnw_details is not None:
                 qcnw_method, qcnw_nbits = qcnw_details
                 fuse_into_linear_qcnw_node(
-                    self.program, graph_module, node, qcnw_method, qcnw_nbits
+                    self._exported_program, graph_module, node, qcnw_method, qcnw_nbits
                 )
                 continue
backends/vulkan/_passes/tag_memory_meta_pass.py

Lines changed: 4 additions & 0 deletions
@@ -230,6 +230,10 @@ def get_arg_tensor_source_repset(
     """
     arg_node = op_node.args[arg_i]

+    # For non-tensor arguments, return ANY_STORAGE
+    if not utils.is_tensor_arg_node(arg_node):
+        return utils.ANY_STORAGE
+
     # Special case for cat - use the first tensor in the list as representative
     if isinstance(arg_node, list):
         arg_node = arg_node[0]

backends/vulkan/op_registry.py

Lines changed: 70 additions & 23 deletions
@@ -16,8 +16,6 @@

 import torch

-from executorch.backends.vulkan.serialization.vulkan_graph_schema import VkMemoryLayout
-
 from executorch.exir.dialects._ops import ops as exir_ops

 from executorch.exir.dialects.edge._ops import EdgeOpOverload
@@ -48,6 +46,9 @@ class OpFeatures:
         # Optional check function used during partitioning to determine if a node's
         # inputs are supported by the operator implementation.
         "are_node_inputs_supported_fn",
+        # Optional function to determine valid representation sets for input and outputs
+        # once a node's actual inputs are known.
+        "pick_io_storage_fn",
     ]

     def __init__(
@@ -61,6 +62,7 @@ def __init__(
         supports_resize: bool = False,
         supports_prepacking: bool = False,
         are_node_inputs_supported_fn: Optional[Callable] = allow_node,
+        pick_io_storage_fn: Optional[Callable] = None,
     ):
         self.inputs_storage: utils.TensorRepSetList = utils.TensorRepSetList(
             inputs_storage if inputs_storage is not None else []
@@ -77,15 +79,21 @@
         self.supports_prepacking = supports_prepacking

         self.are_node_inputs_supported_fn = are_node_inputs_supported_fn
+        self.pick_io_storage_fn = pick_io_storage_fn

     def make_op_repsets(
         self,
         op_node: torch.fx.Node,
         texture_limits: utils.ImageExtents = utils.DEFAULT_TEXTURE_LIMITS,
     ) -> utils.OpRepSets:
-        return utils.OpRepSets(
-            self.inputs_storage, self.outputs_storage, op_node, texture_limits
-        )
+        inputs_storage = self.inputs_storage
+        outputs_storage = self.outputs_storage
+        if self.pick_io_storage_fn is not None:
+            i_storage, o_storage = self.pick_io_storage_fn(op_node)
+            inputs_storage = utils.TensorRepSetList(i_storage)
+            outputs_storage = utils.TensorRepSetList(o_storage)
+
+        return utils.OpRepSets(inputs_storage, outputs_storage, op_node, texture_limits)


 #######################
@@ -410,28 +418,16 @@ def register_softmax_op():
 )
 def register_reduce_op():
     def check_reduce_node(node: torch.fx.Node) -> bool:
+        # Only one argument implies that the reduction is over the entire tensor, which
+        # is not supported yet.
+        if len(node.args) == 1:
+            return False
+
         dim_list = node.args[1]
+        # Only 1D and 2D reductions are supported at the moment.
         if isinstance(dim_list, list) and len(dim_list) > 2:
             return False

-        if isinstance(dim_list, list) and len(dim_list) == 2:
-            # Try to get the memory layout for this node
-            try:
-                memory_layout = utils.get_node_memory_layout(node)
-
-                # If we have memory layout information, check if any dimension in dim_list corresponds to a packed dimension
-                if (
-                    memory_layout is not None
-                    and memory_layout != VkMemoryLayout.DEFAULT_LAYOUT
-                ):
-                    # For now only default layout is supported for 2D reduction.
-                    # Because we can't determine if the input is NCHW or NHWC here,
-                    # assume the reduction dimension is packed so we cannot support it.
-                    return False
-            except (AssertionError, KeyError, AttributeError):
-                # If we can't get memory layout information, we'll assume the dims aren't packed
-                pass
-
         def try_find_keepdim_arg(node: torch.fx.Node) -> bool:
             for arg in node.args:
                 if isinstance(arg, bool):
@@ -446,10 +442,41 @@ def try_find_keepdim_arg(node: torch.fx.Node) -> bool:

         return True

+    def pick_io_storage_for_reduce(node: torch.fx.Node):
+        inputs_storage = utils.ANY_TEXTURE
+        outputs_storage = utils.ANY_TEXTURE
+
+        input_tensor = node.args[0]
+        ndim = input_tensor.meta["val"].ndim
+        dim_list = node.args[1]
+        if isinstance(dim_list, list) and len(dim_list) == 2:
+            reduce_dim1_whcn = utils.nchw_dim_to_whcn_dim(dim_list[0], ndim)
+            reduce_dim2_whcn = utils.nchw_dim_to_whcn_dim(dim_list[1], ndim)
+
+            possible_packed_dims = {0, 1, 2}
+            possible_packed_dims.discard(reduce_dim1_whcn)
+            possible_packed_dims.discard(reduce_dim2_whcn)
+
+            packed_dim = possible_packed_dims.pop()
+            assert packed_dim in [0, 1, 2]
+
+            if packed_dim == 0:
+                inputs_storage = utils.WIDTH_PACKED_TEXTURE
+                outputs_storage = utils.WIDTH_PACKED_TEXTURE
+            elif packed_dim == 1:
+                inputs_storage = utils.HEIGHT_PACKED_TEXTURE
+                outputs_storage = utils.HEIGHT_PACKED_TEXTURE
+            else:
+                inputs_storage = utils.CHANNELS_PACKED_TEXTURE
+                outputs_storage = utils.CHANNELS_PACKED_TEXTURE
+
+        return inputs_storage, outputs_storage
+
     return OpFeatures(
         inputs_storage=utils.ANY_TEXTURE,
         supports_resize=True,
         are_node_inputs_supported_fn=check_reduce_node,
+        pick_io_storage_fn=pick_io_storage_for_reduce,
     )

@@ -474,6 +501,23 @@ def register_2d_pool_op():
     ]
 )
 def register_convolution_op():
+    def check_conv_node(node: torch.fx.Node) -> bool:
+        x = node.args[0]
+        x_shape = x.meta["val"].size()
+        # 4-D input implies 2D convolution
+        if len(x_shape) == 4:
+            batches = x.meta["val"].size()[0]
+            if batches != 1:
+                return False
+        # 3-D input implies 1D convolution
+        if len(x_shape) == 3:
+            transpose = node.args[6]
+            # Transposed 1D convolution is not supported yet
+            if transpose:
+                return False
+
+        return True
+
     return OpFeatures(
         inputs_storage=[
             utils.CHANNELS_PACKED_TEXTURE,  # input
@@ -490,6 +534,7 @@ def register_convolution_op():
         ],
         supports_resize=True,
         supports_prepacking=True,
+        are_node_inputs_supported_fn=check_conv_node,
     )

@@ -716,6 +761,7 @@ def register_ported_ops_with_prepacking():
     return OpFeatures(
         inputs_storage=utils.CHANNELS_PACKED_TEXTURE,
         supports_prepacking=True,
+        supports_resize=True,
     )

@@ -746,6 +792,7 @@ def register_ported_ops_with_prepacking_all_dims():
     return OpFeatures(
         inputs_storage=utils.ANY_TEXTURE,
         supports_prepacking=True,
+        supports_resize=True,
     )

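As a worked example of the `pick_io_storage_for_reduce` logic above (the mechanism behind the ConvNeXT fix in the commit message): for a mean over H and W of a 4-D NCHW tensor, the two reduced dims map to WHCN dims 1 and 0, leaving dim 2 (channels) as the only dimension that may be packed. The snippet below is a standalone illustration that assumes the usual ET-VK convention WHCN dim = ndim - 1 - NCHW dim; it is a hypothetical helper, not the real `utils.nchw_dim_to_whcn_dim`.

```python
# Standalone sketch; assumes WHCN dim = ndim - 1 - NCHW dim (hypothetical helper,
# not the real utils.nchw_dim_to_whcn_dim).
def nchw_dim_to_whcn_dim(nchw_dim: int, ndim: int) -> int:
    return ndim - 1 - (nchw_dim % ndim)  # % ndim normalizes negative dims


def packed_dim_for_2d_reduce(ndim: int, dim_list: list) -> int:
    # The packed (texel) dim cannot be one of the reduced dims, so whichever of
    # {W=0, H=1, C=2} survives is the only valid packed dim for the op.
    remaining = {0, 1, 2}
    for d in dim_list:
        remaining.discard(nchw_dim_to_whcn_dim(d, ndim))
    return remaining.pop()


# mean2d over H and W (NCHW dims 2 and 3, or equivalently -2 and -1) reduces
# WHCN dims {1, 0}, so the tensor must be channels-packed.
assert packed_dim_for_2d_reduce(4, [2, 3]) == 2
assert packed_dim_for_2d_reduce(4, [-2, -1]) == 2
```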
backends/vulkan/partitioner/vulkan_partitioner.py

Lines changed: 6 additions & 4 deletions
@@ -36,7 +36,7 @@
     Partitioner,
     PartitionResult,
 )
-from executorch.exir.backend.utils import tag_constant_data
+from executorch.exir.backend.utils import tag_constant_data, tag_mutated_buffer
 from executorch.exir.dialects._ops import ops as exir_ops

 from torch.export.exported_program import ExportedProgram
@@ -254,9 +254,10 @@ def _is_node_supported(self, node: torch.fx.Node) -> bool:  # noqa: C901
             self.log_skip(node, "permute node of non compatible linear node")
             return False

-        is_in_local_scalar_dense_chain, dst_node_is_compatible = (
-            self.is_in_local_scalar_dense_chain(node)
-        )
+        (
+            is_in_local_scalar_dense_chain,
+            dst_node_is_compatible,
+        ) = self.is_in_local_scalar_dense_chain(node)
         if is_in_local_scalar_dense_chain and dst_node_is_compatible:
             return True
         elif is_in_local_scalar_dense_chain:
@@ -419,6 +420,7 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
         logger.info(f"Found {pl} Vulkan subgraphs to be partitioned.")

         tag_constant_data(exported_program)
+        tag_mutated_buffer(exported_program)

         return PartitionResult(
             tagged_exported_program=exported_program, partition_tags=partition_tags

backends/vulkan/patterns/quantized_linear.py

Lines changed: 7 additions & 5 deletions
@@ -92,9 +92,11 @@ def __init__(self, mm_node: torch.fx.Node) -> None:
             return

         # Identify input node
-        self.fp_input_node, self.quantize_input_node, dq_node = (
-            utils.maybe_skip_q_dq_arg_chain(self.anchor_node.args[0])
-        )
+        (
+            self.fp_input_node,
+            self.quantize_input_node,
+            dq_node,
+        ) = utils.maybe_skip_q_dq_arg_chain(self.anchor_node.args[0])
         assert self.fp_input_node is not None
         self.all_nodes.append(self.fp_input_node)

@@ -386,7 +388,7 @@ def make_linear_dq8ca_q4gsw_op(
     weight_sums_node = create_constant_placeholder(
         exp_program=ep,
         graph=graph_module.graph,
-        kind=InputKind.CONSTANT_TENSOR,
+        kind=InputKind.PARAMETER,
         name=sums_name,
         data=sum_per_quant_group,
     )
@@ -429,7 +431,7 @@ def make_linear_q8ta_q8csw_custom_op(
     weight_sums_node = create_constant_placeholder(
         exp_program=ep,
         graph=graph_module.graph,
-        kind=InputKind.CONSTANT_TENSOR,
+        kind=InputKind.PARAMETER,
         name=sums_name,
         data=sum_per_output_channel,
     )

backends/vulkan/runtime/graph/ops/glsl/conv2d.glsl

Lines changed: 1 addition & 1 deletion
@@ -60,7 +60,7 @@ void main() {
     int num_steps = ((-ipos.y) + dilation.y - 1) / dilation.y;
     start.y = ipos.y + num_steps * dilation.y;
   }
-  const ivec2 end = min(ipos + overlay_region.xy, ivec2(in_sizes.xy));
+  const ivec2 end = min(ipos + overlay_region.xy, in_sizes.xy);
   // Compute the start of the kernel based on how far we are skipping ahead when
   // reading the input. Note that these are "canonical" indices.
   ivec2 kstart = (start - ipos) / dilation;

backends/vulkan/runtime/graph/ops/glsl/conv2d_dw.glsl

Lines changed: 1 addition & 1 deletion
@@ -54,7 +54,7 @@ void main() {
   // Compute the start and end of the input indices to load. Padding is assumed
   // to be constant 0 padding, so reads from the padding region are skipped.
   const ivec2 start = ipos;
-  const ivec2 end = ipos + overlay_region.xy;
+  const ivec2 end = min(ipos + overlay_region.xy, in_sizes.xy);

   VEC4_T sum = texelFetch(t_bias, ivec2(pos.z, 0), 0);
   int kx = 0;
