File tree Expand file tree Collapse file tree 1 file changed +2
-0
lines changed Expand file tree Collapse file tree 1 file changed +2
-0
lines changed Original file line number Diff line number Diff line change @@ -67,6 +67,7 @@ def test_from_pretrained(self):
67
67
68
68
# Load the EMA model from the saved directory
69
69
loaded_ema_unet = EMAModel .from_pretrained (tmpdir , model_cls = UNet2DConditionModel , foreach = False )
70
+ loaded_ema_unet .to (torch_device )
70
71
71
72
# Check that the shadow parameters of the loaded model match the original EMA model
72
73
for original_param , loaded_param in zip (ema_unet .shadow_params , loaded_ema_unet .shadow_params ):
@@ -221,6 +222,7 @@ def test_from_pretrained(self):
221
222
222
223
# Load the EMA model from the saved directory
223
224
loaded_ema_unet = EMAModel .from_pretrained (tmpdir , model_cls = UNet2DConditionModel , foreach = True )
225
+ loaded_ema_unet .to (torch_device )
224
226
225
227
# Check that the shadow parameters of the loaded model match the original EMA model
226
228
for original_param , loaded_param in zip (ema_unet .shadow_params , loaded_ema_unet .shadow_params ):
You can’t perform that action at this time.
0 commit comments