diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py
index 0e0d0ce5b568..d2bf3fe07185 100644
--- a/src/diffusers/training_utils.py
+++ b/src/diffusers/training_utils.py
@@ -379,7 +379,7 @@ def __init__(
 
     @classmethod
     def from_pretrained(cls, path, model_cls, foreach=False) -> "EMAModel":
-        _, ema_kwargs = model_cls.load_config(path, return_unused_kwargs=True)
+        _, ema_kwargs = model_cls.from_config(path, return_unused_kwargs=True)
         model = model_cls.from_pretrained(path)
 
         ema_model = cls(model.parameters(), model_cls=model_cls, model_config=model.config, foreach=foreach)
diff --git a/tests/others/test_ema.py b/tests/others/test_ema.py
index 5bed42b8488f..3443e6366f01 100644
--- a/tests/others/test_ema.py
+++ b/tests/others/test_ema.py
@@ -59,6 +59,25 @@ def simulate_backprop(self, unet):
         unet.load_state_dict(updated_state_dict)
         return unet
 
+    def test_from_pretrained(self):
+        # Save the model parameters to a temporary directory
+        unet, ema_unet = self.get_models()
+        with tempfile.TemporaryDirectory() as tmpdir:
+            ema_unet.save_pretrained(tmpdir)
+
+            # Load the EMA model from the saved directory
+            loaded_ema_unet = EMAModel.from_pretrained(tmpdir, model_cls=UNet2DConditionModel, foreach=False)
+
+        # Check that the shadow parameters of the loaded model match the original EMA model
+        for original_param, loaded_param in zip(ema_unet.shadow_params, loaded_ema_unet.shadow_params):
+            assert torch.allclose(original_param, loaded_param, atol=1e-4)
+
+        # Verify that the optimization step is also preserved
+        assert loaded_ema_unet.optimization_step == ema_unet.optimization_step
+
+        # Check the decay value
+        assert loaded_ema_unet.decay == ema_unet.decay
+
     def test_optimization_steps_updated(self):
         unet, ema_unet = self.get_models()
         # Take the first (hypothetical) EMA step.
@@ -194,6 +213,25 @@ def simulate_backprop(self, unet):
         unet.load_state_dict(updated_state_dict)
         return unet
 
+    def test_from_pretrained(self):
+        # Save the model parameters to a temporary directory
+        unet, ema_unet = self.get_models()
+        with tempfile.TemporaryDirectory() as tmpdir:
+            ema_unet.save_pretrained(tmpdir)
+
+            # Load the EMA model from the saved directory
+            loaded_ema_unet = EMAModel.from_pretrained(tmpdir, model_cls=UNet2DConditionModel, foreach=True)
+
+        # Check that the shadow parameters of the loaded model match the original EMA model
+        for original_param, loaded_param in zip(ema_unet.shadow_params, loaded_ema_unet.shadow_params):
+            assert torch.allclose(original_param, loaded_param, atol=1e-4)
+
+        # Verify that the optimization step is also preserved
+        assert loaded_ema_unet.optimization_step == ema_unet.optimization_step
+
+        # Check the decay value
+        assert loaded_ema_unet.decay == ema_unet.decay
+
     def test_optimization_steps_updated(self):
         unet, ema_unet = self.get_models()
         # Take the first (hypothetical) EMA step.
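
For context, a minimal standalone round trip that this change enables (a sketch assuming diffusers' `ConfigMixin` semantics; the tiny `UNet2DModel` config below is illustrative, not part of the PR): `load_config(path, return_unused_kwargs=True)` returns the config dict plus any caller-supplied kwargs it did not consume, so with no extra kwargs `ema_kwargs` was always an empty dict and the EMA state (`decay`, `optimization_step`, ...) that `save_pretrained` registers into the saved config was dropped on load. `from_config(path, return_unused_kwargs=True)` instead returns the config entries left over after extracting the model's `__init__` arguments, which is exactly where that state lives.

```python
# Sketch of the save/load round trip fixed by this PR. The tiny UNet2DModel
# config below is hypothetical, chosen only to keep the example fast.
import tempfile

import torch
from diffusers import UNet2DModel
from diffusers.training_utils import EMAModel

model = UNet2DModel(
    sample_size=32,
    in_channels=3,
    out_channels=3,
    layers_per_block=1,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "UpBlock2D"),
)
ema = EMAModel(model.parameters(), decay=0.999, model_cls=UNet2DModel, model_config=model.config)

with tempfile.TemporaryDirectory() as tmpdir:
    # save_pretrained() registers the EMA kwargs into the model config it writes out.
    ema.save_pretrained(tmpdir)
    # With from_config(), the leftover config entries come back as ema_kwargs,
    # so the loaded EMA model restores decay/optimization_step instead of defaults.
    loaded = EMAModel.from_pretrained(tmpdir, model_cls=UNet2DModel)

for src, dst in zip(ema.shadow_params, loaded.shadow_params):
    assert torch.allclose(src, dst, atol=1e-4)
assert loaded.decay == ema.decay
assert loaded.optimization_step == ema.optimization_step
```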