@@ -327,6 +327,7 @@ def main_train(args: arg_util.Args):
327
327
# build wandb logger
328
328
if dist .is_master ():
329
329
wandb_utils .wandb .init (project = args .project_name , name = args .exp_name , config = {})
330
+
330
331
for ep in range (start_ep , args .ep ):
331
332
if ep % ep_lg == 0 or ep == start_ep :
332
333
print (f'[PT info] from ep{ start_ep } it{ start_it } , acc_str: { acc_str } , diffs: { args .diffs } , =======> bed: { args .bed } <=======\n ' )
@@ -483,10 +484,15 @@ def train_one_ep(
483
484
with maybe_record_function ('before_train' ):
484
485
# [get data]
485
486
inp , captions = data
486
- tokens = text_tokenizer (text = captions , max_length = text_tokenizer .model_max_length , padding = 'max_length' , truncation = True , return_tensors = 'pt' ) # todo: put this into dataset
487
+ tokens = text_tokenizer (text = captions , max_length = text_tokenizer .model_max_length ,
488
+ padding = 'max_length' , truncation = True , return_tensors = 'pt' ) # todo: put this into dataset
489
+ print ("gongwb tokens:" , tokens )
490
+
487
491
input_ids = tokens .input_ids .cuda (non_blocking = True )
488
492
mask = tokens .attention_mask .cuda (non_blocking = True )
493
+
489
494
text_features = text_encoder (input_ids = input_ids , attention_mask = mask )['last_hidden_state' ].float ()
495
+ print ("gongwb text_features:" , text_features )
490
496
491
497
lens : List [int ] = mask .sum (dim = - 1 ).tolist ()
492
498
cu_seqlens_k = F .pad (mask .sum (dim = - 1 ).to (dtype = torch .int32 ).cumsum_ (0 ), (1 , 0 ))
@@ -521,7 +527,8 @@ def train_one_ep(
521
527
step_cnt += int (stepping )
522
528
523
529
with maybe_record_function ('in_training' ):
524
- grad_norm_t , scale_log2_t = trainer .train_step (
530
+ #grad_norm_t, scale_log2_t =
531
+ trainer .train_step (
525
532
ep = ep , it = it , g_it = g_it , stepping = stepping , clip_decay_ratio = clip_decay_ratio ,
526
533
metric_lg = me ,
527
534
logging_params = stepping and step_cnt == 1 and (ep < 4 or ep in logging_params_milestone ),
0 commit comments