
Commit 38a2dc6

Merge branch 'main' into bump-ao-pin
2 parents 58301a3 + 47acc87

File tree: 68 files changed (+668, -322 lines)


.github/workflows/trunk.yml

Lines changed: 3 additions & 0 deletions

```diff
@@ -8,6 +8,9 @@ on:
     tags:
       - ciflow/trunk/*
   pull_request:
+    paths:
+      - .ci/docker/ci_commit_pins/pytorch.txt
+      - .ci/scripts/**
   workflow_dispatch:
 
 concurrency:
```
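With this filter in place, pull requests trigger the trunk workflow only when they touch the pinned PyTorch commit or the CI scripts; pushes, `ciflow/trunk/*` tags, and manual `workflow_dispatch` runs are unaffected, since GitHub Actions applies a `paths:` filter only to the trigger it is nested under.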

backends/apple/coreml/test/test_coreml_quantizer.py

Lines changed: 2 additions & 4 deletions

```diff
@@ -15,7 +15,7 @@
 )
 
 from executorch.backends.apple.coreml.quantizer import CoreMLQuantizer
-from torch.export import export_for_training
+from torch.export import export
 from torchao.quantization.pt2e.quantize_pt2e import (
     convert_pt2e,
     prepare_pt2e,
@@ -32,9 +32,7 @@ def quantize_and_compare(
 ) -> None:
     assert quantization_type in {"PTQ", "QAT"}
 
-    pre_autograd_aten_dialect = export_for_training(
-        model, example_inputs, strict=True
-    ).module()
+    pre_autograd_aten_dialect = export(model, example_inputs, strict=True).module()
 
     quantization_config = LinearQuantizerConfig.from_dict(
         {
```
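This `export_for_training` → `export` substitution recurs in the MPS, Cortex-M, example-delegate, and MediaTek files below. A minimal sketch of the updated PT2E flow, using a hypothetical `TinyModel` that is not part of this commit:

```python
import torch
from torch.export import export
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

# Hypothetical module for illustration only.
class TinyModel(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.relu(x)

example_inputs = (torch.randn(2, 3),)

# export(...).module() now yields the GraphModule previously produced by
# export_for_training(...).module() at these call sites.
gm = export(TinyModel().eval(), example_inputs, strict=True).module()

# The surrounding PT2E steps are unchanged; a backend-specific quantizer
# (e.g. CoreMLQuantizer) would be passed here.
# prepared = prepare_pt2e(gm, quantizer)
# quantized = convert_pt2e(prepared)
```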

backends/apple/mps/test/test_mps_utils.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -206,7 +206,7 @@ def lower_module_and_test_output(
 
     expected_output = model(*sample_inputs)
 
-    model = torch.export.export_for_training(
+    model = torch.export.export(
         model, sample_inputs, dynamic_shapes=dynamic_shapes, strict=True
     ).module()
```

backends/arm/operators/op_abs.py

Lines changed: 1 addition & 3 deletions

```diff
@@ -73,9 +73,7 @@ def define_node(
         abs_output = output
 
         # Do the INT32 Abs
-        self._serialize_operator(
-            node,
-            tosa_graph,
+        tosa_graph.addOperator(
             ts.TosaOp.Op().ABS,
             [
                 rescaled_inputs[0].name,
```

backends/arm/operators/op_sum.py

Lines changed: 2 additions & 6 deletions

```diff
@@ -67,9 +67,7 @@ def define_node(
             dtype=ts.DType.INT32,
         )
 
-        self._serialize_operator(
-            node,
-            tosa_graph,
+        tosa_graph.addOperator(
             ts.TosaOp.Op().REDUCE_SUM,
             [rescaled_inputs[0].name],
             [intermediate.name],
@@ -113,9 +111,7 @@ def define_node(
         attr = ts.TosaSerializerAttribute()
         attr.ReduceSumAttribute(tensor.dim_order.index(dim))
 
-        self._serialize_operator(
-            node,
-            tosa_graph,
+        tosa_graph.addOperator(
             ts.TosaOp.Op().REDUCE_SUM,
             [tensor.name],
             [output.name],
```
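In both ARM operator files the change is mechanical: the `_serialize_operator` wrapper is dropped in favor of calling `tosa_graph.addOperator` directly, with the operator enum, input/output tensor names, and attribute arguments unchanged. Presumably the wrapper only added per-node bookkeeping around the same underlying call, so the serialized TOSA graph should be identical.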

backends/cadence/aot/ref_implementations.py

Lines changed: 19 additions & 3 deletions

```diff
@@ -458,10 +458,21 @@ def quantized_conv_nhwc_per_tensor(
     - out_shift (int): Unused
     """
 
-    if not input_tensor.is_contiguous(memory_format=torch.channels_last):
-        raise ValueError("Input tensor must be in NHWC format")
+    # Convert to NCHW format to reuse the existing implementation
+    conv_is_1d = False
+    if len(input_tensor.shape) == 3:
+        conv_is_1d = True
+        input_tensor = input_tensor.movedim(-1, 1).contiguous()
+        if len(weight.shape) != 3:
+            raise ValueError("Weight tensor must be 3D if input is 3D")
+        weight = weight.movedim(-1, 1).contiguous()
+    else:
+        input_tensor = input_tensor.movedim(-1, -3)
+        if len(weight.shape) != 4:
+            raise ValueError("Weight tensor must be 4D if input is nd > 3")
+        weight = torch.permute(weight, (0, -1, 1, 2)).contiguous()
 
-    return quantized_conv_per_tensor(
+    nchw_out = quantized_conv_per_tensor(
         input_tensor,
         weight,
         bias,
@@ -478,6 +489,11 @@ def quantized_conv_nhwc_per_tensor(
         out_shift,
     )
 
+    if conv_is_1d:
+        return nchw_out.movedim(1, -1).contiguous()
+    else:
+        return nchw_out.movedim(-3, -1).contiguous()
+
 
 def quantized_conv_variant(
     layout: str,
```
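The conversion hinges on `Tensor.movedim`, which repositions the channel dimension between the last slot (NHWC/NWC) and its NCHW/NCW slot. A standalone sketch of the round trip; the shapes are illustrative, not taken from the commit:

```python
import torch

# 2D conv case: NHWC input with batch=2, H=4, W=5, C=3.
nhwc = torch.randn(2, 4, 5, 3)

# NHWC -> NCHW: move channels (dim -1) in front of the spatial dims.
nchw = nhwc.movedim(-1, -3)
assert nchw.shape == (2, 3, 4, 5)

# NCHW -> NHWC: the inverse move, as applied to the kernel's output.
assert nchw.movedim(-3, -1).shape == nhwc.shape

# 1D conv case: NWC input, channels move to dim 1 instead.
nwc = torch.randn(2, 5, 3)
assert nwc.movedim(-1, 1).shape == (2, 3, 5)
```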

backends/cadence/aot/tests/test_ref_implementations.py

Lines changed: 17 additions & 8 deletions

```diff
@@ -449,7 +449,7 @@ def test_quantized_layer_norm_per_tensor(
             ), # expected_output: [1+2, 2+3, 3+4] / 0.5 = [6, 10, 14]
             memory_format,
         )
-        for memory_format in [torch.contiguous_format]
+        for memory_format in [torch.contiguous_format, torch.channels_last]
     ],
     # Test case 5: Multiple output channels
     *[
@@ -686,10 +686,13 @@ def test_quantized_conv_per_tensor(
     ) -> None:
         assert memory_format in [torch.contiguous_format, torch.channels_last]
 
-        if len(input_tensor.shape) == 3 and memory_format == torch.channels_last:
-            self.fail("Channels last format is not supported for 3D input tensors")
-
-        input_tensor = input_tensor.to(memory_format=memory_format)
+        if memory_format == torch.channels_last:
+            if input_tensor.ndim == 3:
+                input_tensor = input_tensor.movedim(1, -1)
+                weight = weight.movedim(1, -1)
+            else:
+                input_tensor = input_tensor.movedim(-3, -1)
+                weight = weight.movedim(-3, -1)
 
         convs = [
             (
@@ -701,7 +704,7 @@ def test_quantized_conv_per_tensor(
 
         optimized_convs = []
         if input_tensor.dtype == torch.int8 and weight.dtype == torch.int8:
-            if input_tensor.is_contiguous(memory_format=torch.contiguous_format):
+            if memory_format == torch.contiguous_format:
                 optimized_convs = [
                     torch.ops.cadence.quantized_conv_nchw_asym8sxsym8s_asym8s.per_tensor,
                     torch.ops.cadence.quantized_conv_nchw_dilated_asym8sxsym8s_asym8s.per_tensor,
@@ -715,7 +718,7 @@ def test_quantized_conv_per_tensor(
                     torch.ops.cadence.quantized_conv_nhwc_depthwise_asym8sxsym8s_asym8s.per_tensor,
                 ]
         elif input_tensor.dtype == torch.uint8 and weight.dtype == torch.uint8:
-            if input_tensor.is_contiguous(memory_format=torch.contiguous_format):
+            if memory_format == torch.contiguous_format:
                 optimized_convs = [
                     torch.ops.cadence.quantized_conv_nchw_asym8uxsym8u_asym8u.per_tensor,
                     torch.ops.cadence.quantized_conv_nchw_dilated_asym8uxsym8u_asym8u.per_tensor,
@@ -746,7 +749,13 @@ def test_quantized_conv_per_tensor(
                 output_zero_point,
                 out_multiplier,
                 out_shift,
-            ).to(memory_format=torch.contiguous_format)
+            )
+
+            if memory_format == torch.channels_last:
+                if input_tensor.ndim == 3:
+                    output = output.movedim(-1, 1)
+                else:
+                    output = output.movedim(-1, -3)
 
             # Verify output properties
             self.assertEqual(output.dtype, dtype, f"Output dtype should be {dtype}")
```
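Note the semantic shift in the test setup: the old `.to(memory_format=...)` call only reordered strides, whereas the new `movedim` calls permute the logical shape, matching what the reference NHWC implementation above now expects. A quick sketch of the distinction, with an illustrative shape:

```python
import torch

x = torch.randn(2, 3, 4, 5)  # logical NCHW

# memory_format changes strides only; the logical shape stays (2, 3, 4, 5).
cl = x.to(memory_format=torch.channels_last)
assert cl.shape == (2, 3, 4, 5)
assert cl.is_contiguous(memory_format=torch.channels_last)

# movedim changes the logical dimension order: NCHW -> NHWC.
nhwc = x.movedim(-3, -1)
assert nhwc.shape == (2, 4, 5, 3)
```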

backends/cortex_m/test/test_quantize_op_fusion_pass.py

Lines changed: 3 additions & 7 deletions

```diff
@@ -23,7 +23,7 @@
     get_node_args,
 )
 from executorch.exir.dialects._ops import ops as exir_ops
-from torch.export import export, export_for_training
+from torch.export import export
 from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
 
 
@@ -42,9 +42,7 @@ def _prepare_quantized_model(self, model_class):
         model = model_class()
 
         # Export and quantize
-        exported_model = export_for_training(
-            model.eval(), self.example_inputs, strict=True
-        ).module()
+        exported_model = export(model.eval(), self.example_inputs, strict=True).module()
         prepared_model = prepare_pt2e(exported_model, AddQuantizer())
         quantized_model = convert_pt2e(prepared_model)
 
@@ -242,9 +240,7 @@ def forward(self, x, y):
         inputs = (torch.randn(shape), torch.randn(shape))
 
         model = SingleAddModel()
-        exported_model = export_for_training(
-            model.eval(), inputs, strict=True
-        ).module()
+        exported_model = export(model.eval(), inputs, strict=True).module()
         prepared_model = prepare_pt2e(exported_model, AddQuantizer())
         quantized_model = convert_pt2e(prepared_model)
 
```

backends/example/test_example_delegate.py

Lines changed: 2 additions & 6 deletions

```diff
@@ -46,9 +46,7 @@ def get_example_inputs():
         )
 
         m = model.eval()
-        m = torch.export.export_for_training(
-            m, copy.deepcopy(example_inputs), strict=True
-        ).module()
+        m = torch.export.export(m, copy.deepcopy(example_inputs), strict=True).module()
         # print("original model:", m)
         quantizer = ExampleQuantizer()
         # quantizer = XNNPACKQuantizer()
@@ -84,9 +82,7 @@ def test_delegate_mobilenet_v2(self):
         )
 
         m = model.eval()
-        m = torch.export.export_for_training(
-            m, copy.deepcopy(example_inputs), strict=True
-        ).module()
+        m = torch.export.export(m, copy.deepcopy(example_inputs), strict=True).module()
         quantizer = ExampleQuantizer()
 
         m = prepare_pt2e(m, quantizer)
```

backends/mediatek/quantizer/annotator.py

Lines changed: 2 additions & 4 deletions

```diff
@@ -10,7 +10,7 @@
 from torch._ops import OpOverload
 from torch._subclasses import FakeTensor
 
-from torch.export import export_for_training
+from torch.export import export
 from torch.fx import Graph, Node
 from torch.fx.passes.utils.matcher_with_name_node_map_utils import (
     SubgraphMatcherWithNameNodeMap,
@@ -158,9 +158,7 @@ def forward(self, x):
             return norm, {}
 
     for pattern_cls in (ExecuTorchPattern, MTKPattern):
-        pattern_gm = export_for_training(
-            pattern_cls(), (torch.randn(3, 3),), strict=True
-        ).module()
+        pattern_gm = export(pattern_cls(), (torch.randn(3, 3),), strict=True).module()
         matcher = SubgraphMatcherWithNameNodeMap(
             pattern_gm, ignore_literals=True, remove_overlapping_matches=False
         )
```
