Skip to content

Commit de92b67

Browse files
committed
update test
1 parent 50d0a28 commit de92b67

File tree

2 files changed

+10
-17
lines changed

2 files changed

+10
-17
lines changed

src/diffusers/models/unets/unet_motion_model.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1007,10 +1007,12 @@ def forward(
10071007
)[0]
10081008

10091009
if torch.is_grad_enabled() and self.gradient_checkpointing:
1010-
hidden_states = self._gradient_checkpointing_func(motion_module, hidden_states, temb)
1010+
hidden_states = self._gradient_checkpointing_func(
1011+
motion_module, hidden_states, None, None, None, num_frames, None
1012+
)
10111013
hidden_states = self._gradient_checkpointing_func(resnet, hidden_states, temb)
10121014
else:
1013-
hidden_states = motion_module(hidden_states, num_frames=num_frames)
1015+
hidden_states = motion_module(hidden_states, None, None, None, num_frames, None)
10141016
hidden_states = resnet(input_tensor=hidden_states, temb=temb)
10151017

10161018
return hidden_states

tests/models/test_modeling_common.py

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -953,24 +953,15 @@ def test_gradient_checkpointing_is_applied(
953953
init_dict["block_out_channels"] = block_out_channels
954954

955955
model_class_copy = copy.copy(self.model_class)
956-
957-
modules_with_gc_enabled = {}
958-
959-
# now monkey patch the following function:
960-
# def _set_gradient_checkpointing(self, module, value=False):
961-
# if hasattr(module, "gradient_checkpointing"):
962-
# module.gradient_checkpointing = value
963-
964-
def _set_gradient_checkpointing_new(self, module, value=False):
965-
if hasattr(module, "gradient_checkpointing"):
966-
module.gradient_checkpointing = value
967-
modules_with_gc_enabled[module.__class__.__name__] = True
968-
969-
model_class_copy._set_gradient_checkpointing = _set_gradient_checkpointing_new
970-
971956
model = model_class_copy(**init_dict)
972957
model.enable_gradient_checkpointing()
973958

959+
modules_with_gc_enabled = {}
960+
for submodule in model.modules():
961+
if hasattr(submodule, "gradient_checkpointing"):
962+
self.assertTrue(submodule.gradient_checkpointing)
963+
modules_with_gc_enabled[submodule.__class__.__name__] = True
964+
974965
assert set(modules_with_gc_enabled.keys()) == expected_set
975966
assert all(modules_with_gc_enabled.values()), "All modules should be enabled"
976967

0 commit comments

Comments
 (0)