
Commit 1718290

edm and data preprocess tests.

Signed-off-by: sajadn <snorouzi@nvidia.com>

1 parent: 032f2e1

File tree

4 files changed, +475 -0 lines changed


tests/unit_tests/megatron/model/dit/__init__.py

Whitespace-only changes.

tests/unit_tests/megatron/model/dit/edm/__init__.py

Whitespace-only changes.
Lines changed: 399 additions & 0 deletions
@@ -0,0 +1,399 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import torch

from dfm.src.common.utils.batch_ops import batch_mul
from dfm.src.megatron.model.dit.edm.edm_pipeline import EDMPipeline


class _DummyModel:
    """Dummy model for testing that mimics the DiT network interface."""

    def __call__(self, x, timesteps, **condition):
        # Return zeros matching input shape
        return torch.zeros_like(x)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="Requires CUDA")
class TestEDMPipeline:
    """Test class for EDMPipeline with shared setup."""

    def setup_method(self, method, monkeypatch=None):
        """Set up test fixtures before each test method."""
        # Stub parallel_state functions to avoid requiring initialization
        from megatron.core import parallel_state

        if monkeypatch:
            monkeypatch.setattr(
                parallel_state, "get_data_parallel_rank", lambda with_context_parallel=False: 0, raising=False
            )
            monkeypatch.setattr(parallel_state, "get_context_parallel_world_size", lambda: 1, raising=False)
            monkeypatch.setattr(parallel_state, "is_pipeline_last_stage", lambda: True, raising=False)
            monkeypatch.setattr(parallel_state, "get_context_parallel_group", lambda: None, raising=False)

        # Create pipeline with common parameters
        self.sigma_data = 0.5
        self.pipeline = EDMPipeline(
            vae=None,
            p_mean=0.0,
            p_std=1.0,
            sigma_max=80.0,
            sigma_min=0.0002,
            sigma_data=self.sigma_data,
            seed=1234,
        )

        # Create and assign dummy model
        self.model = _DummyModel()
        self.pipeline.net = self.model

        # Create common test data shapes
        self.batch_size = 2
        self.channels = 4
        self.height = self.width = 8

        # Create common test tensors
        self.x0 = torch.randn(self.batch_size, self.channels, self.height, self.width).to(
            **self.pipeline.tensor_kwargs
        )
        self.sigma = torch.ones(self.batch_size).to(**self.pipeline.tensor_kwargs) * 1.0
        self.condition = {"crossattn_emb": torch.randn(self.batch_size, 10, 512).to(**self.pipeline.tensor_kwargs)}
        self.epsilon = torch.randn(self.batch_size, self.channels, self.height, self.width).to(
            **self.pipeline.tensor_kwargs
        )

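A note on the `.to(**self.pipeline.tensor_kwargs)` pattern used throughout these fixtures: the assertions below expect CUDA tensors in bfloat16, which suggests that `tensor_kwargs` is simply a device/dtype mapping. The pipeline source is not part of this commit, so the following is only a sketch of that assumption:

import torch

# Assumed shape of EDMPipeline.tensor_kwargs (not shown in this diff): a plain
# device/dtype dict that `.to(**...)` unpacks into Tensor.to(device=..., dtype=...).
tensor_kwargs = {"device": "cuda", "dtype": torch.bfloat16}

x = torch.randn(2, 4, 8, 8).to(**tensor_kwargs)
assert x.device.type == "cuda" and x.dtype == torch.bfloat16
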
    def test_denoise(self, monkeypatch):
        """Test the denoise method produces correct output shapes and values."""
        # Initialize with monkeypatch
        self.setup_method(None, monkeypatch)

        # Create test inputs (xt on CPU for conversion test)
        xt = torch.randn(self.batch_size, self.channels, self.height, self.width)
        sigma = torch.ones(self.batch_size) * 1.0

        # Test Case 1: is_pipeline_last_stage = True
        # Call denoise
        x0_pred, eps_pred = self.pipeline.denoise(xt, sigma, self.condition)

        # Verify outputs have correct shapes
        assert x0_pred.shape == xt.shape, f"Expected x0_pred shape {xt.shape}, got {x0_pred.shape}"
        assert eps_pred.shape == xt.shape, f"Expected eps_pred shape {xt.shape}, got {eps_pred.shape}"

        # Verify outputs are on CUDA with correct dtype
        assert x0_pred.device.type == "cuda"
        assert x0_pred.dtype == torch.bfloat16
        assert eps_pred.device.type == "cuda"
        assert eps_pred.dtype == torch.bfloat16

        # Verify the outputs follow the expected formulas
        # Convert inputs to expected dtype/device for comparison
        xt_converted = xt.to(**self.pipeline.tensor_kwargs)
        sigma_converted = sigma.to(**self.pipeline.tensor_kwargs)

        # Get scaling factors
        c_skip, c_out, c_in, c_noise = self.pipeline.scaling(sigma=sigma_converted)

        # Since model returns zeros, net_output = 0
        # Expected: x0_pred = c_skip * xt + c_out * 0 = c_skip * xt
        expected_x0_pred = batch_mul(c_skip, xt_converted)
        assert torch.allclose(x0_pred, expected_x0_pred, rtol=1e-3, atol=1e-5), "x0_pred doesn't match expected value"

        # Expected: eps_pred = (xt - x0_pred) / sigma
        expected_eps_pred = batch_mul(xt_converted - x0_pred, 1.0 / sigma_converted)
        assert torch.allclose(eps_pred, expected_eps_pred, rtol=1e-3, atol=1e-5), (
            "eps_pred doesn't match expected value"
        )

        # Test Case 2: is_pipeline_last_stage = False
        # Mock is_pipeline_last_stage to return False
        from megatron.core import parallel_state

        monkeypatch.setattr(parallel_state, "is_pipeline_last_stage", lambda: False)

        # Call denoise again
        net_output = self.pipeline.denoise(xt, sigma, self.condition)

        # Verify output is a single tensor (not a tuple)
        assert isinstance(net_output, torch.Tensor), "Expected net_output to be a single tensor when not last stage"
        assert not isinstance(net_output, tuple), "Expected net_output to not be a tuple when not last stage"

        # Verify output has correct shape (same as model output)
        assert net_output.shape == xt.shape, f"Expected net_output shape {xt.shape}, got {net_output.shape}"

        # Verify output is on CUDA with correct dtype
        assert net_output.device.type == "cuda"
        assert net_output.dtype == torch.bfloat16

        # Since model returns zeros, net_output should be zeros
        assert torch.allclose(net_output, torch.zeros_like(xt_converted), rtol=1e-3, atol=1e-5), (
            "net_output doesn't match expected value (zeros from dummy model)"
        )

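The expectations in test_denoise are the EDM preconditioning of Karras et al. (2022): the denoiser output is x0_pred = c_skip(sigma) * xt + c_out(sigma) * F(c_in(sigma) * xt, c_noise(sigma)), and eps_pred = (xt - x0_pred) / sigma. EDMPipeline.scaling itself is not in this diff, so the sketch below uses the standard EDM coefficients only as an assumption about what it returns:

import torch

def edm_scaling(sigma: torch.Tensor, sigma_data: float = 0.5):
    # Standard EDM preconditioning coefficients (assumed; the pipeline's own
    # scaling() may differ in detail).
    c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
    c_out = sigma * sigma_data / torch.sqrt(sigma**2 + sigma_data**2)
    c_in = 1.0 / torch.sqrt(sigma**2 + sigma_data**2)
    c_noise = 0.25 * torch.log(sigma)
    return c_skip, c_out, c_in, c_noise

# With a network that returns zeros (as _DummyModel does), the denoised
# prediction collapses to c_skip * xt, which is exactly what the test asserts.
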
    def test_compute_loss_with_epsilon_and_sigma(self, monkeypatch):
        """Test the compute_loss_with_epsilon_and_sigma method produces correct output shapes and values."""
        # Initialize with monkeypatch
        self.setup_method(None, monkeypatch)

        # Create test inputs
        data_batch = {"video": self.x0}
        x0_from_data_batch = self.x0

        # Call compute_loss_with_epsilon_and_sigma
        output_batch, pred_mse, edm_loss = self.pipeline.compute_loss_with_epsilon_and_sigma(
            data_batch, x0_from_data_batch, self.x0, self.condition, self.epsilon, self.sigma
        )

        # Verify output_batch contains expected keys
        assert "x0" in output_batch
        assert "xt" in output_batch
        assert "sigma" in output_batch
        assert "weights_per_sigma" in output_batch
        assert "condition" in output_batch
        assert "model_pred" in output_batch
        assert "mse_loss" in output_batch
        assert "edm_loss" in output_batch

        # Verify shapes
        assert output_batch["x0"].shape == self.x0.shape
        assert output_batch["xt"].shape == self.x0.shape
        assert output_batch["sigma"].shape == self.sigma.shape
        assert output_batch["weights_per_sigma"].shape == self.sigma.shape
        assert pred_mse.shape == self.x0.shape
        assert edm_loss.shape == self.x0.shape

        # Verify the loss computation follows expected formulas
        # 1. Compute expected xt from marginal probability
        mean, std = self.pipeline.sde.marginal_prob(self.x0, self.sigma)
        expected_xt = mean + batch_mul(std, self.epsilon)
        assert torch.allclose(output_batch["xt"], expected_xt, rtol=1e-3, atol=1e-5), "xt doesn't match expected value"

        # 2. Verify loss weights
        expected_weights = (self.sigma**2 + self.sigma_data**2) / (self.sigma * self.sigma_data) ** 2
        assert torch.allclose(output_batch["weights_per_sigma"], expected_weights, rtol=1e-3, atol=1e-5), (
            "weights_per_sigma doesn't match expected value"
        )

        # 3. Verify edm_loss = weights * (x0 - x0_pred)^2
        x0_pred = output_batch["model_pred"]["x0_pred"]
        expected_pred_mse = (self.x0 - x0_pred) ** 2
        assert torch.allclose(pred_mse, expected_pred_mse, rtol=1e-3, atol=1e-5), (
            "pred_mse doesn't match expected value"
        )

        expected_edm_loss = batch_mul(expected_pred_mse, expected_weights)
        assert torch.allclose(edm_loss, expected_edm_loss, rtol=1e-3, atol=1e-5), (
            "edm_loss doesn't match expected value"
        )

        # 4. Verify scalar losses are proper means
        assert torch.isclose(output_batch["mse_loss"], pred_mse.mean(), rtol=1e-3, atol=1e-5)
        assert torch.isclose(output_batch["edm_loss"], edm_loss.mean(), rtol=1e-3, atol=1e-5)

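Step 2 above checks the EDM loss weight lambda(sigma) = (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2, and steps 3-4 check that the per-element loss is lambda(sigma) * (x0 - x0_pred)^2 before reduction to a mean. A self-contained sketch of those formulas, assuming the variance-exploding marginal xt = x0 + sigma * eps implied by the mean + std * eps check in step 1 (the helper name is illustrative, not part of the diff):

import torch

def edm_loss_terms(x0, eps, sigma, x0_pred, sigma_data=0.5):
    # xt = x0 + sigma * eps  (assumed VE marginal, matching mean + std * eps above)
    xt = x0 + sigma.view(-1, 1, 1, 1) * eps
    # lambda(sigma) = (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2
    weights = (sigma**2 + sigma_data**2) / (sigma * sigma_data) ** 2
    pred_mse = (x0 - x0_pred) ** 2
    edm_loss = weights.view(-1, 1, 1, 1) * pred_mse
    return xt, weights, pred_mse, edm_loss

# The scalar losses reported in output_batch are then just pred_mse.mean() and
# edm_loss.mean(), as asserted in step 4.
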
    def test_training_step(self, monkeypatch):
        """Test the training_step method with mocked compute_loss_with_epsilon_and_sigma."""
        from unittest.mock import patch

        # Initialize with monkeypatch
        self.setup_method(None, monkeypatch)

        # Create test data batch
        data_batch = {
            "video": self.x0,
            "context_embeddings": torch.randn(self.batch_size, 10, 512).to(**self.pipeline.tensor_kwargs),
        }
        iteration = 0

        # Test Case 1: is_pipeline_last_stage = True
        # Mock compute_loss_with_epsilon_and_sigma to return expected values
        mock_output_batch = {
            "x0": self.x0,
            "xt": torch.randn_like(self.x0),
            "sigma": self.sigma,
            "weights_per_sigma": torch.ones_like(self.sigma),
            "condition": self.condition,
            "model_pred": {"x0_pred": torch.randn_like(self.x0), "eps_pred": torch.randn_like(self.x0)},
            "mse_loss": torch.tensor(0.5, **self.pipeline.tensor_kwargs),
            "edm_loss": torch.tensor(0.3, **self.pipeline.tensor_kwargs),
        }
        mock_pred_mse = torch.randn_like(self.x0)
        mock_edm_loss = torch.randn_like(self.x0)

        with patch.object(
            self.pipeline,
            "compute_loss_with_epsilon_and_sigma",
            return_value=(mock_output_batch, mock_pred_mse, mock_edm_loss),
        ) as mock_compute_loss:
            # Call training_step
            result = self.pipeline.training_step(self.model, data_batch, iteration)

            # Verify compute_loss_with_epsilon_and_sigma was called once
            assert mock_compute_loss.call_count == 1

            # Verify return values are correct (output_batch, edm_loss)
            assert len(result) == 2
            output_batch, edm_loss = result
            assert output_batch == mock_output_batch
            assert torch.equal(edm_loss, mock_edm_loss)

        # Test Case 2: is_pipeline_last_stage = False
        # Mock is_pipeline_last_stage to return False
        from megatron.core import parallel_state

        monkeypatch.setattr(parallel_state, "is_pipeline_last_stage", lambda: False)

        # Mock compute_loss_with_epsilon_and_sigma to return net_output only
        mock_net_output = torch.randn_like(self.x0)

        with patch.object(
            self.pipeline, "compute_loss_with_epsilon_and_sigma", return_value=mock_net_output
        ) as mock_compute_loss:
            # Call training_step
            result = self.pipeline.training_step(self.model, data_batch, iteration)

            # Verify compute_loss_with_epsilon_and_sigma was called once
            assert mock_compute_loss.call_count == 1

            # Verify return value is just net_output (not a tuple)
            assert torch.equal(result, mock_net_output)

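The two cases above pin down the return contract this test assumes for training_step: an (output_batch, edm_loss) pair on the last pipeline-parallel stage, and the raw net output otherwise. A minimal sketch of a caller that respects that contract; the helper name and the reduction are illustrative assumptions, not part of this diff:

from megatron.core import parallel_state

def run_step(pipeline, model, data_batch, iteration):
    # Hypothetical helper: branches on the same condition the test mocks.
    result = pipeline.training_step(model, data_batch, iteration)
    if parallel_state.is_pipeline_last_stage():
        output_batch, edm_loss = result  # pair on the last stage
        return edm_loss.mean(), output_batch
    return result, None  # intermediate stages only pass activations onward
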
    def test_get_data_and_condition(self, monkeypatch):
        """Test the get_data_and_condition method with different dropout rates."""
        # Initialize with monkeypatch
        self.setup_method(None, monkeypatch)

        # Create test data batch
        video_data = torch.randn(self.batch_size, self.channels, self.height, self.width).to(
            **self.pipeline.tensor_kwargs
        )
        context_embeddings = torch.randn(self.batch_size, 10, 512).to(**self.pipeline.tensor_kwargs)

        data_batch = {"video": video_data.clone(), "context_embeddings": context_embeddings.clone()}

        # Test Case 1: With default dropout_rate (0.2)
        raw_state, latent_state, condition = self.pipeline.get_data_and_condition(data_batch.copy(), dropout_rate=0.2)

        # Verify raw_state is video * sigma_data
        expected_raw_state = video_data * self.sigma_data
        assert torch.allclose(raw_state, expected_raw_state, rtol=1e-3, atol=1e-5), (
            "raw_state doesn't match expected value (video * sigma_data)"
        )

        # Verify latent_state equals raw_state
        assert torch.equal(latent_state, raw_state), "latent_state should equal raw_state"

        # Verify condition contains crossattn_emb
        assert "crossattn_emb" in condition, "condition should contain 'crossattn_emb' key"
        assert condition["crossattn_emb"].shape == context_embeddings.shape, (
            f"Expected crossattn_emb shape {context_embeddings.shape}, got {condition['crossattn_emb'].shape}"
        )

        # Verify crossattn_emb is on CUDA with correct dtype
        assert condition["crossattn_emb"].device.type == "cuda"
        assert condition["crossattn_emb"].dtype == torch.bfloat16

        # Test Case 2: With dropout_rate=0.0 (no dropout, should keep all values)
        data_batch_no_dropout = {"video": video_data.clone(), "context_embeddings": context_embeddings.clone()}
        raw_state_no_dropout, latent_state_no_dropout, condition_no_dropout = self.pipeline.get_data_and_condition(
            data_batch_no_dropout, dropout_rate=0.0
        )

        # With dropout_rate=0.0, crossattn_emb should equal context_embeddings
        assert torch.allclose(condition_no_dropout["crossattn_emb"], context_embeddings, rtol=1e-3, atol=1e-5), (
            "With dropout_rate=0.0, crossattn_emb should equal original context_embeddings"
        )

        # Test Case 3: With dropout_rate=1.0 (complete dropout, should zero out all values)
        data_batch_full_dropout = {"video": video_data.clone(), "context_embeddings": context_embeddings.clone()}
        raw_state_full_dropout, latent_state_full_dropout, condition_full_dropout = (
            self.pipeline.get_data_and_condition(data_batch_full_dropout, dropout_rate=1.0)
        )

        # With dropout_rate=1.0, crossattn_emb should be all zeros
        expected_zeros = torch.zeros_like(context_embeddings)
        assert torch.allclose(condition_full_dropout["crossattn_emb"], expected_zeros, rtol=1e-3, atol=1e-5), (
            "With dropout_rate=1.0, crossattn_emb should be all zeros"
        )

        # Verify raw_state is consistent across all dropout rates
        assert torch.equal(raw_state, raw_state_no_dropout), "raw_state should be consistent regardless of dropout"
        assert torch.equal(raw_state, raw_state_full_dropout), "raw_state should be consistent regardless of dropout"

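The dropout behaviour verified above (embeddings untouched at rate 0.0, zeroed at rate 1.0) is the usual conditioning dropout used to train classifier-free guidance. A minimal per-sample sketch of that mechanism, offered only as an assumption about what get_data_and_condition does internally, since its source is not in this diff:

import torch

def drop_condition(crossattn_emb: torch.Tensor, dropout_rate: float) -> torch.Tensor:
    # Zero out the conditioning embedding for a random subset of samples.
    # rate 0.0 keeps every sample; rate 1.0 zeroes every sample.
    keep = torch.rand(crossattn_emb.shape[0], device=crossattn_emb.device) >= dropout_rate
    return crossattn_emb * keep.view(-1, 1, 1).to(crossattn_emb.dtype)
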
    def test_get_x0_fn_from_batch(self, monkeypatch):
        """Test the get_x0_fn_from_batch method returns a callable with correct guidance behavior."""
        from unittest.mock import patch

        # Initialize with monkeypatch
        self.setup_method(None, monkeypatch)

        # Create test data batch
        video_data = torch.randn(self.batch_size, self.channels, self.height, self.width).to(
            **self.pipeline.tensor_kwargs
        )
        context_embeddings = torch.randn(self.batch_size, 10, 512).to(**self.pipeline.tensor_kwargs)

        data_batch = {"video": video_data, "context_embeddings": context_embeddings}

        # Create mock condition and uncondition
        mock_condition = {"crossattn_emb": torch.randn(self.batch_size, 10, 512).to(**self.pipeline.tensor_kwargs)}
        mock_uncondition = {"crossattn_emb": torch.randn(self.batch_size, 10, 512).to(**self.pipeline.tensor_kwargs)}

        # Mock get_condition_uncondition to return our mock conditions
        with patch.object(self.pipeline, "get_condition_uncondition", return_value=(mock_condition, mock_uncondition)):
            # Test Case 1: Default guidance (1.5)
            guidance = 1.5
            x0_fn = self.pipeline.get_x0_fn_from_batch(data_batch, guidance=guidance)

            # Verify x0_fn is callable
            assert callable(x0_fn), "get_x0_fn_from_batch should return a callable"

            # Create test inputs for the returned function
            noise_x = torch.randn(self.batch_size, self.channels, self.height, self.width).to(
                **self.pipeline.tensor_kwargs
            )
            sigma = torch.ones(self.batch_size).to(**self.pipeline.tensor_kwargs) * 1.0

            # Create mock outputs for denoise calls
            mock_cond_x0 = torch.randn_like(noise_x)
            mock_uncond_x0 = torch.randn_like(noise_x)
            mock_eps = torch.randn_like(noise_x)  # dummy eps_pred (not used in x0_fn)

            # Mock denoise to return different values for condition vs uncondition
            call_count = [0]

            def mock_denoise(xt, sig, cond):
                call_count[0] += 1
                if call_count[0] == 1:  # First call (with condition)
                    return mock_cond_x0, mock_eps
                else:  # Second call (with uncondition)
                    return mock_uncond_x0, mock_eps

            with patch.object(self.pipeline, "denoise", side_effect=mock_denoise):
                # Call the returned x0_fn
                result = x0_fn(noise_x, sigma)

                # Verify denoise was called twice
                assert call_count[0] == 2, "mock_denoise should be called twice (condition and uncondition)"

                # Verify the result follows the guidance formula: cond_x0 + guidance * (cond_x0 - uncond_x0)
                expected_result = mock_cond_x0 + guidance * (mock_cond_x0 - mock_uncond_x0)
                assert torch.allclose(result, expected_result, rtol=1e-3, atol=1e-5), (
                    "x0_fn output doesn't match expected guidance formula"
                )

                # Verify output shape and dtype
                assert result.shape == noise_x.shape, f"Expected result shape {noise_x.shape}, got {result.shape}"
                assert result.device.type == "cuda"
                assert result.dtype == torch.bfloat16
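
The formula asserted at the end of the file is classifier-free guidance applied in x0-space: x0_guided = x0_cond + g * (x0_cond - x0_uncond), equivalently (1 + g) * x0_cond - g * x0_uncond. A minimal standalone sketch of such an x0_fn; the real one comes from get_x0_fn_from_batch, and the names below are illustrative:

import torch

def make_x0_fn(denoise, condition, uncondition, guidance=1.5):
    # denoise(xt, sigma, cond) -> (x0_pred, eps_pred), as in EDMPipeline.denoise
    def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
        cond_x0, _ = denoise(noise_x, sigma, condition)
        uncond_x0, _ = denoise(noise_x, sigma, uncondition)
        # Classifier-free guidance in x0-space
        return cond_x0 + guidance * (cond_x0 - uncond_x0)
    return x0_fn
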
