Skip to content

Commit 4c24f61

Browse files
BanzaiTokyovfdev-5
andauthored
adds available_device to test_loss.py #3335 (#3364)
* adds available_device to test_loss.py #3335 * removes available device from test_reset() and test_sum_detached() --------- Co-authored-by: vfdev <[email protected]>
1 parent 4bb80b7 commit 4c24f61

File tree

1 file changed

+21
-16
lines changed

1 file changed

+21
-16
lines changed

tests/ignite/metrics/test_loss.py

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515

1616

1717
class DummyLoss1(Loss):
18-
def __init__(self, loss_fn, true_output, output_transform=lambda x: x):
19-
super(DummyLoss1, self).__init__(loss_fn, output_transform=output_transform)
18+
def __init__(self, loss_fn, true_output, output_transform=lambda x: x, device="cpu"):
19+
super().__init__(loss_fn, output_transform=output_transform, device=device)
2020
print(true_output)
2121
self.true_output = true_output
2222

@@ -30,23 +30,23 @@ def update(self, output):
3030
assert output == self.true_output
3131

3232

33-
def test_output_as_mapping_without_criterion_kwargs():
33+
def test_output_as_mapping_without_criterion_kwargs(available_device):
3434
y_pred = torch.tensor([[2.0], [-2.0]])
3535
y = torch.zeros(2)
3636
criterion_kwargs = {}
3737

38-
loss_metric = DummyLoss1(nll_loss, true_output=(y_pred, y, criterion_kwargs))
38+
loss_metric = DummyLoss1(nll_loss, true_output=(y_pred, y, criterion_kwargs), device=available_device)
3939
state = State(output=({"y_pred": y_pred, "y": y, "criterion_kwargs": {}}))
4040
engine = MagicMock(state=state)
4141
loss_metric.iteration_completed(engine)
4242

4343

44-
def test_output_as_mapping_with_criterion_kwargs():
44+
def test_output_as_mapping_with_criterion_kwargs(available_device):
4545
y_pred = torch.tensor([[2.0], [-2.0]])
4646
y = torch.zeros(2)
4747
criterion_kwargs = {"reduction": "sum"}
4848

49-
loss_metric = DummyLoss1(nll_loss, true_output=(y_pred, y, criterion_kwargs))
49+
loss_metric = DummyLoss1(nll_loss, true_output=(y_pred, y, criterion_kwargs), device=available_device)
5050
state = State(output=({"y_pred": y_pred, "y": y, "criterion_kwargs": {"reduction": "sum"}}))
5151
engine = MagicMock(state=state)
5252
loss_metric.iteration_completed(engine)
@@ -79,8 +79,9 @@ def test_zero_div():
7979

8080

8181
@pytest.mark.parametrize("criterion", [nll_loss, nn.NLLLoss()])
82-
def test_compute(criterion):
83-
loss = Loss(criterion)
82+
def test_compute(criterion, available_device):
83+
loss = Loss(criterion, device=available_device)
84+
assert loss._device == torch.device(available_device)
8485

8586
y_pred, y, expected_loss = y_test_1()
8687
loss.update((y_pred, y))
@@ -99,7 +100,7 @@ def test_non_averaging_loss():
99100
loss.update((y_pred, y))
100101

101102

102-
def test_gradient_based_loss():
103+
def test_gradient_based_loss(available_device):
103104
# Tests https://github.com/pytorch/ignite/issues/1674
104105
x = torch.tensor([[0.1, 0.4, 0.5], [0.1, 0.7, 0.2]], requires_grad=True)
105106
y_pred = x.mm(torch.randn(size=(3, 1)))
@@ -113,12 +114,14 @@ def loss_fn(y_pred, x):
113114

114115
return gradients.norm(2, dim=1).mean()
115116

116-
loss = Loss(loss_fn)
117+
loss = Loss(loss_fn, device=available_device)
118+
assert loss._device == torch.device(available_device)
117119
loss.update((y_pred, x))
118120

119121

120-
def test_kwargs_loss():
121-
loss = Loss(nll_loss)
122+
def test_kwargs_loss(available_device):
123+
loss = Loss(nll_loss, device=available_device)
124+
assert loss._device == torch.device(available_device)
122125

123126
y_pred, y, _ = y_test_1()
124127
kwargs = {"weight": torch.tensor([0.1, 0.1, 0.1])}
@@ -330,8 +333,8 @@ def forward(
330333

331334

332335
class DummyLoss3(Loss):
333-
def __init__(self, loss_fn, expected_loss, output_transform=lambda x: x, skip_unrolling=False):
334-
super(DummyLoss3, self).__init__(loss_fn, output_transform=output_transform, skip_unrolling=skip_unrolling)
336+
def __init__(self, loss_fn, expected_loss, output_transform=lambda x: x, skip_unrolling=False, device="cpu"):
337+
super().__init__(loss_fn, output_transform=output_transform, skip_unrolling=skip_unrolling, device=device)
335338
self._expected_loss = expected_loss
336339
self._loss_fn = loss_fn
337340

@@ -347,7 +350,7 @@ def update(self, output):
347350
assert calculated_loss == self._expected_loss
348351

349352

350-
def test_skip_unrolling_loss():
353+
def test_skip_unrolling_loss(available_device):
351354
a_pred = torch.rand(8, 1)
352355
b_pred = torch.rand(8, 1)
353356
y_pred = [a_pred, b_pred]
@@ -358,7 +361,9 @@ def test_skip_unrolling_loss():
358361
multi_output_mse_loss = CustomMultiMSELoss()
359362
expected_loss = multi_output_mse_loss(y_pred=y_pred, y_true=y_true)
360363

361-
loss_metric = DummyLoss3(loss_fn=multi_output_mse_loss, expected_loss=expected_loss, skip_unrolling=True)
364+
loss_metric = DummyLoss3(
365+
loss_fn=multi_output_mse_loss, expected_loss=expected_loss, skip_unrolling=True, device=available_device
366+
)
362367
state = State(output=(y_pred, y_true))
363368
engine = MagicMock(state=state)
364369
loss_metric.iteration_completed(engine)

0 commit comments

Comments
 (0)