3 changes: 2 additions & 1 deletion dfm/src/automodel/_diffusers/auto_diffusion_pipeline.py
@@ -24,14 +24,15 @@
from nemo_automodel.components.distributed.fsdp2 import FSDP2Manager
from nemo_automodel.shared.utils import dtype_from_str

from dfm.src.automodel.distributed.dfm_parallelizer import WanParallelizationStrategy
from dfm.src.automodel.distributed.dfm_parallelizer import HunyuanParallelizationStrategy, WanParallelizationStrategy


logger = logging.getLogger(__name__)


def _init_parallelizer():
parallelizer.PARALLELIZATION_STRATEGIES["WanTransformer3DModel"] = WanParallelizationStrategy()
parallelizer.PARALLELIZATION_STRATEGIES["HunyuanVideo15Transformer3DModel"] = HunyuanParallelizationStrategy()


def _choose_device(device: Optional[torch.device]) -> torch.device:
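Note: `_init_parallelizer` keys the strategy registry by the transformer class name. The sketch below is illustrative only; the `apply_strategy` helper and the lookup by `type(model).__name__` are assumptions, not part of this diff.

from torch import nn
from torch.distributed.device_mesh import DeviceMesh


def apply_strategy(model: nn.Module, device_mesh: DeviceMesh, strategies: dict) -> nn.Module:
    # `strategies` stands in for parallelizer.PARALLELIZATION_STRATEGIES, which
    # _init_parallelizer() populates with entries keyed by transformer class name.
    strategy = strategies.get(type(model).__name__)
    if strategy is None:
        raise KeyError(f"no parallelization strategy registered for {type(model).__name__}")
    # parallelize(model, device_mesh, ...) matches the signatures of the strategies in this PR.
    return strategy.parallelize(model, device_mesh)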
6 changes: 3 additions & 3 deletions dfm/src/automodel/datasets/__init__.py
Expand Up @@ -12,19 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from dfm.src.automodel.datasets.wan21 import (
from dfm.src.automodel.datasets.dataloader import (
MetaFilesDataset,
build_dataloader,
build_node_parallel_sampler,
build_wan21_dataloader,
collate_fn,
create_dataloader,
)


__all__ = [
"build_dataloader",
"MetaFilesDataset",
"build_node_parallel_sampler",
"build_wan21_dataloader",
"collate_fn",
"create_dataloader",
]
@@ -112,6 +112,20 @@ def __getitem__(self, index: int) -> Dict[str, torch.Tensor]: # type: ignore[ov
text_embeddings: torch.Tensor = data["text_embeddings"].to(self.device)
video_latents: torch.Tensor = data["video_latents"].to(self.device)

# Load text_mask if available (backwards compatible)
text_mask = data.get("text_mask")
text_embeddings_2 = data.get("text_embeddings_2")
text_mask_2 = data.get("text_mask_2")
image_embeds = data.get("image_embeds")
if text_mask is not None:
text_mask = text_mask.to(self.device)
if text_embeddings_2 is not None:
text_embeddings_2 = text_embeddings_2.to(self.device)
if text_mask_2 is not None:
text_mask_2 = text_mask_2.to(self.device)
if image_embeds is not None:
image_embeds = image_embeds.to(self.device)

if self.transform_text is not None:
text_embeddings = self.transform_text(text_embeddings)
if self.transform_video is not None:
@@ -126,13 +140,25 @@ def __getitem__(self, index: int) -> Dict[str, torch.Tensor]: # type: ignore[ov
"num_frames": data.get("num_frames", "unknown"),
}

return {
result = {
"text_embeddings": text_embeddings,
"video_latents": video_latents,
"metadata": data.get("metadata", {}),
"file_info": file_info,
}

# Add text_mask if available (backwards compatible)
if text_mask is not None:
result["text_mask"] = text_mask
if text_embeddings_2 is not None:
result["text_embeddings_2"] = text_embeddings_2
if text_mask_2 is not None:
result["text_mask_2"] = text_mask_2
if image_embeds is not None:
result["image_embeds"] = image_embeds

return result


def collate_fn(batch: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
if len(batch) > 0:
@@ -141,13 +167,30 @@ def collate_fn(batch: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
# use cat to stack the tensors in the batch
text_embeddings = torch.cat([item["text_embeddings"] for item in batch], dim=0)
video_latents = torch.cat([item["video_latents"] for item in batch], dim=0)
return {

result = {
"text_embeddings": text_embeddings,
"video_latents": video_latents,
"metadata": [item["metadata"] for item in batch],
"file_info": [item["file_info"] for item in batch],
}

# Collate text_mask if available (backwards compatible)
if len(batch) > 0 and "text_mask" in batch[0]:
text_mask = torch.cat([item["text_mask"] for item in batch], dim=0)
result["text_mask"] = text_mask
if len(batch) > 0 and "text_embeddings_2" in batch[0]:
text_embeddings_2 = torch.cat([item["text_embeddings_2"] for item in batch], dim=0)
result["text_embeddings_2"] = text_embeddings_2
if len(batch) > 0 and "text_mask_2" in batch[0]:
text_mask_2 = torch.cat([item["text_mask_2"] for item in batch], dim=0)
result["text_mask_2"] = text_mask_2
if len(batch) > 0 and "image_embeds" in batch[0]:
image_embeds = torch.cat([item["image_embeds"] for item in batch], dim=0)
result["image_embeds"] = image_embeds

return result
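
A standalone sketch of the optional-key collation above, assuming each item's tensors carry a leading batch dimension of 1 (hence torch.cat along dim 0 rather than torch.stack); all shapes are made up for illustration.

import torch

from dfm.src.automodel.datasets import collate_fn

# Two fake dataset items that both carry the optional text_mask key.
items = [
    {
        "text_embeddings": torch.randn(1, 226, 4096),
        "video_latents": torch.randn(1, 16, 4, 8, 8),
        "text_mask": torch.ones(1, 226, dtype=torch.bool),
        "metadata": {},
        "file_info": {"width": 64, "height": 64, "num_frames": 4},
    }
    for _ in range(2)
]

batch = collate_fn(items)
assert batch["text_embeddings"].shape[0] == 2  # items concatenated along dim 0
assert batch["text_mask"].shape == (2, 226)    # optional key collated only because both items carry it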


def build_node_parallel_sampler(
dataset: "Dataset",
@@ -167,7 +210,7 @@ def build_node_parallel_sampler(
)


def build_wan21_dataloader(
def build_dataloader(
*,
meta_folder: str,
batch_size: int,
@@ -211,4 +254,4 @@ def create_dataloader(
batch_size: int,
num_nodes: int,
) -> Tuple[DataLoader, Optional[DistributedSampler]]:
return build_wan21_dataloader(meta_folder=meta_folder, batch_size=batch_size, num_nodes=num_nodes)
return build_dataloader(meta_folder=meta_folder, batch_size=batch_size, num_nodes=num_nodes)
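
A hedged usage sketch of the renamed entry point: the meta-folder path is a placeholder, and the assumption that the resulting DataLoader yields collate_fn-style dict batches is not shown in this diff.

from dfm.src.automodel.datasets import create_dataloader

# Placeholder path; create_dataloader forwards to build_dataloader (formerly build_wan21_dataloader).
dataloader, sampler = create_dataloader(
    meta_folder="/path/to/meta_files",
    batch_size=1,
    num_nodes=1,
)

for batch in dataloader:
    text_embeddings = batch["text_embeddings"]  # [B, ...] text-encoder features
    video_latents = batch["video_latents"]      # [B, C, F, H, W] video latents
    text_mask = batch.get("text_mask")          # None unless the stored samples include it
    break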
48 changes: 48 additions & 0 deletions dfm/src/automodel/distributed/dfm_parallelizer.py
@@ -22,6 +22,7 @@
)
from torch import nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
CheckpointImpl,
checkpoint_wrapper,
)
from torch.distributed.device_mesh import DeviceMesh
@@ -146,3 +147,50 @@ def parallelize(
offload_policy=offload_policy,
reshard_after_forward=False,
)


class HunyuanParallelizationStrategy(ParallelizationStrategy):
"""Parallelization strategy for Hunyuan-style transformer modules used in HunyuanVideo.."""

def parallelize(
self,
model: nn.Module,
device_mesh: DeviceMesh,
mp_policy: Optional[MixedPrecisionPolicy] = None,
offload_policy: Optional[OffloadPolicy] = None,
sequence_parallel: bool = False,
activation_checkpointing: bool = True,
tp_shard_plan: Optional[Union[Dict[str, ParallelStyle], str]] = None,
dp_replicate_mesh_name: str = "dp_replicate",
dp_shard_cp_mesh_name: str = "dp_shard_cp",
tp_mesh_name: str = "tp",
) -> nn.Module:
tp_mesh = device_mesh[tp_mesh_name]
dp_mesh_dim_names = (dp_replicate_mesh_name, dp_shard_cp_mesh_name)
dp_mesh = device_mesh[dp_mesh_dim_names]

# Mixed precision default like Default strategy
if not mp_policy:
mp_policy = MixedPrecisionPolicy(
param_dtype=torch.bfloat16,
reduce_dtype=torch.float32,
output_dtype=torch.bfloat16,
)
# Apply activation checkpointing to transformer blocks if requested
if activation_checkpointing:
for idx in range(len(model.transformer_blocks)):
model.transformer_blocks[idx] = checkpoint_wrapper(
model.transformer_blocks[idx],
checkpoint_impl=CheckpointImpl.NO_REENTRANT,
)

# Apply FSDP sharding recursively and to root
apply_fsdp2_sharding_recursively(model, dp_mesh, mp_policy, offload_policy)

return fully_shard(
model,
mesh=dp_mesh,
mp_policy=mp_policy,
offload_policy=offload_policy,
reshard_after_forward=False,
)
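
A sketch of the device-mesh layout the strategy expects by default (mesh dim names dp_replicate, dp_shard_cp, tp). The 2 x 4 x 1 split over 8 GPUs is hypothetical, and the snippet has to be launched under torchrun with a matching world size.

from torch.distributed.device_mesh import init_device_mesh

# Hypothetical 8-GPU layout, e.g. torchrun --nproc-per-node=8 this_sketch.py
mesh = init_device_mesh(
    "cuda",
    (2, 4, 1),  # dp_replicate x dp_shard_cp x tp
    mesh_dim_names=("dp_replicate", "dp_shard_cp", "tp"),
)

tp_mesh = mesh["tp"]                             # tensor-parallel sub-mesh
dp_mesh = mesh[("dp_replicate", "dp_shard_cp")]  # 2D sub-mesh used for FSDP sharding, as in parallelize()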
43 changes: 43 additions & 0 deletions dfm/src/automodel/flow_matching/adapters/__init__.py
@@ -0,0 +1,43 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Model adapters for FlowMatching Pipeline.

This module provides model-specific adapters that decouple the flow matching
logic from model-specific implementation details.

Available Adapters:
- ModelAdapter: Abstract base class for all adapters
- HunyuanAdapter: For HunyuanVideo 1.5 style models
- SimpleAdapter: For simple transformer models (e.g., Wan)

Usage:
from automodel.flow_matching.adapters import HunyuanAdapter, SimpleAdapter

# Or import the base class to create custom adapters
from automodel.flow_matching.adapters import ModelAdapter
"""

from .base import FlowMatchingContext, ModelAdapter
from .hunyuan import HunyuanAdapter
from .simple import SimpleAdapter


__all__ = [
"FlowMatchingContext",
"ModelAdapter",
"HunyuanAdapter",
"SimpleAdapter",
]
160 changes: 160 additions & 0 deletions dfm/src/automodel/flow_matching/adapters/base.py
@@ -0,0 +1,160 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Base classes and data structures for model adapters.

This module defines the abstract ModelAdapter class and the FlowMatchingContext
dataclass used to pass data between the pipeline and adapters.
"""

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict

import torch
import torch.nn as nn


@dataclass
class FlowMatchingContext:
Review comment (Contributor): Is this context limited to just flow matching? If I'm not missing something, the EDM pipeline requires the same set of attributes.

Reply (Contributor Author): Not specifically, but since I've implemented only flow matching I thought to call it this.

"""
Context object passed to model adapters containing all necessary data.

This provides a clean interface for adapters to access the data they need
without coupling to the batch dictionary structure.

Attributes:
noisy_latents: [B, C, F, H, W] - Noisy latents after interpolation
video_latents: [B, C, F, H, W] - Original clean latents
timesteps: [B] - Sampled timesteps
sigma: [B] - Sigma values
task_type: "t2v" or "i2v"
data_type: "video" or "image"
device: Device for tensor operations
dtype: Data type for tensor operations
batch: Original batch dictionary (for model-specific data)
"""

# Core tensors
noisy_latents: torch.Tensor
video_latents: torch.Tensor
timesteps: torch.Tensor
sigma: torch.Tensor

# Task info
task_type: str
data_type: str

# Device/dtype
device: torch.device
dtype: torch.dtype

# Original batch (for model-specific data)
batch: Dict[str, Any]
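
A hedged construction example for the dataclass above; the import path follows the new adapters package, and every shape is illustrative rather than prescribed.

import torch

from dfm.src.automodel.flow_matching.adapters import FlowMatchingContext

ctx = FlowMatchingContext(
    noisy_latents=torch.randn(1, 16, 4, 8, 8),  # [B, C, F, H, W], made-up sizes
    video_latents=torch.randn(1, 16, 4, 8, 8),
    timesteps=torch.tensor([500.0]),
    sigma=torch.tensor([0.5]),
    task_type="t2v",
    data_type="video",
    device=torch.device("cpu"),
    dtype=torch.bfloat16,
    batch={"text_embeddings": torch.randn(1, 226, 4096)},
)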


class ModelAdapter(ABC):
"""
Abstract base class for model-specific forward pass logic.

Implement this class to add support for new model architectures
without modifying the FlowMatchingPipeline.

The adapter pattern decouples the flow matching logic from model-specific
details like input preparation and forward pass conventions.

Example:
class MyCustomAdapter(ModelAdapter):
def prepare_inputs(self, context: FlowMatchingContext) -> Dict[str, Any]:
return {
"x": context.noisy_latents,
"t": context.timesteps,
"cond": context.batch["my_conditioning"],
}

def forward(self, model: nn.Module, inputs: Dict[str, Any]) -> torch.Tensor:
return model(**inputs)

pipeline = FlowMatchingPipelineV2(model_adapter=MyCustomAdapter())
"""

@abstractmethod
def prepare_inputs(self, context: FlowMatchingContext) -> Dict[str, Any]:
"""
Prepare model-specific inputs from the context.

Args:
context: FlowMatchingContext containing all necessary data

Returns:
Dictionary of inputs to pass to the model's forward method
"""
pass

@abstractmethod
def forward(self, model: nn.Module, inputs: Dict[str, Any]) -> torch.Tensor:
"""
Execute the model forward pass.

Args:
model: The model to call
inputs: Dictionary of inputs from prepare_inputs()

Returns:
Model prediction tensor
"""
pass

def get_condition_latents(self, latents: torch.Tensor, task_type: str) -> torch.Tensor:
"""
Generate conditional latents based on task type.

Override this method if your model uses a different conditioning scheme.
Default implementation adds a channel for conditioning mask.

Args:
latents: Input latents [B, C, F, H, W]
task_type: Task type ("t2v" or "i2v")

Returns:
Conditional latents [B, C+1, F, H, W]
"""
b, c, f, h, w = latents.shape
cond = torch.zeros([b, c + 1, f, h, w], device=latents.device, dtype=latents.dtype)

if task_type == "t2v":
return cond
elif task_type == "i2v":
cond[:, :-1, :1] = latents[:, :, :1]
cond[:, -1, 0] = 1
return cond
else:
raise ValueError(f"Unsupported task type: {task_type}")

def post_process_prediction(self, model_pred: torch.Tensor) -> torch.Tensor:
"""
Post-process model prediction if needed.

Override this for models that return extra outputs or need transformation.

Args:
model_pred: Raw model output

Returns:
Processed prediction tensor
"""
if isinstance(model_pred, tuple):
return model_pred[0]
return model_pred
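
To make the default get_condition_latents behavior concrete, here is a standalone sketch of the i2v case with made-up sizes; it mirrors the method body rather than importing it.

import torch

b, c, f, h, w = 1, 4, 3, 2, 2          # made-up latent sizes
latents = torch.randn(b, c, f, h, w)

cond = torch.zeros(b, c + 1, f, h, w)
cond[:, :-1, :1] = latents[:, :, :1]   # copy first-frame latents into the content channels
cond[:, -1, 0] = 1                     # extra channel marks frame 0 as the conditioning frame

assert cond.shape == (b, c + 1, f, h, w)
assert torch.equal(cond[:, :-1, :1], latents[:, :, :1])
assert (cond[:, :, 1:] == 0).all()     # for t2v the whole tensor would stay zero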