Skip to content

Commit f427ead

Browse files
committed
fixes
1 parent c5a9254 commit f427ead

File tree

4 files changed

+16
-1
lines changed

4 files changed

+16
-1
lines changed

tests/models/test_modeling_common.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -837,12 +837,13 @@ def test_effective_gradient_checkpointing(self, loss_tolerance=1e-5, param_grad_
837837
continue
838838
self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=param_grad_tol))
839839

840+
@unittest.skipIf(torch_device == "mps", "This test is not supported for MPS devices.")
840841
def test_gradient_checkpointing_is_applied(
841842
self, expected_set=None, attention_head_dim=None, num_attention_heads=None, block_out_channels=None
842843
):
843844
if not self.model_class._supports_gradient_checkpointing:
844845
return # Skip test if model does not support gradient checkpointing
845-
if torch_device == "mps" and self.model_class.__name__ in [
846+
if self.model_class.__name__ in [
846847
"UNetSpatioTemporalConditionModel",
847848
"AutoencoderKLTemporalDecoder",
848849
]:
@@ -876,6 +877,8 @@ def _set_gradient_checkpointing_new(self, module, value=False):
876877
model = model_class_copy(**init_dict)
877878
model.enable_gradient_checkpointing()
878879

880+
print(f"{set(modules_with_gc_enabled.keys())=}, {expected_set=}")
881+
879882
assert set(modules_with_gc_enabled.keys()) == expected_set
880883
assert all(modules_with_gc_enabled.values()), "All modules should be enabled"
881884

tests/models/transformers/test_models_transformer_allegro.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,3 +77,7 @@ def prepare_init_args_and_inputs_for_common(self):
7777
}
7878
inputs_dict = self.dummy_input
7979
return init_dict, inputs_dict
80+
81+
def test_gradient_checkpointing_is_applied(self):
82+
expected_set = {"AllegroTransformer3DModel"}
83+
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

tests/models/transformers/test_models_transformer_cogview3plus.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,3 +83,7 @@ def prepare_init_args_and_inputs_for_common(self):
8383
}
8484
inputs_dict = self.dummy_input
8585
return init_dict, inputs_dict
86+
87+
def test_gradient_checkpointing_is_applied(self):
88+
expected_set = {"CogView3PlusTransformer2DModel"}
89+
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

tests/models/transformers/test_models_transformer_sd3.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,10 @@ def prepare_init_args_and_inputs_for_common(self):
8484
def test_set_attn_processor_for_determinism(self):
8585
pass
8686

87+
def test_gradient_checkpointing_is_applied(self):
88+
expected_set = {"SD3Transformer2DModel"}
89+
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
90+
8791

8892
class SD35TransformerTests(ModelTesterMixin, unittest.TestCase):
8993
model_class = SD3Transformer2DModel

0 commit comments

Comments
 (0)