@@ -1858,20 +1858,20 @@ def test_hotswapping_compiled_model_linear(self, rank0, rank1):
18581858
18591859    @parameterized .expand ([(11 , 11 ), (7 , 13 ), (13 , 7 )])  # important to test small to large and vice versa  
18601860    def  test_hotswapping_compiled_model_conv2d (self , rank0 , rank1 ):
1861-         # It's important to add this context to raise an error on recompilation 
18621861        if  "unet"  not  in   self .model_class .__name__ .lower ():
18631862            return 
18641863
1864+         # It's important to add this context to raise an error on recompilation 
18651865        target_modules  =  ["conv" , "conv1" , "conv2" ]
18661866        with  torch ._dynamo .config .patch (error_on_recompile = True ):
18671867            self .check_model_hotswap (do_compile = True , rank0 = rank0 , rank1 = rank1 , target_modules0 = target_modules )
18681868
18691869    @parameterized .expand ([(11 , 11 ), (7 , 13 ), (13 , 7 )])  # important to test small to large and vice versa  
18701870    def  test_hotswapping_compiled_model_both_linear_and_conv2d (self , rank0 , rank1 ):
1871-         # It's important to add this context to raise an error on recompilation 
18721871        if  "unet"  not  in   self .model_class .__name__ .lower ():
18731872            return 
18741873
1874+         # It's important to add this context to raise an error on recompilation 
18751875        target_modules  =  ["to_q" , "conv" ]
18761876        with  torch ._dynamo .config .patch (error_on_recompile = True ):
18771877            self .check_model_hotswap (do_compile = True , rank0 = rank0 , rank1 = rank1 , target_modules0 = target_modules )
@@ -1882,14 +1882,14 @@ def test_hotswapping_compiled_model_both_linear_and_other(self, rank0, rank1):
18821882        # with `torch.compile()` for models that have both linear and conv layers. In this test, we check 
18831883        # if we can target a linear layer from the transformer blocks and another linear layer from non-attention 
18841884        # block. 
1885-         # It's important to add this context to raise an error on recompilation 
18861885        target_modules  =  ["to_q" ]
18871886        init_dict , _  =  self .prepare_init_args_and_inputs_for_common ()
18881887        model  =  self .model_class (** init_dict )
18891888
18901889        target_modules .append (self .get_linear_module_name_other_than_attn (model ))
18911890        del  model 
18921891
1892+         # It's important to add this context to raise an error on recompilation 
18931893        with  torch ._dynamo .config .patch (error_on_recompile = True ):
18941894            self .check_model_hotswap (do_compile = True , rank0 = rank0 , rank1 = rank1 , target_modules0 = target_modules )
18951895
0 commit comments