@@ -1775,6 +1775,7 @@ def check_model_hotswap(self, do_compile, rank0, rank1, target_modules0, target_
         fine.
         """
         # create 2 adapters with different ranks and alphas
+        torch.manual_seed(0)
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**init_dict).to(torch_device)

@@ -1809,7 +1810,8 @@ def check_model_hotswap(self, do_compile, rank0, rank1, target_modules0, target_
             del model

             # load the first adapter
-            init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+            torch.manual_seed(0)
+            init_dict, _ = self.prepare_init_args_and_inputs_for_common()
             model = self.model_class(**init_dict).to(torch_device)

             if do_compile or (rank0 != rank1):
@@ -1824,7 +1826,6 @@ def check_model_hotswap(self, do_compile, rank0, rank1, target_modules0, target_
                 model = torch.compile(model, mode="reduce-overhead")

             with torch.inference_mode():
-                torch.manual_seed(0)
                 output0_after = model(**inputs_dict)["sample"]
             assert torch.allclose(output0_before, output0_after, atol=tol, rtol=tol)

@@ -1833,7 +1834,6 @@ def check_model_hotswap(self, do_compile, rank0, rank1, target_modules0, target_

             # we need to call forward to potentially trigger recompilation
             with torch.inference_mode():
-                torch.manual_seed(0)
                 output1_after = model(**inputs_dict)["sample"]
             assert torch.allclose(output1_before, output1_after, atol=tol, rtol=tol)
