
Commit bddb1cc

Author: ssjia
Update on "[ET-VK] Miscellaneous fixes"
Collecting fixes for various models/ops in this diff/PR. They have all been squashed into this single change to make it easier to cherry-pick.

# Fixes

## Wav2Letter

Type: Output correctness failure

This is caused by a bug in SwiftShader and is not reproducible on any other platform. Specifically, the issue is in the softmax shader; the exact cause is unknown, but it is related to the use of shared memory within shaders. The workaround is to use separate shared memory arrays for the shared max and the shared sum.

## ConvNeXT

Type: Exception during runtime

This is caused by an incompatible memory layout being used for mean2d. More technically, the packed dimension of the tensor cannot be one of the dims being reduced. The current operator registry system did not have a way to select valid tensor representations based on the actual arguments of an op. To fix this, we introduce a mechanism for ops to specify valid representations once a node's arguments are known (see the sketch after this message). Once the model is exported with a supported memory layout, the model test passes.

## Inception_V3/ViT

Type: Exception during runtime

The root cause was an interaction between the fuse batch norm pass and how `vulkan_preprocess.py` was applying passes. Essentially, the fuse batch norm pass creates a new param node for the fused weight, but after the pass is applied `_copy_module` is used to copy the transformed graph back into the ExportedProgram. However, it seems that `_copy_module` lowercases the node names without updating the exported program's graph signature, so subsequent passes could not recognize the weight tensor of convolution nodes as a constant/parameter node. The solution was to migrate `vulkan_preprocess.py` to the `_transform()` API instead of `_copy_module`.

## DenseNet 161 (w/ dynamic shapes)

Type: Output mismatch

Cause: the native_batch_norm op doesn't support dynamic shapes, but the backend test runner doesn't set the correct compile option to filter out ops without dynamic shape support.

Differential Revision: [D83703496](https://our.internmc.facebook.com/intern/diff/D83703496/)

[ghstack-poisoned]
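The ConvNeXT fix above hinges on matching an op's memory layout to its reduction dims. Below is a minimal, hypothetical Python sketch of that idea; the layout names, the "packed dim counted from the innermost axis" convention, and the function name are illustrative assumptions, not the actual ET-VK registry API.

```python
# Hypothetical sketch: keep only texture layouts whose packed dim is NOT reduced over.
# Layout names and packed-dim convention are assumptions for illustration only.
LAYOUT_PACKED_DIM_FROM_INNER = {
    "WIDTH_PACKED": 0,
    "HEIGHT_PACKED": 1,
    "CHANNELS_PACKED": 2,
}


def valid_layouts_for_reduction(ndim: int, reduce_dims: list[int]) -> list[str]:
    """Return layouts whose packed dimension is not one of the reduced dims."""
    normalized = {d % ndim for d in reduce_dims}  # handle negative dims like -1, -2
    valid = []
    for layout, packed_from_inner in LAYOUT_PACKED_DIM_FROM_INNER.items():
        packed_dim = ndim - 1 - packed_from_inner
        if packed_dim not in normalized:
            valid.append(layout)
    return valid


if __name__ == "__main__":
    # mean2d over H and W of an NCHW tensor: only a channels-packed layout keeps
    # the packed dim outside the reduction, mirroring the ConvNeXT fix above.
    print(valid_layouts_for_reduction(4, [-1, -2]))  # ['CHANNELS_PACKED']
```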
2 parents: 449035d + 6c7e363 · commit bddb1cc

File tree: 21 files changed (+154, -35 lines)


.github/workflows/pull.yml

Lines changed: 1 addition & 1 deletion
@@ -970,7 +970,7 @@ jobs:
   PYTHON_EXECUTABLE=python bash backends/vulkan/test/scripts/test_model.sh --build

   # Test models serially
-  models="mv2 mv3 edsr resnet18 resnet50 dl3 w2l ic3 ic4 convnext_small vit"
+  models="mv2 mv3 edsr resnet18 resnet50 dl3 w2l ic3 ic4 vit"
   for model in $models; do
     python -m examples.vulkan.export --model_name=$model --test
   done

backends/arm/operators/op_bmm.py

Lines changed: 23 additions & 0 deletions
@@ -79,6 +79,12 @@ def define_node(
             input1_zp = input_qparams[1].get_zp_per_tensor()
             bmm_result = tosa_graph.addIntermediate(output.shape, ts.DType.INT32)
             bmm_output_name = bmm_result.name
+        elif inputs[0].dtype == ts.DType.INT16:
+            input_qparams = get_input_qparams(node)
+            input0_zp = input_qparams[0].get_zp_per_tensor()
+            input1_zp = input_qparams[1].get_zp_per_tensor()
+            bmm_result = tosa_graph.addIntermediate(output.shape, ts.DType.INT48)
+            bmm_output_name = bmm_result.name
         else:
             bmm_output_name = output.name
             input0_zp, input1_zp = 0, 0
@@ -118,3 +124,20 @@ def define_node(
                 output_zp=[output_qparams.get_zp_per_tensor()],
                 rounding_mode=RoundingMode.SINGLE_ROUND,
             )
+        elif output.dtype == ts.DType.INT16:
+            output_qparams = get_output_qparams(node)[0]
+            final_output_scale = (
+                input_qparams[0].get_scale_per_tensor() * input_qparams[1].get_scale_per_tensor()  # type: ignore[possibly-undefined]  # pyre-ignore[61]
+            ) / output_qparams.get_scale_per_tensor()
+
+            build_rescale(
+                tosa_fb=tosa_graph,
+                scale=[final_output_scale],
+                # pyre-ignore[61]: Uninitialized local [61]: Local variable `bmm_result` is undefined, or not always defined.
+                input_node=bmm_result,  # type: ignore[possibly-undefined]
+                output_name=output.name,
+                output_type=ts.DType.INT16,
+                input_zp=[0],
+                output_zp=[output_qparams.get_zp_per_tensor()],
+                rounding_mode=RoundingMode.SINGLE_ROUND,
+            )

backends/arm/test/ops/test_addmm.py

Lines changed: 0 additions & 6 deletions
@@ -213,9 +213,6 @@ def get_symmetric_a16w8_addmm_quantizer(per_channel_quantization=False):


 @common.parametrize("test_data", test_data_suite)
-@pytest.mark.xfail(
-    reason="missing int16 addmm ops support; fails at TOSA reference model with Unsupported operation type or rank. See: https://github.com/pytorch/executorch/issues/13979"
-)
 def test_addmm_16a8w_tosa_INT(test_data: input_t1):
     """Test addmm (FC layer) operation with 16A8W quantization (16-bit activations, 8-bit weights)"""
     per_channel_quantization = False
@@ -268,9 +265,6 @@ def test_addmm_16a8w_u55_INT16(test_data: input_t1):

 @common.parametrize("test_data", test_data_suite)
 @common.XfailIfNoCorstone320
-@pytest.mark.xfail(
-    reason="Vela compilation fails with 'Invalid arguments' for int16 addmm operations"
-)
 def test_addmm_16a8w_u85_INT16(test_data: input_t1):
     """Test addmm (FC layer) operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)"""
     per_channel_quantization = False

backends/cadence/aot/replace_ops.py

Lines changed: 2 additions & 2 deletions
@@ -89,10 +89,10 @@ def replace_logical_nop_where_with_where(

         # Get the third arg node and its input
         logical_not_node = node.args[0]
-        logical_not_input_tensor = logical_not_node.args[0].to_tensor()
+        logical_not_input_node = logical_not_node.args[0]

         # If the logical_not input is not a boolean tensor, bail.
-        if logical_not_input_tensor.meta["spec"].dtype != torch.bool:
+        if logical_not_input_node.meta["val"].dtype != torch.bool:
             continue

         # Replace the where op with another one, flipping the inputs and using the boolean

backends/cuda/TARGETS

Lines changed: 1 addition & 0 deletions
@@ -6,6 +6,7 @@ runtime.python_library(
     name = "cuda_backend",
     srcs = [
         "cuda_backend.py",
+        "replace_slice_copy_with_slice.py",
     ],
     visibility = [
         "//executorch/...",

backends/cuda/cuda_backend.py

Lines changed: 3 additions & 1 deletion
@@ -144,7 +144,9 @@ def preprocess(
        }

        with collect_unsupported_fallback_kernels(), torch.nn.attention.sdpa_kernel(
-            [SDPBackend.MATH]
+            [
+                SDPBackend.MATH  # pyre-ignore[16]: Module `torch.nn.attention` has no attribute `SDPBackend`.
+            ]
        ), torch.no_grad():
            # torch._logging.set_logs(post_grad_graphs=True)
            so_path = torch._inductor.aot_compile(edge_program_module, tuple(user_input_placeholders), options=options)  # type: ignore[arg-type]

backends/cuda/replace_slice_copy_with_slice.py

Lines changed: 8 additions & 5 deletions
@@ -6,20 +6,23 @@

 # pyre-strict

-from typing import Iterable
+from typing import Dict, Iterable, Tuple

 import torch
 from executorch.exir.dialects._ops import ops
+from executorch.exir.dialects.edge._ops import EdgeOpOverload
 from executorch.exir.pass_base import ExportPass, PassResult
 from torch import fx


-_SLICE_COPY_TARGETS = (
+_SLICE_COPY_TARGETS: Tuple[torch._ops.OpOverload | EdgeOpOverload] = (
     torch.ops.aten.slice_copy.Tensor,
     ops.edge.aten.slice_copy.Tensor,
 )

-_SLICE_TARGETS = {
+_SLICE_TARGETS: Dict[
+    torch._ops.OpOverload | EdgeOpOverload, torch._ops.OpOverload | EdgeOpOverload
+] = {
     torch.ops.aten.slice_copy.Tensor: torch.ops.aten.slice.Tensor,
     ops.edge.aten.slice_copy.Tensor: ops.edge.aten.slice.Tensor,
 }
@@ -99,8 +102,8 @@ def _is_view_user(self, node: fx.Node, user: fx.Node) -> bool:
         return False

     def _argument_mutates(
-        self, schema: torch._C.FunctionSchema, key
-    ) -> bool:  # pyre-ignore[11]
+        self, schema: torch._C.FunctionSchema, key: int | str
+    ) -> bool:
         arguments = schema.arguments
         if isinstance(key, int):
             if key >= len(arguments):

backends/cuda/tests/test_cuda_export.py

Lines changed: 4 additions & 1 deletion
@@ -8,6 +8,7 @@
 from typing import Tuple

 import torch
+from executorch.backends.cuda.cuda_backend import CudaBackend
 from executorch.backends.cuda.cuda_partitioner import CudaPartitioner
 from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower
 from torch.export import export
@@ -30,7 +31,9 @@ def _export_to_cuda_with_lower(
         exported_program = export(module, inputs, strict=True)

         # Create partitioner and compile specs
-        partitioner = CudaPartitioner([])
+        partitioner = CudaPartitioner(
+            [CudaBackend.generate_method_name_compile_spec("forward")]
+        )

         # Use to_edge_transform_and_lower for complete pipeline
         edge_program_manager = to_edge_transform_and_lower(

backends/qualcomm/quantizer/qconfig.py

Lines changed: 5 additions & 5 deletions
@@ -200,7 +200,7 @@ def get_16a8w_qnn_qat_config(
     act_observer=MovingAverageMinMaxObserver,
 ) -> QuantizationConfig:
     extra_args: Dict[str, Any] = {"eps": 2**-20}
-    act_fake_quant_ctr = FakeQuantize.with_args(
+    act_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
         dtype=torch.int32,
         quant_min=torch.iinfo(torch.uint16).min,
         quant_max=torch.iinfo(torch.uint16).max,
@@ -398,7 +398,7 @@ def get_ptq_per_block_quant_config(
 def get_8a8w_qnn_qat_config(
     act_symmetric: bool = False, act_observer=MovingAverageMinMaxObserver
 ) -> QuantizationConfig:
-    act_fake_quant_ctr = FakeQuantize.with_args(
+    act_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
         dtype=torch.uint8,
         qscheme=(
             torch.per_tensor_symmetric if act_symmetric else torch.per_tensor_affine
@@ -458,7 +458,7 @@ def get_8a8w_qnn_qat_config(
 def get_16a4w_qnn_qat_config(
     act_observer=MovingAverageMinMaxObserver,
 ) -> QuantizationConfig:
-    act_fake_quant_ctr = FakeQuantize.with_args(
+    act_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
         dtype=torch.int32,
         quant_min=torch.iinfo(torch.uint16).min,
         quant_max=torch.iinfo(torch.uint16).max,
@@ -541,7 +541,7 @@ def get_qat_per_channel_quant_config(
         # If zero_point is 128, htp can do optimizations.
         # If we keep quant_min and quant_max none, observer will default use 128 as zero_point.
         # If we provide uint8 quant_min/max, it will use 127 as zero_point, which is undesired.
-        act_fake_quant_ctr = FakeQuantize.with_args(
+        act_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
            dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
            qscheme=torch.per_tensor_symmetric,
            observer=act_observer,
@@ -553,7 +553,7 @@ def get_qat_per_channel_quant_config(
            observer_or_fake_quant_ctr=act_fake_quant_ctr,
         )
     else:
-        act_fake_quant_ctr = FakeQuantize.with_args(
+        act_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
            dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
            quant_min=torch.iinfo(act_dtype).min,
            quant_max=torch.iinfo(act_dtype).max,

backends/vulkan/_passes/tag_memory_meta_pass.py

Lines changed: 4 additions & 0 deletions
@@ -230,6 +230,10 @@ def get_arg_tensor_source_repset(
     """
     arg_node = op_node.args[arg_i]

+    # For non-tensor arguments, return ANY_STORAGE
+    if not utils.is_tensor_arg_node(arg_node):
+        return utils.ANY_STORAGE
+
     # Special case for cat - use the first tensor in the list as representative
     if isinstance(arg_node, list):
         arg_node = arg_node[0]
