@@ -326,7 +326,7 @@ def test_disallow_eval_train(self) -> None:
326326 m .train ()
327327
328328 # After export: this is not OK
329- m = export_for_training (m , example_inputs ).module ()
329+ m = export_for_training (m , example_inputs , strict = True ).module ()
330330 with self .assertRaises (NotImplementedError ):
331331 m .eval ()
332332 with self .assertRaises (NotImplementedError ):
@@ -380,7 +380,7 @@ def forward(self, x):
380380 m = M ().train ()
381381 example_inputs = (torch .randn (1 , 3 , 3 , 3 ),)
382382 bn_train_op , bn_eval_op = self ._get_bn_train_eval_ops () # pyre-ignore[23]
383- m = export_for_training (m , example_inputs ).module ()
383+ m = export_for_training (m , example_inputs , strict = True ).module ()
384384
385385 def _assert_ops_are_correct (m : torch .fx .GraphModule , train : bool ) -> None :
386386 bn_op = bn_train_op if train else bn_eval_op
@@ -449,10 +449,7 @@ def forward(self, x):
449449 quantizer .set_global (operator_config )
450450 example_inputs = (torch .randn (2 , 2 ),)
451451 m = M ().eval ()
452- m = export_for_training (
453- m ,
454- example_inputs ,
455- ).module ()
452+ m = export_for_training (m , example_inputs , strict = True ).module ()
456453 weight_meta = None
457454 for n in m .graph .nodes : # pyre-ignore[16]
458455 if (
@@ -481,7 +478,7 @@ def test_reentrant(self) -> None:
481478 get_symmetric_quantization_config (is_per_channel = True , is_qat = True )
482479 )
483480 m .conv_bn_relu = export_for_training ( # pyre-ignore[8]
484- m .conv_bn_relu , example_inputs
481+ m .conv_bn_relu , example_inputs , strict = True
485482 ).module ()
486483 m .conv_bn_relu = prepare_qat_pt2e (m .conv_bn_relu , quantizer ) # pyre-ignore[6,8]
487484 m (* example_inputs )
@@ -490,7 +487,7 @@ def test_reentrant(self) -> None:
490487 quantizer = XNNPACKQuantizer ().set_module_type (
491488 torch .nn .Linear , get_symmetric_quantization_config (is_per_channel = False )
492489 )
493- m = export_for_training (m , example_inputs ).module ()
490+ m = export_for_training (m , example_inputs , strict = True ).module ()
494491 m = prepare_pt2e (m , quantizer ) # pyre-ignore[6]
495492 m = convert_pt2e (m )
496493
@@ -553,7 +550,7 @@ def check_nn_module(node: torch.fx.Node) -> None:
553550 )
554551
555552 m .conv_bn_relu = export_for_training ( # pyre-ignore[8]
556- m .conv_bn_relu , example_inputs
553+ m .conv_bn_relu , example_inputs , strict = True
557554 ).module ()
558555 for node in m .conv_bn_relu .graph .nodes : # pyre-ignore[16]
559556 if node .op not in ["placeholder" , "output" , "get_attr" ]:
@@ -568,7 +565,7 @@ def test_speed(self) -> None:
568565
569566 def dynamic_quantize_pt2e (model , example_inputs ) -> torch .fx .GraphModule :
570567 torch ._dynamo .reset ()
571- model = export_for_training (model , example_inputs ).module ()
568+ model = export_for_training (model , example_inputs , strict = True ).module ()
572569 # Per channel quantization for weight
573570 # Dynamic quantization for activation
574571 # Please read a detail: https://fburl.com/code/30zds51q
@@ -625,7 +622,7 @@ def forward(self, x):
625622
626623 example_inputs = (torch .randn (1 , 3 , 5 , 5 ),)
627624 m = M ()
628- m = export_for_training (m , example_inputs ).module ()
625+ m = export_for_training (m , example_inputs , strict = True ).module ()
629626 quantizer = XNNPACKQuantizer ().set_global (
630627 get_symmetric_quantization_config (),
631628 )
@@ -701,7 +698,6 @@ def test_save_load(self) -> None:
701698
702699
703700class TestNumericDebugger (TestCase ):
704-
705701 def _extract_debug_handles (self , model ) -> Dict [str , int ]:
706702 debug_handle_map : Dict [str , int ] = {}
707703
@@ -731,7 +727,7 @@ def _assert_node_has_debug_handle(node: torch.fx.Node) -> None:
731727 def test_quantize_pt2e_preserve_handle (self ) -> None :
732728 m = TestHelperModules .Conv2dThenConv1d ()
733729 example_inputs = m .example_inputs ()
734- ep = export_for_training (m , example_inputs )
730+ ep = export_for_training (m , example_inputs , strict = True )
735731 generate_numeric_debug_handle (ep )
736732 m = ep .module ()
737733
@@ -761,7 +757,7 @@ def test_quantize_pt2e_preserve_handle(self) -> None:
761757 def test_extract_results_from_loggers (self ) -> None :
762758 m = TestHelperModules .Conv2dThenConv1d ()
763759 example_inputs = m .example_inputs ()
764- ep = export_for_training (m , example_inputs )
760+ ep = export_for_training (m , example_inputs , strict = True )
765761 generate_numeric_debug_handle (ep )
766762 m = ep .module ()
767763 m_ref_logger = prepare_for_propagation_comparison (m ) # pyre-ignore[6]
@@ -779,18 +775,20 @@ def test_extract_results_from_loggers(self) -> None:
779775 ref_results = extract_results_from_loggers (m_ref_logger )
780776 quant_results = extract_results_from_loggers (m_quant_logger )
781777 comparison_results = compare_results (
782- ref_results , quant_results # pyre-ignore[6]
778+ ref_results ,
779+ quant_results , # pyre-ignore[6]
783780 )
784781 for node_summary in comparison_results .values ():
785782 if len (node_summary .results ) > 0 :
786783 self .assertGreaterEqual (
787- node_summary .results [0 ].sqnr , 35 # pyre-ignore[6]
784+ node_summary .results [0 ].sqnr ,
785+ 35 , # pyre-ignore[6]
788786 )
789787
790788 def test_extract_results_from_loggers_list_output (self ) -> None :
791789 m = TestHelperModules .Conv2dWithSplit ()
792790 example_inputs = m .example_inputs ()
793- ep = export_for_training (m , example_inputs )
791+ ep = export_for_training (m , example_inputs , strict = True )
794792 generate_numeric_debug_handle (ep )
795793 m = ep .module ()
796794 m_ref_logger = prepare_for_propagation_comparison (m ) # pyre-ignore[6]
@@ -808,7 +806,8 @@ def test_extract_results_from_loggers_list_output(self) -> None:
808806 ref_results = extract_results_from_loggers (m_ref_logger )
809807 quant_results = extract_results_from_loggers (m_quant_logger )
810808 comparison_results = compare_results (
811- ref_results , quant_results # pyre-ignore[6]
809+ ref_results ,
810+ quant_results , # pyre-ignore[6]
812811 )
813812 for node_summary in comparison_results .values ():
814813 if len (node_summary .results ) > 0 :
0 commit comments