Skip to content

Commit 3fbbd51

Browse files
committed
Testing some LAB stuff
1 parent 3b181b7 commit 3fbbd51

File tree

2 files changed

+135
-5
lines changed

2 files changed

+135
-5
lines changed

timm/data/transforms.py

Lines changed: 116 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def __repr__(self) -> str:
9090
return f"{self.__class__.__name__}()"
9191

9292

93-
class ToLab(transforms.ToTensor):
93+
class ToLabPIL:
9494

9595
def __init__(self) -> None:
9696
super().__init__()
@@ -115,6 +115,121 @@ def __repr__(self) -> str:
115115
return f"{self.__class__.__name__}()"
116116

117117

118+
def srgb_to_linear(srgb_image: torch.Tensor) -> torch.Tensor:
119+
return torch.where(
120+
srgb_image <= 0.04045,
121+
srgb_image / 12.92,
122+
((srgb_image + 0.055) / 1.055) ** 2.4
123+
)
124+
125+
126+
def rgb_to_lab_tensor(
127+
rgb_img: torch.Tensor,
128+
normalized: bool = True,
129+
srgb_input: bool = True,
130+
) -> torch.Tensor:
131+
"""
132+
Convert RGB image to LAB color space using tensor operations.
133+
134+
Args:
135+
rgb_img: Tensor of shape (..., 3) with values in range [0, 255]
136+
normalized: If True, outputs L,a,b in [0, 1] range instead of native LAB ranges
137+
138+
Returns:
139+
lab_img: Tensor of same shape with either:
140+
- normalized=False: L in [0, 100] and a,b in [-128, 127]
141+
- normalized=True: L,a,b in [0, 1]
142+
"""
143+
# Constants
144+
epsilon = 216 / 24389
145+
kappa = 24389 / 27
146+
xn = 0.95047
147+
yn = 1.0
148+
zn = 1.08883
149+
150+
# Convert sRGB to linear RGB
151+
if srgb_input:
152+
rgb_img = srgb_to_linear(rgb_img)
153+
154+
# FIXME transforms before this are causing -ve values, can have a large impact on this conversion
155+
rgb_img.clamp_(0, 1.0)
156+
157+
# Convert to XYZ using matrix multiplication
158+
rgb_to_xyz = torch.tensor([
159+
[0.412453, 0.357580, 0.180423],
160+
[0.212671, 0.715160, 0.072169],
161+
[0.019334, 0.119193, 0.950227]
162+
], device=rgb_img.device)
163+
164+
# Reshape input for matrix multiplication if needed
165+
original_shape = rgb_img.shape
166+
if len(original_shape) > 2:
167+
rgb_img = rgb_img.reshape(-1, 3)
168+
169+
# Perform matrix multiplication
170+
xyz = torch.matmul(rgb_img, rgb_to_xyz.T)
171+
172+
# Adjust XYZ values
173+
xyz[..., 0].div_(xn)
174+
xyz[..., 1].div_(yn)
175+
xyz[..., 2].div_(zn)
176+
177+
# Step 4: XYZ to LAB
178+
lab = torch.where(
179+
xyz > epsilon,
180+
torch.pow(xyz, 1 / 3),
181+
(kappa * xyz + 16) / 116
182+
)
183+
184+
if normalized:
185+
# Calculate normalized [0,1] L,a,b values directly
186+
# L: map [0,100] to [0,1] : (116y - 16)/100 = 1.16y - 0.16
187+
# a: map [-128,127] to [0,1] : (500(x-y) + 128)/255 ≈ 1.96(x-y) + 0.502
188+
# b: map [-128,127] to [0,1] : (200(y-z) + 128)/255 ≈ 0.784(y-z) + 0.502
189+
shift_128 = 128 / 255
190+
a_scale = 500 / 255
191+
b_scale = 200 / 255
192+
L = 1.16 * lab[..., 1] - 0.16
193+
a = a_scale * (lab[..., 0] - lab[..., 1]) + shift_128
194+
b = b_scale * (lab[..., 1] - lab[..., 2]) + shift_128
195+
else:
196+
# Calculate native range L,a,b values
197+
L = 116 * lab[..., 1] - 16
198+
a = 500 * (lab[..., 0] - lab[..., 1])
199+
b = 200 * (lab[..., 1] - lab[..., 2])
200+
201+
# Stack the results
202+
lab = torch.stack([L, a, b], dim=-1)
203+
204+
# Restore original shape if needed
205+
if len(original_shape) > 2:
206+
lab = lab.reshape(original_shape)
207+
208+
return lab
209+
210+
211+
class ToLabTensor:
212+
def __init__(self, srgb_input=False, normalized=True) -> None:
213+
self.srgb_input = srgb_input
214+
self.normalized = normalized
215+
216+
def __call__(self, pic) -> torch.Tensor:
217+
return rgb_to_lab_tensor(
218+
pic,
219+
normalized=self.normalized,
220+
srgb_input=self.srgb_input,
221+
)
222+
223+
224+
class ToLinearRgb:
225+
def __init__(self):
226+
pass
227+
228+
def __call__(self, pic) -> torch.Tensor:
229+
assert isinstance(pic, torch.Tensor)
230+
return srgb_to_linear(pic)
231+
232+
118233
# Pillow is deprecating the top-level resampling attributes (e.g., Image.BILINEAR) in
119234
# favor of the Image.Resampling enum. The top-level resampling attributes will be
120235
# removed in Pillow 10.

timm/data/transforms_factory.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from timm.data.auto_augment import rand_augment_transform, augment_and_mix_transform, auto_augment_transform
1515
from timm.data.transforms import str_to_interp_mode, str_to_pil_interp, RandomResizedCropAndInterpolation, \
1616
ResizeKeepRatio, CenterCropOrPad, RandomCropOrPad, TrimBorder, ToNumpy, MaybeToTensor, MaybePILToTensor
17+
from timm.data.transforms import ToLabTensor, ToLinearRgb
1718
from timm.data.random_erasing import RandomErasing
1819

1920

@@ -123,7 +124,10 @@ def transforms_imagenet_train(
123124
* normalizes and converts the branches above with the third, final transform
124125
"""
125126
if use_tensor:
126-
primary_tfl = [MaybeToTensor()]
127+
primary_tfl = [
128+
MaybeToTensor(),
129+
ToLinearRgb(), # FIXME
130+
]
127131
else:
128132
primary_tfl = []
129133

@@ -236,6 +240,7 @@ def transforms_imagenet_train(
236240
if not use_tensor:
237241
final_tfl += [MaybeToTensor()]
238242
final_tfl += [
243+
ToLabTensor(), # FIXME
239244
transforms.Normalize(
240245
mean=torch.tensor(mean),
241246
std=torch.tensor(std),
@@ -268,6 +273,7 @@ def transforms_imagenet_eval(
268273
std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
269274
use_prefetcher: bool = False,
270275
normalize: bool = True,
276+
use_tensor: bool = True,
271277
):
272278
""" ImageNet-oriented image transform for evaluation and inference.
273279
@@ -294,7 +300,13 @@ def transforms_imagenet_eval(
294300
scale_size = math.floor(img_size / crop_pct)
295301
scale_size = (scale_size, scale_size)
296302

297-
tfl = []
303+
if use_tensor:
304+
tfl = [
305+
MaybeToTensor(),
306+
ToLinearRgb(), # FIXME
307+
]
308+
else:
309+
tfl = []
298310

299311
if crop_border_pixels:
300312
tfl += [TrimBorder(crop_border_pixels)]
@@ -332,10 +344,13 @@ def transforms_imagenet_eval(
332344
tfl += [ToNumpy()]
333345
elif not normalize:
334346
# when normalize disabled, converted to tensor without scaling, keeps original dtype
335-
tfl += [MaybePILToTensor()]
347+
if not use_tensor:
348+
tfl += [MaybePILToTensor()]
336349
else:
350+
if not use_tensor:
351+
tfl += [MaybeToTensor()]
337352
tfl += [
338-
MaybeToTensor(),
353+
ToLabTensor(), # FIXME
339354
transforms.Normalize(
340355
mean=torch.tensor(mean),
341356
std=torch.tensor(std),

0 commit comments

Comments
 (0)