
Commit be6478e

updates

1 parent 7d23fb1 · commit be6478e

File tree

1 file changed (+2 −1 lines)


tests/models/test_modeling_common.py

Lines changed: 2 additions & 1 deletion
@@ -746,6 +746,7 @@ def test_effective_gradient_checkpointing(self, loss_tolerance=1e-5, param_grad_
 
         # enable deterministic behavior for gradient checkpointing
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        inputs_dict_copy = copy.deepcopy(inputs_dict)
         model = self.model_class(**init_dict)
         model.to(torch_device)
 
@@ -769,7 +770,7 @@ def test_effective_gradient_checkpointing(self, loss_tolerance=1e-5, param_grad_
 
         assert model_2.is_gradient_checkpointing and model_2.training
 
-        out_2 = model_2(**inputs_dict).sample
+        out_2 = model_2(**inputs_dict_copy).sample
         # run the backwards pass on the model. For backwards pass, for simplicity purpose,
         # we won't calculate the loss and rather backprop on out.sum()
         model_2.zero_grad()
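
The switch to inputs_dict_copy only matters if the first model run can mutate the shared inputs_dict (for example, by modifying an input tensor in place), which would make the second, gradient-checkpointed run see different inputs. That rationale is a reading of the diff, not stated in the commit message; the sketch below uses a hypothetical forward_that_mutates helper to illustrate why a copy.deepcopy taken before the first run keeps the two runs comparable.

import copy

import torch


def forward_that_mutates(inputs):
    # Hypothetical stand-in for a model forward pass that modifies one of
    # its input tensors in place before producing an output.
    inputs["sample"].mul_(2)
    return inputs["sample"].sum()


inputs_dict = {"sample": torch.ones(2, 3)}

# Snapshot the inputs before the first run, mirroring the test change.
inputs_dict_copy = copy.deepcopy(inputs_dict)

out_1 = forward_that_mutates(inputs_dict)       # first run mutates inputs_dict
out_2 = forward_that_mutates(inputs_dict_copy)  # second run sees pristine inputs

# Both runs saw identical inputs, so the outputs match (12.0 == 12.0).
# Reusing inputs_dict for the second call would instead yield 24.0.
assert torch.equal(out_1, out_2)

A shallow dict copy would not be enough here, since both dicts would still reference the same tensor objects; only a deep copy isolates the second run from in-place mutation during the first.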
