diff --git a/dfm/src/automodel/_diffusers/auto_diffusion_pipeline.py b/dfm/src/automodel/_diffusers/auto_diffusion_pipeline.py index cb9e9d00..867a01d1 100644 --- a/dfm/src/automodel/_diffusers/auto_diffusion_pipeline.py +++ b/dfm/src/automodel/_diffusers/auto_diffusion_pipeline.py @@ -24,7 +24,7 @@ from nemo_automodel.components.distributed.fsdp2 import FSDP2Manager from nemo_automodel.shared.utils import dtype_from_str -from dfm.src.automodel.distributed.dfm_parallelizer import WanParallelizationStrategy +from dfm.src.automodel.distributed.dfm_parallelizer import HunyuanParallelizationStrategy, WanParallelizationStrategy logger = logging.getLogger(__name__) @@ -32,6 +32,7 @@ def _init_parallelizer(): parallelizer.PARALLELIZATION_STRATEGIES["WanTransformer3DModel"] = WanParallelizationStrategy() + parallelizer.PARALLELIZATION_STRATEGIES["HunyuanVideo15Transformer3DModel"] = HunyuanParallelizationStrategy() def _choose_device(device: Optional[torch.device]) -> torch.device: diff --git a/dfm/src/automodel/datasets/__init__.py b/dfm/src/automodel/datasets/__init__.py index 051d4cd2..cc00daee 100644 --- a/dfm/src/automodel/datasets/__init__.py +++ b/dfm/src/automodel/datasets/__init__.py @@ -12,19 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dfm.src.automodel.datasets.wan21 import ( +from dfm.src.automodel.datasets.dataloader import ( MetaFilesDataset, + build_dataloader, build_node_parallel_sampler, - build_wan21_dataloader, collate_fn, create_dataloader, ) __all__ = [ + "build_dataloader", "MetaFilesDataset", "build_node_parallel_sampler", - "build_wan21_dataloader", "collate_fn", "create_dataloader", ] diff --git a/dfm/src/automodel/datasets/wan21.py b/dfm/src/automodel/datasets/dataloader.py similarity index 78% rename from dfm/src/automodel/datasets/wan21.py rename to dfm/src/automodel/datasets/dataloader.py index a61e998e..c7dcad71 100644 --- a/dfm/src/automodel/datasets/wan21.py +++ b/dfm/src/automodel/datasets/dataloader.py @@ -112,6 +112,20 @@ def __getitem__(self, index: int) -> Dict[str, torch.Tensor]: # type: ignore[ov text_embeddings: torch.Tensor = data["text_embeddings"].to(self.device) video_latents: torch.Tensor = data["video_latents"].to(self.device) + # Load text_mask if available (backwards compatible) + text_mask = data.get("text_mask") + text_embeddings_2 = data.get("text_embeddings_2") + text_mask_2 = data.get("text_mask_2") + image_embeds = data.get("image_embeds") + if text_mask is not None: + text_mask = text_mask.to(self.device) + if text_embeddings_2 is not None: + text_embeddings_2 = text_embeddings_2.to(self.device) + if text_mask_2 is not None: + text_mask_2 = text_mask_2.to(self.device) + if image_embeds is not None: + image_embeds = image_embeds.to(self.device) + if self.transform_text is not None: text_embeddings = self.transform_text(text_embeddings) if self.transform_video is not None: @@ -126,13 +140,25 @@ def __getitem__(self, index: int) -> Dict[str, torch.Tensor]: # type: ignore[ov "num_frames": data.get("num_frames", "unknown"), } - return { + result = { "text_embeddings": text_embeddings, "video_latents": video_latents, "metadata": data.get("metadata", {}), "file_info": file_info, } + # Add text_mask if available (backwards compatible) + if text_mask is not None: + result["text_mask"] = text_mask + if text_embeddings_2 is not None: + result["text_embeddings_2"] = text_embeddings_2 + if text_mask_2 is not None: + result["text_mask_2"] = text_mask_2 + if image_embeds is not None: 
+ result["image_embeds"] = image_embeds + + return result + def collate_fn(batch: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]: if len(batch) > 0: @@ -141,13 +167,30 @@ def collate_fn(batch: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]: # use cat to stack the tensors in the batch text_embeddings = torch.cat([item["text_embeddings"] for item in batch], dim=0) video_latents = torch.cat([item["video_latents"] for item in batch], dim=0) - return { + + result = { "text_embeddings": text_embeddings, "video_latents": video_latents, "metadata": [item["metadata"] for item in batch], "file_info": [item["file_info"] for item in batch], } + # Collate text_mask if available (backwards compatible) + if len(batch) > 0 and "text_mask" in batch[0]: + text_mask = torch.cat([item["text_mask"] for item in batch], dim=0) + result["text_mask"] = text_mask + if len(batch) > 0 and "text_embeddings_2" in batch[0]: + text_embeddings_2 = torch.cat([item["text_embeddings_2"] for item in batch], dim=0) + result["text_embeddings_2"] = text_embeddings_2 + if len(batch) > 0 and "text_mask_2" in batch[0]: + text_mask_2 = torch.cat([item["text_mask_2"] for item in batch], dim=0) + result["text_mask_2"] = text_mask_2 + if len(batch) > 0 and "image_embeds" in batch[0]: + image_embeds = torch.cat([item["image_embeds"] for item in batch], dim=0) + result["image_embeds"] = image_embeds + + return result + def build_node_parallel_sampler( dataset: "Dataset", @@ -167,7 +210,7 @@ def build_node_parallel_sampler( ) -def build_wan21_dataloader( +def build_dataloader( *, meta_folder: str, batch_size: int, @@ -211,4 +254,4 @@ def create_dataloader( batch_size: int, num_nodes: int, ) -> Tuple[DataLoader, Optional[DistributedSampler]]: - return build_wan21_dataloader(meta_folder=meta_folder, batch_size=batch_size, num_nodes=num_nodes) + return build_dataloader(meta_folder=meta_folder, batch_size=batch_size, num_nodes=num_nodes) diff --git a/dfm/src/automodel/distributed/dfm_parallelizer.py b/dfm/src/automodel/distributed/dfm_parallelizer.py index a2e8a3e5..1a8058c1 100644 --- a/dfm/src/automodel/distributed/dfm_parallelizer.py +++ b/dfm/src/automodel/distributed/dfm_parallelizer.py @@ -22,6 +22,7 @@ ) from torch import nn from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + CheckpointImpl, checkpoint_wrapper, ) from torch.distributed.device_mesh import DeviceMesh @@ -146,3 +147,50 @@ def parallelize( offload_policy=offload_policy, reshard_after_forward=False, ) + + +class HunyuanParallelizationStrategy(ParallelizationStrategy): + """Parallelization strategy for Hunyuan-style transformer modules used in HunyuanVideo..""" + + def parallelize( + self, + model: nn.Module, + device_mesh: DeviceMesh, + mp_policy: Optional[MixedPrecisionPolicy] = None, + offload_policy: Optional[OffloadPolicy] = None, + sequence_parallel: bool = False, + activation_checkpointing: bool = True, + tp_shard_plan: Optional[Union[Dict[str, ParallelStyle], str]] = None, + dp_replicate_mesh_name: str = "dp_replicate", + dp_shard_cp_mesh_name: str = "dp_shard_cp", + tp_mesh_name: str = "tp", + ) -> nn.Module: + tp_mesh = device_mesh[tp_mesh_name] + dp_mesh_dim_names = (dp_replicate_mesh_name, dp_shard_cp_mesh_name) + dp_mesh = device_mesh[dp_mesh_dim_names] + + # Mixed precision default like Default strategy + if not mp_policy: + mp_policy = MixedPrecisionPolicy( + param_dtype=torch.bfloat16, + reduce_dtype=torch.float32, + output_dtype=torch.bfloat16, + ) + # Apply activation checkpointing to transformer blocks if 
requested + if activation_checkpointing: + for idx in range(len(model.transformer_blocks)): + model.transformer_blocks[idx] = checkpoint_wrapper( + model.transformer_blocks[idx], + checkpoint_impl=CheckpointImpl.NO_REENTRANT, + ) + + # Apply FSDP sharding recursively and to root + apply_fsdp2_sharding_recursively(model, dp_mesh, mp_policy, offload_policy) + + return fully_shard( + model, + mesh=dp_mesh, + mp_policy=mp_policy, + offload_policy=offload_policy, + reshard_after_forward=False, + ) diff --git a/dfm/src/automodel/flow_matching/adapters/__init__.py b/dfm/src/automodel/flow_matching/adapters/__init__.py new file mode 100644 index 00000000..15cffef5 --- /dev/null +++ b/dfm/src/automodel/flow_matching/adapters/__init__.py @@ -0,0 +1,43 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Model adapters for FlowMatching Pipeline. + +This module provides model-specific adapters that decouple the flow matching +logic from model-specific implementation details. + +Available Adapters: +- ModelAdapter: Abstract base class for all adapters +- HunyuanAdapter: For HunyuanVideo 1.5 style models +- SimpleAdapter: For simple transformer models (e.g., Wan) + +Usage: + from automodel.flow_matching.adapters import HunyuanAdapter, SimpleAdapter + + # Or import the base class to create custom adapters + from automodel.flow_matching.adapters import ModelAdapter +""" + +from .base import FlowMatchingContext, ModelAdapter +from .hunyuan import HunyuanAdapter +from .simple import SimpleAdapter + + +__all__ = [ + "FlowMatchingContext", + "ModelAdapter", + "HunyuanAdapter", + "SimpleAdapter", +] diff --git a/dfm/src/automodel/flow_matching/adapters/base.py b/dfm/src/automodel/flow_matching/adapters/base.py new file mode 100644 index 00000000..144427bd --- /dev/null +++ b/dfm/src/automodel/flow_matching/adapters/base.py @@ -0,0 +1,160 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Base classes and data structures for model adapters. + +This module defines the abstract ModelAdapter class and the FlowMatchingContext +dataclass used to pass data between the pipeline and adapters. +""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any, Dict + +import torch +import torch.nn as nn + + +@dataclass +class FlowMatchingContext: + """ + Context object passed to model adapters containing all necessary data. 
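+
+    It is constructed by FlowMatchingPipeline.step() once per training step and
+    handed to the active ModelAdapter.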
+ + This provides a clean interface for adapters to access the data they need + without coupling to the batch dictionary structure. + + Attributes: + noisy_latents: [B, C, F, H, W] - Noisy latents after interpolation + video_latents: [B, C, F, H, W] - Original clean latents + timesteps: [B] - Sampled timesteps + sigma: [B] - Sigma values + task_type: "t2v" or "i2v" + data_type: "video" or "image" + device: Device for tensor operations + dtype: Data type for tensor operations + batch: Original batch dictionary (for model-specific data) + """ + + # Core tensors + noisy_latents: torch.Tensor + video_latents: torch.Tensor + timesteps: torch.Tensor + sigma: torch.Tensor + + # Task info + task_type: str + data_type: str + + # Device/dtype + device: torch.device + dtype: torch.dtype + + # Original batch (for model-specific data) + batch: Dict[str, Any] + + +class ModelAdapter(ABC): + """ + Abstract base class for model-specific forward pass logic. + + Implement this class to add support for new model architectures + without modifying the FlowMatchingPipeline. + + The adapter pattern decouples the flow matching logic from model-specific + details like input preparation and forward pass conventions. + + Example: + class MyCustomAdapter(ModelAdapter): + def prepare_inputs(self, context: FlowMatchingContext) -> Dict[str, Any]: + return { + "x": context.noisy_latents, + "t": context.timesteps, + "cond": context.batch["my_conditioning"], + } + + def forward(self, model: nn.Module, inputs: Dict[str, Any]) -> torch.Tensor: + return model(**inputs) + + pipeline = FlowMatchingPipelineV2(model_adapter=MyCustomAdapter()) + """ + + @abstractmethod + def prepare_inputs(self, context: FlowMatchingContext) -> Dict[str, Any]: + """ + Prepare model-specific inputs from the context. + + Args: + context: FlowMatchingContext containing all necessary data + + Returns: + Dictionary of inputs to pass to the model's forward method + """ + pass + + @abstractmethod + def forward(self, model: nn.Module, inputs: Dict[str, Any]) -> torch.Tensor: + """ + Execute the model forward pass. + + Args: + model: The model to call + inputs: Dictionary of inputs from prepare_inputs() + + Returns: + Model prediction tensor + """ + pass + + def get_condition_latents(self, latents: torch.Tensor, task_type: str) -> torch.Tensor: + """ + Generate conditional latents based on task type. + + Override this method if your model uses a different conditioning scheme. + Default implementation adds a channel for conditioning mask. + + Args: + latents: Input latents [B, C, F, H, W] + task_type: Task type ("t2v" or "i2v") + + Returns: + Conditional latents [B, C+1, F, H, W] + """ + b, c, f, h, w = latents.shape + cond = torch.zeros([b, c + 1, f, h, w], device=latents.device, dtype=latents.dtype) + + if task_type == "t2v": + return cond + elif task_type == "i2v": + cond[:, :-1, :1] = latents[:, :, :1] + cond[:, -1, 0] = 1 + return cond + else: + raise ValueError(f"Unsupported task type: {task_type}") + + def post_process_prediction(self, model_pred: torch.Tensor) -> torch.Tensor: + """ + Post-process model prediction if needed. + + Override this for models that return extra outputs or need transformation. 
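+
+        For example, diffusers-style transformers called with return_dict=False
+        return a tuple whose first element is the predicted sample; the default
+        implementation simply unwraps that tuple.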
+ + Args: + model_pred: Raw model output + + Returns: + Processed prediction tensor + """ + if isinstance(model_pred, tuple): + return model_pred[0] + return model_pred diff --git a/dfm/src/automodel/flow_matching/adapters/hunyuan.py b/dfm/src/automodel/flow_matching/adapters/hunyuan.py new file mode 100644 index 00000000..9dfa0cb1 --- /dev/null +++ b/dfm/src/automodel/flow_matching/adapters/hunyuan.py @@ -0,0 +1,158 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +HunyuanVideo model adapter for FlowMatching Pipeline. + +This adapter supports HunyuanVideo 1.5 style models with dual text encoders +and image embeddings for image-to-video conditioning. +""" + +from typing import Any, Dict, Tuple + +import torch +import torch.nn as nn + +from .base import FlowMatchingContext, ModelAdapter + + +class HunyuanAdapter(ModelAdapter): + """ + Model adapter for HunyuanVideo 1.5 style models. + + These models use: + - Condition latents concatenated with noisy latents + - Dual text encoders with attention masks + - Image embeddings for i2v + + Expected batch keys: + - text_embeddings: Primary text encoder output [B, seq_len, dim] + - text_mask: Attention mask for primary encoder [B, seq_len] (optional) + - text_embeddings_2: Secondary text encoder output [B, seq_len, dim] (optional) + - text_mask_2: Attention mask for secondary encoder [B, seq_len] (optional) + - image_embeds: Image embeddings for i2v [B, seq_len, dim] (optional) + + Example: + adapter = HunyuanAdapter() + pipeline = FlowMatchingPipelineV2(model_adapter=adapter) + """ + + def __init__( + self, + default_image_embed_shape: Tuple[int, int] = (729, 1152), + use_condition_latents: bool = True, + ): + """ + Initialize the HunyuanAdapter. + + Args: + default_image_embed_shape: Default shape for image embeddings (seq_len, dim) + when not provided in batch. Defaults to (729, 1152). + use_condition_latents: Whether to concatenate condition latents with + noisy latents. Defaults to True. + """ + self.default_image_embed_shape = default_image_embed_shape + self.use_condition_latents = use_condition_latents + + def prepare_inputs(self, context: FlowMatchingContext) -> Dict[str, Any]: + """ + Prepare inputs for HunyuanVideo model. 
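+
+        For t2v samples (or whenever the batch carries no image_embeds), a zero
+        tensor of shape default_image_embed_shape is substituted so the model
+        always receives image conditioning with a consistent shape.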
+ + Args: + context: FlowMatchingContext with batch data + + Returns: + Dictionary containing: + - latents: Noisy latents (optionally concatenated with condition latents) + - timesteps: Timestep values + - encoder_hidden_states: Primary text embeddings + - encoder_attention_mask: Primary attention mask + - encoder_hidden_states_2: Secondary text embeddings + - encoder_attention_mask_2: Secondary attention mask + - image_embeds: Image embeddings + """ + batch = context.batch + batch_size = context.noisy_latents.shape[0] + device = context.device + dtype = context.dtype + + # Get text embeddings + text_embeddings = batch["text_embeddings"].to(device, dtype=dtype) + if text_embeddings.ndim == 2: + text_embeddings = text_embeddings.unsqueeze(0) + + # Get optional elements + text_mask = batch.get("text_mask") + text_embeddings_2 = batch.get("text_embeddings_2") + text_mask_2 = batch.get("text_mask_2") + + if text_mask is not None: + text_mask = text_mask.to(device, dtype=dtype) + if text_embeddings_2 is not None: + text_embeddings_2 = text_embeddings_2.to(device, dtype=dtype) + if text_mask_2 is not None: + text_mask_2 = text_mask_2.to(device, dtype=dtype) + + # Handle image embeds for i2v + if context.task_type == "i2v" and "image_embeds" in batch: + image_embeds = batch["image_embeds"].to(device, dtype=dtype) + else: + seq_len, dim = self.default_image_embed_shape + image_embeds = torch.zeros( + batch_size, + seq_len, + dim, + dtype=dtype, + device=device, + ) + + # Prepare latents (with or without condition) + if self.use_condition_latents: + cond_latents = self.get_condition_latents(context.video_latents, context.task_type) + latents = torch.cat([context.noisy_latents, cond_latents], dim=1) + else: + latents = context.noisy_latents + + return { + "latents": latents, + "timesteps": context.timesteps.to(dtype), + "encoder_hidden_states": text_embeddings, + "encoder_attention_mask": text_mask, + "encoder_hidden_states_2": text_embeddings_2, + "encoder_attention_mask_2": text_mask_2, + "image_embeds": image_embeds, + } + + def forward(self, model: nn.Module, inputs: Dict[str, Any]) -> torch.Tensor: + """ + Execute forward pass for HunyuanVideo model. + + Args: + model: HunyuanVideo model + inputs: Dictionary from prepare_inputs() + + Returns: + Model prediction tensor + """ + model_pred = model( + inputs["latents"], + inputs["timesteps"], + encoder_hidden_states=inputs["encoder_hidden_states"], + encoder_attention_mask=inputs["encoder_attention_mask"], + encoder_hidden_states_2=inputs["encoder_hidden_states_2"], + encoder_attention_mask_2=inputs["encoder_attention_mask_2"], + image_embeds=inputs["image_embeds"], + return_dict=False, + ) + return self.post_process_prediction(model_pred) diff --git a/dfm/src/automodel/flow_matching/adapters/simple.py b/dfm/src/automodel/flow_matching/adapters/simple.py new file mode 100644 index 00000000..efb7aeb4 --- /dev/null +++ b/dfm/src/automodel/flow_matching/adapters/simple.py @@ -0,0 +1,92 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Simple transformer model adapter for FlowMatching Pipeline. + +This adapter supports simple transformer models with a basic interface, +such as Wan-style models. +""" + +from typing import Any, Dict + +import torch +import torch.nn as nn + +from .base import FlowMatchingContext, ModelAdapter + + +class SimpleAdapter(ModelAdapter): + """ + Model adapter for simple transformer models (e.g., Wan). + + These models use a simple interface with: + - hidden_states: noisy latents + - timestep: timestep values + - encoder_hidden_states: text embeddings + + Expected batch keys: + - text_embeddings: Text encoder output [B, seq_len, dim] + + Example: + adapter = SimpleAdapter() + pipeline = FlowMatchingPipelineV2(model_adapter=adapter) + """ + + def prepare_inputs(self, context: FlowMatchingContext) -> Dict[str, Any]: + """ + Prepare inputs for simple transformer model. + + Args: + context: FlowMatchingContext with batch data + + Returns: + Dictionary containing: + - hidden_states: Noisy latents + - timestep: Timestep values + - encoder_hidden_states: Text embeddings + """ + batch = context.batch + device = context.device + dtype = context.dtype + + # Get text embeddings + text_embeddings = batch["text_embeddings"].to(device, dtype=dtype) + if text_embeddings.ndim == 2: + text_embeddings = text_embeddings.unsqueeze(0) + + return { + "hidden_states": context.noisy_latents, + "timestep": context.timesteps.to(dtype), + "encoder_hidden_states": text_embeddings, + } + + def forward(self, model: nn.Module, inputs: Dict[str, Any]) -> torch.Tensor: + """ + Execute forward pass for simple transformer model. + + Args: + model: Transformer model + inputs: Dictionary from prepare_inputs() + + Returns: + Model prediction tensor + """ + model_pred = model( + hidden_states=inputs["hidden_states"], + timestep=inputs["timestep"], + encoder_hidden_states=inputs["encoder_hidden_states"], + return_dict=False, + ) + return self.post_process_prediction(model_pred) diff --git a/dfm/src/automodel/flow_matching/flow_matching_pipeline.py b/dfm/src/automodel/flow_matching/flow_matching_pipeline.py new file mode 100644 index 00000000..3e1a2478 --- /dev/null +++ b/dfm/src/automodel/flow_matching/flow_matching_pipeline.py @@ -0,0 +1,567 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +FlowMatching Pipeline - Model-agnostic implementation with adapter pattern. + +This module provides a unified FlowMatchingPipeline class that is completely +independent of specific model implementations through the ModelAdapter abstraction. 
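+
+A minimal usage sketch (model, batch, and device are assumed to come from the
+training recipe; the values shown are illustrative):
+
+    import torch
+
+    from dfm.src.automodel.flow_matching.flow_matching_pipeline import create_pipeline
+
+    pipeline = create_pipeline(adapter_type="simple", flow_shift=3.0)
+    loss, metrics = pipeline.step(model, batch, device, torch.bfloat16, global_step=0)
+    loss.backward()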
+ +Features: +- Model-agnostic design via ModelAdapter protocol +- Various timestep sampling strategies (uniform, logit_normal, mode, lognorm) +- Flow shift transformation +- Sigma clamping for finetuning +- Loss weighting +- Detailed training logging +""" + +import logging +import math +import os +import random +from typing import Any, Dict, Optional, Tuple + +import torch +import torch.nn as nn + +# Import adapters from the adapters module +from .adapters import ( + FlowMatchingContext, + HunyuanAdapter, + ModelAdapter, + SimpleAdapter, +) + + +logger = logging.getLogger(__name__) + + +# ============================================================================= +# Noise Schedule +# ============================================================================= + + +class LinearInterpolationSchedule: + """Simple linear interpolation schedule for flow matching.""" + + def forward(self, x0: torch.Tensor, x1: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor: + """ + Linear interpolation: x_t = (1 - σ) * x_0 + σ * x_1 + + Args: + x0: Starting point (clean latents) + x1: Ending point (noise) + sigma: Sigma values in [0, 1] + + Returns: + Interpolated tensor at sigma + """ + sigma = sigma.view(-1, *([1] * (x0.ndim - 1))) + return (1.0 - sigma) * x0 + sigma * x1 + + +# ============================================================================= +# Flow Matching Pipeline +# ============================================================================= + + +class FlowMatchingPipeline: + """ + Flow Matching Pipeline - Model-agnostic implementation. + + This pipeline handles all flow matching training logic while delegating + model-specific operations to a ModelAdapter. This allows adding support + for new model architectures without modifying the pipeline code. + + Features: + - Noise scheduling with linear interpolation + - Timestep sampling with various strategies + - Flow shift transformation + - Sigma clamping for finetuning + - Loss weighting + - Detailed training logging + + Example: + # Create pipeline with HunyuanVideo adapter + from automodel.flow_matching.adapters import HunyuanAdapter + + pipeline = FlowMatchingPipeline( + model_adapter=HunyuanAdapter(), + flow_shift=3.0, + timestep_sampling="logit_normal", + ) + + # Training step + loss, metrics = pipeline.step(model, batch, device, dtype, global_step) + """ + + def __init__( + self, + model_adapter: ModelAdapter, + num_train_timesteps: int = 1000, + timestep_sampling: str = "logit_normal", + flow_shift: float = 3.0, + i2v_prob: float = 0.3, + # Logit-normal distribution parameters + logit_mean: float = 0.0, + logit_std: float = 1.0, + # Mix sampling parameters + mix_uniform_ratio: float = 0.1, + # Sigma clamping for finetuning (pretrain uses [0.0, 1.0]) + sigma_min: float = 0.0, + sigma_max: float = 1.0, + # Loss weighting + use_loss_weighting: bool = True, + # Logging + log_interval: int = 100, + summary_log_interval: int = 10, + device: Optional[torch.device] = None, + ): + """ + Initialize the FlowMatching pipeline. 
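+
+        The default sigma range [0.0, 1.0] matches the pretraining setup;
+        finetuning runs are expected to tighten sigma_min / sigma_max (see the
+        clamp applied in sample_timesteps).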
+ + Args: + model_adapter: ModelAdapter instance for model-specific operations + num_train_timesteps: Total number of timesteps for the flow + timestep_sampling: Sampling strategy: + - "uniform": Pure uniform sampling + - "logit_normal": SD3-style logit-normal (recommended) + - "mode": Mode-based sampling + - "lognorm": Log-normal based sampling + - "mix": Mix of lognorm and uniform + flow_shift: Shift parameter for timestep transformation + i2v_prob: Probability of using image-to-video conditioning + logit_mean: Mean for logit-normal distribution + logit_std: Std for logit-normal distribution + mix_uniform_ratio: Ratio of uniform samples when using mix + sigma_min: Minimum sigma (0.0 for pretrain) + sigma_max: Maximum sigma (1.0 for pretrain) + use_loss_weighting: Whether to apply flow-based loss weighting + log_interval: Steps between detailed logs + summary_log_interval: Steps between summary logs + device: Device to use for computations + """ + self.model_adapter = model_adapter + self.num_train_timesteps = num_train_timesteps + self.timestep_sampling = timestep_sampling + self.flow_shift = flow_shift + self.i2v_prob = i2v_prob + self.logit_mean = logit_mean + self.logit_std = logit_std + self.mix_uniform_ratio = mix_uniform_ratio + self.sigma_min = sigma_min + self.sigma_max = sigma_max + self.use_loss_weighting = use_loss_weighting + self.log_interval = log_interval + self.summary_log_interval = summary_log_interval + self.device = device if device is not None else torch.device("cuda") + + # Initialize noise schedule + self.noise_schedule = LinearInterpolationSchedule() + + def sample_timesteps( + self, + batch_size: int, + device: Optional[torch.device] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, str]: + """ + Sample timesteps and compute sigma values with flow shift. 
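+
+        As a worked example of the shift described below: with flow_shift=3.0,
+        a uniform draw u=0.5 maps to sigma = 3.0 / (3.0 + (1/0.5 - 1)) = 0.75,
+        so training is biased toward the noisier end of the schedule.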
+ + Implements the flow shift transformation: + σ = shift / (shift + (1/u - 1)) + + Args: + batch_size: Number of timesteps to sample + device: Device for tensor operations + + Returns: + sigma: Sigma values in [sigma_min, sigma_max] + timesteps: Timesteps in [0, num_train_timesteps] + sampling_method: Name of the sampling method used + """ + if device is None: + device = self.device + + # Determine if we should use uniform (for mix strategy) + use_uniform = self.timestep_sampling == "uniform" or ( + self.mix_uniform_ratio > 0 and torch.rand(1).item() < self.mix_uniform_ratio + ) + + if use_uniform: + u = torch.rand(size=(batch_size,), device=device) + sampling_method = "uniform" + else: + u = self._sample_from_distribution(batch_size, device) + sampling_method = self.timestep_sampling + + # Apply flow shift: σ = shift / (shift + (1/u - 1)) + u_clamped = torch.clamp(u, min=1e-5) # Avoid division by zero + sigma = self.flow_shift / (self.flow_shift + (1.0 / u_clamped - 1.0)) + + # Apply sigma clamping + sigma = torch.clamp(sigma, self.sigma_min, self.sigma_max) + + # Convert sigma to timesteps [0, T] + timesteps = sigma * self.num_train_timesteps + + return sigma, timesteps, sampling_method + + def _sample_from_distribution(self, batch_size: int, device: torch.device) -> torch.Tensor: + """Sample u values from the configured distribution.""" + if self.timestep_sampling == "logit_normal": + u = torch.normal( + mean=self.logit_mean, + std=self.logit_std, + size=(batch_size,), + device=device, + ) + u = torch.sigmoid(u) + + elif self.timestep_sampling == "lognorm": + u = torch.normal(mean=0.0, std=1.0, size=(batch_size,), device=device) + u = torch.sigmoid(u) + + elif self.timestep_sampling == "mode": + mode_scale = 1.29 + u = torch.rand(size=(batch_size,), device=device) + u = 1.0 - u - mode_scale * (torch.cos(math.pi * u / 2.0) ** 2 - 1.0 + u) + u = torch.clamp(u, 0.0, 1.0) + + elif self.timestep_sampling == "mix": + u = torch.normal(mean=0.0, std=1.0, size=(batch_size,), device=device) + u = torch.sigmoid(u) + + else: + u = torch.rand(size=(batch_size,), device=device) + + return u + + def determine_task_type(self, data_type: str) -> str: + """Determine task type based on data type and randomization.""" + if data_type == "image": + return "t2v" + elif data_type == "video": + return "i2v" if random.random() < self.i2v_prob else "t2v" + else: + return "t2v" + + def compute_loss( + self, + model_pred: torch.Tensor, + target: torch.Tensor, + sigma: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Compute flow matching loss with optional weighting. 
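+
+        As a worked example of the weighting below: with flow_shift=3.0 a noisy
+        sample at sigma=0.9 is weighted 1 + 3.0 * 0.9 = 3.7, while a nearly
+        clean sample at sigma=0.1 is weighted 1.3.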
+ + Loss weight: w = 1 + flow_shift * σ + + Args: + model_pred: Model prediction + target: Target (velocity = noise - clean) + sigma: Sigma values for each sample + + Returns: + weighted_loss: Final loss to backprop + unweighted_loss: Raw MSE loss + loss_weight: Applied weights + """ + loss = nn.functional.mse_loss(model_pred.float(), target.float(), reduction="none") + + if self.use_loss_weighting: + loss_weight = 1.0 + self.flow_shift * sigma + loss_weight = loss_weight.view(-1, *([1] * (loss.ndim - 1))) + else: + loss_weight = torch.ones_like(sigma).view(-1, *([1] * (loss.ndim - 1))) + + loss_weight = loss_weight.to(model_pred.device) + + unweighted_loss = loss.mean() + weighted_loss = (loss * loss_weight).mean() + + return weighted_loss, unweighted_loss, loss_weight + + def step( + self, + model: nn.Module, + batch: Dict[str, Any], + device: torch.device, + dtype: torch.dtype, + global_step: int = 0, + ) -> Tuple[torch.Tensor, Dict[str, Any]]: + """ + Execute a single training step with flow matching. + + Expected batch format: + { + "video_latents": torch.Tensor, # [B, C, F, H, W] + "text_embeddings": torch.Tensor, # [B, seq_len, dim] + "data_type": str, # "video" or "image" (optional) + # ... additional model-specific keys handled by adapter + } + + Args: + model: The model to train + batch: Batch of training data + device: Device to use + dtype: Data type for operations + global_step: Current training step (for logging) + + Returns: + loss: The computed loss + metrics: Dictionary of training metrics + """ + debug_mode = os.environ.get("DEBUG_TRAINING", "0") == "1" + detailed_log = global_step % self.log_interval == 0 + summary_log = global_step % self.summary_log_interval == 0 + + # Extract and prepare batch data + video_latents = batch["video_latents"].to(device, dtype=dtype) + + # Handle tensor shapes + if video_latents.ndim == 4: + video_latents = video_latents.unsqueeze(0) + + batch_size = video_latents.shape[0] + + # Determine task type + data_type = batch.get("data_type", "video") + task_type = self.determine_task_type(data_type) + + # ==================================================================== + # Flow Matching: Sample Timesteps + # ==================================================================== + sigma, timesteps, sampling_method = self.sample_timesteps(batch_size, device) + + # ==================================================================== + # Flow Matching: Add Noise + # ==================================================================== + noise = torch.randn_like(video_latents, dtype=torch.float32) + + # x_t = (1 - σ) * x_0 + σ * ε + noisy_latents = self.noise_schedule.forward(video_latents.float(), noise, sigma) + + # ==================================================================== + # Logging + # ==================================================================== + if detailed_log or debug_mode: + self._log_detailed( + global_step, sampling_method, batch_size, sigma, timesteps, video_latents, noise, noisy_latents + ) + elif summary_log: + logger.info( + f"[STEP {global_step}] σ=[{sigma.min():.3f},{sigma.max():.3f}] | " + f"t=[{timesteps.min():.1f},{timesteps.max():.1f}] | " + f"noisy=[{noisy_latents.min():.1f},{noisy_latents.max():.1f}] | " + f"{sampling_method}" + ) + + # Convert to target dtype + noisy_latents = noisy_latents.to(dtype) + + # ==================================================================== + # Forward Pass (via adapter) + # ==================================================================== + context = 
FlowMatchingContext( + noisy_latents=noisy_latents, + video_latents=video_latents, + timesteps=timesteps, + sigma=sigma, + task_type=task_type, + data_type=data_type, + device=device, + dtype=dtype, + batch=batch, + ) + + inputs = self.model_adapter.prepare_inputs(context) + model_pred = self.model_adapter.forward(model, inputs) + + # ==================================================================== + # Target: Flow Matching Velocity + # ==================================================================== + # v = ε - x_0 + target = noise - video_latents.float() + + # ==================================================================== + # Loss Computation + # ==================================================================== + weighted_loss, unweighted_loss, loss_weight = self.compute_loss(model_pred, target, sigma) + + # Safety check + if torch.isnan(weighted_loss) or weighted_loss > 100: + logger.error(f"[ERROR] Loss explosion! Loss={weighted_loss.item():.3f}") + raise ValueError(f"Loss exploded: {weighted_loss.item()}") + + # Logging + if detailed_log or debug_mode: + self._log_loss_detailed(global_step, model_pred, target, loss_weight, unweighted_loss, weighted_loss) + elif summary_log: + logger.info( + f"[STEP {global_step}] Loss: {weighted_loss.item():.6f} | " + f"w=[{loss_weight.min():.2f},{loss_weight.max():.2f}]" + ) + + # Collect metrics + metrics = { + "loss": weighted_loss.item(), + "unweighted_loss": unweighted_loss.item(), + "sigma_min": sigma.min().item(), + "sigma_max": sigma.max().item(), + "sigma_mean": sigma.mean().item(), + "weight_min": loss_weight.min().item(), + "weight_max": loss_weight.max().item(), + "timestep_min": timesteps.min().item(), + "timestep_max": timesteps.max().item(), + "noisy_min": noisy_latents.min().item(), + "noisy_max": noisy_latents.max().item(), + "sampling_method": sampling_method, + "task_type": task_type, + "data_type": data_type, + } + + return weighted_loss, metrics + + def _log_detailed( + self, + global_step: int, + sampling_method: str, + batch_size: int, + sigma: torch.Tensor, + timesteps: torch.Tensor, + video_latents: torch.Tensor, + noise: torch.Tensor, + noisy_latents: torch.Tensor, + ): + """Log detailed training information.""" + logger.info("\n" + "=" * 80) + logger.info(f"[STEP {global_step}] FLOW MATCHING") + logger.info("=" * 80) + logger.info("[INFO] Using: x_t = (1-σ)x_0 + σ*ε") + logger.info("") + logger.info(f"[SAMPLING] Method: {sampling_method}") + logger.info(f"[FLOW] Shift: {self.flow_shift}") + logger.info(f"[BATCH] Size: {batch_size}") + logger.info("") + logger.info(f"[SIGMA] Range: [{sigma.min():.4f}, {sigma.max():.4f}]") + if sigma.numel() > 1: + logger.info(f"[SIGMA] Mean: {sigma.mean():.4f}, Std: {sigma.std():.4f}") + else: + logger.info(f"[SIGMA] Value: {sigma.item():.4f}") + logger.info("") + logger.info(f"[TIMESTEPS] Range: [{timesteps.min():.2f}, {timesteps.max():.2f}]") + logger.info("") + logger.info(f"[RANGES] Clean latents: [{video_latents.min():.4f}, {video_latents.max():.4f}]") + logger.info(f"[RANGES] Noise: [{noise.min():.4f}, {noise.max():.4f}]") + logger.info(f"[RANGES] Noisy latents: [{noisy_latents.min():.4f}, {noisy_latents.max():.4f}]") + + # Sanity check + max_expected = ( + max( + abs(video_latents.max().item()), + abs(video_latents.min().item()), + abs(noise.max().item()), + abs(noise.min().item()), + ) + * 1.5 + ) + if abs(noisy_latents.max()) > max_expected or abs(noisy_latents.min()) > max_expected: + logger.info(f"\n⚠️ WARNING: Noisy range seems large! 
Expected ~{max_expected:.1f}") + else: + logger.info("\n✓ Noisy latents range is reasonable") + logger.info("=" * 80 + "\n") + + def _log_loss_detailed( + self, + global_step: int, + model_pred: torch.Tensor, + target: torch.Tensor, + loss_weight: torch.Tensor, + unweighted_loss: torch.Tensor, + weighted_loss: torch.Tensor, + ): + """Log detailed loss information.""" + logger.info("=" * 80) + logger.info(f"[STEP {global_step}] LOSS DEBUG") + logger.info("=" * 80) + logger.info("[TARGET] Flow matching: v = ε - x_0") + logger.info("") + logger.info(f"[RANGES] Model pred: [{model_pred.min():.4f}, {model_pred.max():.4f}]") + logger.info(f"[RANGES] Target (v): [{target.min():.4f}, {target.max():.4f}]") + logger.info("") + logger.info(f"[WEIGHTS] Formula: 1 + {self.flow_shift} * σ") + logger.info(f"[WEIGHTS] Range: [{loss_weight.min():.4f}, {loss_weight.max():.4f}]") + logger.info(f"[WEIGHTS] Mean: {loss_weight.mean():.4f}") + logger.info("") + logger.info(f"[LOSS] Unweighted: {unweighted_loss.item():.6f}") + logger.info(f"[LOSS] Weighted: {weighted_loss.item():.6f}") + logger.info(f"[LOSS] Impact: {(weighted_loss / max(unweighted_loss, 1e-8)):.3f}x") + logger.info("=" * 80 + "\n") + + +# ============================================================================= +# Factory Functions +# ============================================================================= + + +def create_adapter(adapter_type: str, **kwargs) -> ModelAdapter: + """ + Factory function to create a model adapter by name. + + Args: + adapter_type: Type of adapter ("hunyuan", "simple") + **kwargs: Additional arguments passed to the adapter constructor + + Returns: + ModelAdapter instance + """ + adapters = { + "hunyuan": HunyuanAdapter, + "simple": SimpleAdapter, + } + + if adapter_type not in adapters: + raise ValueError(f"Unknown adapter type: {adapter_type}. Available: {list(adapters.keys())}") + + return adapters[adapter_type](**kwargs) + + +def create_pipeline( + adapter_type: str, + adapter_kwargs: Optional[Dict[str, Any]] = None, + **pipeline_kwargs, +) -> FlowMatchingPipeline: + """ + Factory function to create a pipeline with a specific adapter. 
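+
+    This mirrors what TrainDiffusionRecipe.setup does explicitly: it calls
+    create_adapter and then constructs FlowMatchingPipeline with the adapter
+    and the flow-matching config values.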
+ + Args: + adapter_type: Type of adapter ("hunyuan", "simple") + adapter_kwargs: Arguments for the adapter constructor + **pipeline_kwargs: Arguments for the pipeline constructor + + Returns: + FlowMatchingPipeline instance + + Example: + pipeline = create_pipeline( + adapter_type="hunyuan", + adapter_kwargs={"use_condition_latents": True}, + flow_shift=3.0, + timestep_sampling="logit_normal", + ) + """ + adapter_kwargs = adapter_kwargs or {} + adapter = create_adapter(adapter_type, **adapter_kwargs) + return FlowMatchingPipeline(model_adapter=adapter, **pipeline_kwargs) diff --git a/dfm/src/automodel/recipes/train.py b/dfm/src/automodel/recipes/train.py index b4561dd5..739e6c42 100644 --- a/dfm/src/automodel/recipes/train.py +++ b/dfm/src/automodel/recipes/train.py @@ -32,10 +32,8 @@ from torch.distributed.fsdp import MixedPrecisionPolicy from transformers.utils.hub import TRANSFORMERS_CACHE -from dfm.src.automodel._diffusers.auto_diffusion_pipeline import NeMoWanPipeline -from dfm.src.automodel.flow_matching.training_step_t2v import ( - step_fsdp_transformer_t2v, -) +from dfm.src.automodel._diffusers.auto_diffusion_pipeline import NeMoAutoDiffusionPipeline, NeMoWanPipeline +from dfm.src.automodel.flow_matching.flow_matching_pipeline import FlowMatchingPipeline, create_adapter def build_model_and_optimizer( @@ -47,11 +45,12 @@ def build_model_and_optimizer( dtype: torch.dtype, cpu_offload: bool = False, fsdp_cfg: Dict[str, Any] = {}, + attention_backend: Optional[str] = None, optimizer_cfg: Optional[Dict[str, Any]] = None, ) -> tuple[NeMoWanPipeline, dict[str, Dict[str, Any]], torch.optim.Optimizer, Any]: - """Build the WAN 2.1 diffusion model, parallel scheme, and optimizer.""" + """Build the diffusion model, parallel scheme, and optimizer.""" - logging.info("[INFO] Building NeMoWanPipeline with transformer parallel scheme...") + logging.info("[INFO] Building NeMoAutoDiffusionPipeline with transformer parallel scheme...") if not dist.is_initialized(): logging.info("[WARN] torch.distributed not initialized; proceeding in single-process mode") @@ -74,7 +73,7 @@ def build_model_and_optimizer( "activation_checkpointing": True, "mp_policy": MixedPrecisionPolicy( param_dtype=dtype, - reduce_dtype=dtype, + reduce_dtype=torch.float32, output_dtype=dtype, ), } @@ -85,7 +84,10 @@ def build_model_and_optimizer( if finetune_mode: kwargs["load_for_training"] = True kwargs["low_cpu_mem_usage"] = True - init_fn = NeMoWanPipeline.from_pretrained if finetune_mode else NeMoWanPipeline.from_config + if "wan" in model_id: + init_fn = NeMoWanPipeline.from_pretrained if finetune_mode else NeMoWanPipeline.from_config + else: + init_fn = NeMoAutoDiffusionPipeline.from_pretrained pipe, created_managers = init_fn( model_id, @@ -97,6 +99,9 @@ def build_model_and_optimizer( ) fsdp2_manager = created_managers["transformer"] transformer_module = pipe.transformer + if attention_backend is not None: + logging.info(f"[INFO] Setting attention backend to {attention_backend}") + transformer_module.set_attention_backend(attention_backend) trainable_params = [p for p in transformer_module.parameters() if p.requires_grad] if not trainable_params: @@ -105,6 +110,7 @@ def build_model_and_optimizer( optimizer_cfg = optimizer_cfg or {} weight_decay = optimizer_cfg.get("weight_decay", 0.01) betas = optimizer_cfg.get("betas", (0.9, 0.999)) + # TODO: Support other optimizers optimizer = torch.optim.AdamW(trainable_params, lr=learning_rate, weight_decay=weight_decay, betas=betas) logging.info("[INFO] Optimizer config: lr=%s, 
weight_decay=%s, betas=%s", learning_rate, weight_decay, betas) @@ -145,8 +151,8 @@ def is_main_process(): return (not dist.is_initialized()) or dist.get_rank() == 0 -class TrainWan21DiffusionRecipe(BaseRecipe): - """Config-driven wrapper around WAN 2.1 T2V training.""" +class TrainDiffusionRecipe(BaseRecipe): + """Training recipe for diffusion models.""" def __init__(self, cfg): self.cfg = cfg @@ -164,7 +170,8 @@ def setup(self): self.seed = self.cfg.get("seed", 42) self.rng = StatefulRNG(seed=self.seed, ranked=True) - self.model_id = self.cfg.get("model.pretrained_model_name_or_path", "Wan-AI/Wan2.1-T2V-1.3B-Diffusers") + self.model_id = self.cfg.get("model.pretrained_model_name_or_path") + self.attention_backend = self.cfg.get("model.attention_backend", "_flash_3_hub") self.learning_rate = self.cfg.get("optim.learning_rate", 5e-6) self.bf16 = torch.bfloat16 @@ -180,7 +187,7 @@ def setup(self): self.num_nodes = max(1, self.world_size // self.local_world_size) self.node_rank = dist.get_rank() // self.local_world_size if dist.is_initialized() else 0 - logging.info("[INFO] WAN 2.1 T2V Trainer with Flow Matching") + logging.info("[INFO] Diffusion Trainer with Flow Matching") logging.info( f"[INFO] Total GPUs: {self.world_size}, GPUs per node: {self.local_world_size}, Num nodes: {self.num_nodes}" ) @@ -191,20 +198,32 @@ def setup(self): fm_cfg = self.cfg.get("flow_matching", {}) self.cpu_offload = fsdp_cfg.get("cpu_offload", False) - self.use_sigma_noise = fm_cfg.get("use_sigma_noise", True) - self.timestep_sampling = fm_cfg.get("timestep_sampling", "uniform") + + # Flow matching configuration + self.adapter_type = fm_cfg.get("adapter_type", "simple") + self.timestep_sampling = fm_cfg.get("timestep_sampling", "logit_normal") self.logit_mean = fm_cfg.get("logit_mean", 0.0) self.logit_std = fm_cfg.get("logit_std", 1.0) self.flow_shift = fm_cfg.get("flow_shift", 3.0) self.mix_uniform_ratio = fm_cfg.get("mix_uniform_ratio", 0.1) self.sigma_min = fm_cfg.get("sigma_min", 0.0) self.sigma_max = fm_cfg.get("sigma_max", 1.0) - - logging.info(f"[INFO] Flow matching: {'ENABLED' if self.use_sigma_noise else 'DISABLED'}") - if self.use_sigma_noise: - logging.info(f"[INFO] - Timestep sampling: {self.timestep_sampling}") - logging.info(f"[INFO] - Flow shift: {self.flow_shift}") - logging.info(f"[INFO] - Mix uniform ratio: {self.mix_uniform_ratio}") + self.num_train_timesteps = fm_cfg.get("num_train_timesteps", 1000) + self.i2v_prob = fm_cfg.get("i2v_prob", 0.3) + self.use_loss_weighting = fm_cfg.get("use_loss_weighting", True) + self.log_interval = fm_cfg.get("log_interval", 100) + self.summary_log_interval = fm_cfg.get("summary_log_interval", 10) + + # Adapter-specific configuration + adapter_kwargs = fm_cfg.get("adapter_kwargs", {}) + self.adapter_kwargs = adapter_kwargs.to_dict() + + logging.info("[INFO] Flow Matching V2 Pipeline") + logging.info(f"[INFO] - Adapter type: {self.adapter_type}") + logging.info(f"[INFO] - Timestep sampling: {self.timestep_sampling}") + logging.info(f"[INFO] - Flow shift: {self.flow_shift}") + logging.info(f"[INFO] - Mix uniform ratio: {self.mix_uniform_ratio}") + logging.info(f"[INFO] - Use loss weighting: {self.use_loss_weighting}") (self.pipe, self.optimizer, self.device_mesh) = build_model_and_optimizer( model_id=self.model_id, @@ -215,6 +234,7 @@ def setup(self): cpu_offload=self.cpu_offload, fsdp_cfg=fsdp_cfg, optimizer_cfg=self.cfg.get("optim.optimizer", {}), + attention_backend=self.attention_backend, ) self.model = self.pipe.transformer @@ -309,6 +329,26 @@ def 
setup(self): self.load_checkpoint(self.restore_from) + # Init Flow Matching Pipeline V2 with model adapter + model_adapter = create_adapter(self.adapter_type, **self.adapter_kwargs) + self.flow_matching_pipeline = FlowMatchingPipeline( + model_adapter=model_adapter, + num_train_timesteps=self.num_train_timesteps, + timestep_sampling=self.timestep_sampling, + flow_shift=self.flow_shift, + i2v_prob=self.i2v_prob, + logit_mean=self.logit_mean, + logit_std=self.logit_std, + mix_uniform_ratio=self.mix_uniform_ratio, + sigma_min=self.sigma_min, + sigma_max=self.sigma_max, + use_loss_weighting=self.use_loss_weighting, + log_interval=self.log_interval, + summary_log_interval=self.summary_log_interval, + device=self.device, + ) + logging.info(f"[INFO] Flow Matching Pipeline V2 initialized with {self.adapter_type} adapter") + if is_main_process(): os.makedirs(self.checkpoint_config.checkpoint_dir, exist_ok=True) @@ -344,20 +384,11 @@ def run_train_validation_loop(self): micro_losses = [] for micro_batch in batch_group: try: - loss, _ = step_fsdp_transformer_t2v( - scheduler=self.pipe.scheduler, + loss, metrics = self.flow_matching_pipeline.step( model=self.model, batch=micro_batch, device=self.device, - bf16=self.bf16, - use_sigma_noise=self.use_sigma_noise, - timestep_sampling=self.timestep_sampling, - logit_mean=self.logit_mean, - logit_std=self.logit_std, - flow_shift=self.flow_shift, - mix_uniform_ratio=self.mix_uniform_ratio, - sigma_min=self.sigma_min, - sigma_max=self.sigma_max, + dtype=self.bf16, global_step=global_step, ) except Exception as exc: diff --git a/dfm/src/automodel/utils/data/preprocess_dataset.py b/dfm/src/automodel/utils/data/preprocess_dataset.py new file mode 100644 index 00000000..f4e221a8 --- /dev/null +++ b/dfm/src/automodel/utils/data/preprocess_dataset.py @@ -0,0 +1,545 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Preprocessing Script for HunyuanVideo-1.5 Training Data + +This script preprocesses videos and text captions for HunyuanVideo-1.5 training by: +1. Loading videos from a folder with meta.json metadata +2. Processing videos to ensure 4n+1 frames (required by VAE) +3. Encoding videos with VAE to get latents +4. Encoding text captions with text encoders (CLIP-like + LLaMA) and byT5 +5. 
Saving preprocessed data to .meta files for faster training + +Usage: + python preprocess_dataset.py \ + --data_dir /path/to/videos \ + --meta_file meta.json \ + --output_dir /path/to/output \ + --pretrained_model_root /path/to/models \ + --target_frames 121 \ + --target_height 720 \ + --target_width 1280 +""" + +import argparse +import json +import logging +import pickle +from pathlib import Path +from typing import Any, Dict, List, Optional + +import imageio +import numpy as np +import torch +from diffusers import HunyuanVideo15ImageToVideoPipeline +from PIL import Image +from tqdm import tqdm + + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def resize_and_center_crop(image, target_width, target_height): + if target_height == image.shape[0] and target_width == image.shape[1]: + return image + + pil_image = Image.fromarray(image) + original_width, original_height = pil_image.size + scale_factor = max(target_width / original_width, target_height / original_height) + resized_width = int(round(original_width * scale_factor)) + resized_height = int(round(original_height * scale_factor)) + resized_image = pil_image.resize((resized_width, resized_height), Image.LANCZOS) + left = (resized_width - target_width) / 2 + top = (resized_height - target_height) / 2 + right = (resized_width + target_width) / 2 + bottom = (resized_height + target_height) / 2 + cropped_image = resized_image.crop((left, top, right, bottom)) + return np.array(cropped_image) + + +def str_to_bool(value): + """Convert string to boolean.""" + if value is None: + return True + if isinstance(value, bool): + return value + if isinstance(value, str): + value = value.lower().strip() + if value in ("true", "1", "yes", "on"): + return True + elif value in ("false", "0", "no", "off"): + return False + raise argparse.ArgumentTypeError(f"Boolean value expected, got: {value}") + + +def format_text_input(prompt: List[str], system_message: str) -> List[Dict[str, Any]]: + """ + Apply text to template. + + Args: + prompt (List[str]): Input text. + system_message (str): System message. + + Returns: + List[Dict[str, Any]]: List of chat conversation. + """ + + template = [ + [{"role": "system", "content": system_message}, {"role": "user", "content": p if p else " "}] for p in prompt + ] + + return template + + +def load_video(video_path: str, start_frame: int = 0, end_frame: Optional[int] = None) -> np.ndarray: + """ + Load video from file. + + Args: + video_path: Path to video file + start_frame: Starting frame index + end_frame: Ending frame index (None means to the end) + + Returns: + Video frames as numpy array [F, H, W, C] in uint8 [0, 255] + """ + reader = imageio.get_reader(video_path, "ffmpeg") + frames = [] + + try: + for i, frame in enumerate(reader): + if i < start_frame: + continue + if end_frame is not None and i >= end_frame: + break + frames.append(frame) + finally: + reader.close() + + if len(frames) == 0: + raise ValueError(f"No frames loaded from {video_path}") + + return np.stack(frames, axis=0) # [F, H, W, C] + + +def adjust_frames_to_4n_plus_1(frames: np.ndarray, target_frames: Optional[int] = None) -> np.ndarray: + """ + Adjust number of frames to 4n+1 format required by VAE. + + Args: + frames: Input frames [F, H, W, C] + target_frames: Target number of frames (must be 4n+1). 
If None, adjust to closest 4n+1 + + Returns: + Adjusted frames [F', H, W, C] where F' = 4n+1 + """ + num_frames = frames.shape[0] + + if target_frames is not None: + # Validate target_frames is 4n+1 + if (target_frames - 1) % 4 != 0: + raise ValueError(f"target_frames must be 4n+1, got {target_frames}") + + if num_frames < target_frames: + # Repeat frames if not enough + logger.warning( + f"Video has {num_frames} frames, but target is {target_frames}. Some frames will be repeated." + ) + indices = np.linspace(0, num_frames - 1, target_frames).astype(int) + frames = frames[indices] + elif num_frames > target_frames: + # Sample frames uniformly + logger.debug(f"Sampling {target_frames} frames from {num_frames} total frames") + indices = np.linspace(0, num_frames - 1, target_frames).astype(int) + frames = frames[indices] + + return frames + else: + # Find closest 4n+1 + n = (num_frames - 1) // 4 + target = 4 * n + 1 + + if target < 1: + target = 1 + + if num_frames != target: + logger.debug(f"Adjusting {num_frames} frames to {target} frames (4n+1 format)") + if num_frames < target: + indices = np.linspace(0, num_frames - 1, target).astype(int) + else: + indices = np.linspace(0, num_frames - 1, target).astype(int) + frames = frames[indices] + + return frames + + +def preprocess_video(frames: np.ndarray, target_height: int, target_width: int) -> torch.Tensor: + """ + Preprocess video frames to target resolution and convert to tensor. + + Args: + frames: Input frames [F, H, W, C] in uint8 [0, 255] + target_height: Target height + target_width: Target width + + Returns: + Preprocessed video tensor [C, F, H, W] in float32 [-1, 1] + """ + num_frames = frames.shape[0] + processed_frames = [] + + for i in range(num_frames): + frame = frames[i] # [H, W, C] + # Resize and center crop + frame = resize_and_center_crop(frame, target_width, target_height) + processed_frames.append(frame) + + processed_frames = np.stack(processed_frames, axis=0) # [F, H, W, C] + + # Convert to tensor and normalize to [-1, 1] + video_tensor = torch.from_numpy(processed_frames).float() / 255.0 # [F, H, W, C] in [0, 1] + video_tensor = video_tensor * 2.0 - 1.0 # [0, 1] -> [-1, 1] + video_tensor = video_tensor.permute(3, 0, 1, 2) # [C, F, H, W] + + return video_tensor + + +class VideoPreprocessor: + def __init__( + self, + pretrained_model_root: str, + transformer_version: str = "720p_t2v", + device: str = "cuda", + dtype: str = "fp16", + ): + """ + Initialize video preprocessor with models. 
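+
+        A typical invocation sketch (paths, caption, and frame count here are
+        illustrative):
+
+            pre = VideoPreprocessor(pretrained_model_root="/models", dtype="fp16")
+            sample = pre.preprocess_single_video(
+                video_path="clip.mp4", caption="a red car", target_frames=9
+            )
+
+        Only the VAE, the text encoders, and the image encoder are exercised
+        during preprocessing; the transformer itself is not used.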
+ + Args: + pretrained_model_root: Path to pretrained models + transformer_version: Transformer version (not loaded, only for pipeline setup) + device: Device to use ('cuda' or 'cpu') + dtype: Data type for encoding ('fp16' or 'bf16' or 'fp32') + """ + self.device = torch.device(device) + + # Set dtype for VAE encoding + if dtype == "fp16": + self.vae_dtype = torch.float16 + elif dtype == "bf16": + self.vae_dtype = torch.bfloat16 + else: + self.vae_dtype = torch.float32 + + logger.info(f"Loading models from {pretrained_model_root}") + logger.info(f"Using device: {device}, dtype: {dtype}") + + # Load pipeline (we only need VAE and text encoders, not transformer) + self.pipeline = HunyuanVideo15ImageToVideoPipeline.from_pretrained( + "hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-720p_i2v", torch_dtype=torch.float16, cpu_offload=True + ) + + self.vae = self.pipeline.vae + self.text_encoder = self.pipeline.text_encoder + self.tokenizer = self.pipeline.tokenizer + + self.tokenizer_max_length = 1000 + self.crop_start = 108 + self.num_hidden_layers_to_skip = 2 + # Set models to eval mode and move to device + if hasattr(self.vae, "enable_tiling"): + self.vae.enable_tiling(tile_sample_min_height=64, tile_sample_min_width=64, tile_overlap_factor=0.25) + logger.info("VAE tiling enabled") + + if hasattr(self.vae, "enable_slicing"): + self.vae.enable_slicing() + logger.info("VAE slicing enabled") + + logger.info("Models loaded successfully") + + @torch.no_grad() + def encode_vae(self, video: torch.Tensor) -> torch.Tensor: + """ + Encode video with VAE. + + Args: + video: Video tensor [C, F, H, W] in float32 [-1, 1] + + Returns: + Latents tensor [C_latent, F, H_latent, W_latent] + """ + if video.max() > 1.0 or video.min() < -1.0: + raise ValueError(f"Video must be in range [-1, 1], got [{video.min()}, {video.max()}]") + + # Add batch dimension + video = video.unsqueeze(0) # [1, C, F, H, W] + video = video.to(device=self.device) + + with torch.autocast(device_type="cuda", dtype=self.vae_dtype, enabled=(self.device.type == "cuda")): + latents = self.vae.encode(video).latent_dist.sample() + if hasattr(self.vae.config, "shift_factor") and self.vae.config.shift_factor: + latents = (latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor + else: + latents = latents * self.vae.config.scaling_factor + + # Remove batch dimension + latents = latents.squeeze(0) # [C_latent, F, H_latent, W_latent] + latents = latents.detach().cpu() + + return latents + + def preprocess_single_video( + self, + video_path: str, + caption: str, + start_frame: int = 0, + end_frame: Optional[int] = None, + target_frames: Optional[int] = None, + target_height: int = 720, + target_width: int = 1280, + data_type: str = "video", + ) -> Dict[str, torch.Tensor]: + """ + Preprocess a single video and caption. 
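+
+        The returned dictionary uses the same keys the training dataloader
+        expects (video_latents, text_embeddings, text_mask, text_embeddings_2,
+        text_mask_2, image_embeds, metadata), so the pickled .meta files can be
+        consumed directly by MetaFilesDataset.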
+
+        Args:
+            video_path: Path to video file
+            caption: Text caption
+            start_frame: Starting frame index
+            end_frame: Ending frame index
+            target_frames: Target number of frames (must be 4n+1)
+            target_height: Target height
+            target_width: Target width
+            data_type: "video" or "image"
+
+        Returns:
+            Dictionary containing preprocessed data
+        """
+        # Load video
+        logger.debug(f"Loading video: {video_path}")
+        frames = load_video(video_path, start_frame, end_frame)
+        logger.debug(f"Loaded {frames.shape[0]} frames from {video_path}")
+
+        # Adjust frames to 4n+1
+        frames = adjust_frames_to_4n_plus_1(frames, target_frames)
+        logger.debug(f"Adjusted to {frames.shape[0]} frames (4n+1 format)")
+
+        # Preprocess video (resize, center crop, normalize to [-1, 1])
+        video_tensor = preprocess_video(frames, target_height, target_width)
+        logger.debug(f"Preprocessed video shape: {video_tensor.shape}")
+
+        # Encode with VAE
+        self.vae.to(self.device)
+        self.vae.eval()
+        logger.debug("Encoding with VAE...")
+        latents = self.encode_vae(video_tensor)
+        logger.debug(f"Latents shape: {latents.shape}")
+        self.vae.to("cpu")
+
+        # Encode text
+        self.text_encoder.to(self.device)
+        self.text_encoder.eval()
+        logger.debug("Encoding text...")
+        prompt_embeds, prompt_embeds_mask, prompt_embeds_2, prompt_embeds_mask_2 = self.pipeline.encode_prompt(
+            prompt=caption, device=self.device, dtype=torch.float16, batch_size=1, num_videos_per_prompt=1
+        )
+        self.text_encoder.to("cpu")
+
+        # Encode first frame for image conditioning
+        logger.debug("Encoding first frame for image embedding...")
+        self.pipeline.image_encoder.to(self.device)
+        first_frame = frames[0]
+        image_embeds = self.pipeline.encode_image(
+            image=first_frame, batch_size=1, device=self.device, dtype=torch.float16
+        )
+        logger.debug(f"Image embeddings shape: {image_embeds.shape}")
+        self.pipeline.image_encoder.to("cpu")
+
+        result = {
+            "video_latents": latents.unsqueeze(0),  # Add batch dim: [1, C, F, H, W] - already detached above
+            "text_embeddings": prompt_embeds.detach().cpu(),  # Already [1, seq_len, dim]
+            "text_mask": prompt_embeds_mask.detach().cpu(),  # [1, seq_len]
+            "text_embeddings_2": prompt_embeds_2.detach().cpu(),  # [1, seq_len, dim]
+            "text_mask_2": prompt_embeds_mask_2.detach().cpu(),  # [1, seq_len]
+            "image_embeds": image_embeds.detach().cpu(),  # [1, 729, 1152]
+            "metadata": {
+                "text": caption,
+                "data_type": data_type,
+                "video_shape": list(video_tensor.shape),  # [C, F, H, W]
+                "latent_shape": list(latents.shape),  # [C_latent, F, H_latent, W_latent]
+            },
+            "original_filename": Path(video_path).name,
+            "original_video_path": str(video_path),
+            "num_frames": video_tensor.shape[1],  # F
+            "deterministic_latents": "vae_encoded",
+            "memory_optimization": f"dtype_{self.vae_dtype}",
+        }
+        logger.debug(
+            f"Result shapes - video_latents: {result['video_latents'].shape}, "
+            f"text_embeddings: {result['text_embeddings'].shape}"
+        )
+
+        return result
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Preprocess videos and captions for HunyuanVideo-1.5 training")
+
+    # Data parameters
+    parser.add_argument("--data_dir", type=str, required=True, help="Directory containing videos")
+    parser.add_argument("--meta_file", type=str, default="meta.json", help="Metadata JSON file name")
+    parser.add_argument("--output_dir", type=str, required=True, help="Output directory for .meta files")
+
+    # Model parameters
+    parser.add_argument("--pretrained_model_root", type=str, required=False, help="Path to pretrained models")
+    parser.add_argument(
+        "--transformer_version", type=str,
default="720p_t2v", help="Transformer version (default: 720p_t2v)" + ) + + # Processing parameters + parser.add_argument( + "--target_frames", + type=int, + default=9, + help="Target number of frames (must be 4n+1, e.g., 1, 5, 9, 13, 17, 21, ..., 121)", + ) + parser.add_argument("--target_height", type=int, default=720, help="Target video height") + parser.add_argument("--target_width", type=int, default=1280, help="Target video width") + parser.add_argument( + "--data_type", + type=str, + default="video", + choices=["video", "image"], + help="Data type for text encoding (default: video)", + ) + + # Caption field + parser.add_argument( + "--caption_field", + type=str, + default="vila_caption", + help="Field name in meta.json containing captions (default: vila_caption)", + ) + + # Device parameters + parser.add_argument("--device", type=str, default="cuda", help="Device to use (cuda or cpu)") + parser.add_argument( + "--dtype", + type=str, + default="fp16", + choices=["fp16", "bf16", "fp32"], + help="Data type for encoding (default: fp16)", + ) + + # Other parameters + parser.add_argument("--batch_size", type=int, default=1, help="Batch size (currently only 1 is supported)") + parser.add_argument( + "--num_workers", type=int, default=0, help="Number of worker processes (currently only 0 is supported)" + ) + + args = parser.parse_args() + + # Validate target_frames is 4n+1 + if (args.target_frames - 1) % 4 != 0: + raise ValueError(f"target_frames must be 4n+1 (e.g., 1, 5, 9, 13, 17, 21, ..., 121), got {args.target_frames}") + + # Setup paths + data_dir = Path(args.data_dir) + meta_path = data_dir / args.meta_file + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + # Load metadata + logger.info(f"Loading metadata from {meta_path}") + with open(meta_path, "r") as f: + metadata = json.load(f) + + logger.info(f"Found {len(metadata)} videos in metadata") + + # Initialize preprocessor + logger.info("Initializing preprocessor...") + preprocessor = VideoPreprocessor( + pretrained_model_root=args.pretrained_model_root, + transformer_version=args.transformer_version, + device=args.device, + dtype=args.dtype, + ) + + # Process each video + logger.info("Starting preprocessing...") + successful = 0 + failed = 0 + + for item in tqdm(metadata, desc="Processing videos"): + try: + file_name = item["file_name"] + video_path = data_dir / file_name + + # Check if video exists + if not video_path.exists(): + logger.warning(f"Video not found: {video_path}") + failed += 1 + continue + + # Get caption + caption = item.get(args.caption_field, "") + if not caption: + logger.warning(f"No caption found for {file_name}") + failed += 1 + continue + + # Get frame range if specified + start_frame = item.get("start_frame", 0) + end_frame = item.get("end_frame", None) + if end_frame is not None: + end_frame = end_frame + 1 # end_frame is inclusive in meta.json, but exclusive in our code + + # Preprocess + result = preprocessor.preprocess_single_video( + video_path=str(video_path), + caption=caption, + start_frame=start_frame, + end_frame=end_frame, + target_frames=args.target_frames, + target_height=args.target_height, + target_width=args.target_width, + data_type=args.data_type, + ) + + # Save .meta file using pickle (compatible with wan21.py dataloader) + output_path = output_dir / f"{Path(file_name).stem}.meta" + with open(output_path, "wb") as f: + pickle.dump(result, f) + + successful += 1 + + except Exception as e: + logger.error(f"Failed to process {item.get('file_name', 
'unknown')}: {e}") + import traceback + + logger.error(traceback.format_exc()) + failed += 1 + + logger.info("=" * 80) + logger.info("Preprocessing complete!") + logger.info(f"Successful: {successful}") + logger.info(f"Failed: {failed}") + logger.info(f"Output directory: {output_dir}") + logger.info("=" * 80) + + +if __name__ == "__main__": + main() diff --git a/dfm/src/automodel/utils/validate_hunyuan.py b/dfm/src/automodel/utils/validate_hunyuan.py new file mode 100644 index 00000000..a03e7d03 --- /dev/null +++ b/dfm/src/automodel/utils/validate_hunyuan.py @@ -0,0 +1,275 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +import pickle +from pathlib import Path + +import torch +from diffusers import HunyuanVideo15Pipeline +from diffusers.utils import export_to_video + + +def parse_args(): + p = argparse.ArgumentParser("HunyuanVideo-1.5 T2V Validation") + + # Model configuration + p.add_argument("--model_id", type=str, default="hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-720p_t2v") + p.add_argument("--transformer_version", type=str, default="720p_t2v", help="Transformer version") + p.add_argument("--checkpoint", type=str, default=None, help="Path to checkpoint (optional)") + + # Data - load from .meta files + p.add_argument("--meta_folder", type=str, required=True, help="Folder containing .meta files with prompts") + + # Generation settings + p.add_argument("--num_samples", type=int, default=None, help="Number of samples (default: all)") + p.add_argument("--num_inference_steps", type=int, default=50) + p.add_argument("--guidance_scale", type=float, default=6.0) + p.add_argument("--negative_prompt", type=str, default="") + p.add_argument("--seed", type=int, default=42) + + # Video settings + p.add_argument("--height", type=int, default=480) + p.add_argument("--width", type=int, default=832) + p.add_argument("--num_frames", type=int, default=129) + p.add_argument("--fps", type=int, default=16) + + # Flow matching settings + p.add_argument("--flow_shift", type=float, default=5.0, help="Flow shift for inference") + + # Output + p.add_argument("--output_dir", type=str, default="./validation_outputs_hunyuan") + + return p.parse_args() + + +def load_prompts_from_meta_files(meta_folder: str): + """ + Load prompts from .meta files. + Each .meta file contains a 'metadata' dict with 'vila_caption'. + + Returns list of dicts: [{"prompt": "...", "name": "...", "meta_file": "..."}, ...] 
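+
+    The caption is read from metadata["vila_caption"], with a fallback to
+    metadata["text"] (the key written by the preprocessing script in this PR).
+
+    Illustrative .meta payload (only the fields used here):
+        {
+            "metadata": {"text": "a red car drives down a coastal road", ...},
+            "video_latents": <Tensor>,
+            "text_embeddings": <Tensor>,
+            ...
+        }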
+    """
+    meta_folder = Path(meta_folder)
+    meta_files = sorted(list(meta_folder.glob("*.meta")))
+
+    if not meta_files:
+        raise FileNotFoundError(f"No .meta files found in {meta_folder}")
+
+    print(f"[INFO] Found {len(meta_files)} .meta files")
+
+    prompts = []
+
+    for meta_file in meta_files:
+        try:
+            with open(meta_file, "rb") as f:
+                data = pickle.load(f)
+
+            # Extract prompt from metadata (fall back to the "text" key written by the preprocessing script)
+            metadata = data.get("metadata", {})
+            prompt = metadata.get("vila_caption") or metadata.get("text", "")
+
+            if not prompt:
+                print(f"[WARNING] No caption found in {meta_file.name}, skipping...")
+                continue
+
+            # Get filename without extension
+            name = meta_file.stem
+
+            prompts.append({"prompt": prompt, "name": name, "meta_file": str(meta_file)})
+
+        except Exception as e:
+            print(f"[WARNING] Failed to load {meta_file.name}: {e}")
+            continue
+
+    if not prompts:
+        raise ValueError(f"No valid prompts found in {meta_folder}")
+
+    return prompts
+
+
+def main():
+    args = parse_args()
+
+    print("=" * 80)
+    print("HunyuanVideo-1.5 Text-to-Video Validation")
+    print("=" * 80)
+
+    # Load prompts from .meta files
+    print(f"\n[1] Loading prompts from .meta files in: {args.meta_folder}")
+    prompts = load_prompts_from_meta_files(args.meta_folder)
+
+    if args.num_samples:
+        prompts = prompts[: args.num_samples]
+
+    print(f"[INFO] Loaded {len(prompts)} prompts")
+
+    # Show first few prompts
+    print("\n[INFO] Sample prompts:")
+    for i, item in enumerate(prompts[:3]):
+        print(f"  {i + 1}. {item['name']}: {item['prompt'][:60]}...")
+
+    # Load pipeline
+    print(f"\n[2] Loading pipeline: {args.model_id}")
+    pipe = HunyuanVideo15Pipeline.from_pretrained(
+        args.model_id,
+        torch_dtype=torch.bfloat16,
+        trust_remote_code=True,
+    )
+    pipe.to("cuda")
+
+    # Enable VAE optimizations (critical for memory)
+    pipe.vae.enable_slicing()
+    pipe.vae.enable_tiling()
+    print("[INFO] Enabled VAE slicing and tiling")
+
+    # Model CPU offloading is enabled further below, right before the generation
+    # loop, to reduce GPU memory usage during denoising.
+
+    # Load checkpoint if provided
+    if args.checkpoint:
+        print(f"\n[3] Loading checkpoint: {args.checkpoint}")
+
+        # Try EMA checkpoint first (best quality)
+        ema_path = os.path.join(args.checkpoint, "ema_shadow.pt")
+        consolidated_path = os.path.join(args.checkpoint, "consolidated_model.bin")
+        sharded_dir = os.path.join(args.checkpoint, "model")
+
+        if os.path.exists(ema_path):
+            print("[INFO] Loading EMA checkpoint (best quality)...")
+            ema_state = torch.load(ema_path, map_location="cuda")
+            pipe.transformer.load_state_dict(ema_state, strict=True)
+            print("[INFO] ✅ Loaded from EMA checkpoint")
+        elif os.path.exists(consolidated_path):
+            print("[INFO] Loading consolidated checkpoint...")
+            state_dict = torch.load(consolidated_path, map_location="cuda")
+            pipe.transformer.load_state_dict(state_dict, strict=True)
+            print("[INFO] ✅ Loaded from consolidated checkpoint")
+        elif os.path.isdir(sharded_dir) and any(name.endswith(".distcp") for name in os.listdir(sharded_dir)):
+            print(f"[INFO] Detected sharded FSDP checkpoint at: {sharded_dir}")
+            print("[INFO] Loading sharded checkpoint via PyTorch Distributed Checkpoint (single process)...")
+
+            import torch.distributed as dist
+            from torch.distributed.checkpoint import FileSystemReader
+            from torch.distributed.checkpoint import load as dist_load
+            from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+            from torch.distributed.fsdp import StateDictType
+            from
torch.distributed.fsdp.api import ShardedStateDictConfig + + # Initialize a single-process group if not already initialized + init_dist = False + if not dist.is_initialized(): + dist.init_process_group(backend="gloo", rank=0, world_size=1) + init_dist = True + + # Wrap current transformer with FSDP to load sharded weights + base_transformer = pipe.transformer + + # Ensure uniform dtype before FSDP wraps/flattening + base_transformer.to(dtype=torch.bfloat16) + fsdp_transformer = FSDP(base_transformer, use_orig_params=True) + + # Configure to expect sharded state dict + FSDP.set_state_dict_type( + fsdp_transformer, + StateDictType.SHARDED_STATE_DICT, + state_dict_config=ShardedStateDictConfig(offload_to_cpu=True), + ) + + # Load shards into the FSDP-wrapped model + model_state = fsdp_transformer.state_dict() + dist_load(state_dict=model_state, storage_reader=FileSystemReader(sharded_dir)) + fsdp_transformer.load_state_dict(model_state) + + # Unwrap back to the original module for inference + pipe.transformer = fsdp_transformer.module + + # Move to CUDA bf16 for inference + pipe.transformer.to("cuda", dtype=torch.bfloat16) + + if init_dist: + dist.destroy_process_group() + + print("[INFO] ✅ Loaded from sharded FSDP checkpoint") + else: + print("[WARNING] No consolidated or EMA checkpoint found") + print("[INFO] Using base model") + + # Create output directory + os.makedirs(args.output_dir, exist_ok=True) + + # Generate videos + print("\n[4] Generating videos...") + print(f"[INFO] Settings: {args.width}x{args.height}, {args.num_frames} frames, {args.num_inference_steps} steps") + print(f"[INFO] Guidance scale: {args.guidance_scale}, Flow shift: {args.flow_shift}") + + torch.manual_seed(args.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(args.seed) + pipe.enable_model_cpu_offload() + pipe.transformer.set_attention_backend("_flash_3_hub") + for i, item in enumerate(prompts): + prompt = item["prompt"] + name = item["name"] + + print(f"\n[{i + 1}/{len(prompts)}] Generating: {name}") + print(f" Prompt: {prompt[:80]}...") + + try: + # Generate from scratch + generator = torch.Generator(device="cuda").manual_seed(args.seed + i) + output = pipe( + prompt=prompt, + negative_prompt=args.negative_prompt, + height=args.height, + width=args.width, + num_frames=args.num_frames, + num_inference_steps=args.num_inference_steps, + generator=generator, + ).frames[0] + + # Save video + output_path = os.path.join(args.output_dir, f"{name}.mp4") + export_to_video(output, output_path, fps=args.fps) + + print(f" ✅ Saved to {output_path}") + + # Clear GPU memory after successful generation to prevent accumulation + del output + torch.cuda.empty_cache() + torch.cuda.synchronize() + + except Exception as e: + print(f" ❌ Failed: {e}") + import traceback + + traceback.print_exc() + + # Critical: Clear GPU memory after failure to prevent accumulation + torch.cuda.empty_cache() + torch.cuda.synchronize() + + continue + + print("\n" + "=" * 80) + print("✅ Validation complete!") + print(f"📁 Videos saved to: {args.output_dir}") + print("=" * 80) + + +if __name__ == "__main__": + main() diff --git a/examples/automodel/finetune/finetune.py b/examples/automodel/finetune/finetune.py index 9f0f3d7f..e24c5bf3 100644 --- a/examples/automodel/finetune/finetune.py +++ b/examples/automodel/finetune/finetune.py @@ -16,12 +16,12 @@ from nemo_automodel.components.config._arg_parser import parse_args_and_load_config -from dfm.src.automodel.recipes.train import TrainWan21DiffusionRecipe +from 
dfm.src.automodel.recipes.train import TrainDiffusionRecipe
 
 
 def main(default_config_path="examples/automodel/finetune/wan2_1_t2v_flow.yaml"):
     cfg = parse_args_and_load_config(default_config_path)
-    recipe = TrainWan21DiffusionRecipe(cfg)
+    recipe = TrainDiffusionRecipe(cfg)
     recipe.setup()
     recipe.run_train_validation_loop()
diff --git a/examples/automodel/finetune/hunyuan_t2v_flow.yaml b/examples/automodel/finetune/hunyuan_t2v_flow.yaml
new file mode 100644
index 00000000..e177d853
--- /dev/null
+++ b/examples/automodel/finetune/hunyuan_t2v_flow.yaml
@@ -0,0 +1,87 @@
+# HunyuanVideo-1.5 720p T2V Training Configuration
+#
+# This configuration file is fully compatible with the TrainDiffusionRecipe class
+# (dfm/src/automodel/recipes/train.py) using FlowMatchingPipelineV2
+
+# Model configuration
+model:
+  pretrained_model_name_or_path: "hunyuanvideo-community/HunyuanVideo-1.5-Diffusers-720p_t2v"
+  mode: "finetune"  # "finetune" or "pretrain"
+  cache_dir: null  # Optional: specify cache directory for model weights
+  attention_backend: "_flash_3_hub"
+
+# Optimizer configuration
+optim:
+  learning_rate: 5e-6
+
+  optimizer:
+    weight_decay: 0.01
+    betas: [0.9, 0.999]
+
+# FSDP (Fully Sharded Data Parallel) configuration
+fsdp:
+  enable_fsdp: true
+  dp_size: 8  # Data-parallel size (set explicitly here; adjust to world_size and the other parallel sizes)
+  dp_replicate_size: 1
+  tp_size: 1  # Tensor parallelism size
+  cp_size: 1  # Context parallelism size
+  pp_size: 1  # Pipeline parallelism size
+  cpu_offload: false
+  activation_checkpointing: true
+  use_hf_tp_plan: false
+
+# Flow matching V2 configuration
+flow_matching:
+  adapter_type: "hunyuan"  # Options: "hunyuan", "simple"
+  adapter_kwargs:
+    use_condition_latents: true
+    default_image_embed_shape: [729, 1152]
+  timestep_sampling: "logit_normal"  # Options: "uniform", "logit_normal", "lognorm", "mix", "mode"
+  logit_mean: 0.0
+  logit_std: 1.0
+  flow_shift: 3.0  # Flow shift for training
+  mix_uniform_ratio: 0.1  # For "mix" timestep sampling
+  sigma_min: 0.0
+  sigma_max: 1.0
+  num_train_timesteps: 1000
+  i2v_prob: 0.3
+  use_loss_weighting: false
+  log_interval: 1000  # Steps between detailed logs
+  summary_log_interval: 100  # Steps between summary logs
+
+# Training step scheduler configuration
+step_scheduler:
+  num_epochs: 30
+  local_batch_size: 1  # Batch size per GPU
+  global_batch_size: 8  # Effective batch size across all GPUs (with gradient accumulation)
+  ckpt_every_steps: 1000  # Save checkpoint every N steps
+  log_every: 10  # Log metrics every N steps
+
+# Data configuration
+data:
+  dataloader:
+    _target_: dfm.src.automodel.datasets.build_dataloader
+    meta_folder: /lustre/fsw/portfolios/coreai/users/pthombre/Automodel/H21/DFM/hunyuanTrainingImages2/
+    num_workers: 2
+    device: cpu
+
+# Checkpoint configuration
+checkpoint:
+  enabled: true
+  checkpoint_dir: /opt/DFM/hunyuan_t2v_flow_outputs_base_recipe_flowPipelineV2/
+  model_save_format: torch_save
+  save_consolidated: false
+  restore_from: null
+
+wandb:
+  project: hunyuan-video-training
+  mode: online
+  name: 720p_t2v_run
+
+# Distributed environment configuration
+dist_env:
+  backend: "nccl"
+  init_method: "env://"
+
+# Random seed
+seed: 42
diff --git a/examples/automodel/finetune/wan2_1_t2v_flow.yaml b/examples/automodel/finetune/wan2_1_t2v_flow.yaml
index 6f45fa66..525cacf1 100644
--- a/examples/automodel/finetune/wan2_1_t2v_flow.yaml
+++ b/examples/automodel/finetune/wan2_1_t2v_flow.yaml
@@ -3,7 +3,7 @@ seed: 42
 wandb:
   project: wan-t2v-flow-matching
   mode: online
-  name: wan2_1_t2v_fm_updated
+  name: wan2_1_t2v_fm_v2
dist_env: backend: nccl @@ -21,7 +21,7 @@ step_scheduler: data: dataloader: - _target_: dfm.src.automodel.datasets.build_wan21_dataloader + _target_: dfm.src.automodel.datasets.build_dataloader meta_folder: /lustre/fsw/portfolios/coreai/users/linnanw/hdvilla_sample/pika/wan21_codes/1.3B_meta/ num_workers: 2 device: cpu @@ -32,13 +32,22 @@ optim: weight_decay: 0.01 betas: [0.9, 0.999] +# Flow matching V2 configuration flow_matching: - use_sigma_noise: true - timestep_sampling: uniform + adapter_type: "simple" # Options: "hunyuan", "simple" + adapter_kwargs: {} + timestep_sampling: "uniform" # Options: "uniform", "logit_normal", "lognorm", "mix", "mode" logit_mean: 0.0 logit_std: 1.0 flow_shift: 3.0 mix_uniform_ratio: 0.1 + sigma_min: 0.0 + sigma_max: 1.0 + num_train_timesteps: 1000 + i2v_prob: 0.3 + use_loss_weighting: true + log_interval: 100 + summary_log_interval: 10 fsdp: tp_size: 1 diff --git a/examples/automodel/finetune/wan2_1_t2v_flow_multinode.yaml b/examples/automodel/finetune/wan2_1_t2v_flow_multinode.yaml index bbb15d3e..47d8e975 100644 --- a/examples/automodel/finetune/wan2_1_t2v_flow_multinode.yaml +++ b/examples/automodel/finetune/wan2_1_t2v_flow_multinode.yaml @@ -3,7 +3,7 @@ seed: 42 wandb: project: wan-t2v-flow-matching mode: online - name: wan2_1_t2v_fm_updated + name: wan2_1_t2v_fm_multinode_v2 dist_env: backend: nccl @@ -21,7 +21,7 @@ step_scheduler: data: dataloader: - _target_: dfm.src.automodel.datasets.build_wan21_dataloader + _target_: dfm.src.automodel.datasets.build_dataloader meta_folder: /lustre/fsw/portfolios/coreai/users/linnanw/hdvilla_sample/pika/wan21_codes/1.3B_meta/ num_workers: 2 device: cpu @@ -33,13 +33,22 @@ optim: weight_decay: 0.01 betas: [0.9, 0.999] +# Flow matching V2 configuration flow_matching: - use_sigma_noise: true - timestep_sampling: uniform + adapter_type: "simple" # Options: "hunyuan", "simple" + adapter_kwargs: {} + timestep_sampling: "uniform" # Options: "uniform", "logit_normal", "lognorm", "mix", "mode" logit_mean: 0.0 logit_std: 1.0 flow_shift: 3.0 mix_uniform_ratio: 0.1 + sigma_min: 0.0 + sigma_max: 1.0 + num_train_timesteps: 1000 + i2v_prob: 0.3 + use_loss_weighting: true + log_interval: 100 + summary_log_interval: 10 fsdp: tp_size: 1 diff --git a/pyproject.toml b/pyproject.toml index e508957d..fccd7cc9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,7 +56,8 @@ classifiers = [ ] dependencies = [ "accelerate", - "diffusers==0.35.1", + "diffusers>=0.36.0", + "kernels", "easydict", "ftfy", "imageio", diff --git a/uv.lock b/uv.lock index 3534f6be..7d341eba 100644 --- a/uv.lock +++ b/uv.lock @@ -1372,10 +1372,11 @@ wheels = [ [[package]] name = "diffusers" -version = "0.35.1" +version = "0.36.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "filelock" }, + { name = "httpx" }, { name = "huggingface-hub" }, { name = "importlib-metadata" }, { name = "numpy" }, @@ -1384,9 +1385,9 @@ dependencies = [ { name = "requests" }, { name = "safetensors" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/49/05/c4c8736c14e0efe9a835fb91c6ff5e1abddf9894a2f2a28fffe6429378a6/diffusers-0.35.1.tar.gz", hash = "sha256:6f4dc0c9d309a4c4914a2179646f2bc801b5e395a43295fff3b5f9dbd3e28fd3", size = 3369127, upload-time = "2025-08-20T04:16:10.668Z" } +sdist = { url = "https://files.pythonhosted.org/packages/88/45/ccb2e2180ddf475a0f931dac6a50346310e4c464ce3cccb8a65d1fc1e16d/diffusers-0.36.0.tar.gz", hash = "sha256:a9cde8721b415bde6a678f2d02abb85396487e1b0e0d2b4abb462d14a9825ab0", size = 3795088, upload-time = 
"2025-12-08T10:14:34.255Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/06/a7/c53f294f34d9e1584388721b3d7aa024ea1ac46e86d0c302fc3db40ed960/diffusers-0.35.1-py3-none-any.whl", hash = "sha256:fe29ff10200970c7c5934c6488c213e2a77a03dad5e6fa00bbd8e1d04234cb0e", size = 4121424, upload-time = "2025-08-20T04:16:08.359Z" }, + { url = "https://files.pythonhosted.org/packages/35/50/281f92cb1f83854dbd79b6e958b3bc5018607e2542971d41604ba7a14b2f/diffusers-0.36.0-py3-none-any.whl", hash = "sha256:525d42abc74bfc3b2db594999961295c054b48ef40a11724dacf50e6abd1af98", size = 4597884, upload-time = "2025-12-08T10:14:31.979Z" }, ] [[package]] @@ -2390,6 +2391,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/41/45/1a4ed80516f02155c51f51e8cedb3c1902296743db0bbc66608a0db2814f/jsonschema_specifications-2025.9.1-py3-none-any.whl", hash = "sha256:98802fee3a11ee76ecaca44429fda8a41bff98b00a0f2838151b113f210cc6fe", size = 18437, upload-time = "2025-09-08T01:34:57.871Z" }, ] +[[package]] +name = "kernels" +version = "0.11.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "huggingface-hub" }, + { name = "packaging" }, + { name = "pyyaml" }, + { name = "tomli", marker = "python_full_version < '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f0/21/9cdb65155a1d8f8e93c71222470835c36812a704db885cdca68543ad9915/kernels-0.11.5.tar.gz", hash = "sha256:ac95579cb6c1d924f50acd18d2dfac2cbe5233d276a877a200794a619c99beb9", size = 50254, upload-time = "2025-12-17T15:02:40.024Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/62/b8/4b88d8f26c7dca96692eacd3d332fcb38c96b1faf4d7fa22e9a995a7c664/kernels-0.11.5-py3-none-any.whl", hash = "sha256:a07a24ca458d6635cd6eff4a9332483e8a9ee531aee5927b425c3230bf6d613b", size = 46484, upload-time = "2025-12-17T15:02:38.448Z" }, +] + [[package]] name = "kiwisolver" version = "1.4.10rc0" @@ -3623,6 +3639,7 @@ dependencies = [ { name = "ftfy" }, { name = "imageio" }, { name = "imageio-ffmpeg" }, + { name = "kernels" }, { name = "megatron-energon" }, { name = "opencv-python-headless" }, ] @@ -3679,11 +3696,12 @@ torch-cu124 = [ [package.metadata] requires-dist = [ { name = "accelerate" }, - { name = "diffusers", specifier = "==0.35.1" }, + { name = "diffusers", specifier = ">=0.36.0" }, { name = "easydict" }, { name = "ftfy" }, { name = "imageio" }, { name = "imageio-ffmpeg" }, + { name = "kernels" }, { name = "megatron-energon" }, { name = "opencv-python-headless", specifier = "==4.10.0.84" }, ] @@ -5549,28 +5567,27 @@ wheels = [ [[package]] name = "ruff" -version = "0.14.4" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/df/55/cccfca45157a2031dcbb5a462a67f7cf27f8b37d4b3b1cd7438f0f5c1df6/ruff-0.14.4.tar.gz", hash = "sha256:f459a49fe1085a749f15414ca76f61595f1a2cc8778ed7c279b6ca2e1fd19df3", size = 5587844, upload-time = "2025-11-06T22:07:45.033Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/17/b9/67240254166ae1eaa38dec32265e9153ac53645a6c6670ed36ad00722af8/ruff-0.14.4-py3-none-linux_armv6l.whl", hash = "sha256:e6604613ffbcf2297cd5dcba0e0ac9bd0c11dc026442dfbb614504e87c349518", size = 12606781, upload-time = "2025-11-06T22:07:01.841Z" }, - { url = "https://files.pythonhosted.org/packages/46/c8/09b3ab245d8652eafe5256ab59718641429f68681ee713ff06c5c549f156/ruff-0.14.4-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:d99c0b52b6f0598acede45ee78288e5e9b4409d1ce7f661f0fa36d4cbeadf9a4", size = 12946765, upload-time = 
"2025-11-06T22:07:05.858Z" }, - { url = "https://files.pythonhosted.org/packages/14/bb/1564b000219144bf5eed2359edc94c3590dd49d510751dad26202c18a17d/ruff-0.14.4-py3-none-macosx_11_0_arm64.whl", hash = "sha256:9358d490ec030f1b51d048a7fd6ead418ed0826daf6149e95e30aa67c168af33", size = 11928120, upload-time = "2025-11-06T22:07:08.023Z" }, - { url = "https://files.pythonhosted.org/packages/a3/92/d5f1770e9988cc0742fefaa351e840d9aef04ec24ae1be36f333f96d5704/ruff-0.14.4-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:81b40d27924f1f02dfa827b9c0712a13c0e4b108421665322218fc38caf615c2", size = 12370877, upload-time = "2025-11-06T22:07:10.015Z" }, - { url = "https://files.pythonhosted.org/packages/e2/29/e9282efa55f1973d109faf839a63235575519c8ad278cc87a182a366810e/ruff-0.14.4-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f5e649052a294fe00818650712083cddc6cc02744afaf37202c65df9ea52efa5", size = 12408538, upload-time = "2025-11-06T22:07:13.085Z" }, - { url = "https://files.pythonhosted.org/packages/8e/01/930ed6ecfce130144b32d77d8d69f5c610e6d23e6857927150adf5d7379a/ruff-0.14.4-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:aa082a8f878deeba955531f975881828fd6afd90dfa757c2b0808aadb437136e", size = 13141942, upload-time = "2025-11-06T22:07:15.386Z" }, - { url = "https://files.pythonhosted.org/packages/6a/46/a9c89b42b231a9f487233f17a89cbef9d5acd538d9488687a02ad288fa6b/ruff-0.14.4-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:1043c6811c2419e39011890f14d0a30470f19d47d197c4858b2787dfa698f6c8", size = 14544306, upload-time = "2025-11-06T22:07:17.631Z" }, - { url = "https://files.pythonhosted.org/packages/78/96/9c6cf86491f2a6d52758b830b89b78c2ae61e8ca66b86bf5a20af73d20e6/ruff-0.14.4-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a9f3a936ac27fb7c2a93e4f4b943a662775879ac579a433291a6f69428722649", size = 14210427, upload-time = "2025-11-06T22:07:19.832Z" }, - { url = "https://files.pythonhosted.org/packages/71/f4/0666fe7769a54f63e66404e8ff698de1dcde733e12e2fd1c9c6efb689cb5/ruff-0.14.4-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:95643ffd209ce78bc113266b88fba3d39e0461f0cbc8b55fb92505030fb4a850", size = 13658488, upload-time = "2025-11-06T22:07:22.32Z" }, - { url = "https://files.pythonhosted.org/packages/ee/79/6ad4dda2cfd55e41ac9ed6d73ef9ab9475b1eef69f3a85957210c74ba12c/ruff-0.14.4-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:456daa2fa1021bc86ca857f43fe29d5d8b3f0e55e9f90c58c317c1dcc2afc7b5", size = 13354908, upload-time = "2025-11-06T22:07:24.347Z" }, - { url = "https://files.pythonhosted.org/packages/b5/60/f0b6990f740bb15c1588601d19d21bcc1bd5de4330a07222041678a8e04f/ruff-0.14.4-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:f911bba769e4a9f51af6e70037bb72b70b45a16db5ce73e1f72aefe6f6d62132", size = 13587803, upload-time = "2025-11-06T22:07:26.327Z" }, - { url = "https://files.pythonhosted.org/packages/c9/da/eaaada586f80068728338e0ef7f29ab3e4a08a692f92eb901a4f06bbff24/ruff-0.14.4-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:76158a7369b3979fa878612c623a7e5430c18b2fd1c73b214945c2d06337db67", size = 12279654, upload-time = "2025-11-06T22:07:28.46Z" }, - { url = "https://files.pythonhosted.org/packages/66/d4/b1d0e82cf9bf8aed10a6d45be47b3f402730aa2c438164424783ac88c0ed/ruff-0.14.4-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:f3b8f3b442d2b14c246e7aeca2e75915159e06a3540e2f4bed9f50d062d24469", size = 12357520, upload-time = 
"2025-11-06T22:07:31.468Z" }, - { url = "https://files.pythonhosted.org/packages/04/f4/53e2b42cc82804617e5c7950b7079d79996c27e99c4652131c6a1100657f/ruff-0.14.4-py3-none-musllinux_1_2_i686.whl", hash = "sha256:c62da9a06779deecf4d17ed04939ae8b31b517643b26370c3be1d26f3ef7dbde", size = 12719431, upload-time = "2025-11-06T22:07:33.831Z" }, - { url = "https://files.pythonhosted.org/packages/a2/94/80e3d74ed9a72d64e94a7b7706b1c1ebaa315ef2076fd33581f6a1cd2f95/ruff-0.14.4-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:5a443a83a1506c684e98acb8cb55abaf3ef725078be40237463dae4463366349", size = 13464394, upload-time = "2025-11-06T22:07:35.905Z" }, - { url = "https://files.pythonhosted.org/packages/54/1a/a49f071f04c42345c793d22f6cf5e0920095e286119ee53a64a3a3004825/ruff-0.14.4-py3-none-win32.whl", hash = "sha256:643b69cb63cd996f1fc7229da726d07ac307eae442dd8974dbc7cf22c1e18fff", size = 12493429, upload-time = "2025-11-06T22:07:38.43Z" }, - { url = "https://files.pythonhosted.org/packages/bc/22/e58c43e641145a2b670328fb98bc384e20679b5774258b1e540207580266/ruff-0.14.4-py3-none-win_amd64.whl", hash = "sha256:26673da283b96fe35fa0c939bf8411abec47111644aa9f7cfbd3c573fb125d2c", size = 13635380, upload-time = "2025-11-06T22:07:40.496Z" }, - { url = "https://files.pythonhosted.org/packages/30/bd/4168a751ddbbf43e86544b4de8b5c3b7be8d7167a2a5cb977d274e04f0a1/ruff-0.14.4-py3-none-win_arm64.whl", hash = "sha256:dd09c292479596b0e6fec8cd95c65c3a6dc68e9ad17b8f2382130f87ff6a75bb", size = 12663065, upload-time = "2025-11-06T22:07:42.603Z" }, +version = "0.11.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e8/5b/3ae20f89777115944e89c2d8c2e795dcc5b9e04052f76d5347e35e0da66e/ruff-0.11.4.tar.gz", hash = "sha256:f45bd2fb1a56a5a85fae3b95add03fb185a0b30cf47f5edc92aa0355ca1d7407", size = 3933063, upload-time = "2025-04-04T18:24:52.197Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9c/db/baee59ac88f57527fcbaad3a7b309994e42329c6bc4d4d2b681a3d7b5426/ruff-0.11.4-py3-none-linux_armv6l.whl", hash = "sha256:d9f4a761ecbde448a2d3e12fb398647c7f0bf526dbc354a643ec505965824ed2", size = 10106493, upload-time = "2025-04-04T18:23:56.751Z" }, + { url = "https://files.pythonhosted.org/packages/c1/d6/9a0962cbb347f4ff98b33d699bf1193ff04ca93bed4b4222fd881b502154/ruff-0.11.4-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:8c1747d903447d45ca3d40c794d1a56458c51e5cc1bc77b7b64bd2cf0b1626cc", size = 10876382, upload-time = "2025-04-04T18:24:02.391Z" }, + { url = "https://files.pythonhosted.org/packages/3a/8f/62bab0c7d7e1ae3707b69b157701b41c1ccab8f83e8501734d12ea8a839f/ruff-0.11.4-py3-none-macosx_11_0_arm64.whl", hash = "sha256:51a6494209cacca79e121e9b244dc30d3414dac8cc5afb93f852173a2ecfc906", size = 10237050, upload-time = "2025-04-04T18:24:05.387Z" }, + { url = "https://files.pythonhosted.org/packages/09/96/e296965ae9705af19c265d4d441958ed65c0c58fc4ec340c27cc9d2a1f5b/ruff-0.11.4-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3f171605f65f4fc49c87f41b456e882cd0c89e4ac9d58e149a2b07930e1d466f", size = 10424984, upload-time = "2025-04-04T18:24:08.134Z" }, + { url = "https://files.pythonhosted.org/packages/e5/56/644595eb57d855afed6e54b852e2df8cd5ca94c78043b2f29bdfb29882d5/ruff-0.11.4-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ebf99ea9af918878e6ce42098981fc8c1db3850fef2f1ada69fb1dcdb0f8e79e", size = 9957438, upload-time = "2025-04-04T18:24:11.061Z" }, + { url = 
"https://files.pythonhosted.org/packages/86/83/9d3f3bed0118aef3e871ded9e5687fb8c5776bde233427fd9ce0a45db2d4/ruff-0.11.4-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:edad2eac42279df12e176564a23fc6f4aaeeb09abba840627780b1bb11a9d223", size = 11547282, upload-time = "2025-04-04T18:24:13.739Z" }, + { url = "https://files.pythonhosted.org/packages/40/e6/0c6e4f5ae72fac5ccb44d72c0111f294a5c2c8cc5024afcb38e6bda5f4b3/ruff-0.11.4-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:f103a848be9ff379fc19b5d656c1f911d0a0b4e3e0424f9532ececf319a4296e", size = 12182020, upload-time = "2025-04-04T18:24:16.799Z" }, + { url = "https://files.pythonhosted.org/packages/b5/92/4aed0e460aeb1df5ea0c2fbe8d04f9725cccdb25d8da09a0d3f5b8764bf8/ruff-0.11.4-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:193e6fac6eb60cc97b9f728e953c21cc38a20077ed64f912e9d62b97487f3f2d", size = 11679154, upload-time = "2025-04-04T18:24:19.797Z" }, + { url = "https://files.pythonhosted.org/packages/1b/d3/7316aa2609f2c592038e2543483eafbc62a0e1a6a6965178e284808c095c/ruff-0.11.4-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7af4e5f69b7c138be8dcffa5b4a061bf6ba6a3301f632a6bce25d45daff9bc99", size = 13905985, upload-time = "2025-04-04T18:24:24.542Z" }, + { url = "https://files.pythonhosted.org/packages/63/80/734d3d17546e47ff99871f44ea7540ad2bbd7a480ed197fe8a1c8a261075/ruff-0.11.4-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:126b1bf13154aa18ae2d6c3c5efe144ec14b97c60844cfa6eb960c2a05188222", size = 11348343, upload-time = "2025-04-04T18:24:27.742Z" }, + { url = "https://files.pythonhosted.org/packages/04/7b/70fc7f09a0161dce9613a4671d198f609e653d6f4ff9eee14d64c4c240fb/ruff-0.11.4-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:e8806daaf9dfa881a0ed603f8a0e364e4f11b6ed461b56cae2b1c0cab0645304", size = 10308487, upload-time = "2025-04-04T18:24:30.59Z" }, + { url = "https://files.pythonhosted.org/packages/1a/22/1cdd62dabd678d75842bf4944fd889cf794dc9e58c18cc547f9eb28f95ed/ruff-0.11.4-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:5d94bb1cc2fc94a769b0eb975344f1b1f3d294da1da9ddbb5a77665feb3a3019", size = 9929091, upload-time = "2025-04-04T18:24:33.24Z" }, + { url = "https://files.pythonhosted.org/packages/9f/20/40e0563506332313148e783bbc1e4276d657962cc370657b2fff20e6e058/ruff-0.11.4-py3-none-musllinux_1_2_i686.whl", hash = "sha256:995071203d0fe2183fc7a268766fd7603afb9996785f086b0d76edee8755c896", size = 10924659, upload-time = "2025-04-04T18:24:36.728Z" }, + { url = "https://files.pythonhosted.org/packages/b5/41/eef9b7aac8819d9e942f617f9db296f13d2c4576806d604aba8db5a753f1/ruff-0.11.4-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:7a37ca937e307ea18156e775a6ac6e02f34b99e8c23fe63c1996185a4efe0751", size = 11428160, upload-time = "2025-04-04T18:24:40.08Z" }, + { url = "https://files.pythonhosted.org/packages/ff/61/c488943414fb2b8754c02f3879de003e26efdd20f38167ded3fb3fc1cda3/ruff-0.11.4-py3-none-win32.whl", hash = "sha256:0e9365a7dff9b93af933dab8aebce53b72d8f815e131796268709890b4a83270", size = 10311496, upload-time = "2025-04-04T18:24:42.94Z" }, + { url = "https://files.pythonhosted.org/packages/b6/2b/2a1c8deb5f5dfa3871eb7daa41492c4d2b2824a74d2b38e788617612a66d/ruff-0.11.4-py3-none-win_amd64.whl", hash = "sha256:5a9fa1c69c7815e39fcfb3646bbfd7f528fa8e2d4bebdcf4c2bd0fa037a255fb", size = 11399146, upload-time = "2025-04-04T18:24:45.651Z" }, + { url = 
"https://files.pythonhosted.org/packages/4f/03/3aec4846226d54a37822e4c7ea39489e4abd6f88388fba74e3d4abe77300/ruff-0.11.4-py3-none-win_arm64.whl", hash = "sha256:d435db6b9b93d02934cf61ef332e66af82da6d8c69aefdea5994c89997c7a0fc", size = 10450306, upload-time = "2025-04-04T18:24:49.603Z" }, ] [[package]]