Commit d1e7779

Merge branch 'main' into grad_checkpointing

2 parents b0b28e2 + 9d294cd

24 files changed, +5134 −296 lines

README.md

Lines changed: 19 additions & 0 deletions
@@ -12,6 +12,25 @@

 ## What's New

+## June 5, 2025
+* Initial NaFlexVit model code. NaFlexVit is a Vision Transformer with:
+  1. Encapsulated embedding and position encoding in a single module
+  2. Support for nn.Linear patch embedding on pre-patchified (dictionary) inputs
+  3. Support for NaFlex variable aspect, variable resolution (SigLip-2: https://arxiv.org/abs/2502.14786)
+  4. Support for FlexiViT variable patch size (https://arxiv.org/abs/2212.08013)
+  5. Support for NaViT fractional/factorized position embedding (https://arxiv.org/abs/2307.06304)
+* Existing ViT models in `vision_transformer.py` can be loaded into the NaFlexVit model by adding the `use_naflex=True` flag to `create_model` (see the sketch after this diff)
+  * Some native weights coming soon
+* A full NaFlex data pipeline is available that allows training / fine-tuning / evaluating with variable aspect / size images
+  * To enable in `train.py` and `validate.py`, add the `--naflex-loader` arg; it must be used with a NaFlexVit model
+  * To evaluate an existing (classic) ViT loaded into the NaFlexVit model w/ the NaFlex data pipeline:
+    * `python validate.py /imagenet --amp -j 8 --model vit_base_patch16_224 --model-kwargs use_naflex=True --naflex-loader --naflex-max-seq-len 256`
+* The training script has some extra args worth noting
+  * The `--naflex-train-seq-lens` argument specifies which sequence lengths to randomly pick from per batch during training
+  * The `--naflex-max-seq-len` argument sets the target sequence length for validation
+  * Adding `--model-kwargs enable_patch_interpolator=True --naflex-patch-sizes 12 16 24` will enable random patch size selection per batch w/ interpolation
+  * The `--naflex-loss-scale` arg changes the loss scaling mode per batch relative to the batch size; `timm` NaFlex loading changes the batch size for each seq len
+
 ## May 28, 2025
 * Add a number of small/fast models thanks to https://github.com/brianhou0208
   * SwiftFormer - [(ICCV2023) SwiftFormer: Efficient Additive Attention for Transformer-based Real-time Mobile Vision Applications](https://github.com/Amshaker/SwiftFormer)
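
A minimal sketch of the `create_model` usage described in the README entry above; the model name and `pretrained` flag are illustrative choices, not taken from the commit:

```python
import timm

# Any classic ViT from vision_transformer.py can be routed through the
# NaFlexVit implementation via the use_naflex flag from the release notes.
model = timm.create_model(
    'vit_base_patch16_224',  # an existing ViT; its weights load into NaFlexVit
    pretrained=True,
    use_naflex=True,
)
```

For training, the flags would combine along these lines (the sequence-length values and the space-separated list format are assumptions, modeled on the `--naflex-patch-sizes 12 16 24` example): `python train.py /imagenet --model vit_base_patch16_224 --model-kwargs use_naflex=True --naflex-loader --naflex-train-seq-lens 256 576 1024`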

tests/test_models.py

Lines changed: 3 additions & 3 deletions
@@ -56,12 +56,12 @@
     'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera', 'fastvit', 'hieradet_sam2', 'aimv2*', 'tnt',
     'tiny_vit', 'vovnet', 'tresnet', 'rexnet', 'resnetv2', 'repghost', 'repvit', 'pvt_v2', 'nextvit', 'nest',
     'mambaout', 'inception_next', 'inception_v4', 'hgnet', 'gcvit', 'focalnet', 'efficientformer_v2', 'edgenext',
-    'davit', 'rdnet', 'convnext', 'pit', 'starnet', 'shvit', 'fasternet', 'swiftformer', 'ghostnet',
+    'davit', 'rdnet', 'convnext', 'pit', 'starnet', 'shvit', 'fasternet', 'swiftformer', 'ghostnet', 'naflexvit'
 ]

 # transformer / hybrid models don't support full set of spatial / feature APIs and/or have spatial output.
 NON_STD_FILTERS = [
-    'vit_*', 'tnt_*', 'pit_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
+    'vit_*', 'naflexvit*', 'tnt_*', 'pit_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
     'convit_*', 'levit*', 'visformer*', 'deit*', 'xcit_*', 'crossvit_*', 'beit*', 'aimv2*', 'swiftformer_*',
     'poolformer_*', 'volo_*', 'sequencer2d_*', 'mvitv2*', 'gcvit*', 'efficientformer*', 'sam_hiera*',
     'eva_*', 'flexivit*', 'eva02*', 'samvit_*', 'efficientvit_m*', 'tiny_vit_*', 'hiera_*', 'vitamin*', 'test_vit*',
@@ -81,7 +81,7 @@
 EXCLUDE_FILTERS = ['*enormous*']
 NON_STD_EXCLUDE_FILTERS = ['*gigantic*', '*enormous*', '*_3b_*']

-EXCLUDE_JIT_FILTERS = ['hiera_*']
+EXCLUDE_JIT_FILTERS = ['hiera_*', '*naflex*']

 TARGET_FWD_SIZE = MAX_FWD_SIZE = 384
 TARGET_BWD_SIZE = 128

timm/data/__init__.py

Lines changed: 12 additions & 0 deletions
@@ -8,6 +8,18 @@
 from .imagenet_info import ImageNetInfo, infer_imagenet_subset
 from .loader import create_loader
 from .mixup import Mixup, FastCollateMixup
+from .naflex_dataset import NaFlexMapDatasetWrapper, calculate_naflex_batch_size
+from .naflex_loader import create_naflex_loader
+from .naflex_mixup import NaFlexMixup, pairwise_mixup_target, mix_batch_variable_size
+from .naflex_transforms import (
+    ResizeToSequence,
+    CenterCropToSequence,
+    RandomCropToSequence,
+    RandomResizedCropToSequence,
+    ResizeKeepRatioToSequence,
+    Patchify,
+    patchify_image,
+)
 from .readers import create_reader
 from .readers import get_img_extensions, is_img_extension, set_img_extensions, add_img_extensions, del_img_extensions
 from .real_labels import RealLabelsImagenet
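
These names now resolve at the `timm.data` package level; a quick import check, purely illustrative:

```python
# Every name below comes straight from the re-exports added above.
from timm.data import (
    NaFlexMapDatasetWrapper,
    calculate_naflex_batch_size,
    create_naflex_loader,
    NaFlexMixup,
    Patchify,
    patchify_image,
)
```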

timm/data/loader.py

Lines changed: 5 additions & 1 deletion
@@ -33,6 +33,7 @@ def fast_collate(batch):
     if isinstance(batch[0][0], tuple):
         # This branch 'deinterleaves' and flattens tuples of input tensors into one tensor ordered by position
         # such that all tuple members at position n end up in the nth chunk of torch.split(tensor, batch_size)
+        is_np = isinstance(batch[0][0][0], np.ndarray)
         inner_tuple_size = len(batch[0][0])
         flattened_batch_size = batch_size * inner_tuple_size
         targets = torch.zeros(flattened_batch_size, dtype=torch.int64)
@@ -41,7 +42,10 @@ def fast_collate(batch):
             assert len(batch[i][0]) == inner_tuple_size  # all input tensor tuples must be same length
             for j in range(inner_tuple_size):
                 targets[i + j * batch_size] = batch[i][1]
-                tensor[i + j * batch_size] += torch.from_numpy(batch[i][0][j])
+                if is_np:
+                    tensor[i + j * batch_size] += torch.from_numpy(batch[i][0][j])
+                else:
+                    tensor[i + j * batch_size] += batch[i][0][j]
         return tensor, targets
     elif isinstance(batch[0][0], np.ndarray):
         targets = torch.tensor([b[1] for b in batch], dtype=torch.int64)
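
A small sketch of the tuple branch above with torch.Tensor members, which now take the new non-numpy path; shapes and labels are illustrative:

```python
import torch
from timm.data.loader import fast_collate

def img(v: int) -> torch.Tensor:
    # A dummy uint8 CHW image filled with a constant value.
    return torch.full((3, 8, 8), v, dtype=torch.uint8)

# Two samples, each a 2-tuple of images (e.g. two augmented views) plus a label.
batch = [((img(0), img(1)), 0), ((img(2), img(3)), 1)]

tensor, targets = fast_collate(batch)  # tensor: (4, 3, 8, 8), targets: (4,)
view0, view1 = torch.split(tensor, 2)  # tuple position n lands in the nth chunk
```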

timm/data/mixup.py

Lines changed: 52 additions & 19 deletions
@@ -229,29 +229,41 @@ def _mix_elem_collate(self, output, batch, half=False):
         num_elem = batch_size // 2 if half else batch_size
         assert len(output) == num_elem
         lam_batch, use_cutmix = self._params_per_elem(num_elem)
+        is_np = isinstance(batch[0][0], np.ndarray)
+
         for i in range(num_elem):
             j = batch_size - i - 1
             lam = lam_batch[i]
             mixed = batch[i][0]
             if lam != 1.:
                 if use_cutmix[i]:
                     if not half:
-                        mixed = mixed.copy()
+                        mixed = mixed.copy() if is_np else mixed.clone()
                     (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
-                        output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
+                        output.shape,
+                        lam,
+                        ratio_minmax=self.cutmix_minmax,
+                        correct_lam=self.correct_lam,
+                    )
                     mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh]
                     lam_batch[i] = lam
                 else:
-                    mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
-                    np.rint(mixed, out=mixed)
-            output[i] += torch.from_numpy(mixed.astype(np.uint8))
+                    if is_np:
+                        mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
+                        np.rint(mixed, out=mixed)
+                    else:
+                        mixed = mixed.float() * lam + batch[j][0].float() * (1 - lam)
+                        torch.round(mixed, out=mixed)
+            output[i] += torch.from_numpy(mixed.astype(np.uint8)) if is_np else mixed.byte()
         if half:
             lam_batch = np.concatenate((lam_batch, np.ones(num_elem)))
         return torch.tensor(lam_batch).unsqueeze(1)

     def _mix_pair_collate(self, output, batch):
         batch_size = len(batch)
         lam_batch, use_cutmix = self._params_per_elem(batch_size // 2)
+        is_np = isinstance(batch[0][0], np.ndarray)
+
         for i in range(batch_size // 2):
             j = batch_size - i - 1
             lam = lam_batch[i]
@@ -261,39 +273,60 @@ def _mix_pair_collate(self, output, batch):
             if lam < 1.:
                 if use_cutmix[i]:
                     (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
-                        output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
-                    patch_i = mixed_i[:, yl:yh, xl:xh].copy()
+                        output.shape,
+                        lam,
+                        ratio_minmax=self.cutmix_minmax,
+                        correct_lam=self.correct_lam,
+                    )
+                    patch_i = mixed_i[:, yl:yh, xl:xh].copy() if is_np else mixed_i[:, yl:yh, xl:xh].clone()
                     mixed_i[:, yl:yh, xl:xh] = mixed_j[:, yl:yh, xl:xh]
                     mixed_j[:, yl:yh, xl:xh] = patch_i
                     lam_batch[i] = lam
                 else:
-                    mixed_temp = mixed_i.astype(np.float32) * lam + mixed_j.astype(np.float32) * (1 - lam)
-                    mixed_j = mixed_j.astype(np.float32) * lam + mixed_i.astype(np.float32) * (1 - lam)
-                    mixed_i = mixed_temp
-                    np.rint(mixed_j, out=mixed_j)
-                    np.rint(mixed_i, out=mixed_i)
-            output[i] += torch.from_numpy(mixed_i.astype(np.uint8))
-            output[j] += torch.from_numpy(mixed_j.astype(np.uint8))
+                    if is_np:
+                        mixed_temp = mixed_i.astype(np.float32) * lam + mixed_j.astype(np.float32) * (1 - lam)
+                        mixed_j = mixed_j.astype(np.float32) * lam + mixed_i.astype(np.float32) * (1 - lam)
+                        mixed_i = mixed_temp
+                        np.rint(mixed_j, out=mixed_j)
+                        np.rint(mixed_i, out=mixed_i)
+                    else:
+                        mixed_temp = mixed_i.float() * lam + mixed_j.float() * (1 - lam)
+                        mixed_j = mixed_j.float() * lam + mixed_i.float() * (1 - lam)
+                        mixed_i = mixed_temp
+                        torch.round(mixed_j, out=mixed_j)
+                        torch.round(mixed_i, out=mixed_i)
+            output[i] += torch.from_numpy(mixed_i.astype(np.uint8)) if is_np else mixed_i.byte()
+            output[j] += torch.from_numpy(mixed_j.astype(np.uint8)) if is_np else mixed_j.byte()
         lam_batch = np.concatenate((lam_batch, lam_batch[::-1]))
         return torch.tensor(lam_batch).unsqueeze(1)

     def _mix_batch_collate(self, output, batch):
         batch_size = len(batch)
         lam, use_cutmix = self._params_per_batch()
+        is_np = isinstance(batch[0][0], np.ndarray)
+
         if use_cutmix:
             (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
-                output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
+                output.shape,
+                lam,
+                ratio_minmax=self.cutmix_minmax,
+                correct_lam=self.correct_lam,
+            )
         for i in range(batch_size):
             j = batch_size - i - 1
             mixed = batch[i][0]
             if lam != 1.:
                 if use_cutmix:
-                    mixed = mixed.copy()  # don't want to modify the original while iterating
+                    mixed = mixed.copy() if is_np else mixed.clone()  # don't modify the original while iterating
                     mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh]
                 else:
-                    mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
-                    np.rint(mixed, out=mixed)
-            output[i] += torch.from_numpy(mixed.astype(np.uint8))
+                    if is_np:
+                        mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
+                        np.rint(mixed, out=mixed)
+                    else:
+                        mixed = mixed.float() * lam + batch[j][0].float() * (1 - lam)
+                        torch.round(mixed, out=mixed)
+            output[i] += torch.from_numpy(mixed.astype(np.uint8)) if is_np else mixed.byte()
         return lam

     def __call__(self, batch, _=None):
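
A brief sketch exercising the updated collate path end to end with torch.Tensor inputs; the hyperparameters and shapes are illustrative:

```python
import torch
from timm.data.mixup import FastCollateMixup

# FastCollateMixup acts as a DataLoader collate_fn. After this change the
# per-sample images may be uint8 torch.Tensors as well as np.uint8 arrays.
collate = FastCollateMixup(mixup_alpha=0.8, cutmix_alpha=1.0, num_classes=10)
batch = [
    (torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8), i % 10)
    for i in range(8)  # batch size must be even
]
inputs, targets = collate(batch)  # inputs: uint8 (8, 3, 32, 32); targets: (8, 10) soft labels
```

In the default 'batch' mode this drives `_mix_batch_collate`, which now clones tensors where it previously copied arrays and rounds with `torch.round` in place of `np.rint`.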
