@@ -59,6 +59,28 @@ def simulate_backprop(self, unet):
         unet.load_state_dict(updated_state_dict)
         return unet
 
+    def test_from_pretrained(self):
+        # Save the model parameters to a temporary directory
+        unet, ema_unet = self.get_models()
+        with tempfile.TemporaryDirectory() as tmpdir:
+            ema_unet.save_pretrained(tmpdir)
+
+            # Load the EMA model from the saved directory
+            loaded_ema_unet = EMAModel.from_pretrained(
+                tmpdir, model_cls=UNet2DConditionModel, foreach=False
+            )
+
+            # Check that the shadow parameters of the loaded model match the original EMA model
+            for original_param, loaded_param in zip(ema_unet.shadow_params, loaded_ema_unet.shadow_params):
+                assert torch.allclose(original_param, loaded_param, atol=1e-4)
+
+            # Verify that the optimization step is also preserved
+            assert loaded_ema_unet.optimization_step == ema_unet.optimization_step
+
+            # Check the decay value
+            assert loaded_ema_unet.decay == ema_unet.decay
+
+
     def test_optimization_steps_updated(self):
         unet, ema_unet = self.get_models()
         # Take the first (hypothetical) EMA step.
@@ -193,6 +215,28 @@ def simulate_backprop(self, unet):
         updated_state_dict.update({k: updated_param})
         unet.load_state_dict(updated_state_dict)
         return unet
+    def test_from_pretrained(self):
+        # Save the model parameters to a temporary directory
+        unet, ema_unet = self.get_models()
+        with tempfile.TemporaryDirectory() as tmpdir:
+            ema_unet.save_pretrained(tmpdir)
+
+            # Load the EMA model from the saved directory
+            loaded_ema_unet = EMAModel.from_pretrained(
+                tmpdir, model_cls=UNet2DConditionModel, foreach=True
+            )
+
+            # Check that the shadow parameters of the loaded model match the original EMA model
+            for original_param, loaded_param in zip(ema_unet.shadow_params, loaded_ema_unet.shadow_params):
+                assert torch.allclose(original_param, loaded_param, atol=1e-4)
+
+            # Verify that the optimization step is also preserved
+            assert loaded_ema_unet.optimization_step == ema_unet.optimization_step
+
+            # Check the decay value
+            assert loaded_ema_unet.decay == ema_unet.decay
+
+
 
     def test_optimization_steps_updated(self):
         unet, ema_unet = self.get_models()
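
For reference, a minimal, self-contained sketch of the save/load round trip these tests exercise, assuming the `EMAModel` API from `diffusers.training_utils`; the tiny `UNet2DConditionModel` config below is illustrative and not necessarily what the suite's `get_models()` builds:

```python
import tempfile

import torch
from diffusers import UNet2DConditionModel
from diffusers.training_utils import EMAModel

# Tiny illustrative UNet; the test suite's get_models() may use a different config.
unet = UNet2DConditionModel(
    block_out_channels=(32, 64),
    layers_per_block=2,
    sample_size=32,
    in_channels=4,
    out_channels=4,
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
    cross_attention_dim=32,
)

# model_cls/model_config are needed so save_pretrained() can serialize the EMA
# weights as a loadable model checkpoint.
ema_unet = EMAModel(
    unet.parameters(),
    model_cls=UNet2DConditionModel,
    model_config=unet.config,
    foreach=False,
)

with tempfile.TemporaryDirectory() as tmpdir:
    ema_unet.save_pretrained(tmpdir)
    loaded = EMAModel.from_pretrained(tmpdir, model_cls=UNet2DConditionModel, foreach=False)

# The round trip should preserve shadow parameters, optimization step, and decay.
for original, restored in zip(ema_unet.shadow_params, loaded.shadow_params):
    assert torch.allclose(original, restored, atol=1e-4)
assert loaded.optimization_step == ema_unet.optimization_step
assert loaded.decay == ema_unet.decay
```

The second hunk repeats the same check with `foreach=True`, which only switches the EMA update to the batched `torch._foreach_*` path and should not affect the serialized state.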