Skip to content

Commit 9d19dcb

Browse files
committed
Add Accelerator API to neural tangent kernels (pytorch#3578)
- Integrate the Accelerator API to support multiple accelerators and improve backend initialization. Signed-off-by: jafraustro <[email protected]>
1 parent 75a263f commit 9d19dcb

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

intermediate_source/neural_tangent_kernels.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
1414
.. note::
1515
16-
This tutorial requires PyTorch 2.0.0 or later.
16+
This tutorial requires PyTorch 2.6.0 or later.
1717
1818
Setup
1919
-----
@@ -24,7 +24,12 @@
2424
import torch
2525
import torch.nn as nn
2626
from torch.func import functional_call, vmap, vjp, jvp, jacrev
27-
device = 'cuda' if torch.cuda.device_count() > 0 else 'cpu'
27+
28+
if torch.accelerator.is_available() and torch.accelerator.device_count() > 0:
29+
device = torch.accelerator.current_accelerator()
30+
else:
31+
device = torch.device("cpu")
32+
2833

2934
class CNN(nn.Module):
3035
def __init__(self):

0 commit comments

Comments
 (0)