@@ -61,10 +61,10 @@ def test_extract_model(model_cls, input_node_name, output_node_name):
6161
6262 model = wrap_model (model_cls ().eval (), example_input = example_input , trace_parameters = True )
6363 extracted_module = extract_model (model , [input_node_name ], [output_node_name ])
64- with torch . no_grad ():
65- ret1 = model (example_input )
66- ret2 = extracted_module ( example_input )
67- assert torch .any (torch .isclose (ret1 , ret2 ))
64+ ret1 = model ( example_input )
65+ ret2 = extracted_module (example_input )
66+ assert not ret2 . grad_fn
67+ assert torch .any (torch .isclose (ret1 , ret2 ))
6868
6969
7070@pytest .mark .parametrize (
@@ -122,10 +122,11 @@ def test_extract_model_for_node_with_fq(model_cls, input_node_name, output_node_
122122 q_model = transformer .transform (layout )
123123
124124 extracted_module = extract_model (model , [input_node_name ], [output_node_name ])
125- with torch .no_grad ():
126- ret1 = q_model (example_input )
127- ret2 = extracted_module (example_input )
128- assert torch .all (torch .isclose (ret1 , ret2 ))
125+
126+ ret1 = q_model (example_input )
127+ ret2 = extracted_module (example_input )
128+ assert torch .all (torch .isclose (ret1 , ret2 ))
129+ assert not ret2 .grad_fn
129130
130131 extracted_fn = extracted_module
131132 if isinstance (extracted_fn , nn .Sequential ):
0 commit comments