Skip to content

Commit 7633eaa

Browse files
committed
fixes
1 parent be6478e commit 7633eaa

File tree

1 file changed

+2
-5
lines changed

1 file changed

+2
-5
lines changed

tests/models/test_modeling_common.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -738,15 +738,11 @@ def test_enable_disable_gradient_checkpointing(self):
738738
def test_effective_gradient_checkpointing(self, loss_tolerance=1e-5, param_grad_tol=5e-5):
739739
if not self.model_class._supports_gradient_checkpointing:
740740
return # Skip test if model does not support gradient checkpointing
741-
if torch_device == "mps" and self.model_class.__name__ in [
742-
"UNetSpatioTemporalConditionModel",
743-
"AutoencoderKLTemporalDecoder",
744-
]:
745-
return
746741

747742
# enable deterministic behavior for gradient checkpointing
748743
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
749744
inputs_dict_copy = copy.deepcopy(inputs_dict)
745+
torch.manual_seed(0)
750746
model = self.model_class(**init_dict)
751747
model.to(torch_device)
752748

@@ -762,6 +758,7 @@ def test_effective_gradient_checkpointing(self, loss_tolerance=1e-5, param_grad_
762758
loss.backward()
763759

764760
# re-instantiate the model now enabling gradient checkpointing
761+
torch.manual_seed(0)
765762
model_2 = self.model_class(**init_dict)
766763
# clone model
767764
model_2.load_state_dict(model.state_dict())

0 commit comments

Comments
 (0)