Commit 655e735

make style

1 parent f9d4a77

tests/others/test_ema.py (13 additions, 19 deletions)
@@ -60,27 +60,24 @@ def simulate_backprop(self, unet):
         return unet
 
     def test_from_pretrained(self):
-        #Save the model parameters to a temporary directory
+        # Save the model parameters to a temporary directory
         unet, ema_unet = self.get_models()
         with tempfile.TemporaryDirectory() as tmpdir:
             ema_unet.save_pretrained(tmpdir)
 
-            #Load the EMA model from the saved directory
-            loaded_ema_unet = EMAModel.from_pretrained(
-                tmpdir, model_cls=UNet2DConditionModel,foreach=False
-            )
+            # Load the EMA model from the saved directory
+            loaded_ema_unet = EMAModel.from_pretrained(tmpdir, model_cls=UNet2DConditionModel, foreach=False)
 
-            #Check that the shadow parameters of the loaded model match the original EMA model
+            # Check that the shadow parameters of the loaded model match the original EMA model
             for original_param, loaded_param in zip(ema_unet.shadow_params, loaded_ema_unet.shadow_params):
                 assert torch.allclose(original_param, loaded_param, atol=1e-4)
 
-            #Verify that the optimization step is also preserved
+            # Verify that the optimization step is also preserved
             assert loaded_ema_unet.optimization_step == ema_unet.optimization_step
 
-            #Check the decay value
+            # Check the decay value
             assert loaded_ema_unet.decay == ema_unet.decay
 
-
     def test_optimization_steps_updated(self):
         unet, ema_unet = self.get_models()
         # Take the first (hypothetical) EMA step.
@@ -215,29 +212,26 @@ def simulate_backprop(self, unet):
             updated_state_dict.update({k: updated_param})
         unet.load_state_dict(updated_state_dict)
         return unet
+
     def test_from_pretrained(self):
-        #Save the model parameters to a temporary directory
+        # Save the model parameters to a temporary directory
         unet, ema_unet = self.get_models()
         with tempfile.TemporaryDirectory() as tmpdir:
             ema_unet.save_pretrained(tmpdir)
 
-            #Load the EMA model from the saved directory
-            loaded_ema_unet = EMAModel.from_pretrained(
-                tmpdir, model_cls=UNet2DConditionModel,foreach=True
-            )
+            # Load the EMA model from the saved directory
+            loaded_ema_unet = EMAModel.from_pretrained(tmpdir, model_cls=UNet2DConditionModel, foreach=True)
 
-            #Check that the shadow parameters of the loaded model match the original EMA model
+            # Check that the shadow parameters of the loaded model match the original EMA model
             for original_param, loaded_param in zip(ema_unet.shadow_params, loaded_ema_unet.shadow_params):
                 assert torch.allclose(original_param, loaded_param, atol=1e-4)
 
-            #Verify that the optimization step is also preserved
+            # Verify that the optimization step is also preserved
             assert loaded_ema_unet.optimization_step == ema_unet.optimization_step
 
-            #Check the decay value
+            # Check the decay value
             assert loaded_ema_unet.decay == ema_unet.decay
 
-
-
     def test_optimization_steps_updated(self):
         unet, ema_unet = self.get_models()
         # Take the first (hypothetical) EMA step.
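Both hunks are pure formatter output: comments gain a space after the hash, the three-line from_pretrained call is collapsed onto one line, and stray blank lines are normalized; the only difference between the two tests is foreach=False versus foreach=True. For reference, a minimal standalone sketch of the save/load round trip these tests exercise, assuming EMAModel from diffusers.training_utils; the tiny UNet config is a hypothetical stand-in for the repo's get_models() helper:

import tempfile

import torch

from diffusers import UNet2DConditionModel
from diffusers.training_utils import EMAModel

# Hypothetical tiny UNet config, chosen only to keep the example fast.
unet = UNet2DConditionModel(
    sample_size=32,
    in_channels=4,
    out_channels=4,
    layers_per_block=2,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
    cross_attention_dim=32,
)

# model_cls and model_config must be set so save_pretrained() can serialize a config.
ema_unet = EMAModel(unet.parameters(), model_cls=UNet2DConditionModel, model_config=unet.config)

with tempfile.TemporaryDirectory() as tmpdir:
    ema_unet.save_pretrained(tmpdir)
    # foreach toggles the fused torch._foreach_* update path; False is the default.
    loaded_ema_unet = EMAModel.from_pretrained(tmpdir, model_cls=UNet2DConditionModel, foreach=False)

# Shadow parameters, optimization step, and decay should survive the round trip.
for original_param, loaded_param in zip(ema_unet.shadow_params, loaded_ema_unet.shadow_params):
    assert torch.allclose(original_param, loaded_param, atol=1e-4)
assert loaded_ema_unet.optimization_step == ema_unet.optimization_step
assert loaded_ema_unet.decay == ema_unet.decay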
