
Commit f9d4a77

fix from_pretrained and added test

Parent: fddbab7

2 files changed (+45, −1)


src/diffusers/training_utils.py (1 addition, 1 deletion)

@@ -379,7 +379,7 @@ def __init__(

     @classmethod
     def from_pretrained(cls, path, model_cls, foreach=False) -> "EMAModel":
-        _, ema_kwargs = model_cls.load_config(path, return_unused_kwargs=True)
+        _, ema_kwargs = model_cls.from_config(path, return_unused_kwargs=True)
         model = model_cls.from_pretrained(path)

         ema_model = cls(model.parameters(), model_cls=model_cls, model_config=model.config, foreach=foreach)
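Why the one-liner matters: with return_unused_kwargs=True, ConfigMixin.load_config only returns keyword arguments left over from the loading call itself (an empty dict here), so the EMA hyperparameters that save_pretrained writes into the model config (decay, optimization_step, and so on) were silently dropped. from_config instead returns the config entries that do not match the model's __init__ signature, which is exactly where that EMA state lives. Below is a minimal round-trip sketch of the repaired path; the tiny UNet2DConditionModel kwargs are illustrative, not taken from this commit.

import tempfile

from diffusers import UNet2DConditionModel
from diffusers.training_utils import EMAModel

# Any valid config works; this one just keeps the model tiny.
unet = UNet2DConditionModel(
    block_out_channels=(32, 64),
    layers_per_block=2,
    sample_size=32,
    in_channels=4,
    out_channels=4,
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
    cross_attention_dim=32,
)
ema = EMAModel(unet.parameters(), decay=0.995, model_cls=UNet2DConditionModel, model_config=unet.config)

with tempfile.TemporaryDirectory() as tmpdir:
    ema.save_pretrained(tmpdir)
    restored = EMAModel.from_pretrained(tmpdir, model_cls=UNet2DConditionModel)

# Before this fix, restored.decay fell back to the default
# because ema_kwargs came back empty.
assert restored.decay == ema.decay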

tests/others/test_ema.py (44 additions, 0 deletions)
@@ -59,6 +59,28 @@ def simulate_backprop(self, unet):
         unet.load_state_dict(updated_state_dict)
         return unet

+    def test_from_pretrained(self):
+        # Save the model parameters to a temporary directory
+        unet, ema_unet = self.get_models()
+        with tempfile.TemporaryDirectory() as tmpdir:
+            ema_unet.save_pretrained(tmpdir)
+
+            # Load the EMA model from the saved directory
+            loaded_ema_unet = EMAModel.from_pretrained(
+                tmpdir, model_cls=UNet2DConditionModel, foreach=False
+            )
+
+        # Check that the shadow parameters of the loaded model match the original EMA model
+        for original_param, loaded_param in zip(ema_unet.shadow_params, loaded_ema_unet.shadow_params):
+            assert torch.allclose(original_param, loaded_param, atol=1e-4)
+
+        # Verify that the optimization step is also preserved
+        assert loaded_ema_unet.optimization_step == ema_unet.optimization_step
+
+        # Check the decay value
+        assert loaded_ema_unet.decay == ema_unet.decay
+
+
     def test_optimization_steps_updated(self):
         unet, ema_unet = self.get_models()
         # Take the first (hypothetical) EMA step.
@@ -193,6 +215,28 @@ def simulate_backprop(self, unet):
             updated_state_dict.update({k: updated_param})
         unet.load_state_dict(updated_state_dict)
         return unet
+    def test_from_pretrained(self):
+        # Save the model parameters to a temporary directory
+        unet, ema_unet = self.get_models()
+        with tempfile.TemporaryDirectory() as tmpdir:
+            ema_unet.save_pretrained(tmpdir)
+
+            # Load the EMA model from the saved directory
+            loaded_ema_unet = EMAModel.from_pretrained(
+                tmpdir, model_cls=UNet2DConditionModel, foreach=True
+            )
+
+        # Check that the shadow parameters of the loaded model match the original EMA model
+        for original_param, loaded_param in zip(ema_unet.shadow_params, loaded_ema_unet.shadow_params):
+            assert torch.allclose(original_param, loaded_param, atol=1e-4)
+
+        # Verify that the optimization step is also preserved
+        assert loaded_ema_unet.optimization_step == ema_unet.optimization_step
+
+        # Check the decay value
+        assert loaded_ema_unet.decay == ema_unet.decay
+
+

     def test_optimization_steps_updated(self):
         unet, ema_unet = self.get_models()
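The test is added twice because the suite covers both EMA implementations: the second copy passes foreach=True to exercise the batched-tensor-ops code path. Once a checkpoint round-trips like this, a training script would typically copy the restored shadow parameters back onto a live model before evaluation or export. A hedged sketch using the EMAModel API; ckpt_dir is an assumed path, not something from this commit:

from diffusers import UNet2DConditionModel
from diffusers.training_utils import EMAModel

# Restore EMA state from a checkpoint directory written by save_pretrained.
ckpt_dir = "path/to/ema-checkpoint"
ema = EMAModel.from_pretrained(ckpt_dir, model_cls=UNet2DConditionModel, foreach=True)

# Apply the averaged weights to a live model and export it.
unet = UNet2DConditionModel.from_pretrained(ckpt_dir)
ema.copy_to(unet.parameters())  # overwrite live weights with the EMA shadow params
unet.save_pretrained("unet-ema")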
