
Commit af2d8c7 ("fixes")
1 parent 6fd6500

11 files changed: +64 additions, -7 deletions

tests/models/autoencoders/test_models_vae.py

Lines changed: 12 additions & 0 deletions
@@ -177,6 +177,10 @@ def test_forward_signature(self):
     def test_training(self):
         pass
 
+    def test_gradient_checkpointing_is_applied(self):
+        expected_set = {"Decoder", "Encoder"}
+        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
     def test_from_pretrained_hub(self):
         model, loading_info = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy", output_loading_info=True)
         self.assertIsNotNone(model)
@@ -330,6 +334,10 @@ def prepare_init_args_and_inputs_for_common(self):
     def test_outputs_equivalence(self):
         pass
 
+    def test_gradient_checkpointing_is_applied(self):
+        expected_set = {"DecoderTiny", "EncoderTiny"}
+        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
 
 class ConsistencyDecoderVAETests(ModelTesterMixin, unittest.TestCase):
     model_class = ConsistencyDecoderVAE
@@ -414,6 +422,10 @@ def test_forward_signature(self):
     def test_training(self):
         pass
 
+    def test_gradient_checkpointing_is_applied(self):
+        expected_set = {"Encoder", "TemporalDecoder"}
+        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
 
 class AutoencoderOobleckTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     model_class = AutoencoderOobleck

tests/models/test_modeling_common.py

Lines changed: 15 additions & 5 deletions
@@ -735,7 +735,7 @@ def test_enable_disable_gradient_checkpointing(self):
         self.assertFalse(model.is_gradient_checkpointing)
 
     @require_torch_accelerator_with_training
-    def test_effective_gradient_checkpointing(self):
+    def test_effective_gradient_checkpointing(self, loss_tolerance=1e-5):
         if not self.model_class._supports_gradient_checkpointing:
             return  # Skip test if model does not support gradient checkpointing
         if torch_device == "mps" and self.model_class.__name__ in [
@@ -777,23 +777,33 @@ def test_effective_gradient_checkpointing(self):
         loss_2.backward()
 
         # compare the output and parameters gradients
-        self.assertTrue((loss - loss_2).abs() < 1e-5)
+        self.assertTrue((loss - loss_2).abs() < loss_tolerance)
         named_params = dict(model.named_parameters())
         named_params_2 = dict(model_2.named_parameters())
         for name, param in named_params.items():
             if "post_quant_conv" in name:
                 continue
             self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5))
 
-    def test_gradient_checkpointing_is_applied(self, expected_set=None):
+    def test_gradient_checkpointing_is_applied(
+        self, expected_set=None, attention_head_dim=None, num_attention_heads=None, block_out_channels=None
+    ):
         if not self.model_class._supports_gradient_checkpointing:
             return  # Skip test if model does not support gradient checkpointing
-        if torch_device == "mps" and self.model_class.__name__ == "UNetSpatioTemporalConditionModel":
+        if torch_device == "mps" and self.model_class.__name__ in [
+            "UNetSpatioTemporalConditionModel",
+            "AutoencoderKLTemporalDecoder",
+        ]:
             return
 
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
 
-        init_dict["num_attention_heads"] = (8, 16)
+        if attention_head_dim is not None:
+            init_dict["attention_head_dim"] = attention_head_dim
+        if num_attention_heads is not None:
+            init_dict["num_attention_heads"] = num_attention_heads
+        if block_out_channels is not None:
+            init_dict["block_out_channels"] = block_out_channels
 
         model_class_copy = copy.copy(self.model_class)
tests/models/transformers/test_models_dit_transformer2d.py

Lines changed: 4 additions & 0 deletions
@@ -84,6 +84,10 @@ def test_correct_class_remapping_from_dict_config(self):
         model = Transformer2DModel.from_config(init_dict)
         assert isinstance(model, DiTTransformer2DModel)
 
+    def test_gradient_checkpointing_is_applied(self):
+        expected_set = {"DiTTransformer2DModel"}
+        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
     def test_correct_class_remapping_from_pretrained_config(self):
         config = DiTTransformer2DModel.load_config("facebook/DiT-XL-2-256", subfolder="transformer")
         model = Transformer2DModel.from_config(config)

tests/models/transformers/test_models_pixart_transformer2d.py

Lines changed: 4 additions & 0 deletions
@@ -92,6 +92,10 @@ def test_output(self):
             expected_output_shape=(self.dummy_input[self.main_input_name].shape[0],) + self.output_shape
         )
 
+    def test_gradient_checkpointing_is_applied(self):
+        expected_set = {"PixArtTransformer2DModel"}
+        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
     def test_correct_class_remapping_from_dict_config(self):
         init_dict, _ = self.prepare_init_args_and_inputs_for_common()
         model = Transformer2DModel.from_config(init_dict)

tests/models/transformers/test_models_transformer_aura_flow.py

Lines changed: 4 additions & 0 deletions
@@ -74,6 +74,10 @@ def prepare_init_args_and_inputs_for_common(self):
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
 
+    def test_gradient_checkpointing_is_applied(self):
+        expected_set = {"AuraFlowTransformer2DModel"}
+        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+
     @unittest.skip("AuraFlowTransformer2DModel uses its own dedicated attention processor. This test does not apply")
     def test_set_attn_processor_for_determinism(self):
         pass

tests/models/transformers/test_models_transformer_cogvideox.py

Lines changed: 4 additions & 0 deletions
@@ -81,3 +81,7 @@ def prepare_init_args_and_inputs_for_common(self):
         }
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
+
+    def test_gradient_checkpointing_is_applied(self):
+        expected_set = {"CogVideoXTransformer3DModel"}
+        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

tests/models/transformers/test_models_transformer_flux.py

Lines changed: 4 additions & 0 deletions
@@ -111,3 +111,7 @@ def test_deprecated_inputs_img_txt_ids_3d(self):
             torch.allclose(output_1, output_2, atol=1e-5),
             msg="output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) are not equal as them as 2d inputs",
         )
+
+    def test_gradient_checkpointing_is_applied(self):
+        expected_set = {"FluxTransformer2DModel"}
+        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

tests/models/transformers/test_models_transformer_latte.py

Lines changed: 4 additions & 0 deletions
@@ -86,3 +86,7 @@ def test_output(self):
         super().test_output(
             expected_output_shape=(self.dummy_input[self.main_input_name].shape[0],) + self.output_shape
         )
+
+    def test_gradient_checkpointing_is_applied(self):
+        expected_set = {"LatteTransformer3DModel"}
+        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

tests/models/transformers/test_models_transformer_sd3.py

Lines changed: 4 additions & 0 deletions
@@ -80,3 +80,7 @@ def prepare_init_args_and_inputs_for_common(self):
     @unittest.skip("SD3Transformer2DModel uses a dedicated attention processor. This test doesn't apply")
     def test_set_attn_processor_for_determinism(self):
         pass
+
+    def test_gradient_checkpointing_is_applied(self):
+        expected_set = {"SD3Transformer2DModel"}
+        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

tests/models/unets/test_models_unet_2d_condition.py

Lines changed: 5 additions & 1 deletion
@@ -565,7 +565,11 @@ def test_gradient_checkpointing_is_applied(self):
             "Transformer2DModel",
             "DownBlock2D",
         }
-        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+        attention_head_dim = (8, 16)
+        block_out_channels = (16, 32)
+        super().test_gradient_checkpointing_is_applied(
+            expected_set=expected_set, attention_head_dim=attention_head_dim, block_out_channels=block_out_channels
+        )
 
     def test_special_attn_proc(self):
         class AttnEasyProc(torch.nn.Module):
