 from torch.nn.parallel import DistributedDataParallel as NativeDDP
 
 from timm import utils
-from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
+from timm.data import create_dataset, create_loader, create_naflex_loader, resolve_data_config, \
+    Mixup, FastCollateMixup, AugMixDataset
 from timm.layers import convert_splitbn_model, convert_sync_batchnorm, set_fast_norm
 from timm.loss import JsdCrossEntropy, SoftTargetCrossEntropy, BinaryCrossEntropy, LabelSmoothingCrossEntropy
 from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint, model_parameters
                     help='Sequence lengths to use for NaFlex loader')
 group.add_argument('--naflex-max-seq-len', type=int, default=576,
                    help='Fixed maximum sequence length for NaFlex loader (validation)')
-
+group.add_argument('--naflex-loss-scale', default='linear', type=str,
+                   help='Scale loss (gradient) by batch_size ("none", "sqrt", or "linear")')
 
 
 def _parse_args():
@@ -762,11 +764,12 @@ def main():
         worker_seeding=args.worker_seeding,
     )
 
+    naflex_mode = False
     if args.naflex_loader:
-        from timm.data.naflex_loader import create_naflex_loader
         if utils.is_primary(args):
             _logger.info('Using NaFlex loader')
 
+        naflex_mode = True
         loader_train = create_naflex_loader(
             dataset=dataset_train,
             patch_size=16,  # Could be derived from model config
@@ -804,7 +807,6 @@ def main():
     )
 
     if args.naflex_loader:
-        from timm.data.naflex_loader import create_naflex_loader
         # Use largest sequence length for validation
         loader_eval = create_naflex_loader(
             dataset=dataset_eval,
@@ -950,6 +952,7 @@ def main():
                 model_ema=model_ema,
                 mixup_fn=mixup_fn,
                 num_updates_total=num_epochs * updates_per_epoch,
+                naflex_mode=naflex_mode,
             )
 
             if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
@@ -1052,6 +1055,7 @@ def train_one_epoch(
         model_ema=None,
         mixup_fn=None,
         num_updates_total=None,
+        naflex_mode=False,
 ):
     if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
         if args.prefetcher and loader.mixup_enabled:
@@ -1097,10 +1101,10 @@ def train_one_epoch(
         def _forward():
             with amp_autocast():
                 output = model(input)
-                loss = loss_fn(output, target)
+                _loss = loss_fn(output, target)
             if accum_steps > 1:
-                loss /= accum_steps
-            return loss
+                _loss /= accum_steps
+            return _loss
 
         def _backward(_loss):
             if loss_scaler is not None:
@@ -1124,18 +1128,48 @@ def _backward(_loss):
                         )
                     optimizer.step()
 
-        if has_no_sync and not need_update:
-            with model.no_sync():
-                loss = _forward()
-                _backward(loss)
-        else:
-            loss = _forward()
-            _backward(loss)
-
-        if isinstance(input, dict):
+        if naflex_mode:
+            assert isinstance(input, dict)
             batch_size = input['patches'].shape[0]
+
+            # scale gradient relative to the minimum batch size (i.e. the batch at max seq len)
+            if not args.naflex_loss_scale or args.naflex_loss_scale == 'none':
+                local_scale = 1.0
+            else:
+                local_scale = batch_size / args.batch_size
+                if args.naflex_loss_scale == 'sqrt':
+                    local_scale = local_scale ** 0.5
+
+            if args.distributed:
+                # scale gradient between distributed ranks, each one can have a different batch size
+                global_batch_size = utils.reduce_tensor(torch.tensor(batch_size, device=device), 1)  # SUM
+                dist_scale = args.world_size * batch_size / global_batch_size
+            else:
+                dist_scale = None
+
+            if has_no_sync and not need_update:
+                with model.no_sync():
+                    loss = _forward()
+                    scaled_loss = local_scale * loss
+                    if dist_scale is not None:
+                        scaled_loss *= dist_scale
+                    _backward(scaled_loss)
+            else:
+                loss = _forward()
+                scaled_loss = local_scale * loss
+                if dist_scale is not None:
+                    scaled_loss *= dist_scale
+                _backward(scaled_loss)
         else:
             batch_size = input.shape[0]
+            if has_no_sync and not need_update:
+                with model.no_sync():
+                    loss = _forward()
+                    _backward(loss)
+            else:
+                loss = _forward()
+                _backward(loss)
+
         losses_m.update(loss.item() * accum_steps, batch_size)
         update_sample_count += batch_size
 
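The scaling above keeps each sample's contribution to the gradient roughly constant even though NaFlex batches vary in size across sequence lengths and across ranks. A minimal standalone sketch of the same arithmetic, assuming hypothetical names (`naflex_loss_scale`, `nominal_bs`, `rank_batch_sizes` are illustrative, not part of the patch):

    # Illustrative re-implementation of the scaling logic above, outside the training loop.
    # local_scale applies the linear (or sqrt) scaling rule vs. the nominal max-seq-len batch size;
    # the distributed factor re-weights ranks so DDP's gradient average becomes a per-sample average.
    def naflex_loss_scale(loss, local_bs, nominal_bs, rank_batch_sizes=None, mode='linear'):
        if mode == 'none':
            local_scale = 1.0
        else:
            local_scale = local_bs / nominal_bs
            if mode == 'sqrt':
                local_scale **= 0.5
        scaled = local_scale * loss
        if rank_batch_sizes:  # distributed: each rank may hold a different number of samples
            world_size = len(rank_batch_sizes)
            global_bs = sum(rank_batch_sizes)
            scaled *= world_size * local_bs / global_bs  # cancels DDP's 1/world_size averaging
        return scaled

Note that `losses_m` in the patch still logs the unscaled `loss`, so reported loss values remain comparable across batch sizes; only the value handed to `_backward` is rescaled.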
@@ -1154,7 +1188,8 @@ def _backward(_loss):
             elif device.type == 'npu':
                 torch.npu.synchronize()
         time_now = time.time()
-        update_time_m.update(time.time() - update_start_time)
+
+        update_time_m.update((time.time() - update_start_time) / update_sample_count, update_sample_count)
         update_start_time = time_now
 
         if update_idx % args.log_interval == 0:
@@ -1173,8 +1208,8 @@ def _backward(_loss):
                     f'Train: {epoch} [{update_idx:>4d}/{updates_per_epoch} '
                     f'({100. * (update_idx + 1) / updates_per_epoch:>3.0f}%)] '
                     f'Loss: {loss_now:#.3g} ({loss_avg:#.3g}) '
-                    f'Time: {update_time_m.val:.3f}s, {update_sample_count / update_time_m.val:>7.2f}/s '
-                    f'({update_time_m.avg:.3f}s, {update_sample_count / update_time_m.avg:>7.2f}/s) '
+                    f'Time: {update_time_m.val:.3f}s, {1 / update_time_m.val:>7.2f}/s '
+                    f'({update_time_m.avg:.3f}s, {1 / update_time_m.avg:>7.2f}/s) '
                     f'LR: {lr:.3e} '
                     f'Data: {data_time_m.val:.3f} ({data_time_m.avg:.3f})'
                 )
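With NaFlex, batch size varies from update to update, so `update_time_m` now tracks seconds per sample (weighted by sample count) and throughput in the log becomes `1 / update_time_m.val`. For a single update the reported rate is unchanged; a quick check with made-up values:

    update_sample_count = 256          # samples processed since the last log (hypothetical)
    elapsed = 0.125                    # seconds spent on those samples (hypothetical)

    old_rate = update_sample_count / elapsed      # old log: samples per second
    per_sample = elapsed / update_sample_count    # new meter value: seconds per sample
    new_rate = 1 / per_sample                     # new log: samples per second
    assert abs(old_rate - new_rate) < 1e-6        # both report 2048.0 samples/s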