Skip to content

Commit d9ca099

Browse files
committed
Fix fixture in test_train_utils
1 parent d58dea2 commit d9ca099

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

src/tests/test_train_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def test_gradient_accumulation(
3636

3737
model = mocker.MagicMock(name="model")
3838
model().loss.__truediv__().detach.return_value = torch.tensor(1)
39+
model().loss.detach.return_value = torch.tensor(1)
3940
mock_tensor = mocker.MagicMock(name="tensor")
4041
batch = {"input": mock_tensor}
4142
train_dataloader = [batch, batch, batch, batch, batch]
@@ -94,6 +95,7 @@ def test_gradient_accumulation(
9495
def test_save_to_json(temp_output_dir, mocker):
9596
model = mocker.MagicMock(name="model")
9697
model().loss.__truediv__().detach.return_value = torch.tensor(1)
98+
model().loss.detach.return_value = torch.tensor(1)
9799
mock_tensor = mocker.MagicMock(name="tensor")
98100
batch = {"input": mock_tensor}
99101
train_dataloader = [batch, batch, batch, batch, batch]

0 commit comments

Comments
 (0)