# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import torch

from dfm.src.common.utils.batch_ops import batch_mul
from dfm.src.megatron.model.dit.edm.edm_pipeline import EDMPipeline
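
# batch_mul is assumed here to broadcast a per-sample tensor of shape [B]
# against a batched tensor of shape [B, ...] over the trailing dimensions,
# i.e. batch_mul(w, x)[i] == w[i] * x[i]. The tests below rely only on that.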


class _DummyModel:
    """Dummy model for testing that mimics the DiT network interface."""

    def __call__(self, x, timesteps, **condition):
        # Return zeros matching input shape
        return torch.zeros_like(x)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA")
class TestEDMPipeline:
    """Test class for EDMPipeline with shared setup."""

    def setup_method(self, method, monkeypatch=None):
        """Set up test fixtures before each test method.

        pytest invokes this hook automatically without ``monkeypatch``; each
        test re-invokes it as ``self.setup_method(None, monkeypatch)`` so the
        parallel_state stubs below can be installed.
        """
        # Stub parallel_state functions to avoid requiring initialization
        from megatron.core import parallel_state

        if monkeypatch:
            monkeypatch.setattr(
                parallel_state, "get_data_parallel_rank", lambda with_context_parallel=False: 0, raising=False
            )
            monkeypatch.setattr(parallel_state, "get_context_parallel_world_size", lambda: 1, raising=False)
            monkeypatch.setattr(parallel_state, "is_pipeline_last_stage", lambda: True, raising=False)
            monkeypatch.setattr(parallel_state, "get_context_parallel_group", lambda: None, raising=False)

        # Create pipeline with common parameters
        self.sigma_data = 0.5
        self.pipeline = EDMPipeline(
            vae=None,
            p_mean=0.0,
            p_std=1.0,
            sigma_max=80.0,
            sigma_min=0.0002,
            sigma_data=self.sigma_data,
            seed=1234,
        )
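        # p_mean/p_std are assumed to parameterize EDM's log-normal training
        # noise distribution, sigma = exp(N(p_mean, p_std^2)); sigma_data is
        # the data standard deviation used by the preconditioner.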

        # Create and assign dummy model
        self.model = _DummyModel()
        self.pipeline.net = self.model

        # Create common test data shapes
        self.batch_size = 2
        self.channels = 4
        self.height = self.width = 8

        # Create common test tensors
        self.x0 = torch.randn(self.batch_size, self.channels, self.height, self.width).to(
            **self.pipeline.tensor_kwargs
        )
        self.sigma = torch.ones(self.batch_size).to(**self.pipeline.tensor_kwargs)
        self.condition = {"crossattn_emb": torch.randn(self.batch_size, 10, 512).to(**self.pipeline.tensor_kwargs)}
        self.epsilon = torch.randn(self.batch_size, self.channels, self.height, self.width).to(
            **self.pipeline.tensor_kwargs
        )

    def test_denoise(self, monkeypatch):
        """Test the denoise method produces correct output shapes and values."""
        # Initialize with monkeypatch
        self.setup_method(None, monkeypatch)

        # Create test inputs (xt on CPU for conversion test)
        xt = torch.randn(self.batch_size, self.channels, self.height, self.width)
        sigma = torch.ones(self.batch_size)

        # Test Case 1: is_pipeline_last_stage = True
        # Call denoise
        x0_pred, eps_pred = self.pipeline.denoise(xt, sigma, self.condition)

        # Verify outputs have correct shapes
        assert x0_pred.shape == xt.shape, f"Expected x0_pred shape {xt.shape}, got {x0_pred.shape}"
        assert eps_pred.shape == xt.shape, f"Expected eps_pred shape {xt.shape}, got {eps_pred.shape}"

        # Verify outputs are on CUDA with correct dtype
        assert x0_pred.device.type == "cuda"
        assert x0_pred.dtype == torch.bfloat16
        assert eps_pred.device.type == "cuda"
        assert eps_pred.dtype == torch.bfloat16

        # Verify the outputs follow the expected formulas
        # Convert inputs to expected dtype/device for comparison
        xt_converted = xt.to(**self.pipeline.tensor_kwargs)
        sigma_converted = sigma.to(**self.pipeline.tensor_kwargs)

        # Get scaling factors
        c_skip, c_out, c_in, c_noise = self.pipeline.scaling(sigma=sigma_converted)
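        # For reference, the standard EDM preconditioning (Karras et al., 2022),
        # which EDMPipeline.scaling is assumed to follow:
        #   c_skip(sigma)  = sigma_data^2 / (sigma^2 + sigma_data^2)
        #   c_out(sigma)   = sigma * sigma_data / sqrt(sigma^2 + sigma_data^2)
        #   c_in(sigma)    = 1 / sqrt(sigma^2 + sigma_data^2)
        #   c_noise(sigma) = ln(sigma) / 4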

        # Since model returns zeros, net_output = 0
        # Expected: x0_pred = c_skip * xt + c_out * 0 = c_skip * xt
        expected_x0_pred = batch_mul(c_skip, xt_converted)
        assert torch.allclose(x0_pred, expected_x0_pred, rtol=1e-3, atol=1e-5), "x0_pred doesn't match expected value"

        # Expected: eps_pred = (xt - x0_pred) / sigma
        expected_eps_pred = batch_mul(xt_converted - x0_pred, 1.0 / sigma_converted)
        assert torch.allclose(eps_pred, expected_eps_pred, rtol=1e-3, atol=1e-5), (
            "eps_pred doesn't match expected value"
        )
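        # (eps_pred is just the algebraic inverse of xt = x0 + sigma * eps.)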

        # Test Case 2: is_pipeline_last_stage = False
        # Mock is_pipeline_last_stage to return False
        from megatron.core import parallel_state

        monkeypatch.setattr(parallel_state, "is_pipeline_last_stage", lambda: False)

        # Call denoise again
        net_output = self.pipeline.denoise(xt, sigma, self.condition)

        # Verify output is a single tensor (not a tuple)
        assert isinstance(net_output, torch.Tensor), "Expected net_output to be a single tensor when not last stage"
        assert not isinstance(net_output, tuple), "Expected net_output to not be a tuple when not last stage"

        # Verify output has correct shape (same as model output)
        assert net_output.shape == xt.shape, f"Expected net_output shape {xt.shape}, got {net_output.shape}"

        # Verify output is on CUDA with correct dtype
        assert net_output.device.type == "cuda"
        assert net_output.dtype == torch.bfloat16

        # Since model returns zeros, net_output should be zeros
        assert torch.allclose(net_output, torch.zeros_like(xt_converted), rtol=1e-3, atol=1e-5), (
            "net_output doesn't match expected value (zeros from dummy model)"
        )

    def test_compute_loss_with_epsilon_and_sigma(self, monkeypatch):
        """Test the compute_loss_with_epsilon_and_sigma method produces correct output shapes and values."""
        # Initialize with monkeypatch
        self.setup_method(None, monkeypatch)

        # Create test inputs
        data_batch = {"video": self.x0}
        x0_from_data_batch = self.x0

        # Call compute_loss_with_epsilon_and_sigma
        output_batch, pred_mse, edm_loss = self.pipeline.compute_loss_with_epsilon_and_sigma(
            data_batch, x0_from_data_batch, self.x0, self.condition, self.epsilon, self.sigma
        )

        # Verify output_batch contains expected keys
        assert "x0" in output_batch
        assert "xt" in output_batch
        assert "sigma" in output_batch
        assert "weights_per_sigma" in output_batch
        assert "condition" in output_batch
        assert "model_pred" in output_batch
        assert "mse_loss" in output_batch
        assert "edm_loss" in output_batch

        # Verify shapes
        assert output_batch["x0"].shape == self.x0.shape
        assert output_batch["xt"].shape == self.x0.shape
        assert output_batch["sigma"].shape == self.sigma.shape
        assert output_batch["weights_per_sigma"].shape == self.sigma.shape
        assert pred_mse.shape == self.x0.shape
        assert edm_loss.shape == self.x0.shape

        # Verify the loss computation follows expected formulas
        # 1. Compute expected xt from marginal probability
        mean, std = self.pipeline.sde.marginal_prob(self.x0, self.sigma)
        expected_xt = mean + batch_mul(std, self.epsilon)
        assert torch.allclose(output_batch["xt"], expected_xt, rtol=1e-3, atol=1e-5), "xt doesn't match expected value"
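        # For EDM's variance-exploding SDE, marginal_prob is expected to return
        # mean = x0 and std = sigma, i.e. xt = x0 + sigma * epsilon.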

        # 2. Verify loss weights
        expected_weights = (self.sigma**2 + self.sigma_data**2) / (self.sigma * self.sigma_data) ** 2
        assert torch.allclose(output_batch["weights_per_sigma"], expected_weights, rtol=1e-3, atol=1e-5), (
            "weights_per_sigma doesn't match expected value"
        )
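        # This matches the EDM loss weight lambda(sigma) =
        # (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2 from Karras et al. (2022).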

        # 3. Verify pred_mse = (x0 - x0_pred)^2 and edm_loss = weights * pred_mse
        x0_pred = output_batch["model_pred"]["x0_pred"]
        expected_pred_mse = (self.x0 - x0_pred) ** 2
        assert torch.allclose(pred_mse, expected_pred_mse, rtol=1e-3, atol=1e-5), (
            "pred_mse doesn't match expected value"
        )

        expected_edm_loss = batch_mul(expected_pred_mse, expected_weights)
        assert torch.allclose(edm_loss, expected_edm_loss, rtol=1e-3, atol=1e-5), (
            "edm_loss doesn't match expected value"
        )

        # 4. Verify scalar losses are proper means
        assert torch.isclose(output_batch["mse_loss"], pred_mse.mean(), rtol=1e-3, atol=1e-5)
        assert torch.isclose(output_batch["edm_loss"], edm_loss.mean(), rtol=1e-3, atol=1e-5)

    def test_training_step(self, monkeypatch):
        """Test the training_step method with mocked compute_loss_with_epsilon_and_sigma."""
        from unittest.mock import patch

        # Initialize with monkeypatch
        self.setup_method(None, monkeypatch)

        # Create test data batch
        data_batch = {
            "video": self.x0,
            "context_embeddings": torch.randn(self.batch_size, 10, 512).to(**self.pipeline.tensor_kwargs),
        }
        iteration = 0

        # Test Case 1: is_pipeline_last_stage = True
        # Mock compute_loss_with_epsilon_and_sigma to return expected values
        mock_output_batch = {
            "x0": self.x0,
            "xt": torch.randn_like(self.x0),
            "sigma": self.sigma,
            "weights_per_sigma": torch.ones_like(self.sigma),
            "condition": self.condition,
            "model_pred": {"x0_pred": torch.randn_like(self.x0), "eps_pred": torch.randn_like(self.x0)},
            "mse_loss": torch.tensor(0.5, **self.pipeline.tensor_kwargs),
            "edm_loss": torch.tensor(0.3, **self.pipeline.tensor_kwargs),
        }
        mock_pred_mse = torch.randn_like(self.x0)
        mock_edm_loss = torch.randn_like(self.x0)

        with patch.object(
            self.pipeline,
            "compute_loss_with_epsilon_and_sigma",
            return_value=(mock_output_batch, mock_pred_mse, mock_edm_loss),
        ) as mock_compute_loss:
            # Call training_step
            result = self.pipeline.training_step(self.model, data_batch, iteration)

            # Verify compute_loss_with_epsilon_and_sigma was called once
            assert mock_compute_loss.call_count == 1

            # Verify return values are correct (output_batch, edm_loss)
            assert len(result) == 2
            output_batch, edm_loss = result
            assert output_batch == mock_output_batch
            assert torch.equal(edm_loss, mock_edm_loss)

        # Test Case 2: is_pipeline_last_stage = False
        # Mock is_pipeline_last_stage to return False
        from megatron.core import parallel_state

        monkeypatch.setattr(parallel_state, "is_pipeline_last_stage", lambda: False)

        # Mock compute_loss_with_epsilon_and_sigma to return net_output only
        mock_net_output = torch.randn_like(self.x0)

        with patch.object(
            self.pipeline, "compute_loss_with_epsilon_and_sigma", return_value=mock_net_output
        ) as mock_compute_loss:
            # Call training_step
            result = self.pipeline.training_step(self.model, data_batch, iteration)

            # Verify compute_loss_with_epsilon_and_sigma was called once
            assert mock_compute_loss.call_count == 1

            # Verify return value is just net_output (not a tuple)
            assert torch.equal(result, mock_net_output)

    def test_get_data_and_condition(self, monkeypatch):
        """Test the get_data_and_condition method with different dropout rates."""
        # Initialize with monkeypatch
        self.setup_method(None, monkeypatch)

        # Create test data batch
        video_data = torch.randn(self.batch_size, self.channels, self.height, self.width).to(
            **self.pipeline.tensor_kwargs
        )
        context_embeddings = torch.randn(self.batch_size, 10, 512).to(**self.pipeline.tensor_kwargs)

        data_batch = {"video": video_data.clone(), "context_embeddings": context_embeddings.clone()}

        # Test Case 1: With the default dropout_rate (0.2)
        raw_state, latent_state, condition = self.pipeline.get_data_and_condition(data_batch.copy(), dropout_rate=0.2)
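        # Conditioning dropout of this kind is what enables classifier-free
        # guidance at sampling time; get_data_and_condition is assumed to zero
        # out the context embeddings per sample with the given probability.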

        # Verify raw_state is video * sigma_data
        expected_raw_state = video_data * self.sigma_data
        assert torch.allclose(raw_state, expected_raw_state, rtol=1e-3, atol=1e-5), (
            "raw_state doesn't match expected value (video * sigma_data)"
        )

        # Verify latent_state equals raw_state
        assert torch.equal(latent_state, raw_state), "latent_state should equal raw_state"

        # Verify condition contains crossattn_emb
        assert "crossattn_emb" in condition, "condition should contain 'crossattn_emb' key"
        assert condition["crossattn_emb"].shape == context_embeddings.shape, (
            f"Expected crossattn_emb shape {context_embeddings.shape}, got {condition['crossattn_emb'].shape}"
        )

        # Verify crossattn_emb is on CUDA with correct dtype
        assert condition["crossattn_emb"].device.type == "cuda"
        assert condition["crossattn_emb"].dtype == torch.bfloat16

        # Test Case 2: With dropout_rate=0.0 (no dropout, should keep all values)
        data_batch_no_dropout = {"video": video_data.clone(), "context_embeddings": context_embeddings.clone()}
        raw_state_no_dropout, latent_state_no_dropout, condition_no_dropout = self.pipeline.get_data_and_condition(
            data_batch_no_dropout, dropout_rate=0.0
        )

        # With dropout_rate=0.0, crossattn_emb should equal context_embeddings
        assert torch.allclose(condition_no_dropout["crossattn_emb"], context_embeddings, rtol=1e-3, atol=1e-5), (
            "With dropout_rate=0.0, crossattn_emb should equal original context_embeddings"
        )

        # Test Case 3: With dropout_rate=1.0 (complete dropout, should zero out all values)
        data_batch_full_dropout = {"video": video_data.clone(), "context_embeddings": context_embeddings.clone()}
        raw_state_full_dropout, latent_state_full_dropout, condition_full_dropout = (
            self.pipeline.get_data_and_condition(data_batch_full_dropout, dropout_rate=1.0)
        )

        # With dropout_rate=1.0, crossattn_emb should be all zeros
        expected_zeros = torch.zeros_like(context_embeddings)
        assert torch.allclose(condition_full_dropout["crossattn_emb"], expected_zeros, rtol=1e-3, atol=1e-5), (
            "With dropout_rate=1.0, crossattn_emb should be all zeros"
        )

        # Verify raw_state is consistent across all dropout rates
        assert torch.equal(raw_state, raw_state_no_dropout), "raw_state should be consistent regardless of dropout"
        assert torch.equal(raw_state, raw_state_full_dropout), "raw_state should be consistent regardless of dropout"

    def test_get_x0_fn_from_batch(self, monkeypatch):
        """Test the get_x0_fn_from_batch method returns a callable with correct guidance behavior."""
        from unittest.mock import patch

        # Initialize with monkeypatch
        self.setup_method(None, monkeypatch)

        # Create test data batch
        video_data = torch.randn(self.batch_size, self.channels, self.height, self.width).to(
            **self.pipeline.tensor_kwargs
        )
        context_embeddings = torch.randn(self.batch_size, 10, 512).to(**self.pipeline.tensor_kwargs)

        data_batch = {"video": video_data, "context_embeddings": context_embeddings}

        # Create mock condition and uncondition
        mock_condition = {"crossattn_emb": torch.randn(self.batch_size, 10, 512).to(**self.pipeline.tensor_kwargs)}
        mock_uncondition = {"crossattn_emb": torch.randn(self.batch_size, 10, 512).to(**self.pipeline.tensor_kwargs)}

        # Mock get_condition_uncondition to return our mock conditions
        with patch.object(self.pipeline, "get_condition_uncondition", return_value=(mock_condition, mock_uncondition)):
            # Test Case 1: Default guidance (1.5)
            guidance = 1.5
            x0_fn = self.pipeline.get_x0_fn_from_batch(data_batch, guidance=guidance)

            # Verify x0_fn is callable
            assert callable(x0_fn), "get_x0_fn_from_batch should return a callable"

            # Create test inputs for the returned function
            noise_x = torch.randn(self.batch_size, self.channels, self.height, self.width).to(
                **self.pipeline.tensor_kwargs
            )
            sigma = torch.ones(self.batch_size).to(**self.pipeline.tensor_kwargs)

            # Create mock outputs for denoise calls
            mock_cond_x0 = torch.randn_like(noise_x)
            mock_uncond_x0 = torch.randn_like(noise_x)
            mock_eps = torch.randn_like(noise_x)  # dummy eps_pred (not used in x0_fn)

            # Mock denoise to return different values for condition vs uncondition
            call_count = [0]

            def mock_denoise(xt, sig, cond):
                call_count[0] += 1
                if call_count[0] == 1:  # First call (with condition)
                    return mock_cond_x0, mock_eps
                else:  # Second call (with uncondition)
                    return mock_uncond_x0, mock_eps

            with patch.object(self.pipeline, "denoise", side_effect=mock_denoise):
                # Call the returned x0_fn
                result = x0_fn(noise_x, sigma)

                # Verify denoise was called twice
                assert call_count[0] == 2, "mock_denoise should be called twice (condition and uncondition)"

                # Verify the result follows the guidance formula: cond_x0 + guidance * (cond_x0 - uncond_x0)
                expected_result = mock_cond_x0 + guidance * (mock_cond_x0 - mock_uncond_x0)
                assert torch.allclose(result, expected_result, rtol=1e-3, atol=1e-5), (
                    "x0_fn output doesn't match expected guidance formula"
                )
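                # Note: cond_x0 + g * (cond_x0 - uncond_x0) is classifier-free
                # guidance expressed around the conditional prediction; it is
                # algebraically equal to uncond_x0 + (1 + g) * (cond_x0 - uncond_x0).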

                # Verify output shape and dtype
                assert result.shape == noise_x.shape, f"Expected result shape {noise_x.shape}, got {result.shape}"
                assert result.device.type == "cuda"
                assert result.dtype == torch.bfloat16