Skip to content

Commit 8fd12ce

Browse files
Ensure constant tensors are on the correct device (GPU)
Moved constant tensors to the same device as input for compatibility.
1 parent 90be62c commit 8fd12ce

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

torchstain/torch/utils/lab2rgb.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,11 @@
66
def lab2rgb(lab):
77
lab = lab.type(torch.float32)
88

9+
# Move constant tensors to the same device as input
10+
device = lab.device
11+
_white_device = _white.to(device)
12+
_xyz2rgb_device = _xyz2rgb.to(device)
13+
914
# rescale back from OpenCV format and extract LAB channel
1015
L, a, b = lab[0] / 2.55, lab[1] - 128, lab[2] - 128
1116

@@ -24,10 +29,10 @@ def lab2rgb(lab):
2429
out.masked_scatter_(not_mask, (torch.masked_select(out, not_mask) - 16 / 116) / 7.787)
2530

2631
# rescale to the reference white (illuminant)
27-
out = torch.mul(out, _white.type(out.dtype).unsqueeze(dim=-1).unsqueeze(dim=-1))
32+
out = torch.mul(out, _white_device.type(out.dtype).unsqueeze(dim=-1).unsqueeze(dim=-1))
2833

2934
# convert XYZ -> RGB color domain
30-
arr = torch.tensordot(out, torch.t(_xyz2rgb).type(out.dtype), dims=([0], [0]))
35+
arr = torch.tensordot(out, torch.t(_xyz2rgb_device).type(out.dtype), dims=([0], [0]))
3136
mask = arr > 0.0031308
3237
not_mask = torch.logical_not(mask)
3338
arr.masked_scatter_(mask, 1.055 * torch.pow(torch.masked_select(arr, mask), 1 / 2.4) - 0.055)

0 commit comments

Comments
 (0)