
Commit 13e0f3a

Add loss scale arg, initial distributed loss scale. Maybe fix FX for the model.
1 parent 6675590 · commit 13e0f3a


3 files changed: +58 −23 lines


timm/data/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -8,6 +8,8 @@
 from .imagenet_info import ImageNetInfo, infer_imagenet_subset
 from .loader import create_loader
 from .mixup import Mixup, FastCollateMixup
+from .naflex_dataset import VariableSeqMapWrapper
+from .naflex_loader import create_naflex_loader
 from .naflex_transforms import (
     ResizeToSequence,
     CenterCropToSequence,
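
With these two re-exports, the NaFlex pieces become importable from the package root, which the train.py change below relies on. A trivial usage sketch, using only names visible in this diff:

# After this change both symbols resolve from timm.data directly,
# so call sites no longer need the submodule path.
from timm.data import VariableSeqMapWrapper, create_naflex_loader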

timm/models/vision_transformer_flex.py

Lines changed: 2 additions & 4 deletions
@@ -356,10 +356,9 @@ def create_attention_mask(
     """
     patch_valid = patch_valid.bool()
     B = patch_valid.shape[0]
-    device = patch_valid.device

     if num_prefix_tokens > 0:
-        prefix_valid = torch.ones((B, num_prefix_tokens), device=device, dtype=torch.bool)
+        prefix_valid = patch_valid.new_ones((B, num_prefix_tokens))
         patch_valid = torch.cat([prefix_valid, patch_valid], dim=1)

     mask_bool = (patch_valid.unsqueeze(-1) & patch_valid.unsqueeze(1)).unsqueeze(1)
@@ -390,10 +389,9 @@ def create_attention_mask2(
     """
     patch_valid = patch_valid.bool()
     B, kv_len = patch_valid.shape
-    device = patch_valid.device

     if num_prefix_tokens > 0:
-        prefix_valid = torch.ones((B, num_prefix_tokens), device=device, dtype=torch.bool)
+        prefix_valid = patch_valid.new_ones((B, num_prefix_tokens))
         patch_valid = torch.cat([prefix_valid, patch_valid], dim=1)
         kv_len = patch_valid.shape[1]
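
Both hunks replace an explicit torch.ones(..., device=..., dtype=torch.bool) with patch_valid.new_ones(...), which inherits dtype and device from patch_valid. The function therefore no longer reads .device off an intermediate tensor, which is presumably the "Maybe fix FX for the model" part of the commit message, since symbolic tracing tends to cope better when device and dtype are inherited rather than queried. A small self-contained sketch of the equivalence (patch_valid and num_prefix_tokens are stand-ins for the real arguments):

import torch

# Stand-ins for the function arguments; shapes are arbitrary.
patch_valid = torch.zeros(2, 196, dtype=torch.bool)
num_prefix_tokens = 1
B = patch_valid.shape[0]

# Before: device and dtype spelled out explicitly.
prefix_a = torch.ones((B, num_prefix_tokens), device=patch_valid.device, dtype=torch.bool)
# After: new_ones inherits both from patch_valid.
prefix_b = patch_valid.new_ones((B, num_prefix_tokens))

assert torch.equal(prefix_a, prefix_b)
assert prefix_a.dtype == prefix_b.dtype and prefix_a.device == prefix_b.device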

train.py

Lines changed: 54 additions & 19 deletions
@@ -33,7 +33,8 @@
 from torch.nn.parallel import DistributedDataParallel as NativeDDP

 from timm import utils
-from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
+from timm.data import create_dataset, create_loader, create_naflex_loader, resolve_data_config, \
+    Mixup, FastCollateMixup, AugMixDataset
 from timm.layers import convert_splitbn_model, convert_sync_batchnorm, set_fast_norm
 from timm.loss import JsdCrossEntropy, SoftTargetCrossEntropy, BinaryCrossEntropy, LabelSmoothingCrossEntropy
 from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint, model_parameters
@@ -403,7 +404,8 @@
                    help='Sequence lengths to use for NaFlex loader')
 group.add_argument('--naflex-max-seq-len', type=int, default=576,
                    help='Fixed maximum sequence length for NaFlex loader (validation)')
-
+group.add_argument('--naflex-loss-scale', default='linear', type=str,
+                   help='Scale loss (gradient) by batch_size ("none", "sqrt", or "linear")')


 def _parse_args():
@@ -762,11 +764,12 @@ def main():
         worker_seeding=args.worker_seeding,
     )

+    naflex_mode = False
     if args.naflex_loader:
-        from timm.data.naflex_loader import create_naflex_loader
         if utils.is_primary(args):
             _logger.info('Using NaFlex loader')

+        naflex_mode = True
         loader_train = create_naflex_loader(
             dataset=dataset_train,
             patch_size=16,  # Could be derived from model config
@@ -804,7 +807,6 @@
     )

     if args.naflex_loader:
-        from timm.data.naflex_loader import create_naflex_loader
         # Use largest sequence length for validation
         loader_eval = create_naflex_loader(
             dataset=dataset_eval,
@@ -950,6 +952,7 @@
                 model_ema=model_ema,
                 mixup_fn=mixup_fn,
                 num_updates_total=num_epochs * updates_per_epoch,
+                naflex_mode=naflex_mode,
             )

             if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
@@ -1052,6 +1055,7 @@ def train_one_epoch(
         model_ema=None,
         mixup_fn=None,
         num_updates_total=None,
+        naflex_mode=False,
 ):
     if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
         if args.prefetcher and loader.mixup_enabled:
@@ -1097,10 +1101,10 @@
         def _forward():
             with amp_autocast():
                 output = model(input)
-                loss = loss_fn(output, target)
+                _loss = loss_fn(output, target)
             if accum_steps > 1:
-                loss /= accum_steps
-            return loss
+                _loss /= accum_steps
+            return _loss

         def _backward(_loss):
             if loss_scaler is not None:
@@ -1124,18 +1128,48 @@ def _backward(_loss):
                         )
                     optimizer.step()

-        if has_no_sync and not need_update:
-            with model.no_sync():
-                loss = _forward()
-                _backward(loss)
-        else:
-            loss = _forward()
-            _backward(loss)
-
-        if isinstance(input, dict):
+        if naflex_mode:
+            assert isinstance(input, dict)
             batch_size = input['patches'].shape[0]
+
+            # scale gradient vs the minimum batch size (for max seq len)
+            if not args.naflex_loss_scale or args.naflex_loss_scale == 'none':
+                local_scale = 1.0
+            else:
+                local_scale = (batch_size / args.batch_size)
+                if args.naflex_loss_scale == 'sqrt':
+                    local_scale = local_scale ** 0.5
+
+            if args.distributed:
+                # scale gradient btw distributed ranks, each one can have a different batch size
+                global_batch_size = utils.reduce_tensor(torch.tensor(batch_size, device=device), 1)  # SUM
+                dist_scale = args.world_size * batch_size / global_batch_size
+            else:
+                dist_scale = None
+
+            if has_no_sync and not need_update:
+                with model.no_sync():
+                    loss = _forward()
+                    scaled_loss = local_scale * loss
+                    if dist_scale is not None:
+                        scaled_loss *= dist_scale
+                    _backward(scaled_loss)
+            else:
+                loss = _forward()
+                scaled_loss = local_scale * loss
+                if dist_scale is not None:
+                    scaled_loss *= dist_scale
+                _backward(scaled_loss)
         else:
             batch_size = input.shape[0]
+            if has_no_sync and not need_update:
+                with model.no_sync():
+                    loss = _forward()
+                    _backward(loss)
+            else:
+                loss = _forward()
+                _backward(loss)
+
         losses_m.update(loss.item() * accum_steps, batch_size)
         update_sample_count += batch_size

@@ -1154,7 +1188,8 @@ def _backward(_loss):
         elif device.type == 'npu':
             torch.npu.synchronize()
         time_now = time.time()
-        update_time_m.update(time.time() - update_start_time)
+
+        update_time_m.update((time.time() - update_start_time) / update_sample_count, update_sample_count)
         update_start_time = time_now

         if update_idx % args.log_interval == 0:
@@ -1173,8 +1208,8 @@ def _backward(_loss):
                     f'Train: {epoch} [{update_idx:>4d}/{updates_per_epoch} '
                     f'({100. * (update_idx + 1) / updates_per_epoch:>3.0f}%)] '
                     f'Loss: {loss_now:#.3g} ({loss_avg:#.3g}) '
-                    f'Time: {update_time_m.val:.3f}s, {update_sample_count / update_time_m.val:>7.2f}/s '
-                    f'({update_time_m.avg:.3f}s, {update_sample_count / update_time_m.avg:>7.2f}/s) '
+                    f'Time: {update_time_m.val:.3f}s, {1 / update_time_m.val:>7.2f}/s '
+                    f'({update_time_m.avg:.3f}s, {1 / update_time_m.avg:>7.2f}/s) '
                     f'LR: {lr:.3e} '
                     f'Data: {data_time_m.val:.3f} ({data_time_m.avg:.3f})'
                 )
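
The new naflex_mode branch scales the loss because the NaFlex loader packs a variable number of images per step, so each rank's actual batch size can differ from args.batch_size and from the other ranks. A plain-Python sketch of the arithmetic above, with made-up numbers (in the real code global_batch_size comes back from utils.reduce_tensor as a tensor; here it is a float for clarity):

# Illustrative numbers only.
ref_batch_size = 128            # args.batch_size: batch size at the max sequence length
rank_batch_sizes = [256, 192]   # per-rank batch sizes produced by the NaFlex loader
world_size = len(rank_batch_sizes)
global_batch_size = sum(rank_batch_sizes)   # what reduce_tensor(..., 1) sums to

for bs in rank_batch_sizes:
    local_scale = bs / ref_batch_size               # 'linear'; 'sqrt' would apply ** 0.5
    dist_scale = world_size * bs / global_batch_size
    print(f'bs={bs}: local_scale={local_scale:.2f}, dist_scale={dist_scale:.2f}')

# DDP averages gradients across ranks (an implicit divide by world_size), so after
# averaging each rank effectively contributes bs / global_batch_size of the gradient,
# i.e. samples are weighted equally regardless of how they were packed per rank.
weights = [(world_size * bs / global_batch_size) / world_size for bs in rank_batch_sizes]
assert abs(sum(weights) - 1.0) < 1e-9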
