Skip to content

Commit 137d89b

Browse files
author
Huy Vu2
committed
update unit test
1 parent f85b23d commit 137d89b

File tree

1 file changed

+37
-29
lines changed

1 file changed

+37
-29
lines changed

tests/unit_tests/automodel/test_flow_matching_pipeline.py

Lines changed: 37 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -333,12 +333,14 @@ def test_loss_weighting_enabled(self, simple_adapter):
333333
model_pred = torch.randn(2, 16, 4, 8, 8)
334334
target = torch.randn_like(model_pred)
335335
sigma = torch.tensor([0.3, 0.7])
336+
batch = {}
336337

337-
weighted_loss, unweighted_loss, loss_weight = pipeline.compute_loss(model_pred, target, sigma)
338+
# Returns: weighted_loss, average_weighted_loss, unweighted_loss, average_unweighted_loss, loss_weight, loss_mask
339+
_, scalar_weighted_loss, _, scalar_unweighted_loss, loss_weight, _ = pipeline.compute_loss(model_pred, target, sigma, batch)
338340

339341
# Verify shapes
340-
assert weighted_loss.ndim == 0, "Weighted loss should be scalar"
341-
assert unweighted_loss.ndim == 0, "Unweighted loss should be scalar"
342+
assert scalar_weighted_loss.ndim == 0, "Weighted loss should be scalar"
343+
assert scalar_unweighted_loss.ndim == 0, "Unweighted loss should be scalar"
342344

343345
# Verify weight formula: w = 1 + shift * σ
344346
expected_weights = 1.0 + 3.0 * sigma
@@ -357,11 +359,12 @@ def test_loss_weighting_disabled(self, simple_adapter):
357359
model_pred = torch.randn(2, 16, 4, 8, 8)
358360
target = torch.randn_like(model_pred)
359361
sigma = torch.tensor([0.3, 0.7])
362+
batch = {}
360363

361-
weighted_loss, unweighted_loss, loss_weight = pipeline.compute_loss(model_pred, target, sigma)
364+
_, scalar_weighted_loss, _, scalar_unweighted_loss, loss_weight, _ = pipeline.compute_loss(model_pred, target, sigma, batch)
362365

363366
# Without weighting, weighted loss should equal unweighted loss
364-
assert torch.allclose(weighted_loss, unweighted_loss, atol=1e-6)
367+
assert torch.allclose(scalar_weighted_loss, scalar_unweighted_loss, atol=1e-6)
365368

366369
# All weights should be 1
367370
assert torch.allclose(loss_weight, torch.ones_like(loss_weight))
@@ -377,10 +380,11 @@ def test_loss_weight_formula(self, simple_adapter):
377380

378381
model_pred = torch.zeros(4, 16, 4, 8, 8)
379382
target = torch.ones_like(model_pred)
383+
batch = {}
380384

381385
for sigma_val in [0.0, 0.25, 0.5, 0.75, 1.0]:
382386
sigma = torch.full((4,), sigma_val)
383-
_, _, loss_weight = pipeline.compute_loss(model_pred, target, sigma)
387+
_, _, _, _, loss_weight, _ = pipeline.compute_loss(model_pred, target, sigma, batch)
384388

385389
expected_weight = 1.0 + flow_shift * sigma_val
386390
actual_weight = loss_weight[0, 0, 0, 0, 0].item()
@@ -396,11 +400,12 @@ def test_loss_is_non_negative(self, simple_adapter):
396400
model_pred = torch.randn(2, 16, 4, 8, 8)
397401
target = torch.randn_like(model_pred)
398402
sigma = torch.rand(2)
403+
batch = {}
399404

400-
weighted_loss, unweighted_loss, _ = pipeline.compute_loss(model_pred, target, sigma)
405+
_, scalar_weighted_loss, _, scalar_unweighted_loss, _, _ = pipeline.compute_loss(model_pred, target, sigma, batch)
401406

402-
assert weighted_loss >= 0, "Weighted loss should be non-negative"
403-
assert unweighted_loss >= 0, "Unweighted loss should be non-negative"
407+
assert scalar_weighted_loss >= 0, "Weighted loss should be non-negative"
408+
assert scalar_unweighted_loss >= 0, "Unweighted loss should be non-negative"
404409

405410
def test_loss_is_finite(self, simple_adapter):
406411
"""Test that computed loss is finite."""
@@ -409,11 +414,12 @@ def test_loss_is_finite(self, simple_adapter):
409414
model_pred = torch.randn(2, 16, 4, 8, 8)
410415
target = torch.randn_like(model_pred)
411416
sigma = torch.rand(2)
417+
batch = {}
412418

413-
weighted_loss, unweighted_loss, _ = pipeline.compute_loss(model_pred, target, sigma)
419+
_, scalar_weighted_loss, _, scalar_unweighted_loss, _, _ = pipeline.compute_loss(model_pred, target, sigma, batch)
414420

415-
assert torch.isfinite(weighted_loss), "Weighted loss should be finite"
416-
assert torch.isfinite(unweighted_loss), "Unweighted loss should be finite"
421+
assert torch.isfinite(scalar_weighted_loss), "Weighted loss should be finite"
422+
assert torch.isfinite(scalar_unweighted_loss), "Unweighted loss should be finite"
417423

418424
def test_loss_mse_correctness(self, simple_adapter):
419425
"""Test that base loss is MSE."""
@@ -425,13 +431,14 @@ def test_loss_mse_correctness(self, simple_adapter):
425431
model_pred = torch.randn(2, 16, 4, 8, 8)
426432
target = torch.randn_like(model_pred)
427433
sigma = torch.rand(2)
434+
batch = {}
428435

429-
_, unweighted_loss, _ = pipeline.compute_loss(model_pred, target, sigma)
436+
_, _, _, scalar_unweighted_loss, _, _ = pipeline.compute_loss(model_pred, target, sigma, batch)
430437

431438
# Manual MSE calculation
432439
expected_mse = nn.functional.mse_loss(model_pred.float(), target.float())
433440

434-
assert torch.allclose(unweighted_loss, expected_mse, atol=1e-6)
441+
assert torch.allclose(scalar_unweighted_loss, expected_mse, atol=1e-6)
435442

436443

437444
class TestFullTrainingStep:
@@ -442,7 +449,8 @@ def test_basic_training_step(self, pipeline, mock_model, sample_batch):
442449
device = torch.device("cpu")
443450
dtype = torch.bfloat16
444451

445-
loss, metrics = pipeline.step(mock_model, sample_batch, device, dtype, global_step=0)
452+
# Returns: weighted_loss, average_weighted_loss, loss_mask, metrics
453+
_, loss, _, metrics = pipeline.step(mock_model, sample_batch, device, dtype, global_step=0)
446454

447455
# Verify loss
448456
assert isinstance(loss, torch.Tensor), "Loss should be a tensor"
@@ -478,7 +486,7 @@ def test_step_with_different_batch_sizes(self, simple_adapter, mock_model):
478486
"text_embeddings": torch.randn(batch_size, 77, 4096),
479487
}
480488

481-
loss, metrics = pipeline.step(mock_model, batch, device, dtype, global_step=0)
489+
_, loss, _, metrics = pipeline.step(mock_model, batch, device, dtype, global_step=0)
482490

483491
assert isinstance(loss, torch.Tensor), f"Loss should be tensor for batch_size={batch_size}"
484492
assert not torch.isnan(loss), f"Loss should not be NaN for batch_size={batch_size}"
@@ -493,7 +501,7 @@ def test_step_with_4d_video_latents(self, pipeline, mock_model):
493501
"text_embeddings": torch.randn(77, 4096), # 2D instead of 3D
494502
}
495503

496-
loss, metrics = pipeline.step(mock_model, batch, device, dtype, global_step=0)
504+
_, loss, _, metrics = pipeline.step(mock_model, batch, device, dtype, global_step=0)
497505

498506
assert isinstance(loss, torch.Tensor)
499507
assert not torch.isnan(loss)
@@ -503,7 +511,7 @@ def test_step_metrics_collection(self, pipeline, mock_model, sample_batch):
503511
device = torch.device("cpu")
504512
dtype = torch.bfloat16
505513

506-
loss, metrics = pipeline.step(mock_model, sample_batch, device, dtype, global_step=100)
514+
_, loss, _, metrics = pipeline.step(mock_model, sample_batch, device, dtype, global_step=100)
507515

508516
expected_keys = [
509517
"loss",
@@ -530,7 +538,7 @@ def test_step_sigma_in_valid_range(self, pipeline, mock_model, sample_batch):
530538
device = torch.device("cpu")
531539
dtype = torch.bfloat16
532540

533-
loss, metrics = pipeline.step(mock_model, sample_batch, device, dtype, global_step=0)
541+
_, loss, _, metrics = pipeline.step(mock_model, sample_batch, device, dtype, global_step=0)
534542

535543
assert 0.0 <= metrics["sigma_min"] <= 1.0, "Sigma min should be in [0, 1]"
536544
assert 0.0 <= metrics["sigma_max"] <= 1.0, "Sigma max should be in [0, 1]"
@@ -548,7 +556,7 @@ def test_step_timesteps_in_valid_range(self, simple_adapter, mock_model, sample_
548556
device = torch.device("cpu")
549557
dtype = torch.bfloat16
550558

551-
loss, metrics = pipeline.step(mock_model, sample_batch, device, dtype, global_step=0)
559+
_, loss, _, metrics = pipeline.step(mock_model, sample_batch, device, dtype, global_step=0)
552560

553561
assert 0.0 <= metrics["timestep_min"] <= num_timesteps
554562
assert 0.0 <= metrics["timestep_max"] <= num_timesteps
@@ -558,7 +566,7 @@ def test_step_noisy_latents_are_finite(self, pipeline, mock_model, sample_batch)
558566
device = torch.device("cpu")
559567
dtype = torch.bfloat16
560568

561-
loss, metrics = pipeline.step(mock_model, sample_batch, device, dtype, global_step=0)
569+
_, loss, _, metrics = pipeline.step(mock_model, sample_batch, device, dtype, global_step=0)
562570

563571
assert torch.isfinite(torch.tensor(metrics["noisy_min"])), "Noisy min should be finite"
564572
assert torch.isfinite(torch.tensor(metrics["noisy_max"])), "Noisy max should be finite"
@@ -568,7 +576,7 @@ def test_step_with_image_batch(self, pipeline, mock_model, image_batch):
568576
device = torch.device("cpu")
569577
dtype = torch.bfloat16
570578

571-
loss, metrics = pipeline.step(mock_model, image_batch, device, dtype, global_step=0)
579+
_, loss, _, metrics = pipeline.step(mock_model, image_batch, device, dtype, global_step=0)
572580

573581
assert isinstance(loss, torch.Tensor)
574582
assert not torch.isnan(loss)
@@ -588,11 +596,11 @@ def test_deterministic_with_seed(self, simple_adapter, mock_model, sample_batch)
588596

589597
# First run
590598
torch.manual_seed(42)
591-
_, metrics1 = pipeline.step(mock_model, sample_batch, device, dtype, global_step=0)
599+
_, _, _, metrics1 = pipeline.step(mock_model, sample_batch, device, dtype, global_step=0)
592600

593601
# Second run with same seed
594602
torch.manual_seed(42)
595-
_, metrics2 = pipeline.step(mock_model, sample_batch, device, dtype, global_step=0)
603+
_, _, _, metrics2 = pipeline.step(mock_model, sample_batch, device, dtype, global_step=0)
596604

597605
# Sigma values should be identical
598606
assert abs(metrics1["sigma_min"] - metrics2["sigma_min"]) < 1e-6
@@ -675,7 +683,7 @@ def test_empty_batch_handling(self, simple_adapter):
675683
}
676684

677685
mock_model = MockModel()
678-
loss, metrics = pipeline.step(mock_model, batch, torch.device("cpu"), torch.float32, global_step=0)
686+
_, loss, _, metrics = pipeline.step(mock_model, batch, torch.device("cpu"), torch.float32, global_step=0)
679687

680688
assert not torch.isnan(loss)
681689

@@ -693,7 +701,7 @@ def test_large_batch_handling(self, simple_adapter):
693701
}
694702

695703
mock_model = MockModel()
696-
loss, metrics = pipeline.step(mock_model, batch, torch.device("cpu"), torch.float32, global_step=0)
704+
_, loss, _, metrics = pipeline.step(mock_model, batch, torch.device("cpu"), torch.float32, global_step=0)
697705

698706
assert not torch.isnan(loss)
699707

@@ -716,7 +724,7 @@ def test_extreme_flow_shift_values(self, simple_adapter):
716724
}
717725

718726
mock_model = MockModel()
719-
loss, metrics = pipeline.step(mock_model, batch, torch.device("cpu"), torch.float32, global_step=0)
727+
_, loss, _, metrics = pipeline.step(mock_model, batch, torch.device("cpu"), torch.float32, global_step=0)
720728

721729
assert torch.isfinite(loss), f"Loss should be finite for shift={shift}"
722730

@@ -755,7 +763,7 @@ def test_multiple_training_steps(self, simple_adapter):
755763
"text_embeddings": torch.randn(2, 77, 4096),
756764
}
757765

758-
loss, metrics = pipeline.step(mock_model, batch, torch.device("cpu"), torch.float32, global_step=step)
766+
_, loss, _, metrics = pipeline.step(mock_model, batch, torch.device("cpu"), torch.float32, global_step=step)
759767
losses.append(loss.item())
760768

761769
assert not torch.isnan(loss), f"Loss became NaN at step {step}"
@@ -779,7 +787,7 @@ def test_pipeline_with_all_sampling_methods(self, simple_adapter):
779787
"text_embeddings": torch.randn(2, 77, 4096),
780788
}
781789

782-
loss, metrics = pipeline.step(mock_model, batch, torch.device("cpu"), torch.float32, global_step=0)
790+
_, loss, _, metrics = pipeline.step(mock_model, batch, torch.device("cpu"), torch.float32, global_step=0)
783791

784792
assert not torch.isnan(loss), f"Loss should not be NaN for method={method}"
785793

0 commit comments

Comments (0)