2 changes: 1 addition & 1 deletion src/diffusers/training_utils.py
@@ -379,7 +379,7 @@ def __init__(

@classmethod
def from_pretrained(cls, path, model_cls, foreach=False) -> "EMAModel":
-        _, ema_kwargs = model_cls.load_config(path, return_unused_kwargs=True)
+        _, ema_kwargs = model_cls.from_config(path, return_unused_kwargs=True)
model = model_cls.from_pretrained(path)

ema_model = cls(model.parameters(), model_cls=model_cls, model_config=model.config, foreach=foreach)
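For context on why return_unused_kwargs surfaces the EMA state here: EMAModel.save_pretrained writes the EMA hyperparameters (decay, optimization_step, and so on) into the saved model's config and copies the shadow parameters into the model weights, so reading the config back yields those extra entries as unused kwargs. The sketch below is a simplified paraphrase under that assumption, not code from this diff, and save_pretrained_sketch is a made-up name:

def save_pretrained_sketch(ema, path):
    # Rebuild a bare model from the stored config, then stash the EMA state
    # (everything except the shadow parameter tensors) as extra config entries.
    model = ema.model_cls.from_config(ema.model_config)
    state = ema.state_dict()
    state.pop("shadow_params", None)
    model.register_to_config(**state)   # decay, optimization_step, ... land in the saved config
    ema.copy_to(model.parameters())     # shadow params become the saved weights
    model.save_pretrained(path)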
38 changes: 38 additions & 0 deletions tests/others/test_ema.py
@@ -59,6 +59,25 @@ def simulate_backprop(self, unet):
unet.load_state_dict(updated_state_dict)
return unet

def test_from_pretrained(self):
# Save the model parameters to a temporary directory
unet, ema_unet = self.get_models()
with tempfile.TemporaryDirectory() as tmpdir:
ema_unet.save_pretrained(tmpdir)

# Load the EMA model from the saved directory
loaded_ema_unet = EMAModel.from_pretrained(tmpdir, model_cls=UNet2DConditionModel, foreach=False)

# Check that the shadow parameters of the loaded model match the original EMA model
for original_param, loaded_param in zip(ema_unet.shadow_params, loaded_ema_unet.shadow_params):
assert torch.allclose(original_param, loaded_param, atol=1e-4)

# Verify that the optimization step is also preserved
assert loaded_ema_unet.optimization_step == ema_unet.optimization_step

# Check the decay value
assert loaded_ema_unet.decay == ema_unet.decay
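The test above relies on a get_models helper defined outside this hunk. A self-contained round trip equivalent to what it exercises might look like the following sketch; the tiny UNet2DConditionModel config is purely illustrative and not the one the test suite actually uses:

import tempfile

import torch
from diffusers import UNet2DConditionModel
from diffusers.training_utils import EMAModel

# Build a deliberately small UNet so the round trip is fast; the config values
# here are placeholders, not taken from the test suite.
unet = UNet2DConditionModel(
    sample_size=32,
    in_channels=4,
    out_channels=4,
    layers_per_block=2,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
    cross_attention_dim=32,
)
ema_unet = EMAModel(unet.parameters(), model_cls=UNet2DConditionModel, model_config=unet.config)

with tempfile.TemporaryDirectory() as tmpdir:
    ema_unet.save_pretrained(tmpdir)
    loaded = EMAModel.from_pretrained(tmpdir, model_cls=UNet2DConditionModel, foreach=False)

# The reloaded shadow parameters should match the originals.
for original, restored in zip(ema_unet.shadow_params, loaded.shadow_params):
    assert torch.allclose(original, restored, atol=1e-4)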

def test_optimization_steps_updated(self):
unet, ema_unet = self.get_models()
# Take the first (hypothetical) EMA step.
@@ -194,6 +213,25 @@ def simulate_backprop(self, unet):
unet.load_state_dict(updated_state_dict)
return unet

def test_from_pretrained(self):
# Save the model parameters to a temporary directory
unet, ema_unet = self.get_models()
with tempfile.TemporaryDirectory() as tmpdir:
ema_unet.save_pretrained(tmpdir)

# Load the EMA model from the saved directory
loaded_ema_unet = EMAModel.from_pretrained(tmpdir, model_cls=UNet2DConditionModel, foreach=True)

# Check that the shadow parameters of the loaded model match the original EMA model
for original_param, loaded_param in zip(ema_unet.shadow_params, loaded_ema_unet.shadow_params):
assert torch.allclose(original_param, loaded_param, atol=1e-4)

# Verify that the optimization step is also preserved
assert loaded_ema_unet.optimization_step == ema_unet.optimization_step

# Check the decay value
assert loaded_ema_unet.decay == ema_unet.decay
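The foreach=True variant exercises the bulk-tensor path, where the EMA update is applied to all shadow parameters at once in the style of torch's _foreach_ ops rather than a Python loop. A simplified sketch of that kind of update (not the library code verbatim):

import torch

decay = 0.9999
shadow_params = [torch.randn(3, 3), torch.randn(5)]
model_params = [torch.randn(3, 3), torch.randn(5)]

# shadow <- shadow - (1 - decay) * (shadow - param), applied to every tensor in one call
diffs = torch._foreach_sub(shadow_params, model_params)
torch._foreach_mul_(diffs, 1.0 - decay)
torch._foreach_sub_(shadow_params, diffs)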

def test_optimization_steps_updated(self):
unet, ema_unet = self.get_models()
# Take the first (hypothetical) EMA step.