
Commit e905dcd

NXP backend: Update deprecated export call.
1 parent: ad7fb23
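The change is mechanical: every call site that produced the ATen exported program with the deprecated torch.export.export_for_training now calls torch.export.export with the same arguments. A minimal before/after sketch (the model and input below are placeholders, not code from this repository):

import torch
from torch import nn

# Placeholder model and input, only to make the sketch runnable.
model = nn.Sequential(nn.Conv2d(4, 8, 3), nn.ReLU()).eval()
example_input = (torch.ones(1, 4, 32, 32),)

# Before (deprecated):
#   exported = torch.export.export_for_training(model, example_input, strict=True)
# After:
exported = torch.export.export(model, example_input, strict=True)
graph_module = exported.module()  # ATen-dialect GraphModule, as used by the call sites below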

File tree: 4 files changed (+16 -43 lines)


backends/nxp/tests/executorch_pipeline.py

Lines changed: 3 additions & 5 deletions
@@ -4,8 +4,6 @@
 # LICENSE file in the root directory of this source tree.

 import torch
-from torch import nn
-from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

 from executorch import exir
 from executorch.backends.nxp.backend.custom_delegation_options import (
@@ -27,6 +25,8 @@
     ExecutorchProgramManager,
 )
 from executorch.extension.export_util.utils import export_to_edge
+from torch import nn
+from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e


 def _quantize_model(model, calibration_inputs: list[tuple[torch.Tensor]]):
@@ -73,9 +73,7 @@ def to_quantized_edge_program(
     # Make sure the model is in the evaluation mode.
     model.eval()

-    exir_program_aten = torch.export.export_for_training(
-        model, example_input, strict=True
-    )
+    exir_program_aten = torch.export.export(model, example_input, strict=True)

     exir_program_aten__module_quant = _quantize_model(
         exir_program_aten.module(), calibration_inputs
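For context, the helper being edited here follows the usual PT2E flow: export the eval-mode model, quantize the resulting ATen graph module, then lower it to an edge program. A condensed sketch of that flow, using a toy model and assuming _quantize_model is roughly prepare_pt2e, calibration, convert_pt2e (its body is not shown in this diff):

import torch
from torch import nn
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer

# Toy stand-ins for the model and calibration data handled by the pipeline.
model = nn.Sequential(nn.Conv2d(4, 8, 3), nn.ReLU()).eval()
example_input = (torch.ones(1, 4, 32, 32),)
calibration_inputs = [example_input]

exir_program_aten = torch.export.export(model, example_input, strict=True)

# Quantize the ATen graph module (assumed equivalent of _quantize_model above).
m = prepare_pt2e(exir_program_aten.module(), NeutronQuantizer())
for data in calibration_inputs:
    m(*data)  # calibration
exir_program_aten__module_quant = convert_pt2e(m)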

backends/nxp/tests/test_batch_norm_fusion.py

Lines changed: 2 additions & 2 deletions
@@ -95,7 +95,7 @@ def test_batch_norm_conv_fusing(bias: bool, input_shape: list[int]):
     example_input = (torch.ones(*input_shape),)

     module = ConvBatchNormModule(bias, len(input_shape), 4)
-    program = torch.export.export_for_training(module, example_input, strict=True)
+    program = torch.export.export(module, example_input, strict=True)
     og_module = program.module()

     pm = NeutronAtenPassManager()
@@ -129,7 +129,7 @@ def test_batch_norm_linear_fusing(bias: bool):
     example_input = (torch.ones(*input_shape),)

     module = LinearBatchNormModule(bias, 4, input_shape[-1], input_shape[1])
-    program = torch.export.export_for_training(module, example_input, strict=True)
+    program = torch.export.export(module, example_input, strict=True)
     og_module = program.module()

     pm = NeutronAtenPassManager()
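Both fusion tests share the same shape: export the eager module, take the ATen GraphModule via .module(), and hand it to NeutronAtenPassManager for batch-norm fusion. A rough sketch with a stand-in conv + batch-norm module (ConvBatchNormModule itself is defined elsewhere in the test file and not shown here):

import torch
from torch import nn


class ConvBN(nn.Module):
    """Stand-in for the ConvBatchNormModule used by the tests."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(4, 4, kernel_size=3, bias=True)
        self.bn = nn.BatchNorm2d(4)

    def forward(self, x):
        return self.bn(self.conv(x))


example_input = (torch.ones(1, 4, 32, 32),)
program = torch.export.export(ConvBN().eval(), example_input, strict=True)
og_module = program.module()  # the tests pass this module to NeutronAtenPassManager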

backends/nxp/tests/test_quantizer.py

Lines changed: 10 additions & 30 deletions
@@ -23,9 +23,7 @@ def test_quantizer_conv2d():

     example_input = (torch.ones(1, 4, 32, 32),)
     quantizer = NeutronQuantizer()
-    graph_module = torch.export.export_for_training(
-        model, example_input, strict=True
-    ).module()
+    graph_module = torch.export.export(model, example_input, strict=True).module()

     # noinspection PyTypeChecker
     m = prepare_pt2e(graph_module, quantizer)
@@ -64,9 +62,7 @@ def test_quantizer_linear():

     example_input = (torch.ones(10, 32),)
     quantizer = NeutronQuantizer()
-    graph_module = torch.export.export_for_training(
-        model, example_input, strict=True
-    ).module()
+    graph_module = torch.export.export(model, example_input, strict=True).module()

     # noinspection PyTypeChecker
     m = prepare_pt2e(graph_module, quantizer)
@@ -105,9 +101,7 @@ def test_quantizer_maxpool2d():

     example_input = (torch.ones(1, 8, 32, 32),)
     quantizer = NeutronQuantizer()
-    graph_module = torch.export.export_for_training(
-        model, example_input, strict=True
-    ).module()
+    graph_module = torch.export.export(model, example_input, strict=True).module()

     # noinspection PyTypeChecker
     m = prepare_pt2e(graph_module, quantizer)
@@ -143,9 +137,7 @@ def test_quantizer_softmax():

     example_input = (torch.ones(1, 10),)
     quantizer = NeutronQuantizer()
-    graph_module = torch.export.export_for_training(
-        model, example_input, strict=True
-    ).module()
+    graph_module = torch.export.export(model, example_input, strict=True).module()

     # noinspection PyTypeChecker
     m = prepare_pt2e(graph_module, quantizer)
@@ -182,9 +174,7 @@ def test_quantizer_single_maxpool2d():

     example_input = (torch.ones(1, 4, 32, 32),)
     quantizer = NeutronQuantizer()
-    graph_module = torch.export.export_for_training(
-        model, example_input, strict=True
-    ).module()
+    graph_module = torch.export.export(model, example_input, strict=True).module()

     # noinspection PyTypeChecker
     m = prepare_pt2e(graph_module, quantizer)
@@ -206,9 +196,7 @@ def test_quantizer_conv2d_relu():

     example_input = (torch.ones(1, 4, 32, 32),)
     quantizer = NeutronQuantizer()
-    graph_module = torch.export.export_for_training(
-        model, example_input, strict=True
-    ).module()
+    graph_module = torch.export.export(model, example_input, strict=True).module()

     # noinspection PyTypeChecker
     m = prepare_pt2e(graph_module, quantizer)
@@ -231,9 +219,7 @@ def test_quantizer_conv2d_avg_pool2d():

     example_input = (torch.ones(1, 4, 16, 16),)
     quantizer = NeutronQuantizer()
-    graph_module = torch.export.export_for_training(
-        model, example_input, strict=True
-    ).module()
+    graph_module = torch.export.export(model, example_input, strict=True).module()

     # noinspection PyTypeChecker
     m = prepare_pt2e(graph_module, quantizer)
@@ -256,9 +242,7 @@ def test_quantizer_conv2d_permute():

     example_input = (torch.ones(1, 4, 16, 16),)
     quantizer = NeutronQuantizer()
-    graph_module = torch.export.export_for_training(
-        model, example_input, strict=True
-    ).module()
+    graph_module = torch.export.export(model, example_input, strict=True).module()

     # noinspection PyTypeChecker
     m = prepare_pt2e(graph_module, quantizer)
@@ -285,9 +269,7 @@ def test_multiple_shared_spec_ops_in_row():

     example_input = (torch.ones(1, 3, 64, 64),)
     quantizer = NeutronQuantizer()
-    graph_module = torch.export.export_for_training(
-        model, example_input, strict=True
-    ).module()
+    graph_module = torch.export.export(model, example_input, strict=True).module()

     # noinspection PyTypeChecker
     m = prepare_pt2e(graph_module, quantizer)
@@ -321,9 +303,7 @@ def test_quantizers_order_invariance():
     example_input = (torch.ones(1, 4, 64, 64),)
     quantizer = NeutronQuantizer()

-    graph_module = torch.export.export_for_training(
-        model, example_input, strict=True
-    ).module()
+    graph_module = torch.export.export(model, example_input, strict=True).module()

     m = prepare_pt2e(deepcopy(graph_module), quantizer)
     m(*example_input)
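All ten hunks in this file apply the same one-line substitution: prepare_pt2e consumes the ATen GraphModule, hence the .module() call on the exported program; the last test also deep-copies that module before preparing, presumably so the original graph can be reused. The shared pattern, sketched with a placeholder model (each real test defines its own module and asserts on the quantized ops):

from copy import deepcopy

import torch
from torch import nn
from torchao.quantization.pt2e.quantize_pt2e import prepare_pt2e

from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer

model = nn.Sequential(nn.Conv2d(4, 8, 3), nn.ReLU()).eval()  # placeholder test model
example_input = (torch.ones(1, 4, 32, 32),)

graph_module = torch.export.export(model, example_input, strict=True).module()

quantizer = NeutronQuantizer()
m = prepare_pt2e(deepcopy(graph_module), quantizer)
m(*example_input)  # calibration pass, as in the tests above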

examples/nxp/aot_neutron_compile.py

Lines changed: 1 addition & 6 deletions
@@ -15,7 +15,6 @@
 import executorch.kernels.quantized  # noqa F401

 import torch
-
 from executorch.backends.nxp.backend.ir.edge_passes.remove_io_quant_ops_pass import (
     RemoveIOQuantOpsPass,
 )
@@ -24,14 +23,12 @@
 from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer
 from executorch.examples.models import MODEL_NAME_TO_MODEL
 from executorch.examples.models.model_factory import EagerModelFactory
-
 from executorch.exir import (
     EdgeCompileConfig,
     ExecutorchBackendConfig,
     to_edge_transform_and_lower,
 )
 from executorch.extension.export_util import save_pte_program
-
 from torch.export import export
 from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

@@ -227,9 +224,7 @@ def _get_batch_size(data):
     model = model.eval()

     # 2. Export the model to ATEN
-    exported_program = torch.export.export_for_training(
-        model, example_inputs, strict=True
-    )
+    exported_program = torch.export.export(model, example_inputs, strict=True)

     module = exported_program.module()

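The same substitution lands in the example's AOT flow which, judging from the imports above, exports to ATEN, quantizes with NeutronQuantizer, lowers with to_edge_transform_and_lower, and saves a .pte. A heavily condensed sketch under those assumptions; the real script also wires up a partitioner, RemoveIOQuantOpsPass, and backend/compile configs that are not visible in this diff:

import torch
from torch import nn
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e

from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer
from executorch.exir import to_edge_transform_and_lower
from executorch.extension.export_util import save_pte_program

model = nn.Sequential(nn.Conv2d(4, 8, 3), nn.ReLU()).eval()  # placeholder model
example_inputs = (torch.ones(1, 4, 32, 32),)

# 2. Export the model to ATEN (the step changed by this commit).
exported_program = torch.export.export(model, example_inputs, strict=True)

# Quantize and re-export (PT2E flow; partitioner and compile config omitted here).
m = prepare_pt2e(exported_program.module(), NeutronQuantizer())
m(*example_inputs)  # calibration
m = convert_pt2e(m)
exported_quantized = torch.export.export(m, example_inputs, strict=True)

edge = to_edge_transform_and_lower(exported_quantized)  # no partitioner in this sketch
exec_prog = edge.to_executorch()
save_pte_program(exec_prog, "model")  # writes model.pte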