1515
1616
1717class DummyLoss1 (Loss ):
18- def __init__ (self , loss_fn , true_output , output_transform = lambda x : x ):
19- super (DummyLoss1 , self ).__init__ (loss_fn , output_transform = output_transform )
18+ def __init__ (self , loss_fn , true_output , output_transform = lambda x : x , device = "cpu" ):
19+ super ().__init__ (loss_fn , output_transform = output_transform , device = device )
2020 print (true_output )
2121 self .true_output = true_output
2222
@@ -30,23 +30,23 @@ def update(self, output):
3030 assert output == self .true_output
3131
3232
33- def test_output_as_mapping_without_criterion_kwargs ():
33+ def test_output_as_mapping_without_criterion_kwargs (available_device ):
3434 y_pred = torch .tensor ([[2.0 ], [- 2.0 ]])
3535 y = torch .zeros (2 )
3636 criterion_kwargs = {}
3737
38- loss_metric = DummyLoss1 (nll_loss , true_output = (y_pred , y , criterion_kwargs ))
38+ loss_metric = DummyLoss1 (nll_loss , true_output = (y_pred , y , criterion_kwargs ), device = available_device )
3939 state = State (output = ({"y_pred" : y_pred , "y" : y , "criterion_kwargs" : {}}))
4040 engine = MagicMock (state = state )
4141 loss_metric .iteration_completed (engine )
4242
4343
44- def test_output_as_mapping_with_criterion_kwargs ():
44+ def test_output_as_mapping_with_criterion_kwargs (available_device ):
4545 y_pred = torch .tensor ([[2.0 ], [- 2.0 ]])
4646 y = torch .zeros (2 )
4747 criterion_kwargs = {"reduction" : "sum" }
4848
49- loss_metric = DummyLoss1 (nll_loss , true_output = (y_pred , y , criterion_kwargs ))
49+ loss_metric = DummyLoss1 (nll_loss , true_output = (y_pred , y , criterion_kwargs ), device = available_device )
5050 state = State (output = ({"y_pred" : y_pred , "y" : y , "criterion_kwargs" : {"reduction" : "sum" }}))
5151 engine = MagicMock (state = state )
5252 loss_metric .iteration_completed (engine )
@@ -79,8 +79,9 @@ def test_zero_div():
7979
8080
8181@pytest .mark .parametrize ("criterion" , [nll_loss , nn .NLLLoss ()])
82- def test_compute (criterion ):
83- loss = Loss (criterion )
82+ def test_compute (criterion , available_device ):
83+ loss = Loss (criterion , device = available_device )
84+ assert loss ._device == torch .device (available_device )
8485
8586 y_pred , y , expected_loss = y_test_1 ()
8687 loss .update ((y_pred , y ))
@@ -99,7 +100,7 @@ def test_non_averaging_loss():
99100 loss .update ((y_pred , y ))
100101
101102
102- def test_gradient_based_loss ():
103+ def test_gradient_based_loss (available_device ):
103104 # Tests https://github.com/pytorch/ignite/issues/1674
104105 x = torch .tensor ([[0.1 , 0.4 , 0.5 ], [0.1 , 0.7 , 0.2 ]], requires_grad = True )
105106 y_pred = x .mm (torch .randn (size = (3 , 1 )))
@@ -113,12 +114,14 @@ def loss_fn(y_pred, x):
113114
114115 return gradients .norm (2 , dim = 1 ).mean ()
115116
116- loss = Loss (loss_fn )
117+ loss = Loss (loss_fn , device = available_device )
118+ assert loss ._device == torch .device (available_device )
117119 loss .update ((y_pred , x ))
118120
119121
120- def test_kwargs_loss ():
121- loss = Loss (nll_loss )
122+ def test_kwargs_loss (available_device ):
123+ loss = Loss (nll_loss , device = available_device )
124+ assert loss ._device == torch .device (available_device )
122125
123126 y_pred , y , _ = y_test_1 ()
124127 kwargs = {"weight" : torch .tensor ([0.1 , 0.1 , 0.1 ])}
@@ -330,8 +333,8 @@ def forward(
330333
331334
332335class DummyLoss3 (Loss ):
333- def __init__ (self , loss_fn , expected_loss , output_transform = lambda x : x , skip_unrolling = False ):
334- super (DummyLoss3 , self ).__init__ (loss_fn , output_transform = output_transform , skip_unrolling = skip_unrolling )
336+ def __init__ (self , loss_fn , expected_loss , output_transform = lambda x : x , skip_unrolling = False , device = "cpu" ):
337+ super ().__init__ (loss_fn , output_transform = output_transform , skip_unrolling = skip_unrolling , device = device )
335338 self ._expected_loss = expected_loss
336339 self ._loss_fn = loss_fn
337340
@@ -347,7 +350,7 @@ def update(self, output):
347350 assert calculated_loss == self ._expected_loss
348351
349352
350- def test_skip_unrolling_loss ():
353+ def test_skip_unrolling_loss (available_device ):
351354 a_pred = torch .rand (8 , 1 )
352355 b_pred = torch .rand (8 , 1 )
353356 y_pred = [a_pred , b_pred ]
@@ -358,7 +361,9 @@ def test_skip_unrolling_loss():
358361 multi_output_mse_loss = CustomMultiMSELoss ()
359362 expected_loss = multi_output_mse_loss (y_pred = y_pred , y_true = y_true )
360363
361- loss_metric = DummyLoss3 (loss_fn = multi_output_mse_loss , expected_loss = expected_loss , skip_unrolling = True )
364+ loss_metric = DummyLoss3 (
365+ loss_fn = multi_output_mse_loss , expected_loss = expected_loss , skip_unrolling = True , device = available_device
366+ )
362367 state = State (output = (y_pred , y_true ))
363368 engine = MagicMock (state = state )
364369 loss_metric .iteration_completed (engine )
0 commit comments