Skip to content

Commit 16554e5

Browse files
authored
imagenet: fix typo addressing args.gpu (#1361)
1 parent 8d408d2 commit 16554e5

File tree

2 files changed

+2
-1
lines changed

2 files changed

+2
-1
lines changed

imagenet/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -372,7 +372,7 @@ def run_validate(loader, base_progress=0):
372372
i = base_progress + i
373373
if use_accel:
374374
if args.gpu is not None and device.type=='cuda':
375-
torch.accelerator.set_device_index(argps.gpu)
375+
torch.accelerator.set_device_index(args.gpu)
376376
images = images.cuda(args.gpu, non_blocking=True)
377377
target = target.cuda(args.gpu, non_blocking=True)
378378
else:

run_python_examples.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ function imagenet() {
8181
cp sample/train/n/* sample/val/n/
8282
fi
8383
uv run main.py --epochs 1 sample/ || error "imagenet example failed"
84+
uv run main.py --epochs 1 --gpu 0 sample/ || error "imagenet example failed"
8485
}
8586

8687
function language_translation() {

0 commit comments

Comments
 (0)