Skip to content

Commit 0161de0

Browse files
committed
Switch RandoErasing back to on GPU normal sampling
1 parent 3129bdb commit 0161de0

File tree

1 file changed

+2
-4
lines changed

1 file changed

+2
-4
lines changed

timm/data/random_erasing.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,10 @@ def _get_pixels(per_pixel, rand_color, patch_size, dtype=torch.float32, device='
77
# NOTE I've seen CUDA illegal memory access errors being caused by the normal_()
88
# paths, flip the order so normal is run on CPU if this becomes a problem
99
# Issue has been fixed in master https://github.com/pytorch/pytorch/issues/19508
10-
# will revert back to doing normal_() on GPU when it's in next release
1110
if per_pixel:
12-
return torch.empty(
13-
patch_size, dtype=dtype).normal_().to(device=device)
11+
return torch.empty(patch_size, dtype=dtype, device=device).normal_()
1412
elif rand_color:
15-
return torch.empty((patch_size[0], 1, 1), dtype=dtype).normal_().to(device=device)
13+
return torch.empty((patch_size[0], 1, 1), dtype=dtype, device=device).normal_()
1614
else:
1715
return torch.zeros((patch_size[0], 1, 1), dtype=dtype, device=device)
1816

0 commit comments

Comments
 (0)