@@ -60,27 +60,24 @@ def simulate_backprop(self, unet):
6060 return unet
6161
6262 def test_from_pretrained (self ):
63- #Save the model parameters to a temporary directory
63+ # Save the model parameters to a temporary directory
6464 unet , ema_unet = self .get_models ()
6565 with tempfile .TemporaryDirectory () as tmpdir :
6666 ema_unet .save_pretrained (tmpdir )
6767
68- #Load the EMA model from the saved directory
69- loaded_ema_unet = EMAModel .from_pretrained (
70- tmpdir , model_cls = UNet2DConditionModel ,foreach = False
71- )
68+ # Load the EMA model from the saved directory
69+ loaded_ema_unet = EMAModel .from_pretrained (tmpdir , model_cls = UNet2DConditionModel , foreach = False )
7270
73- #Check that the shadow parameters of the loaded model match the original EMA model
71+ # Check that the shadow parameters of the loaded model match the original EMA model
7472 for original_param , loaded_param in zip (ema_unet .shadow_params , loaded_ema_unet .shadow_params ):
7573 assert torch .allclose (original_param , loaded_param , atol = 1e-4 )
7674
77- #Verify that the optimization step is also preserved
75+ # Verify that the optimization step is also preserved
7876 assert loaded_ema_unet .optimization_step == ema_unet .optimization_step
7977
80- #Check the decay value
78+ # Check the decay value
8179 assert loaded_ema_unet .decay == ema_unet .decay
8280
83-
8481 def test_optimization_steps_updated (self ):
8582 unet , ema_unet = self .get_models ()
8683 # Take the first (hypothetical) EMA step.
@@ -215,29 +212,26 @@ def simulate_backprop(self, unet):
215212 updated_state_dict .update ({k : updated_param })
216213 unet .load_state_dict (updated_state_dict )
217214 return unet
215+
218216 def test_from_pretrained (self ):
219- #Save the model parameters to a temporary directory
217+ # Save the model parameters to a temporary directory
220218 unet , ema_unet = self .get_models ()
221219 with tempfile .TemporaryDirectory () as tmpdir :
222220 ema_unet .save_pretrained (tmpdir )
223221
224- #Load the EMA model from the saved directory
225- loaded_ema_unet = EMAModel .from_pretrained (
226- tmpdir , model_cls = UNet2DConditionModel ,foreach = True
227- )
222+ # Load the EMA model from the saved directory
223+ loaded_ema_unet = EMAModel .from_pretrained (tmpdir , model_cls = UNet2DConditionModel , foreach = True )
228224
229- #Check that the shadow parameters of the loaded model match the original EMA model
225+ # Check that the shadow parameters of the loaded model match the original EMA model
230226 for original_param , loaded_param in zip (ema_unet .shadow_params , loaded_ema_unet .shadow_params ):
231227 assert torch .allclose (original_param , loaded_param , atol = 1e-4 )
232228
233- #Verify that the optimization step is also preserved
229+ # Verify that the optimization step is also preserved
234230 assert loaded_ema_unet .optimization_step == ema_unet .optimization_step
235231
236- #Check the decay value
232+ # Check the decay value
237233 assert loaded_ema_unet .decay == ema_unet .decay
238234
239-
240-
241235 def test_optimization_steps_updated (self ):
242236 unet , ema_unet = self .get_models ()
243237 # Take the first (hypothetical) EMA step.
0 commit comments