@@ -300,11 +300,26 @@ def test_run_with_ao_quantization_configs(
300
300
artifact = PipelineArtifact (data = models_dict , context = {})
301
301
stage .run (artifact )
302
302
303
- # Verify quantize_ was called with the model and config
304
- mock_quantize .assert_called_once_with (self .model , mock_config , mock_filter_fn )
303
+ # Verify quantize_ was called once (with the copied model, not the original)
304
+ self .assertEqual (mock_quantize .call_count , 1 )
305
+ # Verify the config and filter_fn arguments are correct
306
+ call_args = mock_quantize .call_args [0 ]
307
+ self .assertNotEqual (self .model , call_args [0 ])
308
+ self .assertEqual (call_args [1 ], mock_config )
309
+ self .assertEqual (call_args [2 ], mock_filter_fn )
305
310
306
- # Verify unwrap_tensor_subclass was called with the model
307
- mock_unwrap .assert_called_once_with (self .model )
311
+ # Verify unwrap_tensor_subclass was called once (with the copied model)
312
+ self .assertEqual (mock_unwrap .call_count , 1 )
313
+
314
+ # Verify that the original models_dict is unchanged
315
+ self .assertEqual (models_dict , {"forward" : self .model })
316
+
317
+ # Verify that the result artifact data contains valid models
318
+ result_artifact = stage .get_artifacts ()
319
+ self .assertIn ("forward" , result_artifact .data )
320
+ self .assertIsNotNone (result_artifact .data ["forward" ])
321
+ # verify the result model is NOT the same object as the original
322
+ self .assertIsNot (result_artifact .data ["forward" ], self .model )
308
323
309
324
310
325
class TestQuantizeStage (unittest .TestCase ):
@@ -398,6 +413,10 @@ def test_run_with_quantizers(
398
413
self .assertIn ("forward" , result_artifact .data )
399
414
self .assertEqual (result_artifact .data ["forward" ], mock_quantized_model )
400
415
416
+ # Verify that the original model in the input artifact is unchanged
417
+ self .assertEqual (artifact .data ["forward" ], self .model )
418
+ self .assertIsNot (result_artifact .data ["forward" ], self .model )
419
+
401
420
def test_run_empty_example_inputs (self ) -> None :
402
421
"""Test error when example inputs list is empty."""
403
422
mock_quantizer = Mock ()
0 commit comments