Skip to content

Commit d285526

Browse files
committed
Lazy loader for TF, more LAB fiddling
1 parent 3fbbd51 commit d285526

File tree

3 files changed

+57
-51
lines changed

3 files changed

+57
-51
lines changed

timm/data/readers/reader_tfds.py

Lines changed: 33 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -15,25 +15,36 @@
1515
import torch.distributed as dist
1616
from PIL import Image
1717

18-
try:
19-
import tensorflow as tf
20-
tf.config.set_visible_devices([], 'GPU') # Hands off my GPU! (or pip install tensorflow-cpu)
21-
import tensorflow_datasets as tfds
22-
try:
23-
tfds.even_splits('', 1, drop_remainder=False) # non-buggy even_splits has drop_remainder arg
24-
has_buggy_even_splits = False
25-
except TypeError:
26-
print("Warning: This version of tfds doesn't have the latest even_splits impl. "
27-
"Please update or use tfds-nightly for better fine-grained split behaviour.")
28-
has_buggy_even_splits = True
29-
# NOTE uncomment below if having file limit issues on dataset build (or alter your OS defaults)
30-
# import resource
31-
# low, high = resource.getrlimit(resource.RLIMIT_NOFILE)
32-
# resource.setrlimit(resource.RLIMIT_NOFILE, (high, high))
33-
except ImportError as e:
34-
print(e)
35-
print("Please install tensorflow_datasets package `pip install tensorflow-datasets`.")
36-
raise e
18+
import importlib
19+
20+
class LazyTfLoader:
21+
def __init__(self):
22+
self._tf = None
23+
24+
def __getattr__(self, name):
25+
if self._tf is None:
26+
self._tf = importlib.import_module('tensorflow')
27+
self._tf.config.set_visible_devices([], 'GPU') # Hands off my GPU! (or pip install tensorflow-cpu)
28+
return getattr(self._tf, name)
29+
30+
class LazyTfdsLoader:
31+
def __init__(self):
32+
self._tfds = None
33+
self.has_buggy_even_splits = False
34+
35+
def __getattr__(self, name):
36+
if self._tfds is None:
37+
self._tfds = importlib.import_module('tensorflow_datasets')
38+
try:
39+
self._tfds.even_splits('', 1, drop_remainder=False) # non-buggy even_splits has drop_remainder arg
40+
except TypeError:
41+
print("Warning: This version of tfds doesn't have the latest even_splits impl. "
42+
"Please update or use tfds-nightly for better fine-grained split behaviour.")
43+
self.has_buggy_even_splits = True
44+
return getattr(self._tfds, name)
45+
46+
tf = LazyTfLoader()
47+
tfds = LazyTfdsLoader()
3748

3849
from .class_map import load_class_map
3950
from .reader import Reader
@@ -45,7 +56,6 @@
4556
PREFETCH_SIZE = int(os.environ.get('TFDS_PREFETCH_SIZE', 2048)) # samples to prefetch
4657

4758

48-
@tfds.decode.make_decoder()
4959
def decode_example(serialized_image, feature, dct_method='INTEGER_ACCURATE', channels=3):
5060
return tf.image.decode_jpeg(
5161
serialized_image,
@@ -231,7 +241,7 @@ def _lazy_init(self):
231241
if should_subsplit:
232242
# split the dataset w/o using sharding for more even samples / worker, can result in less optimal
233243
# read patterns for distributed training (overlap across shards) so better to use InputContext there
234-
if has_buggy_even_splits:
244+
if tfds.has_buggy_even_splits:
235245
# my even_split workaround doesn't work on subsplits, upgrade tfds!
236246
if not isinstance(self.split_info, tfds.core.splits.SubSplitInfo):
237247
subsplits = even_split_indices(self.split, self.global_num_workers, self.num_samples)
@@ -253,10 +263,11 @@ def _lazy_init(self):
253263
shuffle_reshuffle_each_iteration=True,
254264
input_context=input_context,
255265
)
266+
decode_fn = tfds.decode.make_decoder()(decode_example)
256267
ds = self.builder.as_dataset(
257268
split=self.subsplit or self.split,
258269
shuffle_files=self.is_training,
259-
decoders=dict(image=decode_example(channels=1 if self.input_img_mode == 'L' else 3)),
270+
decoders=dict(image=decode_fn(channels=1 if self.input_img_mode == 'L' else 3)),
260271
read_config=read_config,
261272
)
262273
# avoid overloading threading w/ combo of TF ds threads + PyTorch workers

timm/data/transforms.py

Lines changed: 22 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -127,14 +127,16 @@ def rgb_to_lab_tensor(
127127
rgb_img: torch.Tensor,
128128
normalized: bool = True,
129129
srgb_input: bool = True,
130-
) -> torch.Tensor:
130+
split_channels: bool = False,
131+
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
131132
"""
132133
Convert RGB image to LAB color space using tensor operations.
133134
134135
Args:
135136
rgb_img: Tensor of shape (..., 3) with values in range [0, 255]
136137
normalized: If True, outputs L,a,b in [0, 1] range instead of native LAB ranges
137-
138+
srgb_input: Input is gamma corrected sRGB, otherwise linear RGB is assumed (rare unless part of a pipeline)
139+
split_channels: If True, outputs a tuple of flattened colour channels instead of stacked image
138140
Returns:
139141
lab_img: Tensor of same shape with either:
140142
- normalized=False: L in [0, 100] and a,b in [-128, 127]
@@ -152,13 +154,14 @@ def rgb_to_lab_tensor(
152154
rgb_img = srgb_to_linear(rgb_img)
153155

154156
# FIXME transforms before this are causing -ve values, can have a large impact on this conversion
155-
rgb_img.clamp_(0, 1.0)
157+
rgb_img = rgb_img.clamp(0, 1.0)
156158

157159
# Convert to XYZ using matrix multiplication
158160
rgb_to_xyz = torch.tensor([
159-
[0.412453, 0.357580, 0.180423],
160-
[0.212671, 0.715160, 0.072169],
161-
[0.019334, 0.119193, 0.950227]
161+
# X Y Z
162+
[0.412453, 0.212671, 0.019334], # R
163+
[0.357580, 0.715160, 0.119193], # G
164+
[0.180423, 0.072169, 0.950227], # B
162165
], device=rgb_img.device)
163166

164167
# Reshape input for matrix multiplication if needed
@@ -167,38 +170,30 @@ def rgb_to_lab_tensor(
167170
rgb_img = rgb_img.reshape(-1, 3)
168171

169172
# Perform matrix multiplication
170-
xyz = torch.matmul(rgb_img, rgb_to_xyz.T)
173+
xyz = rgb_img @ rgb_to_xyz
171174

172175
# Adjust XYZ values
173-
xyz[..., 0].div_(xn)
174-
xyz[..., 1].div_(yn)
175-
xyz[..., 2].div_(zn)
176+
xyz.div_(torch.tensor([xn, yn, zn], device=xyz.device))
176177

177178
# Step 4: XYZ to LAB
178-
lab = torch.where(
179+
fxfyfz = torch.where(
179180
xyz > epsilon,
180181
torch.pow(xyz, 1 / 3),
181182
(kappa * xyz + 16) / 116
182183
)
183184

185+
L = 116 * fxfyfz[..., 1] - 16
186+
a = 500 * (fxfyfz[..., 0] - fxfyfz[..., 1])
187+
b = 200 * (fxfyfz[..., 1] - fxfyfz[..., 2])
184188
if normalized:
185-
# Calculate normalized [0,1] L,a,b values directly
186-
# L: map [0,100] to [0,1] : (116y - 16)/100 = 1.16y - 0.16
187-
# a: map [-128,127] to [0,1] : (500(x-y) + 128)/255 ≈ 1.96(x-y) + 0.502
188-
# b: map [-128,127] to [0,1] : (200(y-z) + 128)/255 ≈ 0.784(y-z) + 0.502
189-
shift_128 = 128 / 255
190-
a_scale = 500 / 255
191-
b_scale = 200 / 255
192-
L = 1.16 * lab[..., 1] - 0.16
193-
a = a_scale * (lab[..., 0] - lab[..., 1]) + shift_128
194-
b = b_scale * (lab[..., 1] - lab[..., 2]) + shift_128
195-
else:
196-
# Calculate native range L,a,b values
197-
L = 116 * lab[..., 1] - 16
198-
a = 500 * (lab[..., 0] - lab[..., 1])
199-
b = 200 * (lab[..., 1] - lab[..., 2])
189+
# output in rage [0, 1] for each channel
190+
L.div_(100)
191+
a.add_(128).div_(255)
192+
b.add_(128).div_(255)
193+
194+
if split_channels:
195+
return L, a, b
200196

201-
# Stack the results
202197
lab = torch.stack([L, a, b], dim=-1)
203198

204199
# Restore original shape if needed

timm/data/transforms_factory.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def transforms_imagenet_train(
8686
use_prefetcher: bool = False,
8787
normalize: bool = True,
8888
separate: bool = False,
89-
use_tensor: Optional[bool] = True, # FIXME forced True for testing
89+
use_tensor: Optional[bool] = False,
9090
):
9191
""" ImageNet-oriented image transforms for training.
9292
@@ -273,7 +273,7 @@ def transforms_imagenet_eval(
273273
std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
274274
use_prefetcher: bool = False,
275275
normalize: bool = True,
276-
use_tensor: bool = True,
276+
use_tensor: bool = False,
277277
):
278278
""" ImageNet-oriented image transform for evaluation and inference.
279279

0 commit comments

Comments
 (0)