@@ -1354,14 +1354,65 @@ def compute_time_ids(original_size, crops_coords_top_left):
13541354 progress_bar .update (1 )
13551355 global_step += 1
13561356
1357+ # Enhanced logging for LPL metrics
13571358 log_data = {
13581359 "train_loss" : train_loss ,
13591360 "diffusion_loss" : loss .item (),
1361+ "learning_rate" : lr_scheduler .get_last_lr ()[0 ],
13601362 }
1361- if args .use_lpl and lpl_loss_value .item () > 0 :
1362- log_data ["lpl_loss" ] = lpl_loss_value .item ()
1363+
1364+ if args .use_lpl and lpl_fn is not None and global_step >= args .lpl_start :
1365+ if lpl_mask .any ():
1366+ # LPL application statistics
1367+ log_data .update ({
1368+ "lpl/loss" : lpl_loss_value .item (),
1369+ "lpl/num_samples" : lpl_mask .sum ().item (),
1370+ "lpl/application_ratio" : lpl_mask .float ().mean ().item (),
1371+ "lpl/weight" : args .lpl_weight ,
1372+ "lpl/weighted_loss" : (args .lpl_weight * lpl_loss_value ).item (),
1373+ })
1374+
1375+ # SNR statistics for LPL-applied samples
1376+ if args .snr_gamma is not None :
1377+ snr_values = snr [masked_indices ]
1378+ log_data .update ({
1379+ "lpl/snr_mean" : snr_values .mean ().item (),
1380+ "lpl/snr_std" : snr_values .std ().item (),
1381+ "lpl/snr_min" : snr_values .min ().item (),
1382+ "lpl/snr_max" : snr_values .max ().item (),
1383+ })
1384+
1385+ # Feature statistics if available
1386+ if hasattr (lpl_fn , 'last_feature_stats' ):
1387+ for layer_idx , stats in enumerate (lpl_fn .last_feature_stats ):
1388+ log_data .update ({
1389+ f"lpl/features/layer_{ layer_idx } /mean" : stats ['mean' ].item (),
1390+ f"lpl/features/layer_{ layer_idx } /std" : stats ['std' ].item (),
1391+ f"lpl/features/layer_{ layer_idx } /outlier_ratio" : stats .get ('outlier_ratio' , 0.0 ),
1392+ })
1393+
1394+ # Memory usage if available
1395+ if torch .cuda .is_available ():
1396+ log_data .update ({
1397+ "lpl/memory/allocated" : torch .cuda .memory_allocated () / 1024 ** 2 , # MB
1398+ "lpl/memory/reserved" : torch .cuda .memory_reserved () / 1024 ** 2 , # MB
1399+ })
1400+
1401+ # Log to accelerator
13631402 accelerator .log (log_data , step = global_step )
13641403
1404+ # Update progress bar with more metrics
1405+ progress_bar_logs = {
1406+ "loss" : loss .detach ().item (),
1407+ "lr" : lr_scheduler .get_last_lr ()[0 ],
1408+ }
1409+ if args .use_lpl and lpl_loss_value .item () > 0 :
1410+ progress_bar_logs .update ({
1411+ "lpl" : lpl_loss_value .item (),
1412+ "lpl_ratio" : lpl_mask .float ().mean ().item () if lpl_mask .any () else 0.0 ,
1413+ })
1414+ progress_bar .set_postfix (** progress_bar_logs )
1415+
13651416 # DeepSpeed requires saving weights on every device; saving weights only on the main process would cause issues.
13661417 if accelerator .distributed_type == DistributedType .DEEPSPEED or accelerator .is_main_process :
13671418 if global_step % args .checkpointing_steps == 0 :
0 commit comments