From d48686cd65f267b666442cb7654f30a41f15860c Mon Sep 17 00:00:00 2001 From: hlky Date: Fri, 20 Dec 2024 17:54:54 +0000 Subject: [PATCH] Fix EMAModel test_from_pretrained --- tests/others/test_ema.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/others/test_ema.py b/tests/others/test_ema.py index 3443e6366f01..7cf8f30ecc44 100644 --- a/tests/others/test_ema.py +++ b/tests/others/test_ema.py @@ -67,6 +67,7 @@ def test_from_pretrained(self): # Load the EMA model from the saved directory loaded_ema_unet = EMAModel.from_pretrained(tmpdir, model_cls=UNet2DConditionModel, foreach=False) + loaded_ema_unet.to(torch_device) # Check that the shadow parameters of the loaded model match the original EMA model for original_param, loaded_param in zip(ema_unet.shadow_params, loaded_ema_unet.shadow_params): @@ -221,6 +222,7 @@ def test_from_pretrained(self): # Load the EMA model from the saved directory loaded_ema_unet = EMAModel.from_pretrained(tmpdir, model_cls=UNet2DConditionModel, foreach=True) + loaded_ema_unet.to(torch_device) # Check that the shadow parameters of the loaded model match the original EMA model for original_param, loaded_param in zip(ema_unet.shadow_params, loaded_ema_unet.shadow_params):