Skip to content

Commit 8341cb9

Browse files
committed
Add Accelerator API to ensembling tutorial
Signed-off-by: jafraustro <[email protected]>
1 parent fc8981b commit 8341cb9

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

intermediate_source/ensembling.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ def forward(self, x):
5050
# minibatch of size 64. Furthermore, lets say we want to combine the predictions
5151
# from 10 different models.
5252

53-
device = 'cuda'
53+
54+
device = torch.accelerator.current_accelerator()
5455
num_models = 10
5556

5657
data = torch.randn(100, 64, 1, 28, 28, device=device)

0 commit comments

Comments
 (0)