
Commit 6c7e363

Author: ssjia (committed)
Update base for Update on "[ET-VK] Miscellaneous fixes"
Collecting fixes for various models/ops in this diff/PR. They have all been squashed into this single change to make it easier to cherry-pick.

# Fixes

## Wav2Letter

Type: Output correctness failure

This is caused by a bug in SwiftShader and is not reproducible on any other platform. Specifically, the issue is in the softmax shader; the exact cause is unknown, but it is related to using shared memory within shaders. The workaround is to use separate shared memory arrays for the shared max and the shared sum.

## ConvNeXT

Type: Exception during runtime

This is caused by an incompatible memory layout being used for mean2d. More technically, the packed dimension of the tensor cannot be one of the dims being reduced. The current operator registry system did not have a way to select valid tensor representations based on the actual arguments of an op. To fix this, we introduce a mechanism for ops to specify valid representations once a node's arguments are known. Once the model is exported with a supported memory layout, the model test passes.

## Inception_V3/ViT

Type: Exception during runtime

The root cause was an interaction between the fuse batch norm pass and how `vulkan_preprocess.py` was applying passes. Essentially, the fuse batch norm pass creates a new param node for the fused weight, but after the pass is applied `_copy_module` is used to copy the transformed graph back into the ExportedProgram. However, it seems that `_copy_module` lowercases the node names without updating the exported program's graph signature. As a result, subsequent passes couldn't recognize the weight tensor of convolution nodes as a constant/parameter node. The solution was to migrate `vulkan_preprocess.py` to use the `_transform()` API instead of `_copy_module`.

## DenseNet 161 (w/ dynamic shapes)

Type: Output mismatch

Cause: the native_batch_norm op doesn't support dynamic shapes, but the backend test runner doesn't set the compile option that filters out ops without dynamic shape support.

Differential Revision: [D83703496](https://our.internmc.facebook.com/intern/diff/D83703496/)

[ghstack-poisoned]
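To make the ConvNeXT fix more concrete, the sketch below illustrates the kind of argument-aware check the new mechanism enables: given the dims a reduction op such as mean2d will reduce over, any representation whose packed dimension is among the reduced dims is rejected. This is a minimal, hypothetical Python sketch; the layout names and the `valid_reduction_layouts` helper are illustrative and are not the actual ExecuTorch Vulkan registry API.

```python
# Hypothetical sketch: choosing memory layouts that are valid for a reduction
# op such as mean2d. The constraint from the fix is that the packed dimension
# of the tensor representation must not be one of the dims being reduced.

from typing import List

# Illustrative layout constants: which dim index is "packed" in each layout.
PACKED_DIM_BY_LAYOUT = {
    "WIDTH_PACKED": -1,     # innermost (width) dim is packed
    "HEIGHT_PACKED": -2,    # height dim is packed
    "CHANNELS_PACKED": -3,  # channels dim is packed
}


def valid_reduction_layouts(ndim: int, reduce_dims: List[int]) -> List[str]:
    """Return the layouts whose packed dim is not being reduced."""
    # Normalize negative dims (e.g. -1) to positive indices.
    normalized = {d % ndim for d in reduce_dims}
    return [
        layout
        for layout, packed_dim in PACKED_DIM_BY_LAYOUT.items()
        if packed_dim % ndim not in normalized
    ]


if __name__ == "__main__":
    # mean2d over the spatial dims (H, W) of an NCHW tensor: only a
    # channels-packed representation keeps the packed dim out of the reduction.
    print(valid_reduction_layouts(ndim=4, reduce_dims=[-1, -2]))
    # -> ['CHANNELS_PACKED']
```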
2 parents (f0480bb, 70ea661) · commit 6c7e363

File tree

16 files changed · +122 −32 lines changed

backends/arm/operators/op_bmm.py

Lines changed: 23 additions & 0 deletions
@@ -79,6 +79,12 @@ def define_node(
             input1_zp = input_qparams[1].get_zp_per_tensor()
             bmm_result = tosa_graph.addIntermediate(output.shape, ts.DType.INT32)
             bmm_output_name = bmm_result.name
+        elif inputs[0].dtype == ts.DType.INT16:
+            input_qparams = get_input_qparams(node)
+            input0_zp = input_qparams[0].get_zp_per_tensor()
+            input1_zp = input_qparams[1].get_zp_per_tensor()
+            bmm_result = tosa_graph.addIntermediate(output.shape, ts.DType.INT48)
+            bmm_output_name = bmm_result.name
         else:
             bmm_output_name = output.name
             input0_zp, input1_zp = 0, 0
@@ -118,3 +124,20 @@ def define_node(
                 output_zp=[output_qparams.get_zp_per_tensor()],
                 rounding_mode=RoundingMode.SINGLE_ROUND,
             )
+        elif output.dtype == ts.DType.INT16:
+            output_qparams = get_output_qparams(node)[0]
+            final_output_scale = (
+                input_qparams[0].get_scale_per_tensor() * input_qparams[1].get_scale_per_tensor()  # type: ignore[possibly-undefined]  # pyre-ignore[61]
+            ) / output_qparams.get_scale_per_tensor()
+
+            build_rescale(
+                tosa_fb=tosa_graph,
+                scale=[final_output_scale],
+                # pyre-ignore[61]: Uninitialized local [61]: Local variable `bmm_result` is undefined, or not always defined.
+                input_node=bmm_result,  # type: ignore[possibly-undefined]
+                output_name=output.name,
+                output_type=ts.DType.INT16,
+                input_zp=[0],
+                output_zp=[output_qparams.get_zp_per_tensor()],
+                rounding_mode=RoundingMode.SINGLE_ROUND,
+            )
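For readers skimming the diff: the `final_output_scale` computed in the INT16 path is the usual requantization factor for quantized matmul, since the wide INT48 accumulator implicitly carries the product of the two input scales. A tiny numeric illustration (the scale values below are made up, not taken from the commit):

```python
# Illustrative only: how the combined rescale factor for a quantized
# matmul/bmm falls out. With real_x ≈ s_x * q_x, the accumulator holds sums
# of (q_a * q_b) terms whose real value is s_a * s_b * acc, so mapping it to
# the output quantization requires multiplying by (s_a * s_b) / s_out.

s_a, s_b, s_out = 0.02, 0.01, 0.05  # made-up per-tensor scales

final_output_scale = (s_a * s_b) / s_out
print(final_output_scale)  # ≈ 0.004: each accumulator unit is about 0.004 output units
```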

backends/arm/test/ops/test_addmm.py

Lines changed: 0 additions & 6 deletions
@@ -213,9 +213,6 @@ def get_symmetric_a16w8_addmm_quantizer(per_channel_quantization=False):
 
 
 @common.parametrize("test_data", test_data_suite)
-@pytest.mark.xfail(
-    reason="missing int16 addmm ops support; fails at TOSA reference model with Unsupported operation type or rank. See: https://github.com/pytorch/executorch/issues/13979"
-)
 def test_addmm_16a8w_tosa_INT(test_data: input_t1):
     """Test addmm (FC layer) operation with 16A8W quantization (16-bit activations, 8-bit weights)"""
     per_channel_quantization = False
@@ -268,9 +265,6 @@ def test_addmm_16a8w_u55_INT16(test_data: input_t1):
 
 @common.parametrize("test_data", test_data_suite)
 @common.XfailIfNoCorstone320
-@pytest.mark.xfail(
-    reason="Vela compilation fails with 'Invalid arguments' for int16 addmm operations"
-)
 def test_addmm_16a8w_u85_INT16(test_data: input_t1):
     """Test addmm (FC layer) operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)"""
     per_channel_quantization = False

backends/cadence/aot/replace_ops.py

Lines changed: 2 additions & 2 deletions
@@ -89,10 +89,10 @@ def replace_logical_nop_where_with_where(
 
             # Get the third arg node and its input
             logical_not_node = node.args[0]
-            logical_not_input_tensor = logical_not_node.args[0].to_tensor()
+            logical_not_input_node = logical_not_node.args[0]
 
             # If the logical_not input is not a boolean tensor, bail.
-            if logical_not_input_tensor.meta["spec"].dtype != torch.bool:
+            if logical_not_input_node.meta["val"].dtype != torch.bool:
                 continue
 
             # Replace the where op with another one, flipping the inputs and using the boolean
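As a side note on the change above: in export-style FX graphs each node's `meta["val"]` typically carries a FakeTensor with the node's dtype and shape, which is why the pass can check the dtype without materializing a tensor. A minimal, hypothetical sketch of the same pattern (not the Cadence pass itself; assumes a recent PyTorch with `torch.export` available):

```python
# Minimal sketch (not the actual Cadence pass): reading dtype information
# from node.meta["val"] in an exported FX graph.
import torch
from torch.export import export


class Gate(torch.nn.Module):
    def forward(self, cond: torch.Tensor, x: torch.Tensor, y: torch.Tensor):
        return torch.where(torch.logical_not(cond), x, y)


ep = export(Gate(), (torch.zeros(4, dtype=torch.bool), torch.ones(4), torch.zeros(4)))
for node in ep.graph.nodes:
    if node.op == "call_function" and "val" in node.meta:
        # meta["val"] is a FakeTensor carrying dtype/shape metadata.
        print(node.target, node.meta["val"].dtype)
```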

backends/cuda/TARGETS

Lines changed: 1 addition & 0 deletions
@@ -6,6 +6,7 @@ runtime.python_library(
     name = "cuda_backend",
     srcs = [
         "cuda_backend.py",
+        "replace_slice_copy_with_slice.py",
     ],
     visibility = [
         "//executorch/...",

backends/cuda/cuda_backend.py

Lines changed: 3 additions & 1 deletion
@@ -144,7 +144,9 @@ def preprocess(
         }
 
         with collect_unsupported_fallback_kernels(), torch.nn.attention.sdpa_kernel(
-            [SDPBackend.MATH]
+            [
+                SDPBackend.MATH  # pyre-ignore[16]: Module `torch.nn.attention` has no attribute `SDPBackend`.
+            ]
         ), torch.no_grad():
             # torch._logging.set_logs(post_grad_graphs=True)
             so_path = torch._inductor.aot_compile(edge_program_module, tuple(user_input_placeholders), options=options)  # type: ignore[arg-type]
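For readers unfamiliar with the context manager used above: `torch.nn.attention.sdpa_kernel` restricts which scaled-dot-product-attention backends are eligible inside the `with` block, and the preprocess step pins it to the math backend while compiling. A small standalone sketch (assumes a PyTorch version where this API is available, roughly 2.3+):

```python
# Standalone sketch: forcing the math SDPA backend, similar to what the AOTI
# preprocess step above does during compilation.
import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

q = k = v = torch.randn(1, 2, 8, 16)

with sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
print(out.shape)  # torch.Size([1, 2, 8, 16])
```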

backends/cuda/replace_slice_copy_with_slice.py

Lines changed: 8 additions & 5 deletions
@@ -6,20 +6,23 @@
 
 # pyre-strict
 
-from typing import Iterable
+from typing import Dict, Iterable, Tuple
 
 import torch
 from executorch.exir.dialects._ops import ops
+from executorch.exir.dialects.edge._ops import EdgeOpOverload
 from executorch.exir.pass_base import ExportPass, PassResult
 from torch import fx
 
 
-_SLICE_COPY_TARGETS = (
+_SLICE_COPY_TARGETS: Tuple[torch._ops.OpOverload | EdgeOpOverload] = (
     torch.ops.aten.slice_copy.Tensor,
     ops.edge.aten.slice_copy.Tensor,
 )
 
-_SLICE_TARGETS = {
+_SLICE_TARGETS: Dict[
+    torch._ops.OpOverload | EdgeOpOverload, torch._ops.OpOverload | EdgeOpOverload
+] = {
     torch.ops.aten.slice_copy.Tensor: torch.ops.aten.slice.Tensor,
     ops.edge.aten.slice_copy.Tensor: ops.edge.aten.slice.Tensor,
 }
@@ -99,8 +102,8 @@ def _is_view_user(self, node: fx.Node, user: fx.Node) -> bool:
         return False
 
     def _argument_mutates(
-        self, schema: torch._C.FunctionSchema, key
-    ) -> bool:  # pyre-ignore[11]
+        self, schema: torch._C.FunctionSchema, key: int | str
+    ) -> bool:
         arguments = schema.arguments
         if isinstance(key, int):
             if key >= len(arguments):

backends/cuda/tests/test_cuda_export.py

Lines changed: 4 additions & 1 deletion
@@ -8,6 +8,7 @@
 from typing import Tuple
 
 import torch
+from executorch.backends.cuda.cuda_backend import CudaBackend
 from executorch.backends.cuda.cuda_partitioner import CudaPartitioner
 from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower
 from torch.export import export
@@ -30,7 +31,9 @@ def _export_to_cuda_with_lower(
         exported_program = export(module, inputs, strict=True)
 
         # Create partitioner and compile specs
-        partitioner = CudaPartitioner([])
+        partitioner = CudaPartitioner(
+            [CudaBackend.generate_method_name_compile_spec("forward")]
+        )
 
         # Use to_edge_transform_and_lower for complete pipeline
         edge_program_manager = to_edge_transform_and_lower(

backends/qualcomm/quantizer/qconfig.py

Lines changed: 5 additions & 5 deletions
@@ -200,7 +200,7 @@ def get_16a8w_qnn_qat_config(
     act_observer=MovingAverageMinMaxObserver,
 ) -> QuantizationConfig:
     extra_args: Dict[str, Any] = {"eps": 2**-20}
-    act_fake_quant_ctr = FakeQuantize.with_args(
+    act_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
         dtype=torch.int32,
         quant_min=torch.iinfo(torch.uint16).min,
         quant_max=torch.iinfo(torch.uint16).max,
@@ -398,7 +398,7 @@ def get_ptq_per_block_quant_config(
 def get_8a8w_qnn_qat_config(
     act_symmetric: bool = False, act_observer=MovingAverageMinMaxObserver
 ) -> QuantizationConfig:
-    act_fake_quant_ctr = FakeQuantize.with_args(
+    act_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
         dtype=torch.uint8,
         qscheme=(
             torch.per_tensor_symmetric if act_symmetric else torch.per_tensor_affine
@@ -458,7 +458,7 @@ def get_8a8w_qnn_qat_config(
 def get_16a4w_qnn_qat_config(
     act_observer=MovingAverageMinMaxObserver,
 ) -> QuantizationConfig:
-    act_fake_quant_ctr = FakeQuantize.with_args(
+    act_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
         dtype=torch.int32,
         quant_min=torch.iinfo(torch.uint16).min,
         quant_max=torch.iinfo(torch.uint16).max,
@@ -541,7 +541,7 @@ def get_qat_per_channel_quant_config(
         # If zero_point is 128, htp can do optimizations.
         # If we keep quant_min and quant_max none, observer will default use 128 as zero_point.
        # If we provide uint8 quant_min/max, it will use 127 as zero_point, which is undesired.
-        act_fake_quant_ctr = FakeQuantize.with_args(
+        act_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
            dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
            qscheme=torch.per_tensor_symmetric,
            observer=act_observer,
@@ -553,7 +553,7 @@ def get_qat_per_channel_quant_config(
             observer_or_fake_quant_ctr=act_fake_quant_ctr,
         )
     else:
-        act_fake_quant_ctr = FakeQuantize.with_args(
+        act_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
            dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
            quant_min=torch.iinfo(act_dtype).min,
            quant_max=torch.iinfo(act_dtype).max,
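For context on the swap above: `FusedMovingAvgObsFakeQuantize` (from `torch.ao.quantization.fake_quantize`) fuses the moving-average observer with the fake-quant step, which is the usual choice for QAT. A minimal standalone sketch of building such a constructor partial with `with_args` and applying it to a tensor; the parameter values are illustrative and are not the Qualcomm/QNN config:

```python
# Illustrative sketch (not the QNN config itself): building a fused
# observer/fake-quant constructor with with_args and using it on a tensor.
import torch
from torch.ao.quantization.fake_quantize import FusedMovingAvgObsFakeQuantize
from torch.ao.quantization.observer import MovingAverageMinMaxObserver

act_fake_quant_ctr = FusedMovingAvgObsFakeQuantize.with_args(
    observer=MovingAverageMinMaxObserver,
    dtype=torch.quint8,
    quant_min=0,
    quant_max=255,
)

fq = act_fake_quant_ctr()  # instantiate the fused observer/fake-quant module
x = torch.randn(2, 3)
y = fq(x)                  # observes statistics and fake-quantizes in one step
print(fq.scale, fq.zero_point, y.shape)
```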

examples/models/llama/model_args.py

Lines changed: 3 additions & 0 deletions
@@ -63,6 +63,9 @@ class ModelArgs:
     use_sdpa_with_kv_cache_op: bool = (
         False  # Use custom sdpa op that updates kv cache in-place
     )
+    # Device to use for the model: "cpu" or "cuda" (needed for QAT)
+    # Only used for creating Rope parameters
+    device: str = "cpu"
     # Generate logits for all inputs. When it's True, it would take big memory usage
     # at runtime. Enable it only necessary (e.g., use perplexity tools that requires
     # logits for all input tokens.)

examples/models/llama/rope.py

Lines changed: 8 additions & 2 deletions
@@ -138,15 +138,19 @@ def forward(
 # and https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_rope_utils.py#L242.
 # Current only support non-long rope.
 def hf_precompute_freqs_cis(
-    dim: int, end: int, theta: float, partial_rotary_factor: float = 1.0
+    dim: int,
+    end: int,
+    theta: float,
+    partial_rotary_factor: float = 1.0,
+    device: Union[str, torch.device] = "cpu",
 ):
     # Partial rotary embeddings.
     dim = int(dim * partial_rotary_factor)
 
     # Short factor scaling.
     freqs = 1.0 / (
         theta
-        ** (torch.arange(0, dim, 2, device="cpu", dtype=torch.int64).float() / dim)
+        ** (torch.arange(0, dim, 2, device=device, dtype=torch.int64).float() / dim)
     )
     # TODO: support long factor scaling.
 
@@ -236,6 +240,7 @@ def __init__(self, params: ModelArgs):
         self.precompute_freqs_cis = partial(
             hf_precompute_freqs_cis,
             partial_rotary_factor=self.params.partial_rotary_factor,
+            device=self.params.device,
         )
         self.apply_rotary_emb = hf_apply_rotary_emb
     else:
@@ -244,6 +249,7 @@ def __init__(self, params: ModelArgs):
             use_scaled=self.params.use_scaled_rope,
             scale_factor=self.params.rope_scale_factor,
             high_freq_factor=self.params.high_freq_factor,
+            device=self.params.device,
         )
         self.apply_rotary_emb = RotaryEmbedding()
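To illustrate why threading `device` through matters for QAT: the RoPE frequency table is built once at model construction, so if it is hard-coded to CPU while the rest of the model lives on CUDA, the precompute either fails or forces device copies. Below is a simplified, hedged sketch of a device-aware precompute in the spirit of the diff; it omits `partial_rotary_factor` handling and is not the `hf_precompute_freqs_cis` implementation itself:

```python
# Simplified sketch of a device-aware RoPE frequency precompute.
from typing import Union

import torch


def precompute_freqs(
    dim: int, end: int, theta: float = 10000.0,
    device: Union[str, torch.device] = "cpu",
) -> torch.Tensor:
    # Inverse frequencies for each pair of rotary dimensions, created directly
    # on the requested device instead of being hard-coded to "cpu".
    freqs = 1.0 / (
        theta ** (torch.arange(0, dim, 2, device=device, dtype=torch.int64).float() / dim)
    )
    t = torch.arange(end, device=device).float()
    return torch.outer(t, freqs)  # (end, dim // 2) table of rotation angles


angles = precompute_freqs(dim=64, end=128)  # CPU by default
# angles = precompute_freqs(dim=64, end=128, device="cuda")  # e.g. when running QAT on GPU
print(angles.shape)  # torch.Size([128, 32])
```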

0 commit comments