File tree Expand file tree Collapse file tree 1 file changed +2
-0
lines changed Expand file tree Collapse file tree 1 file changed +2
-0
lines changed Original file line number Diff line number Diff line change @@ -67,6 +67,7 @@ def test_from_pretrained(self):
6767
6868            # Load the EMA model from the saved directory 
6969            loaded_ema_unet  =  EMAModel .from_pretrained (tmpdir , model_cls = UNet2DConditionModel , foreach = False )
70+             loaded_ema_unet .to (torch_device )
7071
7172        # Check that the shadow parameters of the loaded model match the original EMA model 
7273        for  original_param , loaded_param  in  zip (ema_unet .shadow_params , loaded_ema_unet .shadow_params ):
@@ -221,6 +222,7 @@ def test_from_pretrained(self):
221222
222223            # Load the EMA model from the saved directory 
223224            loaded_ema_unet  =  EMAModel .from_pretrained (tmpdir , model_cls = UNet2DConditionModel , foreach = True )
225+             loaded_ema_unet .to (torch_device )
224226
225227        # Check that the shadow parameters of the loaded model match the original EMA model 
226228        for  original_param , loaded_param  in  zip (ema_unet .shadow_params , loaded_ema_unet .shadow_params ):
 
 
   
 
     
   
   
          
    
    
     
    
      
     
     
    You can’t perform that action at this time.
  
 
    
  
    
      
        
     
       
      
     
   
 
    
    
  
 
  
 
     
    
0 commit comments