Skip to content

Commit e619ec8 — "Fix tests"
Parent: eb92a65

File tree

4 files changed

+34
-3
lines changed

4 files changed

+34
-3
lines changed

tests/models/autoencoders/test_models_autoencoder_kl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ def test_enable_disable_slicing(self):
146146
)
147147

148148
def test_gradient_checkpointing_is_applied(self):
149-
expected_set = {"Decoder", "Encoder"}
149+
expected_set = {"Decoder", "Encoder", "UNetMidBlock2D"}
150150
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
151151

152152
def test_from_pretrained_hub(self):

tests/models/autoencoders/test_models_autoencoder_kl_temporal_decoder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def prepare_init_args_and_inputs_for_common(self):
6565
return init_dict, inputs_dict
6666

6767
def test_gradient_checkpointing_is_applied(self):
68-
expected_set = {"Encoder", "TemporalDecoder"}
68+
expected_set = {"Encoder", "TemporalDecoder", "UNetMidBlock2D"}
6969
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
7070

7171
@unittest.skip("Test unsupported.")

tests/models/test_modeling_common.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -803,7 +803,7 @@ def test_enable_disable_gradient_checkpointing(self):
803803
self.assertFalse(model.is_gradient_checkpointing)
804804

805805
@require_torch_accelerator_with_training
806-
def test_effective_gradient_checkpointing(self, loss_tolerance=1e-5, param_grad_tol=5e-5):
806+
def test_effective_gradient_checkpointing(self, loss_tolerance=1e-5, param_grad_tol=5e-5, skip: set[str] = {}):
807807
if not self.model_class._supports_gradient_checkpointing:
808808
return # Skip test if model does not support gradient checkpointing
809809

@@ -850,6 +850,8 @@ def test_effective_gradient_checkpointing(self, loss_tolerance=1e-5, param_grad_
850850
for name, param in named_params.items():
851851
if "post_quant_conv" in name:
852852
continue
853+
if name in skip:
854+
continue
853855
self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=param_grad_tol))
854856

855857
@unittest.skipIf(torch_device == "mps", "This test is not supported for MPS devices.")

tests/models/unets/test_models_unet_2d.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,17 @@ def test_output_pretrained(self):
237237

238238
self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-3))
239239

240+
def test_gradient_checkpointing_is_applied(self):
241+
expected_set = {"DownBlock2D", "UNetMidBlock2D", "UpBlock2D"}
242+
243+
# NOTE: unlike UNet2DConditionModel, UNet2DModel does not currently support tuples for `attention_head_dim`
244+
attention_head_dim = 32
245+
block_out_channels = (32, 64)
246+
247+
super().test_gradient_checkpointing_is_applied(
248+
expected_set=expected_set, attention_head_dim=attention_head_dim, block_out_channels=block_out_channels
249+
)
250+
240251

241252
class NCSNppModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
242253
model_class = UNet2DModel
@@ -346,3 +357,21 @@ def test_output_pretrained_ve_large(self):
346357
def test_forward_with_norm_groups(self):
347358
# not required for this model
348359
pass
360+
361+
def test_gradient_checkpointing_is_applied(self):
362+
expected_set = {
363+
"UNetMidBlock2D",
364+
"SkipDownBlock2D",
365+
"AttnSkipDownBlock2D",
366+
"SkipUpBlock2D",
367+
"AttnSkipUpBlock2D",
368+
}
369+
370+
block_out_channels = (32, 64, 64, 64)
371+
372+
super().test_gradient_checkpointing_is_applied(
373+
expected_set=expected_set, block_out_channels=block_out_channels
374+
)
375+
376+
def test_effective_gradient_checkpointing(self):
377+
super().test_effective_gradient_checkpointing(skip={"time_proj.weight"})

0 commit comments

Comments (0)