@@ -333,12 +333,14 @@ def test_loss_weighting_enabled(self, simple_adapter):
333333 model_pred = torch .randn (2 , 16 , 4 , 8 , 8 )
334334 target = torch .randn_like (model_pred )
335335 sigma = torch .tensor ([0.3 , 0.7 ])
336+ batch = {}
336337
337- weighted_loss , unweighted_loss , loss_weight = pipeline .compute_loss (model_pred , target , sigma )
338+ # Returns: weighted_loss, average_weighted_loss, unweighted_loss, average_unweighted_loss, loss_weight, loss_mask
339+ _ , scalar_weighted_loss , _ , scalar_unweighted_loss , loss_weight , _ = pipeline .compute_loss (model_pred , target , sigma , batch )
338340
339341 # Verify shapes
340- assert weighted_loss .ndim == 0 , "Weighted loss should be scalar"
341- assert unweighted_loss .ndim == 0 , "Unweighted loss should be scalar"
342+ assert scalar_weighted_loss .ndim == 0 , "Weighted loss should be scalar"
343+ assert scalar_unweighted_loss .ndim == 0 , "Unweighted loss should be scalar"
342344
343345 # Verify weight formula: w = 1 + shift * σ
344346 expected_weights = 1.0 + 3.0 * sigma
@@ -357,11 +359,12 @@ def test_loss_weighting_disabled(self, simple_adapter):
357359 model_pred = torch .randn (2 , 16 , 4 , 8 , 8 )
358360 target = torch .randn_like (model_pred )
359361 sigma = torch .tensor ([0.3 , 0.7 ])
362+ batch = {}
360363
361- weighted_loss , unweighted_loss , loss_weight = pipeline .compute_loss (model_pred , target , sigma )
364+ _ , scalar_weighted_loss , _ , scalar_unweighted_loss , loss_weight , _ = pipeline .compute_loss (model_pred , target , sigma , batch )
362365
363366 # Without weighting, weighted loss should equal unweighted loss
364- assert torch .allclose (weighted_loss , unweighted_loss , atol = 1e-6 )
367+ assert torch .allclose (scalar_weighted_loss , scalar_unweighted_loss , atol = 1e-6 )
365368
366369 # All weights should be 1
367370 assert torch .allclose (loss_weight , torch .ones_like (loss_weight ))
@@ -377,10 +380,11 @@ def test_loss_weight_formula(self, simple_adapter):
377380
378381 model_pred = torch .zeros (4 , 16 , 4 , 8 , 8 )
379382 target = torch .ones_like (model_pred )
383+ batch = {}
380384
381385 for sigma_val in [0.0 , 0.25 , 0.5 , 0.75 , 1.0 ]:
382386 sigma = torch .full ((4 ,), sigma_val )
383- _ , _ , loss_weight = pipeline .compute_loss (model_pred , target , sigma )
387+ _ , _ , _ , _ , loss_weight , _ = pipeline .compute_loss (model_pred , target , sigma , batch )
384388
385389 expected_weight = 1.0 + flow_shift * sigma_val
386390 actual_weight = loss_weight [0 , 0 , 0 , 0 , 0 ].item ()
@@ -396,11 +400,12 @@ def test_loss_is_non_negative(self, simple_adapter):
396400 model_pred = torch .randn (2 , 16 , 4 , 8 , 8 )
397401 target = torch .randn_like (model_pred )
398402 sigma = torch .rand (2 )
403+ batch = {}
399404
400- weighted_loss , unweighted_loss , _ = pipeline .compute_loss (model_pred , target , sigma )
405+ _ , scalar_weighted_loss , _ , scalar_unweighted_loss , _ , _ = pipeline .compute_loss (model_pred , target , sigma , batch )
401406
402- assert weighted_loss >= 0 , "Weighted loss should be non-negative"
403- assert unweighted_loss >= 0 , "Unweighted loss should be non-negative"
407+ assert scalar_weighted_loss >= 0 , "Weighted loss should be non-negative"
408+ assert scalar_unweighted_loss >= 0 , "Unweighted loss should be non-negative"
404409
405410 def test_loss_is_finite (self , simple_adapter ):
406411 """Test that computed loss is finite."""
@@ -409,11 +414,12 @@ def test_loss_is_finite(self, simple_adapter):
409414 model_pred = torch .randn (2 , 16 , 4 , 8 , 8 )
410415 target = torch .randn_like (model_pred )
411416 sigma = torch .rand (2 )
417+ batch = {}
412418
413- weighted_loss , unweighted_loss , _ = pipeline .compute_loss (model_pred , target , sigma )
419+ _ , scalar_weighted_loss , _ , scalar_unweighted_loss , _ , _ = pipeline .compute_loss (model_pred , target , sigma , batch )
414420
415- assert torch .isfinite (weighted_loss ), "Weighted loss should be finite"
416- assert torch .isfinite (unweighted_loss ), "Unweighted loss should be finite"
421+ assert torch .isfinite (scalar_weighted_loss ), "Weighted loss should be finite"
422+ assert torch .isfinite (scalar_unweighted_loss ), "Unweighted loss should be finite"
417423
418424 def test_loss_mse_correctness (self , simple_adapter ):
419425 """Test that base loss is MSE."""
@@ -425,13 +431,14 @@ def test_loss_mse_correctness(self, simple_adapter):
425431 model_pred = torch .randn (2 , 16 , 4 , 8 , 8 )
426432 target = torch .randn_like (model_pred )
427433 sigma = torch .rand (2 )
434+ batch = {}
428435
429- _ , unweighted_loss , _ = pipeline .compute_loss (model_pred , target , sigma )
436+ _ , _ , _ , scalar_unweighted_loss , _ , _ = pipeline .compute_loss (model_pred , target , sigma , batch )
430437
431438 # Manual MSE calculation
432439 expected_mse = nn .functional .mse_loss (model_pred .float (), target .float ())
433440
434- assert torch .allclose (unweighted_loss , expected_mse , atol = 1e-6 )
441+ assert torch .allclose (scalar_unweighted_loss , expected_mse , atol = 1e-6 )
435442
436443
437444class TestFullTrainingStep :
@@ -442,7 +449,8 @@ def test_basic_training_step(self, pipeline, mock_model, sample_batch):
442449 device = torch .device ("cpu" )
443450 dtype = torch .bfloat16
444451
445- loss , metrics = pipeline .step (mock_model , sample_batch , device , dtype , global_step = 0 )
452+ # Returns: weighted_loss, average_weighted_loss, loss_mask, metrics
453+ _ , loss , _ , metrics = pipeline .step (mock_model , sample_batch , device , dtype , global_step = 0 )
446454
447455 # Verify loss
448456 assert isinstance (loss , torch .Tensor ), "Loss should be a tensor"
@@ -478,7 +486,7 @@ def test_step_with_different_batch_sizes(self, simple_adapter, mock_model):
478486 "text_embeddings" : torch .randn (batch_size , 77 , 4096 ),
479487 }
480488
481- loss , metrics = pipeline .step (mock_model , batch , device , dtype , global_step = 0 )
489+ _ , loss , _ , metrics = pipeline .step (mock_model , batch , device , dtype , global_step = 0 )
482490
483491 assert isinstance (loss , torch .Tensor ), f"Loss should be tensor for batch_size={ batch_size } "
484492 assert not torch .isnan (loss ), f"Loss should not be NaN for batch_size={ batch_size } "
@@ -493,7 +501,7 @@ def test_step_with_4d_video_latents(self, pipeline, mock_model):
493501 "text_embeddings" : torch .randn (77 , 4096 ), # 2D instead of 3D
494502 }
495503
496- loss , metrics = pipeline .step (mock_model , batch , device , dtype , global_step = 0 )
504+ _ , loss , _ , metrics = pipeline .step (mock_model , batch , device , dtype , global_step = 0 )
497505
498506 assert isinstance (loss , torch .Tensor )
499507 assert not torch .isnan (loss )
@@ -503,7 +511,7 @@ def test_step_metrics_collection(self, pipeline, mock_model, sample_batch):
503511 device = torch .device ("cpu" )
504512 dtype = torch .bfloat16
505513
506- loss , metrics = pipeline .step (mock_model , sample_batch , device , dtype , global_step = 100 )
514+ _ , loss , _ , metrics = pipeline .step (mock_model , sample_batch , device , dtype , global_step = 100 )
507515
508516 expected_keys = [
509517 "loss" ,
@@ -530,7 +538,7 @@ def test_step_sigma_in_valid_range(self, pipeline, mock_model, sample_batch):
530538 device = torch .device ("cpu" )
531539 dtype = torch .bfloat16
532540
533- loss , metrics = pipeline .step (mock_model , sample_batch , device , dtype , global_step = 0 )
541+ _ , loss , _ , metrics = pipeline .step (mock_model , sample_batch , device , dtype , global_step = 0 )
534542
535543 assert 0.0 <= metrics ["sigma_min" ] <= 1.0 , "Sigma min should be in [0, 1]"
536544 assert 0.0 <= metrics ["sigma_max" ] <= 1.0 , "Sigma max should be in [0, 1]"
@@ -548,7 +556,7 @@ def test_step_timesteps_in_valid_range(self, simple_adapter, mock_model, sample_
548556 device = torch .device ("cpu" )
549557 dtype = torch .bfloat16
550558
551- loss , metrics = pipeline .step (mock_model , sample_batch , device , dtype , global_step = 0 )
559+ _ , loss , _ , metrics = pipeline .step (mock_model , sample_batch , device , dtype , global_step = 0 )
552560
553561 assert 0.0 <= metrics ["timestep_min" ] <= num_timesteps
554562 assert 0.0 <= metrics ["timestep_max" ] <= num_timesteps
@@ -558,7 +566,7 @@ def test_step_noisy_latents_are_finite(self, pipeline, mock_model, sample_batch)
558566 device = torch .device ("cpu" )
559567 dtype = torch .bfloat16
560568
561- loss , metrics = pipeline .step (mock_model , sample_batch , device , dtype , global_step = 0 )
569+ _ , loss , _ , metrics = pipeline .step (mock_model , sample_batch , device , dtype , global_step = 0 )
562570
563571 assert torch .isfinite (torch .tensor (metrics ["noisy_min" ])), "Noisy min should be finite"
564572 assert torch .isfinite (torch .tensor (metrics ["noisy_max" ])), "Noisy max should be finite"
@@ -568,7 +576,7 @@ def test_step_with_image_batch(self, pipeline, mock_model, image_batch):
568576 device = torch .device ("cpu" )
569577 dtype = torch .bfloat16
570578
571- loss , metrics = pipeline .step (mock_model , image_batch , device , dtype , global_step = 0 )
579+ _ , loss , _ , metrics = pipeline .step (mock_model , image_batch , device , dtype , global_step = 0 )
572580
573581 assert isinstance (loss , torch .Tensor )
574582 assert not torch .isnan (loss )
@@ -588,11 +596,11 @@ def test_deterministic_with_seed(self, simple_adapter, mock_model, sample_batch)
588596
589597 # First run
590598 torch .manual_seed (42 )
591- _ , metrics1 = pipeline .step (mock_model , sample_batch , device , dtype , global_step = 0 )
599+ _ , _ , _ , metrics1 = pipeline .step (mock_model , sample_batch , device , dtype , global_step = 0 )
592600
593601 # Second run with same seed
594602 torch .manual_seed (42 )
595- _ , metrics2 = pipeline .step (mock_model , sample_batch , device , dtype , global_step = 0 )
603+ _ , _ , _ , metrics2 = pipeline .step (mock_model , sample_batch , device , dtype , global_step = 0 )
596604
597605 # Sigma values should be identical
598606 assert abs (metrics1 ["sigma_min" ] - metrics2 ["sigma_min" ]) < 1e-6
@@ -675,7 +683,7 @@ def test_empty_batch_handling(self, simple_adapter):
675683 }
676684
677685 mock_model = MockModel ()
678- loss , metrics = pipeline .step (mock_model , batch , torch .device ("cpu" ), torch .float32 , global_step = 0 )
686+ _ , loss , _ , metrics = pipeline .step (mock_model , batch , torch .device ("cpu" ), torch .float32 , global_step = 0 )
679687
680688 assert not torch .isnan (loss )
681689
@@ -693,7 +701,7 @@ def test_large_batch_handling(self, simple_adapter):
693701 }
694702
695703 mock_model = MockModel ()
696- loss , metrics = pipeline .step (mock_model , batch , torch .device ("cpu" ), torch .float32 , global_step = 0 )
704+ _ , loss , _ , metrics = pipeline .step (mock_model , batch , torch .device ("cpu" ), torch .float32 , global_step = 0 )
697705
698706 assert not torch .isnan (loss )
699707
@@ -716,7 +724,7 @@ def test_extreme_flow_shift_values(self, simple_adapter):
716724 }
717725
718726 mock_model = MockModel ()
719- loss , metrics = pipeline .step (mock_model , batch , torch .device ("cpu" ), torch .float32 , global_step = 0 )
727+ _ , loss , _ , metrics = pipeline .step (mock_model , batch , torch .device ("cpu" ), torch .float32 , global_step = 0 )
720728
721729 assert torch .isfinite (loss ), f"Loss should be finite for shift={ shift } "
722730
@@ -755,7 +763,7 @@ def test_multiple_training_steps(self, simple_adapter):
755763 "text_embeddings" : torch .randn (2 , 77 , 4096 ),
756764 }
757765
758- loss , metrics = pipeline .step (mock_model , batch , torch .device ("cpu" ), torch .float32 , global_step = step )
766+ _ , loss , _ , metrics = pipeline .step (mock_model , batch , torch .device ("cpu" ), torch .float32 , global_step = step )
759767 losses .append (loss .item ())
760768
761769 assert not torch .isnan (loss ), f"Loss became NaN at step { step } "
@@ -779,7 +787,7 @@ def test_pipeline_with_all_sampling_methods(self, simple_adapter):
779787 "text_embeddings" : torch .randn (2 , 77 , 4096 ),
780788 }
781789
782- loss , metrics = pipeline .step (mock_model , batch , torch .device ("cpu" ), torch .float32 , global_step = 0 )
790+ _ , loss , _ , metrics = pipeline .step (mock_model , batch , torch .device ("cpu" ), torch .float32 , global_step = 0 )
783791
784792 assert not torch .isnan (loss ), f"Loss should not be NaN for method={ method } "
785793
0 commit comments