Skip to content

Commit 8a0a25b

Browse files
Fix bug where source transform passes untransformed model to next stage (#14186)
1 parent 7fbb88c commit 8a0a25b

File tree

4 files changed

+26
-7
lines changed

4 files changed

+26
-7
lines changed

backends/apple/coreml/test/test_coreml_recipes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ def forward(self, x):
166166
session, example_inputs, atol=1e-3
167167
)
168168
self._compare_eager_unquantized_model_outputs(
169-
session, model, example_inputs
169+
session, model, example_inputs, sqnr_threshold=15
170170
)
171171

172172
def test_int4_weight_only_per_group_validation(self):

backends/xnnpack/test/recipes/test_xnnpack_recipes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ def forward(self, x) -> torch.Tensor:
154154
),
155155
ExportRecipe.get_recipe(
156156
XNNPackRecipeType.TORCHAO_INT8_DYNAMIC_ACT_INT4_WEIGHT_PER_TENSOR,
157-
group_size=8,
157+
group_size=32,
158158
),
159159
]
160160

export/stages.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -332,7 +332,7 @@ def run(self, artifact: PipelineArtifact) -> None:
332332
self._transformed_models = copy.deepcopy(artifact.data)
333333

334334
# Apply torchao quantize_ to each model
335-
for _, model in artifact.data.items():
335+
for _, model in self._transformed_models.items():
336336
# pyre-ignore
337337
if len(self._quantization_recipe.ao_quantization_configs) > 1:
338338
raise ValueError(

export/tests/test_export_stages.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -300,11 +300,26 @@ def test_run_with_ao_quantization_configs(
300300
artifact = PipelineArtifact(data=models_dict, context={})
301301
stage.run(artifact)
302302

303-
# Verify quantize_ was called with the model and config
304-
mock_quantize.assert_called_once_with(self.model, mock_config, mock_filter_fn)
303+
# Verify quantize_ was called once (with the copied model, not the original)
304+
self.assertEqual(mock_quantize.call_count, 1)
305+
# Verify the config and filter_fn arguments are correct
306+
call_args = mock_quantize.call_args[0]
307+
self.assertNotEqual(self.model, call_args[0])
308+
self.assertEqual(call_args[1], mock_config)
309+
self.assertEqual(call_args[2], mock_filter_fn)
305310

306-
# Verify unwrap_tensor_subclass was called with the model
307-
mock_unwrap.assert_called_once_with(self.model)
311+
# Verify unwrap_tensor_subclass was called once (with the copied model)
312+
self.assertEqual(mock_unwrap.call_count, 1)
313+
314+
# Verify that the original models_dict is unchanged
315+
self.assertEqual(models_dict, {"forward": self.model})
316+
317+
# Verify that the result artifact data contains valid models
318+
result_artifact = stage.get_artifacts()
319+
self.assertIn("forward", result_artifact.data)
320+
self.assertIsNotNone(result_artifact.data["forward"])
321+
# verify the result model is NOT the same object as the original
322+
self.assertIsNot(result_artifact.data["forward"], self.model)
308323

309324

310325
class TestQuantizeStage(unittest.TestCase):
@@ -398,6 +413,10 @@ def test_run_with_quantizers(
398413
self.assertIn("forward", result_artifact.data)
399414
self.assertEqual(result_artifact.data["forward"], mock_quantized_model)
400415

416+
# Verify that the original model in the input artifact is unchanged
417+
self.assertEqual(artifact.data["forward"], self.model)
418+
self.assertIsNot(result_artifact.data["forward"], self.model)
419+
401420
def test_run_empty_example_inputs(self) -> None:
402421
"""Test error when example inputs list is empty."""
403422
mock_quantizer = Mock()

0 commit comments

Comments (0)