
Commit 6fd6500

Commit message: check.
1 parent aa73072 commit 6fd6500

6 files changed: +102 additions, -275 deletions


tests/models/autoencoders/test_models_vae.py

Lines changed: 9 additions & 86 deletions
@@ -39,7 +39,6 @@
     load_hf_numpy,
     require_torch_accelerator,
     require_torch_accelerator_with_fp16,
-    require_torch_accelerator_with_training,
     require_torch_gpu,
     skip_mps,
     slow,
@@ -170,53 +169,14 @@ def prepare_init_args_and_inputs_for_common(self):
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
 
+    @unittest.skip("Not tested.")
     def test_forward_signature(self):
         pass
 
+    @unittest.skip("Not tested.")
     def test_training(self):
         pass
 
-    @require_torch_accelerator_with_training
-    def test_gradient_checkpointing(self):
-        # enable deterministic behavior for gradient checkpointing
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-        model = self.model_class(**init_dict)
-        model.to(torch_device)
-
-        assert not model.is_gradient_checkpointing and model.training
-
-        out = model(**inputs_dict).sample
-        # run the backwards pass on the model. For backwards pass, for simplicity purpose,
-        # we won't calculate the loss and rather backprop on out.sum()
-        model.zero_grad()
-
-        labels = torch.randn_like(out)
-        loss = (out - labels).mean()
-        loss.backward()
-
-        # re-instantiate the model now enabling gradient checkpointing
-        model_2 = self.model_class(**init_dict)
-        # clone model
-        model_2.load_state_dict(model.state_dict())
-        model_2.to(torch_device)
-        model_2.enable_gradient_checkpointing()
-
-        assert model_2.is_gradient_checkpointing and model_2.training
-
-        out_2 = model_2(**inputs_dict).sample
-        # run the backwards pass on the model. For backwards pass, for simplicity purpose,
-        # we won't calculate the loss and rather backprop on out.sum()
-        model_2.zero_grad()
-        loss_2 = (out_2 - labels).mean()
-        loss_2.backward()
-
-        # compare the output and parameters gradients
-        self.assertTrue((loss - loss_2).abs() < 1e-5)
-        named_params = dict(model.named_parameters())
-        named_params_2 = dict(model_2.named_parameters())
-        for name, param in named_params.items():
-            self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5))
-
     def test_from_pretrained_hub(self):
         model, loading_info = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy", output_loading_info=True)
         self.assertIsNotNone(model)
@@ -329,9 +289,11 @@ def prepare_init_args_and_inputs_for_common(self):
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
 
+    @unittest.skip("Not tested.")
     def test_forward_signature(self):
         pass
 
+    @unittest.skip("Not tested.")
     def test_forward_with_norm_groups(self):
         pass
 
@@ -364,6 +326,7 @@ def prepare_init_args_and_inputs_for_common(self):
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
 
+    @unittest.skip("Not tested.")
     def test_outputs_equivalence(self):
         pass
 
@@ -443,56 +406,14 @@ def prepare_init_args_and_inputs_for_common(self):
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
 
+    @unittest.skip("Not tested.")
     def test_forward_signature(self):
         pass
 
+    @unittest.skip("Not tested.")
     def test_training(self):
         pass
 
-    @unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS")
-    def test_gradient_checkpointing(self):
-        # enable deterministic behavior for gradient checkpointing
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-        model = self.model_class(**init_dict)
-        model.to(torch_device)
-
-        assert not model.is_gradient_checkpointing and model.training
-
-        out = model(**inputs_dict).sample
-        # run the backwards pass on the model. For backwards pass, for simplicity purpose,
-        # we won't calculate the loss and rather backprop on out.sum()
-        model.zero_grad()
-
-        labels = torch.randn_like(out)
-        loss = (out - labels).mean()
-        loss.backward()
-
-        # re-instantiate the model now enabling gradient checkpointing
-        model_2 = self.model_class(**init_dict)
-        # clone model
-        model_2.load_state_dict(model.state_dict())
-        model_2.to(torch_device)
-        model_2.enable_gradient_checkpointing()
-
-        assert model_2.is_gradient_checkpointing and model_2.training
-
-        out_2 = model_2(**inputs_dict).sample
-        # run the backwards pass on the model. For backwards pass, for simplicity purpose,
-        # we won't calculate the loss and rather backprop on out.sum()
-        model_2.zero_grad()
-        loss_2 = (out_2 - labels).mean()
-        loss_2.backward()
-
-        # compare the output and parameters gradients
-        self.assertTrue((loss - loss_2).abs() < 1e-5)
-        named_params = dict(model.named_parameters())
-        named_params_2 = dict(model_2.named_parameters())
-        for name, param in named_params.items():
-            if "post_quant_conv" in name:
-                continue
-
-            self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5))
-
 
 class AutoencoderOobleckTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     model_class = AutoencoderOobleck
@@ -522,9 +443,11 @@ def prepare_init_args_and_inputs_for_common(self):
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
 
+    @unittest.skip("Not tested.")
     def test_forward_signature(self):
         pass
 
+    @unittest.skip("Not tested.")
     def test_forward_with_norm_groups(self):
         pass
 
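For context, the invariant these deleted per-model tests checked, and which the shared test in the next file now checks, is that, starting from identical weights, a forward/backward pass with gradient checkpointing enabled must reproduce the same loss and parameter gradients as one without it. Below is a minimal standalone sketch of that invariant in plain PyTorch; it is not part of this commit, and `TinyBlock` is a made-up module used only for illustration.

import torch
from torch.utils.checkpoint import checkpoint


class TinyBlock(torch.nn.Module):
    def __init__(self, use_checkpointing=False):
        super().__init__()
        self.use_checkpointing = use_checkpointing
        self.net = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.GELU(), torch.nn.Linear(8, 8))

    def forward(self, x):
        if self.use_checkpointing:
            # recompute activations during backward instead of storing them
            return checkpoint(self.net, x, use_reentrant=False)
        return self.net(x)


torch.manual_seed(0)
x = torch.randn(4, 8)

model = TinyBlock(use_checkpointing=False)
model_2 = TinyBlock(use_checkpointing=True)
model_2.load_state_dict(model.state_dict())  # identical weights in both copies

loss = (model(x) - torch.ones_like(x)).mean()
loss.backward()
loss_2 = (model_2(x) - torch.ones_like(x)).mean()
loss_2.backward()

# checkpointing must not change the loss or the parameter gradients
assert (loss - loss_2).abs() < 1e-5
for (name, p), (_, p_2) in zip(model.named_parameters(), model_2.named_parameters()):
    assert torch.allclose(p.grad, p_2.grad, atol=5e-5), name

The tolerances mirror the test's own (`1e-5` on the loss, `atol=5e-5` on gradients), which absorb the small numeric differences that recomputation can introduce.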

tests/models/test_modeling_common.py

Lines changed: 85 additions & 0 deletions
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import copy
 import inspect
 import json
 import os
@@ -50,6 +51,7 @@
     require_torch_gpu,
     require_torch_multi_gpu,
     run_test_in_subprocess,
+    torch_all_close,
     torch_device,
 )
 
@@ -732,6 +734,89 @@ def test_enable_disable_gradient_checkpointing(self):
         model.disable_gradient_checkpointing()
         self.assertFalse(model.is_gradient_checkpointing)
 
+    @require_torch_accelerator_with_training
+    def test_effective_gradient_checkpointing(self):
+        if not self.model_class._supports_gradient_checkpointing:
+            return  # Skip test if model does not support gradient checkpointing
+        if torch_device == "mps" and self.model_class.__name__ in [
+            "UNetSpatioTemporalConditionModel",
+            "AutoencoderKLTemporalDecoder",
+        ]:
+            return
+
+        # enable deterministic behavior for gradient checkpointing
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+
+        assert not model.is_gradient_checkpointing and model.training
+
+        out = model(**inputs_dict).sample
+        # run the backwards pass on the model. For backwards pass, for simplicity purpose,
+        # we won't calculate the loss and rather backprop on out.sum()
+        model.zero_grad()
+
+        labels = torch.randn_like(out)
+        loss = (out - labels).mean()
+        loss.backward()
+
+        # re-instantiate the model now enabling gradient checkpointing
+        model_2 = self.model_class(**init_dict)
+        # clone model
+        model_2.load_state_dict(model.state_dict())
+        model_2.to(torch_device)
+        model_2.enable_gradient_checkpointing()
+
+        assert model_2.is_gradient_checkpointing and model_2.training
+
+        out_2 = model_2(**inputs_dict).sample
+        # run the backwards pass on the model. For backwards pass, for simplicity purpose,
+        # we won't calculate the loss and rather backprop on out.sum()
+        model_2.zero_grad()
+        loss_2 = (out_2 - labels).mean()
+        loss_2.backward()
+
+        # compare the output and parameters gradients
+        self.assertTrue((loss - loss_2).abs() < 1e-5)
+        named_params = dict(model.named_parameters())
+        named_params_2 = dict(model_2.named_parameters())
+        for name, param in named_params.items():
+            if "post_quant_conv" in name:
+                continue
+            self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5))
+
+    def test_gradient_checkpointing_is_applied(self, expected_set=None):
+        if not self.model_class._supports_gradient_checkpointing:
+            return  # Skip test if model does not support gradient checkpointing
+        if torch_device == "mps" and self.model_class.__name__ == "UNetSpatioTemporalConditionModel":
+            return
+
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        init_dict["num_attention_heads"] = (8, 16)
+
+        model_class_copy = copy.copy(self.model_class)
+
+        modules_with_gc_enabled = {}
+
+        # now monkey patch the following function:
+        #     def _set_gradient_checkpointing(self, module, value=False):
+        #         if hasattr(module, "gradient_checkpointing"):
+        #             module.gradient_checkpointing = value
+
+        def _set_gradient_checkpointing_new(self, module, value=False):
+            if hasattr(module, "gradient_checkpointing"):
+                module.gradient_checkpointing = value
+                modules_with_gc_enabled[module.__class__.__name__] = True
+
+        model_class_copy._set_gradient_checkpointing = _set_gradient_checkpointing_new
+
+        model = model_class_copy(**init_dict)
+        model.enable_gradient_checkpointing()
+
+        assert set(modules_with_gc_enabled.keys()) == expected_set
+        assert all(modules_with_gc_enabled.values()), "All modules should be enabled"
+
     def test_deprecated_kwargs(self):
         has_kwarg_in_model_class = "kwargs" in inspect.signature(self.model_class.__init__).parameters
         has_deprecated_kwarg = len(self.model_class._deprecated_kwargs) > 0
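Since the shared `test_gradient_checkpointing_is_applied` now takes an `expected_set` argument, a model-specific test class only has to supply the module names it expects and delegate the rest, as the UNet test in the next file does. A hypothetical sketch of that pattern follows; the class and module names are placeholders, and required fixtures such as `dummy_input` and `prepare_init_args_and_inputs_for_common` are omitted.

class MyModelTests(ModelTesterMixin, unittest.TestCase):  # placeholder test case, for illustration only
    model_class = MyModel  # placeholder; any model class that supports gradient checkpointing

    def test_gradient_checkpointing_is_applied(self):
        # placeholder module class names that should have gradient_checkpointing toggled on
        expected_set = {"MyDownBlock", "MyMidBlock", "MyUpBlock"}
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)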

tests/models/unets/test_models_unet_2d_condition.py

Lines changed: 2 additions & 70 deletions
@@ -43,7 +43,6 @@
     require_peft_backend,
     require_torch_accelerator,
     require_torch_accelerator_with_fp16,
-    require_torch_accelerator_with_training,
     require_torch_gpu,
     skip_mps,
     slow,
@@ -406,47 +405,6 @@ def test_xformers_enable_works(self):
             == "XFormersAttnProcessor"
         ), "xformers is not enabled"
 
-    @require_torch_accelerator_with_training
-    def test_gradient_checkpointing(self):
-        # enable deterministic behavior for gradient checkpointing
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-        model = self.model_class(**init_dict)
-        model.to(torch_device)
-
-        assert not model.is_gradient_checkpointing and model.training
-
-        out = model(**inputs_dict).sample
-        # run the backwards pass on the model. For backwards pass, for simplicity purpose,
-        # we won't calculate the loss and rather backprop on out.sum()
-        model.zero_grad()
-
-        labels = torch.randn_like(out)
-        loss = (out - labels).mean()
-        loss.backward()
-
-        # re-instantiate the model now enabling gradient checkpointing
-        model_2 = self.model_class(**init_dict)
-        # clone model
-        model_2.load_state_dict(model.state_dict())
-        model_2.to(torch_device)
-        model_2.enable_gradient_checkpointing()
-
-        assert model_2.is_gradient_checkpointing and model_2.training
-
-        out_2 = model_2(**inputs_dict).sample
-        # run the backwards pass on the model. For backwards pass, for simplicity purpose,
-        # we won't calculate the loss and rather backprop on out.sum()
-        model_2.zero_grad()
-        loss_2 = (out_2 - labels).mean()
-        loss_2.backward()
-
-        # compare the output and parameters gradients
-        self.assertTrue((loss - loss_2).abs() < 1e-5)
-        named_params = dict(model.named_parameters())
-        named_params_2 = dict(model_2.named_parameters())
-        for name, param in named_params.items():
-            self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5))
-
     def test_model_with_attention_head_dim_tuple(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
 
@@ -599,41 +557,15 @@ def check_sliceable_dim_attr(module: torch.nn.Module):
             check_sliceable_dim_attr(module)
 
     def test_gradient_checkpointing_is_applied(self):
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
-        init_dict["block_out_channels"] = (16, 32)
-        init_dict["attention_head_dim"] = (8, 16)
-
-        model_class_copy = copy.copy(self.model_class)
-
-        modules_with_gc_enabled = {}
-
-        # now monkey patch the following function:
-        #     def _set_gradient_checkpointing(self, module, value=False):
-        #         if hasattr(module, "gradient_checkpointing"):
-        #             module.gradient_checkpointing = value
-
-        def _set_gradient_checkpointing_new(self, module, value=False):
-            if hasattr(module, "gradient_checkpointing"):
-                module.gradient_checkpointing = value
-                modules_with_gc_enabled[module.__class__.__name__] = True
-
-        model_class_copy._set_gradient_checkpointing = _set_gradient_checkpointing_new
-
-        model = model_class_copy(**init_dict)
-        model.enable_gradient_checkpointing()
-
-        EXPECTED_SET = {
+        expected_set = {
             "CrossAttnUpBlock2D",
             "CrossAttnDownBlock2D",
             "UNetMidBlock2DCrossAttn",
             "UpBlock2D",
             "Transformer2DModel",
            "DownBlock2D",
         }
-
-        assert set(modules_with_gc_enabled.keys()) == EXPECTED_SET
-        assert all(modules_with_gc_enabled.values()), "All modules should be enabled"
+        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
 
     def test_special_attn_proc(self):
         class AttnEasyProc(torch.nn.Module):

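To exercise the relocated checks for a single model file, a keyword filter such as `pytest tests/models/unets/test_models_unet_2d_condition.py -k "gradient_checkpointing"` should select both the effectiveness test and the is-applied test (assuming pytest as the test runner, which this test suite uses).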