Commit 75a263f (1 parent: fc8981b)

Add Accelerator API to tensorboard_profiler tutorial

Signed-off-by: jafraustro <[email protected]>

File tree

1 file changed: +8 -5 lines changed

intermediate_source/tensorboard_profiler_tutorial.py

Lines changed: 8 additions & 5 deletions
@@ -15,7 +15,9 @@
 Introduction
 ------------
 PyTorch 1.8 includes an updated profiler API capable of
-recording the CPU side operations as well as the CUDA kernel launches on the GPU side.
+recording CPU-side operations as well as device-side kernel launches (for example CUDA or XPU),
+when supported by the platform and underlying tracing integrations.
+
 The profiler can visualize this information
 in TensorBoard Plugin and provide analysis of the performance bottlenecks.
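As a minimal sketch of the profiler API this hunk describes, the following records CPU-side operator events for a single forward pass. The tiny `Linear` model and input shapes are arbitrary illustrative choices, not taken from the tutorial itself.

```python
import torch
import torch.profiler

# Arbitrary small model and input, chosen only for illustration.
model = torch.nn.Linear(64, 32)
inputs = torch.randn(8, 64)

# Record CPU-side operations; on a machine with an accelerator, adding
# the corresponding ProfilerActivity (e.g. CUDA) would also capture
# device-side kernel launches, as the updated text notes.
with torch.profiler.profile(
    activities=[torch.profiler.ProfilerActivity.CPU],
    record_shapes=True,
) as prof:
    model(inputs)

# Summarize the recorded operator events.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))
```

The same `profile` context accepts an `on_trace_ready` handler to export traces for the TensorBoard plugin mentioned below.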
@@ -76,9 +78,10 @@
 # Next, create Resnet model, loss function, and optimizer objects.
 # To run on GPU, move model and loss to GPU device.

-device = torch.device("cuda:0")
-model = torchvision.models.resnet18(weights='IMAGENET1K_V1').cuda(device)
-criterion = torch.nn.CrossEntropyLoss().cuda(device)
+acc = torch.accelerator.current_accelerator()
+device = torch.device(f'{acc}:0')
+model = torchvision.models.resnet18(weights='IMAGENET1K_V1').to(device)
+criterion = torch.nn.CrossEntropyLoss().to(device)
 optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
 model.train()
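The device-selection pattern in this hunk can be sketched in isolation as below. The `hasattr` guard and CPU fallback are additions of this sketch (the `torch.accelerator` namespace only exists in recent PyTorch releases), not part of the commit itself.

```python
import torch

# Prefer the current accelerator (CUDA, XPU, ...) when one is available;
# fall back to CPU otherwise. The hasattr check is a safeguard for older
# PyTorch releases that predate the torch.accelerator namespace.
if hasattr(torch, "accelerator") and torch.accelerator.is_available():
    acc = torch.accelerator.current_accelerator()
    device = torch.device(f"{acc}:0")
else:
    device = torch.device("cpu")

# Unlike the CUDA-only .cuda() calls being removed, .to(device) works
# uniformly across backends. A small stand-in model is used here.
model = torch.nn.Linear(4, 2).to(device)
print(device.type)
```

This is why the commit swaps `.cuda(device)` for `.to(device)`: the same tutorial code then runs on any supported accelerator, not just CUDA GPUs.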

@@ -346,7 +349,7 @@ def train(data):
 # For example, "GPU0" means the following table only shows each operator's memory usage on GPU 0, not including CPU or other GPUs.
 #
 # The memory curve shows the trends of memory consumption. The "Allocated" curve shows the total memory that is actually
-# in use, e.g., tensors. In PyTorch, caching mechanism is employed in CUDA allocator and some other allocators. The
+# in use, e.g., tensors. In PyTorch, caching mechanism is employed in the device allocator and some other allocators. The
 # "Reserved" curve shows the total memory that is reserved by the allocator. You can left click and drag on the graph
 # to select events in the desired range:
 #
