Skip to content

Commit be9f8ca

Browse files
authored
feat: Add more comprehensive testing for the automodel path (#82)
* Add more comprehensive testing for the automodel path — Signed-off-by: Pranav Prashant Thombre <pthombre@nvidia.com>
* Adding unit tests for the flow matching pipeline — Signed-off-by: Pranav Prashant Thombre <pthombre@nvidia.com>
* Adding functional test for Wan — Signed-off-by: Pranav Prashant Thombre <pthombre@nvidia.com>
* Fixing linting errors — Signed-off-by: Pranav Prashant Thombre <pthombre@nvidia.com>
* Linting fixes — Signed-off-by: Pranav Prashant Thombre <pthombre@nvidia.com>
* Increase test timeout — Signed-off-by: Pranav Prashant Thombre <pthombre@nvidia.com>
* Remove flash attention3 as the default attention backend during training — Signed-off-by: Pranav Prashant Thombre <pthombre@nvidia.com>

Signed-off-by: Pranav Prashant Thombre <pthombre@nvidia.com>
1 parent b8e525e commit be9f8ca

File tree

17 files changed

+2279
-34
lines changed

17 files changed

+2279
-34
lines changed

dfm/src/automodel/datasets/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,11 @@
1919
collate_fn,
2020
create_dataloader,
2121
)
22+
from dfm.src.automodel.datasets.mock_dataloader import (
23+
MockWanDataset,
24+
build_mock_dataloader,
25+
mock_collate_fn,
26+
)
2227

2328

2429
__all__ = [
@@ -27,4 +32,7 @@
2732
"build_node_parallel_sampler",
2833
"collate_fn",
2934
"create_dataloader",
35+
"MockWanDataset",
36+
"build_mock_dataloader",
37+
"mock_collate_fn",
3038
]
Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,194 @@
1+
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Mock dataloader for automodel WAN training tests.
16+
17+
This module provides a mock dataset and dataloader that generates random
18+
tensors with the correct shapes for WAN 2.1 training, allowing functional
19+
tests to run without requiring real data.
20+
"""
21+
22+
from typing import Dict, Optional, Tuple
23+
24+
import torch
25+
from torch.utils.data import DataLoader, Dataset, DistributedSampler
26+
27+
28+
class MockWanDataset(Dataset):
    """Synthetic dataset emitting random samples shaped like WAN 2.1 training data.

    Every sample carries a random video-latent tensor, a random text-embedding
    tensor, an empty metadata dict, and fabricated file info, so functional
    tests can exercise the training path without any real data on disk.

    Args:
        length: Number of samples (clamped to at least 1).
        num_channels: Number of latent channels (default: 16 for WAN).
        num_frame_latents: Number of temporal latent frames.
        spatial_h: Height of spatial latents.
        spatial_w: Width of spatial latents.
        text_seq_len: Length of text sequence.
        text_embed_dim: Dimension of text embeddings (default: 4096 for UMT5).
        device: Device to place tensors on.
    """

    def __init__(
        self,
        length: int = 1024,
        num_channels: int = 16,
        num_frame_latents: int = 16,
        spatial_h: int = 30,
        spatial_w: int = 52,
        text_seq_len: int = 77,
        text_embed_dim: int = 4096,
        device: str = "cpu",
    ) -> None:
        # Clamp so __len__ never reports zero or a negative size.
        self.length = max(int(length), 1)
        self.num_channels = num_channels
        self.num_frame_latents = num_frame_latents
        self.spatial_h = spatial_h
        self.spatial_w = spatial_w
        self.text_seq_len = text_seq_len
        self.text_embed_dim = text_embed_dim
        self.device = device

    def __len__(self) -> int:
        """Return the configured (clamped) sample count."""
        return self.length

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Produce one random mock sample.

        Returns:
            Dict containing:
                - text_embeddings: [1, text_seq_len, text_embed_dim]
                - video_latents: [1, num_channels, num_frame_latents, spatial_h, spatial_w]
                - metadata: empty dict
                - file_info: fabricated per-index file details
        """
        # Draw video latents first, then text embeddings — same sampling
        # order as before, so RNG streams stay comparable across versions.
        latent_shape = (
            1,
            self.num_channels,
            self.num_frame_latents,
            self.spatial_h,
            self.spatial_w,
        )
        latents = torch.randn(*latent_shape, dtype=torch.float32, device=self.device)

        embed_shape = (1, self.text_seq_len, self.text_embed_dim)
        embeddings = torch.randn(*embed_shape, dtype=torch.float32, device=self.device)

        file_info = {
            "meta_filename": f"mock_sample_{idx}.meta",
            "original_filename": f"mock_video_{idx}.mp4",
            "original_video_path": f"/mock/path/video_{idx}.mp4",
            "deterministic_latents": True,
            "memory_optimization": False,
            "num_frames": self.num_frame_latents * 4,  # Approximate original frames
        }

        return {
            "text_embeddings": embeddings,
            "video_latents": latents,
            "metadata": {},
            "file_info": file_info,
        }
108+
109+
110+
def mock_collate_fn(batch):
    """Collate mock samples into one batch, mirroring the real collate_fn.

    Tensors are concatenated along the batch dimension (each sample already
    carries a leading batch axis of 1); metadata and file info are gathered
    into plain lists, one entry per sample.
    """
    stacked = {
        "text_embeddings": torch.cat([sample["text_embeddings"] for sample in batch], dim=0),
        "video_latents": torch.cat([sample["video_latents"] for sample in batch], dim=0),
        "metadata": [sample["metadata"] for sample in batch],
        "file_info": [sample["file_info"] for sample in batch],
    }
    return stacked
121+
122+
123+
def build_mock_dataloader(
    *,
    dp_rank: int = 0,
    dp_world_size: int = 1,
    batch_size: int = 1,
    num_workers: int = 0,
    device: str = "cpu",
    length: int = 1024,
    num_channels: int = 16,
    num_frame_latents: int = 16,
    spatial_h: int = 30,
    spatial_w: int = 52,
    text_seq_len: int = 77,
    text_embed_dim: int = 4096,
    shuffle: bool = True,
) -> Tuple[DataLoader, Optional[DistributedSampler]]:
    """Build a DataLoader over a MockWanDataset for WAN training tests.

    Follows the same interface as build_dataloader but serves random tensors
    instead of loading from .meta files.

    Args:
        dp_rank: Data parallel rank.
        dp_world_size: Data parallel world size.
        batch_size: Batch size per GPU.
        num_workers: Number of dataloader workers.
        device: Device to place tensors on.
        length: Number of samples in mock dataset.
        num_channels: Number of latent channels (default: 16).
        num_frame_latents: Number of temporal latent frames.
        spatial_h: Height of spatial latents.
        spatial_w: Width of spatial latents.
        text_seq_len: Length of text sequence.
        text_embed_dim: Dimension of text embeddings.
        shuffle: Whether to shuffle data.

    Returns:
        Tuple of (DataLoader, DistributedSampler or None) — the sampler is
        created only when dp_world_size > 1.
    """
    mock_dataset = MockWanDataset(
        length=length,
        num_channels=num_channels,
        num_frame_latents=num_frame_latents,
        spatial_h=spatial_h,
        spatial_w=spatial_w,
        text_seq_len=text_seq_len,
        text_embed_dim=text_embed_dim,
        device=device,
    )

    # Shard across data-parallel ranks only when there is more than one rank.
    sampler: Optional[DistributedSampler] = None
    if dp_world_size > 1:
        sampler = DistributedSampler(
            mock_dataset,
            num_replicas=dp_world_size,
            rank=dp_rank,
            shuffle=shuffle,
            drop_last=False,
        )

    # DataLoader forbids shuffle together with a sampler; the sampler owns
    # shuffling in the distributed case. Pinning applies to CPU tensors only.
    loader = DataLoader(
        mock_dataset,
        batch_size=batch_size,
        shuffle=shuffle if sampler is None else False,
        sampler=sampler,
        num_workers=num_workers,
        collate_fn=mock_collate_fn,
        pin_memory=(device == "cpu"),
    )

    return loader, sampler

dfm/src/automodel/flow_matching/adapters/base.py

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -117,32 +117,6 @@ def forward(self, model: nn.Module, inputs: Dict[str, Any]) -> torch.Tensor:
117117
"""
118118
pass
119119

120-
def get_condition_latents(self, latents: torch.Tensor, task_type: str) -> torch.Tensor:
121-
"""
122-
Generate conditional latents based on task type.
123-
124-
Override this method if your model uses a different conditioning scheme.
125-
Default implementation adds a channel for conditioning mask.
126-
127-
Args:
128-
latents: Input latents [B, C, F, H, W]
129-
task_type: Task type ("t2v" or "i2v")
130-
131-
Returns:
132-
Conditional latents [B, C+1, F, H, W]
133-
"""
134-
b, c, f, h, w = latents.shape
135-
cond = torch.zeros([b, c + 1, f, h, w], device=latents.device, dtype=latents.dtype)
136-
137-
if task_type == "t2v":
138-
return cond
139-
elif task_type == "i2v":
140-
cond[:, :-1, :1] = latents[:, :, :1]
141-
cond[:, -1, 0] = 1
142-
return cond
143-
else:
144-
raise ValueError(f"Unsupported task type: {task_type}")
145-
146120
def post_process_prediction(self, model_pred: torch.Tensor) -> torch.Tensor:
147121
"""
148122
Post-process model prediction if needed.

dfm/src/automodel/flow_matching/adapters/hunyuan.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,29 @@ def __init__(
6565
self.default_image_embed_shape = default_image_embed_shape
6666
self.use_condition_latents = use_condition_latents
6767

68+
def get_condition_latents(self, latents: torch.Tensor, task_type: str) -> torch.Tensor:
    """
    Generate conditional latents based on task type.

    For "t2v" the conditioning is all zeros; for "i2v" the first latent frame
    is copied into the conditioning channels and the extra mask channel is set
    to 1 on that frame only.

    Args:
        latents: Input latents [B, C, F, H, W]
        task_type: Task type ("t2v" or "i2v")

    Returns:
        Conditional latents [B, C+1, F, H, W]

    Raises:
        ValueError: If task_type is neither "t2v" nor "i2v".
    """
    batch, channels, frames, height, width = latents.shape
    # new_zeros inherits device and dtype from the input latents.
    conditioning = latents.new_zeros((batch, channels + 1, frames, height, width))

    if task_type == "i2v":
        # Condition on the first frame: copy it into the first C channels
        # and flag it via the trailing mask channel.
        conditioning[:, :-1, :1] = latents[:, :, :1]
        conditioning[:, -1, 0] = 1
        return conditioning

    if task_type == "t2v":
        return conditioning

    raise ValueError(f"Unsupported task type: {task_type}")
90+
6891
def prepare_inputs(self, context: FlowMatchingContext) -> Dict[str, Any]:
6992
"""
7093
Prepare inputs for HunyuanVideo model.

dfm/src/automodel/recipes/train.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ def setup(self):
171171
self.rng = StatefulRNG(seed=self.seed, ranked=True)
172172

173173
self.model_id = self.cfg.get("model.pretrained_model_name_or_path")
174-
self.attention_backend = self.cfg.get("model.attention_backend", "_flash_3_hub")
174+
self.attention_backend = self.cfg.get("model.attention_backend")
175175
self.learning_rate = self.cfg.get("optim.learning_rate", 5e-6)
176176
self.bf16 = torch.bfloat16
177177

@@ -250,8 +250,6 @@ def setup(self):
250250
raise ValueError(
251251
"checkpoint config is required in YAML (enabled, checkpoint_dir, model_save_format, save_consolidated)"
252252
)
253-
if not checkpoint_cfg.get("enabled", False):
254-
raise ValueError("checkpoint.enabled must be true in YAML for diffusion training")
255253

256254
# Build BaseRecipe-style checkpointing configuration (DCP/TORCH_SAVE) from YAML
257255
model_state_dict_keys = list(self.model.state_dict().keys())

examples/automodel/pretrain/pretrain.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,12 @@
1616

1717
from nemo_automodel.components.config._arg_parser import parse_args_and_load_config
1818

19-
from dfm.src.automodel.recipes.train import TrainWan21DiffusionRecipe
19+
from dfm.src.automodel.recipes.train import TrainDiffusionRecipe
2020

2121

2222
def main(default_config_path="examples/automodel/pretrain/wan2_1_t2v_flow.yaml"):
    """Entry point for automodel pretraining.

    Loads the training configuration (from CLI args, falling back to
    *default_config_path*), builds the diffusion training recipe, and runs
    the combined train/validation loop.

    Args:
        default_config_path: YAML config used when no path is given on the CLI.
    """
    cfg = parse_args_and_load_config(default_config_path)
    recipe = TrainDiffusionRecipe(cfg)
    recipe.setup()
    recipe.run_train_validation_loop()
2727

examples/automodel/pretrain/wan2_1_t2v_flow.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ optim:
3737

3838

3939
flow_matching:
40+
adapter_type: "simple" # Options: "hunyuan", "simple"
41+
adapter_kwargs: {}
4042
use_sigma_noise: true
4143
timestep_sampling: uniform
4244
logit_mean: 0.0

tests/functional_tests/L2_Automodel_Wan21_Test.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,4 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
CUDA_VISIBLE_DEVICES="0,1" uv run coverage run -a --data-file=/opt/DFM/.coverage --source=/opt/DFM/ -m pytest tests/functional_tests/automodel/wan21 -m "not pleasefixme" --with_downloads
15+
CUDA_VISIBLE_DEVICES="0,1" uv run --group automodel coverage run -a --data-file=/opt/DFM/.coverage --source=/opt/DFM/ -m pytest tests/functional_tests/automodel/wan21 -m "not pleasefixme" --with_downloads

0 commit comments

Comments
 (0)