Commit 71a7806

NXP backend: Disable training mode and replace deprecated call (#13756)
### Summary

Ensure the input models are always in evaluation mode. Additionally, the deprecated `torch.export.export_for_training` call was replaced with the up-to-date `torch.export.export`.

### Test plan

Correct behaviour is covered by most of the existing tests.
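The same migration pattern is applied in every file below. As a quick illustration, here is a minimal before/after sketch; the `MyModel` module is a hypothetical stand-in, not code from this repository:

```python
import torch


class MyModel(torch.nn.Module):
    """Hypothetical stand-in for the models used in the NXP tests."""

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(4, 8, 3)

    def forward(self, x):
        return self.conv(x)


model = MyModel()
example_input = (torch.ones(1, 4, 32, 32),)

# Old pattern (deprecated):
#   exported = torch.export.export_for_training(model, example_input, strict=True)

# New pattern: make sure the model is in evaluation mode, then export.
model.eval()
exported = torch.export.export(model, example_input, strict=True)
graph_module = exported.module()
```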
1 parent a42423c · commit 71a7806

4 files changed: +21 −41 lines

backends/nxp/tests/executorch_pipeline.py

Lines changed: 8 additions & 3 deletions
```diff
@@ -96,9 +96,10 @@ def to_quantized_edge_program(
 
     example_input = calibration_inputs[0]
 
-    exir_program_aten = torch.export.export_for_training(
-        model, example_input, strict=True
-    )
+    # Make sure the model is in the evaluation mode.
+    model.eval()
+
+    exir_program_aten = torch.export.export(model, example_input, strict=True)
 
     exir_program_aten__module_quant = _quantize_model(
         exir_program_aten.module(), calibration_inputs
@@ -147,5 +148,9 @@ def to_edge_program(
     calibration_inputs = get_random_calibration_inputs(to_model_input_spec(input_spec))
 
     example_input = calibration_inputs[0]
+
+    # Make sure the model is in the evaluation mode.
+    model.eval()
+
     exir_program = torch.export.export(model, example_input)
     return exir.to_edge(exir_program)
```
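For context, a minimal sketch of the flow that `to_edge_program` now follows after this change, assuming the `executorch.exir` entry point already used in the file and a stand-in `SmallModel` defined purely for illustration:

```python
import torch
from executorch import exir


class SmallModel(torch.nn.Module):
    """Stand-in model; the real tests build their models elsewhere."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(32, 16)

    def forward(self, x):
        return self.linear(x)


model = SmallModel()
example_input = (torch.ones(10, 32),)

# Make sure the model is in the evaluation mode before exporting.
model.eval()

exir_program = torch.export.export(model, example_input)
edge_program = exir.to_edge(exir_program)
```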

backends/nxp/tests/test_batch_norm_fusion.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -95,7 +95,7 @@ def test_batch_norm_conv_fusing(bias: bool, input_shape: list[int]):
     example_input = (torch.ones(*input_shape),)
 
     module = ConvBatchNormModule(bias, len(input_shape), 4)
-    program = torch.export.export_for_training(module, example_input, strict=True)
+    program = torch.export.export(module, example_input, strict=True)
     og_module = program.module()
 
     pm = NeutronAtenPassManager()
@@ -129,7 +129,7 @@ def test_batch_norm_linear_fusing(bias: bool):
     example_input = (torch.ones(*input_shape),)
 
     module = LinearBatchNormModule(bias, 4, input_shape[-1], input_shape[1])
-    program = torch.export.export_for_training(module, example_input, strict=True)
+    program = torch.export.export(module, example_input, strict=True)
     og_module = program.module()
 
     pm = NeutronAtenPassManager()
```
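These fusion tests depend on the exported graph reflecting inference-time batch-norm behaviour, which is why evaluation mode matters here: in eval mode batch norm applies its fixed running statistics, so it can later be folded into the preceding convolution or linear layer. A small sketch of that idea with a throwaway `TinyConvBN` module (not the test's `ConvBatchNormModule`):

```python
import torch


class TinyConvBN(torch.nn.Module):
    """Throwaway Conv + BatchNorm module, for illustration only."""

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(4, 4, 3, bias=False)
        self.bn = torch.nn.BatchNorm2d(4)

    def forward(self, x):
        return self.bn(self.conv(x))


module = TinyConvBN()
example_input = (torch.ones(1, 4, 16, 16),)

# In eval mode batch norm uses its fixed running statistics, so the exported
# graph is the inference-time graph that a fusion pass can fold into the conv.
module.eval()
program = torch.export.export(module, example_input, strict=True)
print(program)
```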

backends/nxp/tests/test_quantizer.py

Lines changed: 10 additions & 30 deletions
```diff
@@ -23,9 +23,7 @@ def test_quantizer_conv2d():
 
     example_input = (torch.ones(1, 4, 32, 32),)
     quantizer = NeutronQuantizer()
-    graph_module = torch.export.export_for_training(
-        model, example_input, strict=True
-    ).module()
+    graph_module = torch.export.export(model, example_input, strict=True).module()
 
     # noinspection PyTypeChecker
     m = prepare_pt2e(graph_module, quantizer)
@@ -64,9 +62,7 @@ def test_quantizer_linear():
 
     example_input = (torch.ones(10, 32),)
     quantizer = NeutronQuantizer()
-    graph_module = torch.export.export_for_training(
-        model, example_input, strict=True
-    ).module()
+    graph_module = torch.export.export(model, example_input, strict=True).module()
 
     # noinspection PyTypeChecker
     m = prepare_pt2e(graph_module, quantizer)
@@ -105,9 +101,7 @@ def test_quantizer_maxpool2d():
 
     example_input = (torch.ones(1, 8, 32, 32),)
     quantizer = NeutronQuantizer()
-    graph_module = torch.export.export_for_training(
-        model, example_input, strict=True
-    ).module()
+    graph_module = torch.export.export(model, example_input, strict=True).module()
 
     # noinspection PyTypeChecker
     m = prepare_pt2e(graph_module, quantizer)
@@ -143,9 +137,7 @@ def test_quantizer_softmax():
 
     example_input = (torch.ones(1, 10),)
     quantizer = NeutronQuantizer()
-    graph_module = torch.export.export_for_training(
-        model, example_input, strict=True
-    ).module()
+    graph_module = torch.export.export(model, example_input, strict=True).module()
 
     # noinspection PyTypeChecker
     m = prepare_pt2e(graph_module, quantizer)
@@ -182,9 +174,7 @@ def test_quantizer_single_maxpool2d():
 
     example_input = (torch.ones(1, 4, 32, 32),)
     quantizer = NeutronQuantizer()
-    graph_module = torch.export.export_for_training(
-        model, example_input, strict=True
-    ).module()
+    graph_module = torch.export.export(model, example_input, strict=True).module()
 
     # noinspection PyTypeChecker
     m = prepare_pt2e(graph_module, quantizer)
@@ -206,9 +196,7 @@ def test_quantizer_conv2d_relu():
 
     example_input = (torch.ones(1, 4, 32, 32),)
     quantizer = NeutronQuantizer()
-    graph_module = torch.export.export_for_training(
-        model, example_input, strict=True
-    ).module()
+    graph_module = torch.export.export(model, example_input, strict=True).module()
 
     # noinspection PyTypeChecker
     m = prepare_pt2e(graph_module, quantizer)
@@ -231,9 +219,7 @@ def test_quantizer_conv2d_avg_pool2d():
 
     example_input = (torch.ones(1, 4, 16, 16),)
     quantizer = NeutronQuantizer()
-    graph_module = torch.export.export_for_training(
-        model, example_input, strict=True
-    ).module()
+    graph_module = torch.export.export(model, example_input, strict=True).module()
 
     # noinspection PyTypeChecker
     m = prepare_pt2e(graph_module, quantizer)
@@ -256,9 +242,7 @@ def test_quantizer_conv2d_permute():
 
     example_input = (torch.ones(1, 4, 16, 16),)
     quantizer = NeutronQuantizer()
-    graph_module = torch.export.export_for_training(
-        model, example_input, strict=True
-    ).module()
+    graph_module = torch.export.export(model, example_input, strict=True).module()
 
     # noinspection PyTypeChecker
     m = prepare_pt2e(graph_module, quantizer)
@@ -285,9 +269,7 @@ def test_multiple_shared_spec_ops_in_row():
 
     example_input = (torch.ones(1, 3, 64, 64),)
     quantizer = NeutronQuantizer()
-    graph_module = torch.export.export_for_training(
-        model, example_input, strict=True
-    ).module()
+    graph_module = torch.export.export(model, example_input, strict=True).module()
 
     # noinspection PyTypeChecker
     m = prepare_pt2e(graph_module, quantizer)
@@ -321,9 +303,7 @@ def test_quantizers_order_invariance():
     example_input = (torch.ones(1, 4, 64, 64),)
     quantizer = NeutronQuantizer()
 
-    graph_module = torch.export.export_for_training(
-        model, example_input, strict=True
-    ).module()
+    graph_module = torch.export.export(model, example_input, strict=True).module()
 
     m = prepare_pt2e(deepcopy(graph_module), quantizer)
     m(*example_input)
```
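All of the tests above follow the same PT2E quantization flow. Below is a condensed sketch of that flow, using a stand-in `ConvModel` and the `NeutronQuantizer` / `prepare_pt2e` / `convert_pt2e` imports that appear elsewhere in this commit:

```python
import torch
from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer
from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e


class ConvModel(torch.nn.Module):
    """Stand-in model for illustration; the tests define their own modules."""

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(4, 8, 3)

    def forward(self, x):
        return self.conv(x)


model = ConvModel().eval()
example_input = (torch.ones(1, 4, 32, 32),)
quantizer = NeutronQuantizer()

# Export with the non-deprecated API and grab the ATen-level graph module.
graph_module = torch.export.export(model, example_input, strict=True).module()

# PT2E flow: annotate with the quantizer, run a calibration pass, then convert.
m = prepare_pt2e(graph_module, quantizer)
m(*example_input)  # calibration
quantized_model = convert_pt2e(m)
```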

examples/nxp/aot_neutron_compile.py

Lines changed: 1 addition & 6 deletions
```diff
@@ -15,7 +15,6 @@
 import executorch.kernels.quantized  # noqa F401
 
 import torch
-
 from executorch.backends.nxp.backend.ir.edge_passes.remove_io_quant_ops_pass import (
     RemoveIOQuantOpsPass,
 )
@@ -24,14 +23,12 @@
 from executorch.backends.nxp.quantizer.neutron_quantizer import NeutronQuantizer
 from executorch.examples.models import MODEL_NAME_TO_MODEL
 from executorch.examples.models.model_factory import EagerModelFactory
-
 from executorch.exir import (
     EdgeCompileConfig,
     ExecutorchBackendConfig,
     to_edge_transform_and_lower,
 )
 from executorch.extension.export_util import save_pte_program
-
 from torch.export import export
 from torchao.quantization.pt2e.quantize_pt2e import convert_pt2e, prepare_pt2e
 
@@ -227,9 +224,7 @@ def _get_batch_size(data):
     model = model.eval()
 
     # 2. Export the model to ATEN
-    exported_program = torch.export.export_for_training(
-        model, example_inputs, strict=True
-    )
+    exported_program = torch.export.export(model, example_inputs, strict=True)
 
     module = exported_program.module()
 
```