From 75ed5b206f53e1a96006d6ccc2ae40e10340a263 Mon Sep 17 00:00:00 2001 From: Huy Vu2 Date: Thu, 30 Oct 2025 07:32:08 -0700 Subject: [PATCH 01/80] first commit --- .../data/prepare_energon_dataset_wan.py | 404 +++++++++ .../megatron/data/wan_energon_datamodule.py | 47 + dfm/src/megatron/data/wan_taskencoder.py | 192 ++++ .../model/wan/flow_matching/__init__.py | 13 + .../flow_matching/flow_inference_pipeline.py | 568 ++++++++++++ .../model/wan/flow_matching/flow_pipeline.py | 223 +++++ .../wan/flow_matching/time_shift_utils.py | 108 +++ .../model/wan/inference/configs/__init__.py | 52 ++ .../wan/inference/configs/shared_config.py | 18 + .../wan/inference/configs/wan_i2v_14B.py | 35 + .../wan/inference/configs/wan_t2v_14B.py | 28 + .../wan/inference/configs/wan_t2v_1_3B.py | 28 + .../model/wan/inference/utils/fm_solvers.py | 858 ++++++++++++++++++ .../wan/inference/utils/fm_solvers_unipc.py | 801 ++++++++++++++++ .../model/wan/inference/utils/utils.py | 117 +++ .../megatron/model/wan/modules/__init__.py | 13 + dfm/src/megatron/model/wan/modules/t5.py | 512 +++++++++++ .../megatron/model/wan/modules/tokenizers.py | 81 ++ dfm/src/megatron/model/wan/modules/vae.py | 662 ++++++++++++++ dfm/src/megatron/model/wan/rope_utils.py | 65 ++ dfm/src/megatron/model/wan/utils/utils.py | 128 +++ dfm/src/megatron/model/wan/wan_layer_spec.py | 591 ++++++++++++ dfm/src/megatron/model/wan/wan_model.py | 332 +++++++ 23 files changed, 5876 insertions(+) create mode 100644 dfm/src/megatron/data/prepare_energon_dataset_wan.py create mode 100644 dfm/src/megatron/data/wan_energon_datamodule.py create mode 100644 dfm/src/megatron/data/wan_taskencoder.py create mode 100644 dfm/src/megatron/model/wan/flow_matching/__init__.py create mode 100644 dfm/src/megatron/model/wan/flow_matching/flow_inference_pipeline.py create mode 100644 dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py create mode 100644 dfm/src/megatron/model/wan/flow_matching/time_shift_utils.py create mode 100644 dfm/src/megatron/model/wan/inference/configs/__init__.py create mode 100644 dfm/src/megatron/model/wan/inference/configs/shared_config.py create mode 100644 dfm/src/megatron/model/wan/inference/configs/wan_i2v_14B.py create mode 100644 dfm/src/megatron/model/wan/inference/configs/wan_t2v_14B.py create mode 100644 dfm/src/megatron/model/wan/inference/configs/wan_t2v_1_3B.py create mode 100644 dfm/src/megatron/model/wan/inference/utils/fm_solvers.py create mode 100644 dfm/src/megatron/model/wan/inference/utils/fm_solvers_unipc.py create mode 100644 dfm/src/megatron/model/wan/inference/utils/utils.py create mode 100644 dfm/src/megatron/model/wan/modules/__init__.py create mode 100644 dfm/src/megatron/model/wan/modules/t5.py create mode 100644 dfm/src/megatron/model/wan/modules/tokenizers.py create mode 100644 dfm/src/megatron/model/wan/modules/vae.py create mode 100644 dfm/src/megatron/model/wan/rope_utils.py create mode 100644 dfm/src/megatron/model/wan/utils/utils.py create mode 100644 dfm/src/megatron/model/wan/wan_layer_spec.py create mode 100644 dfm/src/megatron/model/wan/wan_model.py diff --git a/dfm/src/megatron/data/prepare_energon_dataset_wan.py b/dfm/src/megatron/data/prepare_energon_dataset_wan.py new file mode 100644 index 00000000..a8464aa6 --- /dev/null +++ b/dfm/src/megatron/data/prepare_energon_dataset_wan.py @@ -0,0 +1,404 @@ +import os +import json +import pickle +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import cv2 +import numpy as np +import torch +import webdataset as wds + +from diffusers import AutoencoderKLWan +from transformers import AutoTokenizer, UMT5EncoderModel + + +def _map_interpolation(resize_mode: str) -> int: + interpolation_map = { + "bilinear": cv2.INTER_LINEAR, + "bicubic": cv2.INTER_CUBIC, + "nearest": cv2.INTER_NEAREST, + "area": cv2.INTER_AREA, + "lanczos": cv2.INTER_LANCZOS4, + } + if resize_mode not in interpolation_map: + raise ValueError(f"Invalid resize_mode '{resize_mode}'. Choose from: {list(interpolation_map.keys())}") + return interpolation_map[resize_mode] + + +def _calculate_resize_dimensions( + original_height: int, + original_width: int, + target_size: Optional[Tuple[int, int]], + maintain_aspect_ratio: bool, +) -> Tuple[int, int]: + if target_size is None: + return original_height, original_width + + target_height, target_width = target_size + if not maintain_aspect_ratio: + return target_height, target_width + + original_aspect = original_width / max(1, original_height) + target_aspect = target_width / max(1, target_height) + + if original_aspect > target_aspect: + new_width = target_width + new_height = int(round(target_width / max(1e-6, original_aspect))) + else: + new_height = target_height + new_width = int(round(target_height * original_aspect)) + + return new_height, new_width + + +def _resize_frame( + frame: np.ndarray, + target_size: Optional[Tuple[int, int]], + resize_mode: str, + maintain_aspect_ratio: bool, + center_crop: bool, +) -> np.ndarray: + if target_size is None: + return frame + + original_height, original_width = frame.shape[:2] + resize_height, resize_width = _calculate_resize_dimensions( + original_height, original_width, target_size, maintain_aspect_ratio + ) + + interpolation = _map_interpolation(resize_mode) + resized_frame = cv2.resize(frame, (resize_width, resize_height), interpolation=interpolation) + + if maintain_aspect_ratio and center_crop: + target_height, target_width = target_size + if resize_height != target_height or resize_width != target_width: + y_start = max(0, (resize_height - target_height) // 2) + x_start = max(0, (resize_width - target_width) // 2) + y_end = min(resize_height, y_start + target_height) + x_end = min(resize_width, x_start + target_width) + resized_frame = resized_frame[y_start:y_end, x_start:x_end] + + if resized_frame.shape[0] < target_height or resized_frame.shape[1] < target_width: + pad_height = max(0, target_height - resized_frame.shape[0]) + pad_width = max(0, target_width - resized_frame.shape[1]) + resized_frame = np.pad( + resized_frame, ((0, pad_height), (0, pad_width), (0, 0)), mode="constant", constant_values=0 + ) + + return resized_frame + + +def _read_sidecar_caption(jsonl_path: Path) -> str: + if not jsonl_path.exists(): + return "" + try: + with open(jsonl_path, "r") as f: + for line in f: + line = line.strip() + if not line: + continue + try: + obj = json.loads(line) + except Exception: + continue + # Prefer keys used across datasets + for key in ("vila_caption", "gemini_v2_caption", "caption", "text"): + if key in obj and isinstance(obj[key], str): + return obj[key] + # If no known key, try first string value + for v in obj.values(): + if isinstance(v, str): + return v + break + except Exception: + return "" + return "" + + +def _get_total_frames(video_path: str) -> int: + cap = cv2.VideoCapture(video_path) + total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + cap.release() + return max(0, total) + + +def _load_metadata(video_folder: Path) -> List[Dict]: + meta_path = video_folder / "meta.json" + if meta_path.exists(): + with open(meta_path, "r") as f: + return json.load(f) + + # Fallback: scan for .mp4 files with sidecar .jsonl; use full frame range + items: List[Dict] = [] + for entry in sorted(video_folder.iterdir()): + if not entry.is_file(): + continue + if entry.suffix.lower() != ".mp4": + continue + video_name = entry.name + video_path = str(entry) + total_frames = _get_total_frames(video_path) + start_frame = 0 + end_frame = max(0, total_frames - 1) + sidecar = entry.with_suffix("") + # Handle names with additional dots gracefully + sidecar_jsonl = Path(str(entry).rsplit(".", 1)[0] + ".jsonl") + caption = _read_sidecar_caption(sidecar_jsonl) + items.append( + { + "file_name": video_name, + "start_frame": start_frame, + "end_frame": end_frame, + "vila_caption": caption, + } + ) + if not items: + raise FileNotFoundError(f"No meta.json and no .mp4 files found in {video_folder}") + return items + + +def _load_frames_cv2( + video_path: str, + start_frame: int, + end_frame: int, + target_size: Optional[Tuple[int, int]], + resize_mode: str, + maintain_aspect_ratio: bool, + center_crop: bool, + target_dtype: torch.dtype, +) -> torch.Tensor: + cap = cv2.VideoCapture(video_path) + frames: List[np.ndarray] = [] + + cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame) + for frame_idx in range(start_frame, end_frame + 1): + ret, frame = cap.read() + if not ret: + break + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + frame = _resize_frame(frame, target_size, resize_mode, maintain_aspect_ratio, center_crop) + frame = frame.astype(np.float32) / 255.0 + frames.append(frame) + cap.release() + + if not frames: + raise ValueError(f"No frames loaded from {video_path}") + + video_array = np.array(frames) # T, H, W, C in [0,1] + video_tensor = torch.from_numpy(video_array) # T, H, W, C + video_tensor = video_tensor.permute(3, 0, 1, 2).unsqueeze(0) # 1, C, T, H, W + video_tensor = video_tensor.to(dtype=target_dtype) + return video_tensor + + +@torch.no_grad() +def _init_hf_models( + model_id: str, + device: str, + enable_memory_optimization: bool, +): + dtype = torch.float16 if device.startswith("cuda") else torch.float32 + + text_encoder = UMT5EncoderModel.from_pretrained( + model_id, + subfolder="text_encoder", + torch_dtype=dtype, + ) + text_encoder.to(device) + text_encoder.eval() + + vae = AutoencoderKLWan.from_pretrained( + model_id, + subfolder="vae", + torch_dtype=dtype, + ) + vae.to(device) + vae.eval() + if enable_memory_optimization: + vae.enable_slicing() + vae.enable_tiling() + + tokenizer = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer") + + return vae, text_encoder, tokenizer, dtype + + +@torch.no_grad() +def _encode_text( + tokenizer: AutoTokenizer, + text_encoder: UMT5EncoderModel, + device: str, + caption: str, +) -> torch.Tensor: + caption = caption.strip() + inputs = tokenizer( + caption, + max_length=512, + padding="max_length", + truncation=True, + return_tensors="pt", + return_attention_mask=True, + ) + inputs = {k: v.to(device) for k, v in inputs.items()} + outputs = text_encoder(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"]).last_hidden_state + return outputs + + +@torch.no_grad() +def _encode_video_latents( + vae: AutoencoderKLWan, + device: str, + video_tensor: torch.Tensor, + deterministic_latents: bool, +) -> torch.Tensor: + video_tensor = video_tensor.to(device=device, dtype=vae.dtype) + video_tensor = video_tensor * 2.0 - 1.0 # [0,1] -> [-1,1] + + latent_dist = vae.encode(video_tensor) + if deterministic_latents: + video_latents = latent_dist.latent_dist.mean + else: + video_latents = latent_dist.latent_dist.sample() + + latent_mean = video_latents.mean().item() + latent_std = video_latents.std().item() + + if abs(latent_mean) < 0.5 and 0.5 < latent_std < 2.0: + final_latents = video_latents + else: + if not hasattr(vae.config, "latents_mean") or not hasattr(vae.config, "latents_std"): + raise ValueError("Wan2.1 VAE requires latents_mean and latents_std in config") + latents_mean = torch.tensor(vae.config.latents_mean, device=device, dtype=vae.dtype).view(1, -1, 1, 1, 1) + latents_std = torch.tensor(vae.config.latents_std, device=device, dtype=vae.dtype).view(1, -1, 1, 1, 1) + final_latents = (video_latents - latents_mean) / latents_std + + return final_latents + + +def main(): + import argparse + + parser = argparse.ArgumentParser( + description="Prepare WAN WebDataset shards using HF automodel encoders and resizing" + ) + parser.add_argument("--video_folder", type=str, required=True, help="Folder containing videos and meta.json") + parser.add_argument("--output_dir", type=str, required=True, help="Directory to write webdataset shards") + parser.add_argument( + "--model", + default="Wan-AI/Wan2.1-T2V-14B-Diffusers", + help="Wan2.1 model ID (e.g., Wan-AI/Wan2.1-T2V-14B-Diffusers or Wan-AI/Wan2.1-T2V-1.3B-Diffusers)", + ) + parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="Device to use") + parser.add_argument( + "--stochastic", + action="store_true", + help="Use stochastic encoding (sampling) instead of deterministic posterior mean", + ) + parser.add_argument("--no-memory-optimization", action="store_true", help="Disable VAE slicing/tiling") + parser.add_argument("--shard_maxcount", type=int, default=10000, help="Max samples per shard") + + # Resize arguments (match automodel) + parser.add_argument("--height", type=int, default=None, help="Target height for video frames") + parser.add_argument("--width", type=int, default=None, help="Target width for video frames") + parser.add_argument( + "--resize_mode", + default="bilinear", + choices=["bilinear", "bicubic", "nearest", "area", "lanczos"], + help="Interpolation mode for resizing", + ) + parser.add_argument("--no-aspect-ratio", action="store_true", help="Disable aspect ratio preservation") + parser.add_argument("--center-crop", action="store_true", help="Center crop to exact target size after resize") + + args = parser.parse_args() + + video_folder = Path(args.video_folder) + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + shard_pattern = str(output_dir / "shard-%06d.tar") + + # Target size + target_size = None + if args.height is not None and args.width is not None: + target_size = (args.height, args.width) + elif (args.height is None) ^ (args.width is None): + parser.error("Both --height and --width must be specified together") + + # Init HF models + vae, text_encoder, tokenizer, model_dtype = _init_hf_models( + model_id=args.model, + device=args.device, + enable_memory_optimization=not args.no_memory_optimization, + ) + + # Load metadata list + metadata_list = _load_metadata(video_folder) + + with wds.ShardWriter(shard_pattern, maxcount=args.shard_maxcount) as sink: + written = 0 + for index, meta in enumerate(metadata_list): + video_name = meta["file_name"] + start_frame = int(meta["start_frame"]) # inclusive + end_frame = int(meta["end_frame"]) # inclusive + caption_text = meta.get("vila_caption", "") + + video_path = str(video_folder / video_name) + # Load frames using the same OpenCV + resize path as automodel + video_tensor = _load_frames_cv2( + video_path=video_path, + start_frame=start_frame, + end_frame=end_frame, + target_size=target_size, + resize_mode=args.resize_mode, + maintain_aspect_ratio=not args.no_aspect_ratio, + center_crop=args.center_crop, + target_dtype=model_dtype, + ) + + # Encode text and video with HF models exactly like automodel + text_embed = _encode_text(tokenizer, text_encoder, args.device, caption_text) + latents = _encode_video_latents(vae, args.device, video_tensor, deterministic_latents=not args.stochastic) + + # Move to CPU without changing dtype; keep exact values to match automodel outputs + text_embed_cpu = text_embed.detach().to(device="cpu") + latents_cpu = latents.detach().to(device="cpu") + + # Reshape to match Mcore's Wan input format + text_embed_cpu = text_embed_cpu[0] + latents_cpu = latents_cpu[0] + + # Build JSON side-info similar to prepare_energon script + C, T, H, W = video_tensor.shape[1:] # 1,C,T,H,W + json_data = { + "video_path": video_path, + "processed_frames": int(T), + "processed_height": int(H), + "processed_width": int(W), + "caption": caption_text, + "deterministic_latents": bool(not args.stochastic), + "memory_optimization": bool(not args.no_memory_optimization), + "model_version": "wan2.1", + "resize_settings": { + "target_size": target_size, + "resize_mode": args.resize_mode, + "maintain_aspect_ratio": bool(not args.no_aspect_ratio), + "center_crop": bool(args.center_crop), + }, + } + + sample = { + "__key__": f"{index:06}", + "pth": latents_cpu, + "pickle": pickle.dumps(text_embed_cpu), + "json": json_data, + } + sink.write(sample) + written += 1 + + print("Done writing shards using HF automodel encoders.") + + +if __name__ == "__main__": + main() + + diff --git a/dfm/src/megatron/data/wan_energon_datamodule.py b/dfm/src/megatron/data/wan_energon_datamodule.py new file mode 100644 index 00000000..98774e81 --- /dev/null +++ b/dfm/src/megatron/data/wan_energon_datamodule.py @@ -0,0 +1,47 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=C0115,C0116,C0301 + +from dataclasses import dataclass +import logging +from typing import Any, Dict, Literal + +from torch import int_repr + +from megatron.bridge.data.Dit.data.diffusion_energon_datamodule import DiffusionDataModule +from megatron.bridge.data.wan.wan_taskencoder import WanTaskEncoder +from megatron.bridge.data.utils import DatasetBuildContext, DatasetProvider + +@dataclass(kw_only=True) +class WanDataModuleConfig(DatasetProvider): + path: str + seq_length: int + micro_batch_size: int + global_batch_size: int + num_workers: int_repr + dataloader_type: str = "external" + + def __post_init__(self): + self.dataset = DiffusionDataModule( + path=self.path, + seq_length=self.seq_length, + task_encoder=WanTaskEncoder(seq_length=self.seq_length), + micro_batch_size=self.micro_batch_size, + global_batch_size=self.global_batch_size, + num_workers=self.num_workers) + self.sequence_length = self.dataset.seq_length + + def build_datasets(self, context: DatasetBuildContext): + return self.dataset.train_dataloader(), self.dataset.train_dataloader(), self.dataset.train_dataloader() \ No newline at end of file diff --git a/dfm/src/megatron/data/wan_taskencoder.py b/dfm/src/megatron/data/wan_taskencoder.py new file mode 100644 index 00000000..097a8583 --- /dev/null +++ b/dfm/src/megatron/data/wan_taskencoder.py @@ -0,0 +1,192 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=C0115,C0116,C0301 + +import torch +import torch.nn.functional as F +from megatron.energon import DefaultTaskEncoder, SkipSample +from megatron.energon.task_encoder.cooking import Cooker, basic_sample_keys +from megatron.bridge.models.wan.utils.utils import grid_sizes_calculation, patchify +from megatron.core import parallel_state + + +def cook(sample: dict) -> dict: + """ + Processes a raw sample dictionary from energon dataset and returns a new dictionary with specific keys. + + Args: + sample (dict): The input dictionary containing the raw sample data. + + Returns: + dict: A new dictionary containing the processed sample data with the following keys: + - All keys from the result of `basic_sample_keys(sample)` + - 'json': The contains meta data like resolution, aspect ratio, fps, etc. + - 'pth': contains video latent tensor + - 'pickle': contains text embeddings + """ + return dict( + **basic_sample_keys(sample), + json=sample["json"], + pth=sample["pth"], + pickle=sample["pickle"], + ) + + +class WanTaskEncoder(DefaultTaskEncoder): + """ + Task encoder for Wan dataset. + Attributes: + cookers (list): A list of Cooker objects used for processing. + patch_spatial (int): The spatial patch size. Defaults to 2. + patch_temporal (int): The temporal patch size. Defaults to 1. + seq_length (int): The sequence length. Defaults to 1024. + """ + + cookers = [ + Cooker(cook), + ] + + def __init__( + self, + *args, + max_frames: int = None, + patch_spatial: int = 2, + patch_temporal: int = 1, + seq_length: int = 1024, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.patch_spatial = patch_spatial + self.patch_temporal = patch_temporal + self.seq_length = seq_length + + + # def actual_encode_sample(self, sample: dict) -> dict: + + # video_latent = sample["pth"] + # context_embeddings = sample["pickle"] + # video_metadata = sample["json"] + + # # sanity quality check + # if torch.isnan(video_latent).any() or torch.isinf(video_latent).any(): + # raise SkipSample() + # if torch.max(torch.abs(video_latent)) > 1e3: + # raise SkipSample() + + # # calculate grid size + # grid_size = grid_sizes_calculation( + # input_shape = video_latent.shape[1:], + # patch_size = (self.patch_temporal, self.patch_spatial, self.patch_spatial), + # ) + + # ### Note: shape of sample's values + # # video_latent: [latents_channels, F_latents, W_latents, H_latents] + # # grid_size: [F_patches, W_patches, H_patches] + # # context_embeddings: [context_seq_len, text_embedding_dim] + + # return dict( + # video_latent=video_latent, + # grid_size=grid_size, + # context_embeddings=context_embeddings, + # video_metadata=video_metadata, + # ) + + + def encode_sample(self, sample: dict) -> dict: + + # mock encode sample + video_latent = torch.tensor(torch.randn(16, 3, 104, 60), dtype=torch.float32) + # video_latent = torch.tensor(torch.randn(16, 24, 104, 60), dtype=torch.float32) + grid_size = torch.tensor([video_latent.shape[1] // self.patch_temporal, video_latent.shape[2] // self.patch_spatial, video_latent.shape[3] // self.patch_spatial], dtype=torch.int32) + context_embeddings = torch.tensor(torch.randn(512, 4096), dtype=torch.float32) + video_metadata = {} + + return dict( + video_latent=video_latent, + grid_size=grid_size, + context_embeddings=context_embeddings, + video_metadata=video_metadata, + ) + + + def batch(self, samples: list[dict]) -> dict: + + # process video latents + # do padding here for video latents + self.patch_size = (self.patch_temporal, self.patch_spatial, self.patch_spatial) + + # running patchify + video_latents = patchify([sample["video_latent"] for sample in samples], self.patch_size) + + # build per-sample loss masks (1 for valid tokens pre-padding) + loss_masks = [torch.ones(v.shape[0]) for v in video_latents] + # calculate all sequence lengths of video latents for self-attention (for videos, we do this before padding to get original seq len) + seq_len_q = [v.shape[0] for v in video_latents] + seq_len_q = torch.tensor(seq_len_q, dtype=torch.int32) + + + # padding and stack video latents + max_video_seq_len = max([video_latent.shape[0] for video_latent in video_latents]) + # CAVEAT: + # when using pipeline parallelism, we need to set batch sequence length to DataModule's seq_length because + # because pipeline parallelism requires pre-specified sequence length to create buffer + if parallel_state.get_pipeline_model_parallel_world_size() > 1: + if max_video_seq_len > self.seq_length: + raise ValueError(f"max_video_seq_len {max_video_seq_len} is greater than DataModule's seq_length {self.seq_length}") + else: + # set max_video_seq_len to DataModule's seq_length + max_video_seq_len = self.seq_length + # CAVEAT: + # when using context parallelism, we need to pad batch sequence length to be divisible by [cp_rank*2] + # (because TransformerEngine's context parallelism requires "AssertionError: Sequence length per GPU needs to be divisible by 2!") + if parallel_state.get_context_parallel_world_size() > 1: + batch_size = len(video_latents) + assert batch_size == 1, "Error: Batch size must be 1 when using context parallelism" + sharding_factor = parallel_state.get_context_parallel_world_size() * 2 + max_video_seq_len = ((max_video_seq_len + sharding_factor - 1) // sharding_factor) * sharding_factor + video_latents = [F.pad(video_latent, (0, 0, 0, max_video_seq_len - video_latent.shape[0])) for video_latent in video_latents] + video_latents = torch.stack(video_latents, dim=1) + # pad and stack loss masks to shape [S_max, B] + loss_masks = [F.pad(m, (0, max_video_seq_len - m.shape[0])) for m in loss_masks] + loss_masks = torch.stack(loss_masks, dim=1) + + # process grid sizes + grid_sizes = [torch.tensor(sample["grid_size"], dtype=torch.int32) for sample in samples] + grid_sizes = torch.stack(grid_sizes, dim=0) + + # process text embeddings + # pad here for text embeddings + context_max_len = 512 + context_embeddings = [sample["context_embeddings"] for sample in samples] + context_embeddings = [F.pad(context_embedding, (0, 0, 0, context_max_len - context_embedding.shape[0])) for context_embedding in context_embeddings] + # calculate all sequence lengths of context embeddings for cross-attention (for videos, we do this after padding to get padded seq len) + seq_len_kv = [c.shape[0] for c in context_embeddings] + seq_len_kv = torch.tensor(seq_len_kv, dtype=torch.int32) + # stack context embeddings + context_embeddings = torch.stack(context_embeddings, dim=1) + + # process video metadata + video_metadata = [sample["video_metadata"] for sample in samples] + + return dict( + video_latents = video_latents, + max_video_seq_len = max_video_seq_len, + grid_sizes = grid_sizes, + context_embeddings = context_embeddings, + loss_mask = loss_masks, + seq_len_q = seq_len_q, + seq_len_kv = seq_len_kv, + video_metadata = video_metadata, + ) \ No newline at end of file diff --git a/dfm/src/megatron/model/wan/flow_matching/__init__.py b/dfm/src/megatron/model/wan/flow_matching/__init__.py new file mode 100644 index 00000000..d9155f92 --- /dev/null +++ b/dfm/src/megatron/model/wan/flow_matching/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/dfm/src/megatron/model/wan/flow_matching/flow_inference_pipeline.py b/dfm/src/megatron/model/wan/flow_matching/flow_inference_pipeline.py new file mode 100644 index 00000000..fedef4f4 --- /dev/null +++ b/dfm/src/megatron/model/wan/flow_matching/flow_inference_pipeline.py @@ -0,0 +1,568 @@ +import gc +import logging +import math +import os +import random +import sys +import types +import re +from contextlib import contextmanager +from functools import partial + +import torch +import torch.cuda.amp as amp +import torch.distributed as dist +from tqdm import tqdm + +from megatron.bridge.models.wan.wan_model import WanModel +from megatron.bridge.models.wan.wan_provider import WanModelProvider +from megatron.bridge.models.wan.modules.t5 import T5EncoderModel +from megatron.bridge.models.wan.modules import WanVAE +from megatron.bridge.models.wan.inference.utils.fm_solvers import ( + FlowDPMSolverMultistepScheduler, + get_sampling_sigmas, + retrieve_timesteps, +) +from megatron.bridge.models.wan.inference.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler +from megatron.bridge.models.wan.utils.utils import grid_sizes_calculation, patchify +from megatron.core import parallel_state +from torch.nn import functional as F +from megatron.bridge.models.wan.utils.utils import split_inputs_cp, cat_outputs_cp + +import math +from typing import Tuple, Union + +class FlowInferencePipeline: + + def __init__( + self, + config, + checkpoint_dir, + checkpoint_step=None, + t5_checkpoint_dir=None, + vae_checkpoint_dir=None, + device_id=0, + rank=0, + t5_cpu=False, + + tensor_parallel_size=1, + context_parallel_size=1, + pipeline_parallel_size=1, + sequence_parallel=False, + pipeline_dtype=torch.float32, + ): + r""" + Initializes the FlowInferencePipeline with the given parameters. + + Args: + config (EasyDict): + Object containing model parameters initialized from config.py + checkpoint_dir (`str`): + Path to directory containing model checkpoints + t5_checkpoint_dir (`str`, *optional*, defaults to None): + Optional directory containing T5 checkpoint and tokenizer; falls back to `checkpoint_dir` if None. + vae_checkpoint_dir (`str`, *optional*, defaults to None): + Optional directory containing VAE checkpoint; falls back to `checkpoint_dir` if None. + device_id (`int`, *optional*, defaults to 0): + Id of target GPU device + rank (`int`, *optional*, defaults to 0): + Process rank for distributed training + t5_cpu (`bool`, *optional*, defaults to False): + Whether to place T5 model on CPU. Only works without t5_fsdp. + """ + self.device = torch.device(f"cuda:{device_id}") + self.config = config + self.rank = rank + self.t5_cpu = t5_cpu + self.tensor_parallel_size = tensor_parallel_size + self.context_parallel_size = context_parallel_size + self.pipeline_parallel_size = pipeline_parallel_size + self.sequence_parallel = sequence_parallel + self.pipeline_dtype = pipeline_dtype + self.num_train_timesteps = config.num_train_timesteps + self.param_dtype = config.param_dtype + + self.text_encoder = T5EncoderModel( + text_len=config.text_len, + dtype=config.t5_dtype, + device=torch.device('cpu'), + checkpoint_path=os.path.join(t5_checkpoint_dir, config.t5_checkpoint), + tokenizer_path=os.path.join(t5_checkpoint_dir, config.t5_tokenizer), + shard_fn=None) + + self.vae_stride = config.vae_stride + self.patch_size = config.patch_size + self.vae = WanVAE( + vae_pth=os.path.join(vae_checkpoint_dir, config.vae_checkpoint), + device=self.device) + + wan_checkpoint_dir = self._select_checkpoint_dir(checkpoint_dir, checkpoint_step) + self.model = self.setup_model_from_checkpoint(wan_checkpoint_dir) + + # DEBUGGING + # set qkv_format to to "thd" for context parallelism + self.model.config.qkv_format = "sbhd" + + # set self.sp_size=1 for later use, just to respect the original Wan inference code + self.sp_size = 1 + + if dist.is_initialized(): + dist.barrier() + self.model.to(self.device) + + self.sample_neg_prompt = config.sample_neg_prompt + + + def unpatchify(self, x: torch.Tensor, grid_sizes: torch.Tensor, out_dim: int) -> list[torch.Tensor]: + r""" + Reconstruct video tensors from patch embeddings into a list of videotensors. + + Args: + x (torch.Tensor): + Tensor of patchified features, with shape [seq_len, c * pF * pH * pW] + grid_sizes (Tensor): + Original spatial-temporal grid dimensions before patching, + shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches) + + Returns: + list[torch.Tensor]: list of tensors, each with shape [c, F_latents, H_latents, W_latents] + """ + + c = out_dim + out = [] + for u, v in zip(x, grid_sizes.tolist()): + u = u[:math.prod(v)].view(*v, *self.patch_size, c) + u = torch.einsum('fhwpqrc->cfphqwr', u) + u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)]) + out.append(u) + return out + + + def setup_model_from_checkpoint(self, checkpoint_dir): + provider = WanModelProvider() + provider.tensor_model_parallel_size = self.tensor_parallel_size + provider.pipeline_model_parallel_size = self.pipeline_parallel_size + provider.context_parallel_size = self.context_parallel_size + provider.sequence_parallel = self.sequence_parallel + provider.pipeline_dtype = self.pipeline_dtype + # Once all overrides are set, finalize the model provider to ensure the post initialization logic is run + provider.finalize() + provider.initialize_model_parallel(seed=0) + + ## Read from megatron checkpoint + from megatron.bridge.training.model_load_save import load_megatron_model as _load_megatron_model + model = _load_megatron_model( + checkpoint_dir, + mp_overrides={ + "tensor_model_parallel_size": self.tensor_parallel_size, + "pipeline_model_parallel_size": self.pipeline_parallel_size, + "context_parallel_size": self.context_parallel_size, + "sequence_parallel": self.sequence_parallel, + "pipeline_dtype": self.pipeline_dtype, + }, + ) + if isinstance(model, list): + model = model[0] + if hasattr(model, "module"): + model = model.module + + return model + + def _select_checkpoint_dir(self, base_dir: str, checkpoint_step) -> str: + """ + Resolve checkpoint directory: + - If checkpoint_step is provided, use base_dir/iter_{step:07d} + - Otherwise, pick the largest iter_######## subdirectory under base_dir + """ + if checkpoint_step is not None: + path = os.path.join(base_dir, f"iter_{int(checkpoint_step):07d}") + if os.path.isdir(path): + logging.info(f"Using specified checkpoint: {path}") + return path + raise FileNotFoundError(f"Specified checkpoint step {checkpoint_step} not found at {path}") + + if not os.path.isdir(base_dir): + raise FileNotFoundError(f"Checkpoint base directory does not exist: {base_dir}") + + pattern = re.compile(r"^iter_(\d+)$") + try: + _, latest_path = max( + ((int(pattern.match(e.name).group(1)), e.path) + for e in os.scandir(base_dir) + if e.is_dir() and pattern.match(e.name)), + key=lambda x: x[0], + ) + except ValueError: + raise FileNotFoundError( + f"No checkpoints found under {base_dir}. Expected subdirectories named like 'iter_0001800'.") + + logging.info(f"Auto-selected latest checkpoint: {latest_path}") + return latest_path + + + def forward_pp_step( + self, + latent_model_input: torch.Tensor, + grid_sizes: list[Tuple[int, int, int]], + max_video_seq_len: int, + timestep: torch.Tensor, + arg_c: dict, + ) -> torch.Tensor: + """ + Forward pass supporting pipeline parallelism. + """ + + from megatron.core import parallel_state + from megatron.core.inference.communication_utils import broadcast_from_last_pipeline_stage, recv_from_prev_pipeline_rank_, send_to_next_pipeline_rank + + pp_world_size = parallel_state.get_pipeline_model_parallel_world_size() + is_pp_first = parallel_state.is_pipeline_first_stage(ignore_virtual=True) + is_pp_last = parallel_state.is_pipeline_last_stage(ignore_virtual=True) + + # PP=1: no pipeline parallelism + if pp_world_size == 1: + noise_pred_pp = self.model( + latent_model_input, + grid_sizes=grid_sizes, + t=timestep, + **arg_c) + return noise_pred_pp + + # PP>1: pipeline parallelism + hidden_size = self.model.config.hidden_size + batch_size = latent_model_input.shape[1] + # noise prediction shape for communication between first and last pipeline stages + noise_pred_pp_shape = list(latent_model_input.shape) + + if is_pp_first: + # First stage: compute multimodal + first PP slice, send activations, then receive sampled token + hidden_states = self.model( + latent_model_input, + grid_sizes=grid_sizes, + t=timestep, + **arg_c) + send_to_next_pipeline_rank(hidden_states) + + noise_pred_pp = broadcast_from_last_pipeline_stage(noise_pred_pp_shape, dtype=torch.float32) + return noise_pred_pp + + if is_pp_last: + # Last stage: recv activations, run final slice + output, sample, broadcast + recv_buffer = torch.empty( + (max_video_seq_len, batch_size, hidden_size), + dtype=next(self.model.parameters()).dtype, + device=latent_model_input[0].device, + ) + recv_from_prev_pipeline_rank_(recv_buffer) + recv_buffer = recv_buffer.to(torch.bfloat16) # ???? + self.model.set_input_tensor(recv_buffer) + noise_pred_pp = self.model( + latent_model_input, + grid_sizes=grid_sizes, + t=timestep, + **arg_c) + + noise_pred_pp = broadcast_from_last_pipeline_stage(noise_pred_pp_shape, dtype=noise_pred_pp.dtype, tensor=noise_pred_pp.contiguous()) + return noise_pred_pp + + # Intermediate stages: recv -> run local slice -> send -> receive broadcast token + recv_buffer = torch.empty( + (max_video_seq_len, batch_size, hidden_size), + dtype=next(self.model.parameters()).dtype, + device=latent_model_input[0].device, + ) + recv_from_prev_pipeline_rank_(recv_buffer) + recv_buffer = recv_buffer.to(torch.bfloat16) # ???? + self.model.set_input_tensor(recv_buffer) + hidden_states = self.model( + latent_model_input, + grid_sizes=grid_sizes, + t=timestep, + **arg_c) + send_to_next_pipeline_rank(hidden_states) + + noise_pred_pp = broadcast_from_last_pipeline_stage(noise_pred_pp_shape, dtype=torch.float32) + return noise_pred_pp + + + def generate(self, + prompts, + sizes, + frame_nums, + shift=5.0, + sample_solver='unipc', + sampling_steps=50, + guide_scale=5.0, + n_prompt="", + seed=-1, + offload_model=True): + r""" + Generates video frames from text prompt using diffusion process. + + Args: + prompts (`list[str]`): + Text prompt for content generation + sizes (list[tuple[int, int]]): + Controls video resolution, (width,height). + frame_nums (`list[int]`): + How many frames to sample from a video. The number should be 4n+1 + shift (`float`, *optional*, defaults to 5.0): + Noise schedule shift parameter. Affects temporal dynamics + sample_solver (`str`, *optional*, defaults to 'unipc'): + Solver used to sample the video. + sampling_steps (`int`, *optional*, defaults to 40): + Number of diffusion sampling steps. Higher values improve quality but slow generation + guide_scale (`float`, *optional*, defaults 5.0): + Classifier-free guidance scale. Controls prompt adherence vs. creativity + n_prompt (`str`, *optional*, defaults to ""): + Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt` + seed (`int`, *optional*, defaults to -1): + Random seed for noise generation. If -1, use random seed. + offload_model (`bool`, *optional*, defaults to True): + If True, offloads models to CPU during generation to save VRAM + + Returns: + torch.Tensor: + Generated video frames tensor. Dimensions: (C, N H, W) where: + - C: Color channels (3 for RGB) + - N: Number of frames (81) + - H: Frame height (from size) + - W: Frame width from size) + """ + + # preprocess + target_shapes = [] + for size, frame_num in zip(sizes, frame_nums): + target_shapes.append((self.vae.model.z_dim, (frame_num - 1) // self.vae_stride[0] + 1, + size[1] // self.vae_stride[1], + size[0] // self.vae_stride[2])) + + max_video_seq_len = 0 + seq_lens = [] + for target_shape in target_shapes: + seq_len = math.ceil((target_shape[2] * target_shape[3]) / + (self.patch_size[1] * self.patch_size[2]) * + target_shape[1] / self.sp_size) * self.sp_size + seq_lens.append(seq_len) + max_video_seq_len = max(seq_lens) + + if n_prompt == "": + n_prompt = self.sample_neg_prompt + seed = seed if seed >= 0 else random.randint(0, sys.maxsize) + seed_g = torch.Generator(device=self.device) + seed_g.manual_seed(seed) + + + ## process context + context_max_len = 512 + context_lens = [] + contexts = [] + contexts_null = [] + for prompt in prompts: + if not self.t5_cpu: + self.text_encoder.model.to(self.device) + context = self.text_encoder([prompt], self.device)[0] + context_null = self.text_encoder([n_prompt], self.device)[0] + if offload_model: + self.text_encoder.model.cpu() + else: + context = self.text_encoder([prompt], torch.device('cpu'))[0].to(self.device) + context_null = self.text_encoder([n_prompt], torch.device('cpu'))[0].to(self.device) + context_lens.append(context_max_len) # all samples have the same context_max_len + contexts.append(context) + contexts_null.append(context_null) + # pad to context_max_len tokens, and stack to a tensor of shape [s, b, hidden] + contexts = [F.pad(context, (0, 0, 0, context_max_len - context.shape[0])) for context in contexts] + contexts_null = [F.pad(context_null, (0, 0, 0, context_max_len - context_null.shape[0])) for context_null in contexts_null] + contexts = torch.stack(contexts, dim=1) + contexts_null = torch.stack(contexts_null, dim=1) + + + ## setup noise + noises = [] + for target_shape in target_shapes: + noises.append( + torch.randn( + target_shape[0], + target_shape[1], + target_shape[2], + target_shape[3], + dtype=torch.float32, + device=self.device, + generator=seed_g) + ) + + + # calculate grid_sizes + grid_sizes = [grid_sizes_calculation( + input_shape =u.shape[1:], + patch_size=self.model.patch_size, + ) for u in noises] + grid_sizes = torch.tensor(grid_sizes, dtype=torch.long) + + + @contextmanager + def noop_no_sync(): + yield + + no_sync = getattr(self.model, 'no_sync', noop_no_sync) + + # evaluation mode + with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync(): + + if sample_solver == 'unipc': + # Create a prototype scheduler to compute shared timesteps + sample_scheduler = FlowUniPCMultistepScheduler( + num_train_timesteps=self.num_train_timesteps, + shift=1, + use_dynamic_shifting=False) + sample_scheduler.set_timesteps( + sampling_steps, device=self.device, shift=shift) + timesteps = sample_scheduler.timesteps + + # Instantiate per-sample schedulers so each sample maintains its own state + batch_size_for_schedulers = len(noises) + schedulers = [] + for _ in range(batch_size_for_schedulers): + s = FlowUniPCMultistepScheduler( + num_train_timesteps=self.num_train_timesteps, + shift=1, + use_dynamic_shifting=False) + s.set_timesteps(sampling_steps, device=self.device, shift=shift) + schedulers.append(s) + elif sample_solver == 'dpm++': + sample_scheduler = FlowDPMSolverMultistepScheduler( + num_train_timesteps=self.num_train_timesteps, + shift=1, + use_dynamic_shifting=False) + sampling_sigmas = get_sampling_sigmas(sampling_steps, shift) + timesteps, _ = retrieve_timesteps( + sample_scheduler, + device=self.device, + sigmas=sampling_sigmas) + else: + raise NotImplementedError("Unsupported solver.") + + # sample videos + latents = noises + + from megatron.core.packed_seq_params import PackedSeqParams + cu_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)]) + cu_q = cu_q.to(torch.int32).to(self.device) + cu_kv_self = cu_q + cu_kv_cross = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(context_lens), dim=0)]) + cu_kv_cross = cu_kv_cross.to(torch.int32).to(self.device) + packed_seq_params = { + "self_attention": PackedSeqParams( + cu_seqlens_q=cu_q, + cu_seqlens_kv=cu_kv_self, + qkv_format=self.model.config.qkv_format, + ), + "cross_attention": PackedSeqParams( + cu_seqlens_q=cu_q, + cu_seqlens_kv=cu_kv_cross, + qkv_format=self.model.config.qkv_format, + ), + } + + + arg_c = {'context': contexts, 'max_seq_len': max_video_seq_len, 'packed_seq_params': packed_seq_params} + arg_null = {'context': contexts_null, 'max_seq_len': max_video_seq_len, 'packed_seq_params': packed_seq_params} + + for _, t in enumerate(tqdm(timesteps)): + + batch_size = len(latents) + + # patchify latents + unpatchified_latents = latents + latents = patchify(latents, self.patch_size) + # pad to have same length + for i in range(batch_size): + latents[i] = F.pad(latents[i], (0, 0, 0, max_video_seq_len - latents[i].shape[0])) + latents = torch.stack(latents, dim=1) + + + latent_model_input = latents + timestep = [t] * batch_size + timestep = torch.stack(timestep) + + # run context parallelism slitting + if parallel_state.get_context_parallel_world_size() > 1: + latent_model_input = split_inputs_cp(latent_model_input, 0) + arg_c['context'] = split_inputs_cp(arg_c['context'], 0) + arg_null['context'] = split_inputs_cp(arg_null['context'], 0) + + self.model.to(self.device) + noise_pred_cond = self.forward_pp_step( + latent_model_input, grid_sizes=grid_sizes, max_video_seq_len=max_video_seq_len, timestep=timestep, arg_c=arg_c) + + noise_pred_uncond = self.forward_pp_step( + latent_model_input, grid_sizes=grid_sizes, max_video_seq_len=max_video_seq_len, timestep=timestep, arg_c=arg_null) + + # run context parallelism gathering + if parallel_state.get_context_parallel_world_size() > 1: + arg_c['context'] = cat_outputs_cp(arg_c['context'], 0) # we need to cat the context back together for the next timestep + arg_null['context'] = cat_outputs_cp(arg_null['context'], 0) # we need to cat the context back together for the next timestep + # TODO: does this step slow down speed??? + noise_pred_cond = noise_pred_cond.contiguous() + noise_pred_uncond = noise_pred_uncond.contiguous() + noise_pred_cond = cat_outputs_cp(noise_pred_cond, 0) + noise_pred_uncond = cat_outputs_cp(noise_pred_uncond, 0) + + # run unpatchify + unpatchified_noise_pred_cond = noise_pred_cond + unpatchified_noise_pred_cond = unpatchified_noise_pred_cond.transpose(0, 1) # bring sbhd -> bshd + # when unpatchifying, the code will truncate the padded videos into the original video shape, based on the grid_sizes. + unpatchified_noise_pred_cond = self.unpatchify(unpatchified_noise_pred_cond, grid_sizes, self.vae.model.z_dim) + unpatchified_noise_pred_uncond = noise_pred_uncond + unpatchified_noise_pred_uncond = unpatchified_noise_pred_uncond.transpose(0, 1) # bring sbhd -> bshd + # when unpatchifying, the code will truncate the padded videos into the original video shape, based on the grid_sizes. + unpatchified_noise_pred_uncond = self.unpatchify(unpatchified_noise_pred_uncond, grid_sizes, self.vae.model.z_dim) + + noise_preds = [] + for i in range(batch_size): + noise_pred = unpatchified_noise_pred_uncond[i] + guide_scale * ( + unpatchified_noise_pred_cond[i] - unpatchified_noise_pred_uncond[i]) + noise_preds.append(noise_pred) + + # step and update latents + latents = [] + for i in range(batch_size): + + if sample_solver == 'unipc': + temp_x0 = schedulers[i].step( + noise_preds[i].unsqueeze(0), + t, + unpatchified_latents[i].unsqueeze(0), + return_dict=False, + generator=seed_g)[0] + else: + temp_x0 = sample_scheduler.step( + noise_preds[i].unsqueeze(0), + t, + unpatchified_latents[i].unsqueeze(0), + return_dict=False, + generator=seed_g)[0] + latents.append(temp_x0.squeeze(0)) + + x0 = latents + if offload_model: + self.model.cpu() + torch.cuda.empty_cache() + if self.rank == 0: + videos = self.vae.decode(x0) + else: + videos = None + + del noises, latents + if sample_solver == 'unipc': + del schedulers + else: + del sample_scheduler + if offload_model: + gc.collect() + torch.cuda.synchronize() + if dist.is_initialized(): + dist.barrier() + + return videos if self.rank == 0 else None diff --git a/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py b/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py new file mode 100644 index 00000000..af82fa7e --- /dev/null +++ b/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py @@ -0,0 +1,223 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Callable, Dict, Optional, Tuple, List + +import numpy as np +import torch +from megatron.core import parallel_state +from torch import Tensor +from diffusers import WanPipeline +from megatron.bridge.models.wan.flow_matching.time_shift_utils import compute_density_for_timestep_sampling +from megatron.bridge.models.wan.utils.utils import patchify, split_inputs_cp + +class FlowPipeline: + + def __init__( + self, + model_id="Wan-AI/Wan2.1-T2V-1.3B-Diffusers", + seed=1234, + ): + """ + Initializes the FlowPipeline with the given parameters. + """ + self.pipe = WanPipeline.from_pretrained(model_id, vae=None, torch_dtype=torch.float32, text_encoder=None) + + + def training_step( + self, + model, + data_batch: dict[str, torch.Tensor], + # Flow matching parameters + use_sigma_noise: bool = True, + timestep_sampling: str = "uniform", + logit_mean: float = 0.0, + logit_std: float = 1.0, + flow_shift: float = 3.0, + mix_uniform_ratio: float = 0.1, + ) -> tuple[dict[str, torch.Tensor], torch.Tensor]: + """ + Performs a single training step using flow matching algorithm. + + This method is responsible for executing one iteration of the model's training. It involves: + 1. Generate noise and add it to the input data. + 2. Pass the noisy data through the network to generate predictions. + 3. Compute the loss based on the difference between the predictions and target. + """ + + video_latents = data_batch['video_latents'] + max_video_seq_len = data_batch['max_video_seq_len'] + context_embeddings = data_batch['context_embeddings'] + loss_mask = data_batch['loss_mask'] + grid_sizes = data_batch['grid_sizes'] + packed_seq_params = data_batch['packed_seq_params'] + video_metadata = data_batch['video_metadata'] + + self.model = model + + batch_size = video_latents.shape[1] + device = video_latents.device + + # # # DEBUGGING precision + # # import torch.cuda.amp as amp + # # with amp.autocast(dtype=torch.bfloat16): + # # # Pass through model + # # ... + + # ======================================================================== + # Flow Matching Timestep Sampling + # ======================================================================== + + num_train_timesteps = self.pipe.scheduler.config.num_train_timesteps + + if use_sigma_noise: + use_uniform = torch.rand(1).item() < mix_uniform_ratio + + if use_uniform or timestep_sampling == "uniform": + # Pure uniform: u ~ U(0, 1) + u = torch.rand(size=(batch_size,), device=device) + sampling_method = "uniform" + else: + # Density-based sampling + u = compute_density_for_timestep_sampling( + weighting_scheme=timestep_sampling, + batch_size=batch_size, + logit_mean=logit_mean, + logit_std=logit_std, + ).to(device) + sampling_method = timestep_sampling + + # Apply flow shift: σ = shift/(shift + (1/u - 1)) + u_clamped = torch.clamp(u, min=1e-5) # Avoid division by zero + sigma = flow_shift / (flow_shift + (1.0 / u_clamped - 1.0)) + sigma = torch.clamp(sigma, 0.0, 1.0) + + else: + # Simple uniform without shift + u = torch.rand(size=(batch_size,), device=device) + sigma = u + sampling_method = "uniform_no_shift" + + # ======================================================================== + # Manual Flow Matching Noise Addition + # ======================================================================== + + # Generate noise + noise = torch.randn_like(torch.ones([1, 16, grid_sizes[0][0], grid_sizes[0][1]*2, grid_sizes[0][2]*2], device=video_latents.device), dtype=torch.float32) + noise = patchify(noise, (1, 2, 2))[0].unsqueeze(1) + # DEBUGGING + # because video_latents might be padded, we need to make sure noise also be padded to have the same shape + seq_noise = noise.shape[0] + seq_video = video_latents.shape[0] + if seq_noise < seq_video: + pad_len = seq_video - seq_noise + pad = torch.zeros((pad_len, noise.shape[1], noise.shape[2]), device=noise.device, dtype=noise.dtype) + noise = torch.cat([noise, pad], dim=0) + + # CRITICAL: Manual flow matching (NOT scheduler.add_noise!) + # x_t = (1 - σ) * x_0 + σ * ε + sigma_reshaped = sigma.view(1, batch_size, 1) + noisy_latents = ( + (1.0 - sigma_reshaped) * video_latents.float() + + sigma_reshaped * noise + ) + + # Timesteps for model [0, 1000] + timesteps = sigma * num_train_timesteps + + # ======================================================================== + # Cast model inputs to bf16 + # ======================================================================== + + video_latents = video_latents.to(torch.bfloat16) + noisy_latents = noisy_latents.to(torch.bfloat16) + context_embeddings = context_embeddings.to(torch.bfloat16) + timesteps = timesteps.to(torch.bfloat16) + + # ======================================================================== + # Split accross context parallelism + # ======================================================================== + + if parallel_state.get_context_parallel_world_size() > 1: + video_latents = split_inputs_cp(video_latents, 0) + noisy_latents = split_inputs_cp(noisy_latents, 0) + noise = split_inputs_cp(noise, 0) + context_embeddings = split_inputs_cp(context_embeddings, 0) + split_loss_mask = split_inputs_cp(loss_mask, 0) + else: + video_latents = video_latents + noisy_latents = noisy_latents + noise = noise + context_embeddings = context_embeddings + split_loss_mask = loss_mask + + + # ======================================================================== + # Forward Pass + # ======================================================================== + + if parallel_state.is_pipeline_last_stage(): + + model_pred = self.model( + x = noisy_latents, + grid_sizes = grid_sizes, + t = timesteps, + context = context_embeddings, + max_seq_len = max_video_seq_len, + packed_seq_params=packed_seq_params, + ) + + # ======================================================================== + # Target: Flow Matching Velocity + # ======================================================================== + + # Flow matching target: v = ε - x_0 + target = noise - video_latents.float() + + # ======================================================================== + # Loss with Flow Weighting + # ======================================================================== + + loss = torch.nn.functional.mse_loss( + model_pred.float(), + target.float(), + reduction="none" + ) + + # Flow weight: w = 1 + shift * σ + loss_weight = 1.0 + flow_shift * sigma # shape [batch_size] + loss_weight = loss_weight.view(1, batch_size, 1).to(device) # shape [1, batch_size, 1] + unweighted_loss = loss + weighted_loss = (loss * loss_weight) # shape [seq_length / cp_size, batch_size, -1] + + # Safety check + mean_weighted_loss = weighted_loss.mean() + if torch.isnan(mean_weighted_loss) or mean_weighted_loss > 100: + print(f"[ERROR] Loss explosion! Loss={mean_weighted_loss.item():.3f}") + print(f"[DEBUG] Stopping training - check hyperparameters") + raise ValueError(f"Loss exploded: {mean_weighted_loss.item()}") + + return model_pred, weighted_loss, split_loss_mask + + else: + hidden_states = self.model( + x = noisy_latents, + grid_sizes = grid_sizes, + t = timesteps, + context = context_embeddings, + max_seq_len = max_video_seq_len, + packed_seq_params=packed_seq_params, + ) + + return hidden_states \ No newline at end of file diff --git a/dfm/src/megatron/model/wan/flow_matching/time_shift_utils.py b/dfm/src/megatron/model/wan/flow_matching/time_shift_utils.py new file mode 100644 index 00000000..56faee4b --- /dev/null +++ b/dfm/src/megatron/model/wan/flow_matching/time_shift_utils.py @@ -0,0 +1,108 @@ +# time_shift_utils.py - Timestep sampling and sigma computation utilities + +import math +import numpy as np +import torch + + +def time_shift( + t: torch.Tensor, + image_seq_len: int, + shift_type: str = "constant", + base_shift: float = 0.5, + max_shift: float = 1.15, + constant: float = 3.0, +): + """ + Convert timesteps to sigmas with sequence-length-aware shifting. + + Args: + t: timesteps in range [0, 1] + image_seq_len: number of tokens (frames * height * width / patch_size^2) + shift_type: "linear", "sqrt", or "constant" + base_shift: base shift for linear mode + max_shift: max shift for linear mode + constant: shift value for constant mode (default 3.0 matches Pika) + + Returns: + sigma values for noise scheduling + """ + if shift_type == "linear": + # Linear interpolation based on sequence length + mu = base_shift + (max_shift - base_shift) * (image_seq_len / 4096) + return math.exp(mu) / (math.exp(mu) + (1 / t - 1)) + + elif shift_type == "sqrt": + # Square root scaling (Flux-style) + # Assuming 128x128 latent space (1024x1024 image) gives mu=3 + mu = np.maximum(1.0, np.sqrt(image_seq_len / (128.0 * 128.0)) * 3.0) + return mu / (mu + (1 / t - 1)) + + elif shift_type == "constant": + # Constant shift (Pika default) + return constant / (constant + (1 / t - 1)) + + else: + # No shift, return original t + return t + + +def compute_density_for_timestep_sampling( + weighting_scheme: str, + batch_size: int, + logit_mean: float = 0.0, + logit_std: float = 1.0, + mode_scale: float = 1.29, +): + """ + Sample timesteps from different distributions for better training coverage. + + Args: + weighting_scheme: "uniform", "logit_normal", or "mode" + batch_size: number of samples to generate + logit_mean: mean for logit-normal distribution + logit_std: std for logit-normal distribution + mode_scale: scale for mode-based sampling + + Returns: + Tensor of shape (batch_size,) with values in [0, 1] + """ + if weighting_scheme == "logit_normal": + # SD3-style logit-normal sampling + u = torch.normal( + mean=logit_mean, + std=logit_std, + size=(batch_size,), + device="cpu" + ) + u = torch.nn.functional.sigmoid(u) + + elif weighting_scheme == "mode": + # Mode-based sampling (concentrates around certain timesteps) + u = torch.rand(size=(batch_size,), device="cpu") + u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u) + + else: + # Uniform sampling (default) + u = torch.rand(size=(batch_size,), device="cpu") + + return u + + +def get_flow_match_loss_weight(sigma: torch.Tensor, shift: float = 3.0): + """ + Compute loss weights for flow matching based on sigma values. + + Higher sigma (more noise) typically gets higher weight. + + Args: + sigma: sigma values in range [0, 1] + shift: weight scaling factor + + Returns: + Loss weights with same shape as sigma + """ + # Flow matching weight: weight = 1 + shift * sigma + # This gives more weight to noisier timesteps + weight = 1.0 + shift * sigma + return weight \ No newline at end of file diff --git a/dfm/src/megatron/model/wan/inference/configs/__init__.py b/dfm/src/megatron/model/wan/inference/configs/__init__.py new file mode 100644 index 00000000..a28c03c5 --- /dev/null +++ b/dfm/src/megatron/model/wan/inference/configs/__init__.py @@ -0,0 +1,52 @@ +import copy +import os + +os.environ['TOKENIZERS_PARALLELISM'] = 'false' + +from .wan_i2v_14B import i2v_14B +from .wan_t2v_1_3B import t2v_1_3B +from .wan_t2v_14B import t2v_14B + +# the config of t2i_14B is the same as t2v_14B +t2i_14B = copy.deepcopy(t2v_14B) +t2i_14B.__name__ = 'Config: Wan T2I 14B' + +# the config of flf2v_14B is the same as i2v_14B +flf2v_14B = copy.deepcopy(i2v_14B) +flf2v_14B.__name__ = 'Config: Wan FLF2V 14B' +flf2v_14B.sample_neg_prompt = "镜头切换," + flf2v_14B.sample_neg_prompt + +WAN_CONFIGS = { + 't2v-14B': t2v_14B, + 't2v-1.3B': t2v_1_3B, + 'i2v-14B': i2v_14B, + 't2i-14B': t2i_14B, + 'flf2v-14B': flf2v_14B, + 'vace-1.3B': t2v_1_3B, + 'vace-14B': t2v_14B, +} + +SIZE_CONFIGS = { + '720*1280': (720, 1280), + '1280*720': (1280, 720), + '480*832': (480, 832), + '832*480': (832, 480), + '1024*1024': (1024, 1024), +} + +MAX_AREA_CONFIGS = { + '720*1280': 720 * 1280, + '1280*720': 1280 * 720, + '480*832': 480 * 832, + '832*480': 832 * 480, +} + +SUPPORTED_SIZES = { + 't2v-14B': ('720*1280', '1280*720', '480*832', '832*480'), + 't2v-1.3B': ('480*832', '832*480'), + 'i2v-14B': ('720*1280', '1280*720', '480*832', '832*480'), + 'flf2v-14B': ('720*1280', '1280*720', '480*832', '832*480'), + 't2i-14B': tuple(SIZE_CONFIGS.keys()), + 'vace-1.3B': ('480*832', '832*480'), + 'vace-14B': ('720*1280', '1280*720', '480*832', '832*480') +} diff --git a/dfm/src/megatron/model/wan/inference/configs/shared_config.py b/dfm/src/megatron/model/wan/inference/configs/shared_config.py new file mode 100644 index 00000000..37d3ae0c --- /dev/null +++ b/dfm/src/megatron/model/wan/inference/configs/shared_config.py @@ -0,0 +1,18 @@ +import torch +from easydict import EasyDict + +#------------------------ Wan shared config ------------------------# +wan_shared_cfg = EasyDict() + +# t5 +wan_shared_cfg.t5_model = 'umt5_xxl' +wan_shared_cfg.t5_dtype = torch.bfloat16 +wan_shared_cfg.text_len = 512 + +# transformer +wan_shared_cfg.param_dtype = torch.bfloat16 + +# inference +wan_shared_cfg.num_train_timesteps = 1000 +wan_shared_cfg.sample_fps = 16 +wan_shared_cfg.sample_neg_prompt = '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走' diff --git a/dfm/src/megatron/model/wan/inference/configs/wan_i2v_14B.py b/dfm/src/megatron/model/wan/inference/configs/wan_i2v_14B.py new file mode 100644 index 00000000..764d2ed8 --- /dev/null +++ b/dfm/src/megatron/model/wan/inference/configs/wan_i2v_14B.py @@ -0,0 +1,35 @@ +import torch +from easydict import EasyDict + +from .shared_config import wan_shared_cfg + +#------------------------ Wan I2V 14B ------------------------# + +i2v_14B = EasyDict(__name__='Config: Wan I2V 14B') +i2v_14B.update(wan_shared_cfg) +i2v_14B.sample_neg_prompt = "镜头晃动," + i2v_14B.sample_neg_prompt + +i2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth' +i2v_14B.t5_tokenizer = 'google/umt5-xxl' + +# clip +i2v_14B.clip_model = 'clip_xlm_roberta_vit_h_14' +i2v_14B.clip_dtype = torch.float16 +i2v_14B.clip_checkpoint = 'models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth' +i2v_14B.clip_tokenizer = 'xlm-roberta-large' + +# vae +i2v_14B.vae_checkpoint = 'Wan2.1_VAE.pth' +i2v_14B.vae_stride = (4, 8, 8) + +# transformer +i2v_14B.patch_size = (1, 2, 2) +i2v_14B.dim = 5120 +i2v_14B.ffn_dim = 13824 +i2v_14B.freq_dim = 256 +i2v_14B.num_heads = 40 +i2v_14B.num_layers = 40 +i2v_14B.window_size = (-1, -1) +i2v_14B.qk_norm = True +i2v_14B.cross_attn_norm = True +i2v_14B.eps = 1e-6 diff --git a/dfm/src/megatron/model/wan/inference/configs/wan_t2v_14B.py b/dfm/src/megatron/model/wan/inference/configs/wan_t2v_14B.py new file mode 100644 index 00000000..c793f7f6 --- /dev/null +++ b/dfm/src/megatron/model/wan/inference/configs/wan_t2v_14B.py @@ -0,0 +1,28 @@ +from easydict import EasyDict + +from .shared_config import wan_shared_cfg + +#------------------------ Wan T2V 14B ------------------------# + +t2v_14B = EasyDict(__name__='Config: Wan T2V 14B') +t2v_14B.update(wan_shared_cfg) + +# t5 +t2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth' +t2v_14B.t5_tokenizer = 'google/umt5-xxl' + +# vae +t2v_14B.vae_checkpoint = 'Wan2.1_VAE.pth' +t2v_14B.vae_stride = (4, 8, 8) + +# transformer +t2v_14B.patch_size = (1, 2, 2) +t2v_14B.dim = 5120 +t2v_14B.ffn_dim = 13824 +t2v_14B.freq_dim = 256 +t2v_14B.num_heads = 40 +t2v_14B.num_layers = 40 +t2v_14B.window_size = (-1, -1) +t2v_14B.qk_norm = True +t2v_14B.cross_attn_norm = True +t2v_14B.eps = 1e-6 diff --git a/dfm/src/megatron/model/wan/inference/configs/wan_t2v_1_3B.py b/dfm/src/megatron/model/wan/inference/configs/wan_t2v_1_3B.py new file mode 100644 index 00000000..c8458ce8 --- /dev/null +++ b/dfm/src/megatron/model/wan/inference/configs/wan_t2v_1_3B.py @@ -0,0 +1,28 @@ +from easydict import EasyDict + +from .shared_config import wan_shared_cfg + +#------------------------ Wan T2V 1.3B ------------------------# + +t2v_1_3B = EasyDict(__name__='Config: Wan T2V 1.3B') +t2v_1_3B.update(wan_shared_cfg) + +# t5 +t2v_1_3B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth' +t2v_1_3B.t5_tokenizer = 'google/umt5-xxl' + +# vae +t2v_1_3B.vae_checkpoint = 'Wan2.1_VAE.pth' +t2v_1_3B.vae_stride = (4, 8, 8) + +# transformer +t2v_1_3B.patch_size = (1, 2, 2) +t2v_1_3B.dim = 1536 +t2v_1_3B.ffn_dim = 8960 +t2v_1_3B.freq_dim = 256 +t2v_1_3B.num_heads = 12 +t2v_1_3B.num_layers = 30 +t2v_1_3B.window_size = (-1, -1) +t2v_1_3B.qk_norm = True +t2v_1_3B.cross_attn_norm = True +t2v_1_3B.eps = 1e-6 diff --git a/dfm/src/megatron/model/wan/inference/utils/fm_solvers.py b/dfm/src/megatron/model/wan/inference/utils/fm_solvers.py new file mode 100644 index 00000000..a38b755c --- /dev/null +++ b/dfm/src/megatron/model/wan/inference/utils/fm_solvers.py @@ -0,0 +1,858 @@ +# Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +# Convert dpm solver for flow matching + +import inspect +import math +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.schedulers.scheduling_utils import ( + KarrasDiffusionSchedulers, + SchedulerMixin, + SchedulerOutput, +) +from diffusers.utils import deprecate, is_scipy_available +from diffusers.utils.torch_utils import randn_tensor + +if is_scipy_available(): + pass + + +def get_sampling_sigmas(sampling_steps, shift): + sigma = np.linspace(1, 0, sampling_steps + 1)[:sampling_steps] + sigma = (shift * sigma / (1 + (shift - 1) * sigma)) + + return sigma + + +def retrieve_timesteps( + scheduler, + num_inference_steps=None, + device=None, + timesteps=None, + sigmas=None, + **kwargs, +): + if timesteps is not None and sigmas is not None: + raise ValueError( + "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values" + ) + if timesteps is not None: + accepts_timesteps = "timesteps" in set( + inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set( + inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class FlowDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): + """ + `FlowDPMSolverMultistepScheduler` is a fast dedicated high-order solver for diffusion ODEs. + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. This determines the resolution of the diffusion process. + solver_order (`int`, defaults to 2): + The DPMSolver order which can be `1`, `2`, or `3`. It is recommended to use `solver_order=2` for guided + sampling, and `solver_order=3` for unconditional sampling. This affects the number of model outputs stored + and used in multistep updates. + prediction_type (`str`, defaults to "flow_prediction"): + Prediction type of the scheduler function; must be `flow_prediction` for this scheduler, which predicts + the flow of the diffusion process. + shift (`float`, *optional*, defaults to 1.0): + A factor used to adjust the sigmas in the noise schedule. It modifies the step sizes during the sampling + process. + use_dynamic_shifting (`bool`, defaults to `False`): + Whether to apply dynamic shifting to the timesteps based on image resolution. If `True`, the shifting is + applied on the fly. + thresholding (`bool`, defaults to `False`): + Whether to use the "dynamic thresholding" method. This method adjusts the predicted sample to prevent + saturation and improve photorealism. + dynamic_thresholding_ratio (`float`, defaults to 0.995): + The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. + sample_max_value (`float`, defaults to 1.0): + The threshold value for dynamic thresholding. Valid only when `thresholding=True` and + `algorithm_type="dpmsolver++"`. + algorithm_type (`str`, defaults to `dpmsolver++`): + Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, `sde-dpmsolver` or `sde-dpmsolver++`. The + `dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927) + paper, and the `dpmsolver++` type implements the algorithms in the + [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or + `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion. + solver_type (`str`, defaults to `midpoint`): + Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the + sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers. + lower_order_final (`bool`, defaults to `True`): + Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can + stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10. + euler_at_final (`bool`, defaults to `False`): + Whether to use Euler's method in the final step. It is a trade-off between numerical stability and detail + richness. This can stabilize the sampling of the SDE variant of DPMSolver for small number of inference + steps, but sometimes may result in blurring. + final_sigmas_type (`str`, *optional*, defaults to "zero"): + The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final + sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0. + lambda_min_clipped (`float`, defaults to `-inf`): + Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the + cosine (`squaredcos_cap_v2`) noise schedule. + variance_type (`str`, *optional*): + Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output + contains the predicted Gaussian variance. + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + solver_order: int = 2, + prediction_type: str = "flow_prediction", + shift: Optional[float] = 1.0, + use_dynamic_shifting=False, + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + algorithm_type: str = "dpmsolver++", + solver_type: str = "midpoint", + lower_order_final: bool = True, + euler_at_final: bool = False, + final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min" + lambda_min_clipped: float = -float("inf"), + variance_type: Optional[str] = None, + invert_sigmas: bool = False, + ): + if algorithm_type in ["dpmsolver", "sde-dpmsolver"]: + deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead" + deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", + deprecation_message) + + # settings for DPM-Solver + if algorithm_type not in [ + "dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++" + ]: + if algorithm_type == "deis": + self.register_to_config(algorithm_type="dpmsolver++") + else: + raise NotImplementedError( + f"{algorithm_type} is not implemented for {self.__class__}") + + if solver_type not in ["midpoint", "heun"]: + if solver_type in ["logrho", "bh1", "bh2"]: + self.register_to_config(solver_type="midpoint") + else: + raise NotImplementedError( + f"{solver_type} is not implemented for {self.__class__}") + + if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++" + ] and final_sigmas_type == "zero": + raise ValueError( + f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please choose `sigma_min` instead." + ) + + # setable values + self.num_inference_steps = None + alphas = np.linspace(1, 1 / num_train_timesteps, + num_train_timesteps)[::-1].copy() + sigmas = 1.0 - alphas + sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32) + + if not use_dynamic_shifting: + # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution + sigmas = shift * sigmas / (1 + + (shift - 1) * sigmas) # pyright: ignore + + self.sigmas = sigmas + self.timesteps = sigmas * num_train_timesteps + + self.model_outputs = [None] * solver_order + self.lower_order_nums = 0 + self._step_index = None + self._begin_index = None + + # self.sigmas = self.sigmas.to( + # "cpu") # to avoid too much CPU/GPU communication + self.sigma_min = self.sigmas[-1].item() + self.sigma_max = self.sigmas[0].item() + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + + # Modified from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.set_timesteps + def set_timesteps( + self, + num_inference_steps: Union[int, None] = None, + device: Union[str, torch.device] = None, + sigmas: Optional[List[float]] = None, + mu: Optional[Union[float, None]] = None, + shift: Optional[Union[float, None]] = None, + ): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + Args: + num_inference_steps (`int`): + Total number of the spacing of the time steps. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + + if self.config.use_dynamic_shifting and mu is None: + raise ValueError( + " you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`" + ) + + if sigmas is None: + sigmas = np.linspace(self.sigma_max, self.sigma_min, + num_inference_steps + + 1).copy()[:-1] # pyright: ignore + + if self.config.use_dynamic_shifting: + sigmas = self.time_shift(mu, 1.0, sigmas) # pyright: ignore + else: + if shift is None: + shift = self.config.shift + sigmas = shift * sigmas / (1 + + (shift - 1) * sigmas) # pyright: ignore + + if self.config.final_sigmas_type == "sigma_min": + sigma_last = ((1 - self.alphas_cumprod[0]) / + self.alphas_cumprod[0])**0.5 + elif self.config.final_sigmas_type == "zero": + sigma_last = 0 + else: + raise ValueError( + f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" + ) + + timesteps = sigmas * self.config.num_train_timesteps + sigmas = np.concatenate([sigmas, [sigma_last] + ]).astype(np.float32) # pyright: ignore + + self.sigmas = torch.from_numpy(sigmas) + self.timesteps = torch.from_numpy(timesteps).to( + device=device, dtype=torch.int64) + + self.num_inference_steps = len(timesteps) + + self.model_outputs = [ + None, + ] * self.config.solver_order + self.lower_order_nums = 0 + + self._step_index = None + self._begin_index = None + # self.sigmas = self.sigmas.to( + # "cpu") # to avoid too much CPU/GPU communication + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample + def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." + https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, *remaining_dims = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float( + ) # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile( + abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + s = s.unsqueeze( + 1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = torch.clamp( + sample, -s, s + ) / s # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, *remaining_dims) + sample = sample.to(dtype) + + return sample + + # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._sigma_to_t + def _sigma_to_t(self, sigma): + return sigma * self.config.num_train_timesteps + + def _sigma_to_alpha_sigma_t(self, sigma): + return 1 - sigma, sigma + + # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.set_timesteps + def time_shift(self, mu: float, sigma: float, t: torch.Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1)**sigma) + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.convert_model_output + def convert_model_output( + self, + model_output: torch.Tensor, + *args, + sample: torch.Tensor = None, + **kwargs, + ) -> torch.Tensor: + """ + Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is + designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an + integral of the data prediction model. + + The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise + prediction and data prediction models. + + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + Returns: + `torch.Tensor`: + The converted model output. + """ + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError( + "missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + # DPM-Solver++ needs to solve an integral of the data prediction model. + if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]: + if self.config.prediction_type == "flow_prediction": + sigma_t = self.sigmas[self.step_index] + x0_pred = sample - sigma_t * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`," + " `v_prediction`, or `flow_prediction` for the FlowDPMSolverMultistepScheduler." + ) + + if self.config.thresholding: + x0_pred = self._threshold_sample(x0_pred) + + return x0_pred + + # DPM-Solver needs to solve an integral of the noise prediction model. + elif self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]: + if self.config.prediction_type == "flow_prediction": + sigma_t = self.sigmas[self.step_index] + epsilon = sample - (1 - sigma_t) * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`," + " `v_prediction` or `flow_prediction` for the FlowDPMSolverMultistepScheduler." + ) + + if self.config.thresholding: + sigma_t = self.sigmas[self.step_index] + x0_pred = sample - sigma_t * model_output + x0_pred = self._threshold_sample(x0_pred) + epsilon = model_output + x0_pred + + return epsilon + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.dpm_solver_first_order_update + def dpm_solver_first_order_update( + self, + model_output: torch.Tensor, + *args, + sample: torch.Tensor = None, + noise: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """ + One step for the first-order DPMSolver (equivalent to DDIM). + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + Returns: + `torch.Tensor`: + The sample tensor at the previous timestep. + """ + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop( + "prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError( + " missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[ + self.step_index] # pyright: ignore + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s) + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s = torch.log(alpha_s) - torch.log(sigma_s) + + h = lambda_t - lambda_s + if self.config.algorithm_type == "dpmsolver++": + x_t = (sigma_t / + sigma_s) * sample - (alpha_t * + (torch.exp(-h) - 1.0)) * model_output + elif self.config.algorithm_type == "dpmsolver": + x_t = (alpha_t / + alpha_s) * sample - (sigma_t * + (torch.exp(h) - 1.0)) * model_output + elif self.config.algorithm_type == "sde-dpmsolver++": + assert noise is not None + x_t = ((sigma_t / sigma_s * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise) + elif self.config.algorithm_type == "sde-dpmsolver": + assert noise is not None + x_t = ((alpha_t / alpha_s) * sample - 2.0 * + (sigma_t * (torch.exp(h) - 1.0)) * model_output + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise) + return x_t # pyright: ignore + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_second_order_update + def multistep_dpm_solver_second_order_update( + self, + model_output_list: List[torch.Tensor], + *args, + sample: torch.Tensor = None, + noise: Optional[torch.Tensor] = None, + **kwargs, + ) -> torch.Tensor: + """ + One step for the second-order multistep DPMSolver. + Args: + model_output_list (`List[torch.Tensor]`): + The direct outputs from learned diffusion model at current and latter timesteps. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + Returns: + `torch.Tensor`: + The sample tensor at the previous timestep. + """ + timestep_list = args[0] if len(args) > 0 else kwargs.pop( + "timestep_list", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop( + "prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError( + " missing `sample` as a required keyward argument") + if timestep_list is not None: + deprecate( + "timestep_list", + "1.0.0", + "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma_t, sigma_s0, sigma_s1 = ( + self.sigmas[self.step_index + 1], # pyright: ignore + self.sigmas[self.step_index], + self.sigmas[self.step_index - 1], # pyright: ignore + ) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) + + m0, m1 = model_output_list[-1], model_output_list[-2] + + h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 + r0 = h_0 / h + D0, D1 = m0, (1.0 / r0) * (m0 - m1) + if self.config.algorithm_type == "dpmsolver++": + # See https://arxiv.org/abs/2211.01095 for detailed derivations + if self.config.solver_type == "midpoint": + x_t = ((sigma_t / sigma_s0) * sample - + (alpha_t * (torch.exp(-h) - 1.0)) * D0 - 0.5 * + (alpha_t * (torch.exp(-h) - 1.0)) * D1) + elif self.config.solver_type == "heun": + x_t = ((sigma_t / sigma_s0) * sample - + (alpha_t * (torch.exp(-h) - 1.0)) * D0 + + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1) + elif self.config.algorithm_type == "dpmsolver": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + if self.config.solver_type == "midpoint": + x_t = ((alpha_t / alpha_s0) * sample - + (sigma_t * (torch.exp(h) - 1.0)) * D0 - 0.5 * + (sigma_t * (torch.exp(h) - 1.0)) * D1) + elif self.config.solver_type == "heun": + x_t = ((alpha_t / alpha_s0) * sample - + (sigma_t * (torch.exp(h) - 1.0)) * D0 - + (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1) + elif self.config.algorithm_type == "sde-dpmsolver++": + assert noise is not None + if self.config.solver_type == "midpoint": + x_t = ((sigma_t / sigma_s0 * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + 0.5 * + (alpha_t * (1 - torch.exp(-2.0 * h))) * D1 + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise) + elif self.config.solver_type == "heun": + x_t = ((sigma_t / sigma_s0 * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / + (-2.0 * h) + 1.0)) * D1 + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise) + elif self.config.algorithm_type == "sde-dpmsolver": + assert noise is not None + if self.config.solver_type == "midpoint": + x_t = ((alpha_t / alpha_s0) * sample - 2.0 * + (sigma_t * (torch.exp(h) - 1.0)) * D0 - + (sigma_t * (torch.exp(h) - 1.0)) * D1 + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise) + elif self.config.solver_type == "heun": + x_t = ((alpha_t / alpha_s0) * sample - 2.0 * + (sigma_t * (torch.exp(h) - 1.0)) * D0 - 2.0 * + (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 + + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise) + return x_t # pyright: ignore + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_third_order_update + def multistep_dpm_solver_third_order_update( + self, + model_output_list: List[torch.Tensor], + *args, + sample: torch.Tensor = None, + **kwargs, + ) -> torch.Tensor: + """ + One step for the third-order multistep DPMSolver. + Args: + model_output_list (`List[torch.Tensor]`): + The direct outputs from learned diffusion model at current and latter timesteps. + sample (`torch.Tensor`): + A current instance of a sample created by diffusion process. + Returns: + `torch.Tensor`: + The sample tensor at the previous timestep. + """ + + timestep_list = args[0] if len(args) > 0 else kwargs.pop( + "timestep_list", None) + prev_timestep = args[1] if len(args) > 1 else kwargs.pop( + "prev_timestep", None) + if sample is None: + if len(args) > 2: + sample = args[2] + else: + raise ValueError( + " missing`sample` as a required keyward argument") + if timestep_list is not None: + deprecate( + "timestep_list", + "1.0.0", + "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma_t, sigma_s0, sigma_s1, sigma_s2 = ( + self.sigmas[self.step_index + 1], # pyright: ignore + self.sigmas[self.step_index], + self.sigmas[self.step_index - 1], # pyright: ignore + self.sigmas[self.step_index - 2], # pyright: ignore + ) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) + alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) + lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2) + + m0, m1, m2 = model_output_list[-1], model_output_list[ + -2], model_output_list[-3] + + h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2 + r0, r1 = h_0 / h, h_1 / h + D0 = m0 + D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2) + D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1) + D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1) + if self.config.algorithm_type == "dpmsolver++": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + x_t = ((sigma_t / sigma_s0) * sample - + (alpha_t * (torch.exp(-h) - 1.0)) * D0 + + (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 - + (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2) + elif self.config.algorithm_type == "dpmsolver": + # See https://arxiv.org/abs/2206.00927 for detailed derivations + x_t = ((alpha_t / alpha_s0) * sample - (sigma_t * + (torch.exp(h) - 1.0)) * D0 - + (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 - + (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2) + return x_t # pyright: ignore + + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + pos = 1 if len(indices) > 1 else 0 + + return indices[pos].item() + + def _init_step_index(self, timestep): + """ + Initialize the step_index counter for the scheduler. + """ + + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + # Modified from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.step + def step( + self, + model_output: torch.Tensor, + timestep: Union[int, torch.Tensor], + sample: torch.Tensor, + generator=None, + variance_noise: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with + the multistep DPMSolver. + Args: + model_output (`torch.Tensor`): + The direct output from learned diffusion model. + timestep (`int`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + generator (`torch.Generator`, *optional*): + A random number generator. + variance_noise (`torch.Tensor`): + Alternative to generating noise with `generator` by directly providing the noise for the variance + itself. Useful for methods such as [`LEdits++`]. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if self.step_index is None: + self._init_step_index(timestep) + + # Improve numerical stability for small number of steps + lower_order_final = (self.step_index == len(self.timesteps) - 1) and ( + self.config.euler_at_final or + (self.config.lower_order_final and len(self.timesteps) < 15) or + self.config.final_sigmas_type == "zero") + lower_order_second = ((self.step_index == len(self.timesteps) - 2) and + self.config.lower_order_final and + len(self.timesteps) < 15) + + model_output = self.convert_model_output(model_output, sample=sample) + for i in range(self.config.solver_order - 1): + self.model_outputs[i] = self.model_outputs[i + 1] + self.model_outputs[-1] = model_output + + # Upcast to avoid precision issues when computing prev_sample + sample = sample.to(torch.float32) + if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++" + ] and variance_noise is None: + noise = randn_tensor( + model_output.shape, + generator=generator, + device=model_output.device, + dtype=torch.float32) + elif self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]: + noise = variance_noise.to( + device=model_output.device, + dtype=torch.float32) # pyright: ignore + else: + noise = None + + if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final: + prev_sample = self.dpm_solver_first_order_update( + model_output, sample=sample, noise=noise) + elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second: + prev_sample = self.multistep_dpm_solver_second_order_update( + self.model_outputs, sample=sample, noise=noise) + else: + prev_sample = self.multistep_dpm_solver_third_order_update( + self.model_outputs, sample=sample) + + if self.lower_order_nums < self.config.solver_order: + self.lower_order_nums += 1 + + # Cast sample back to expected dtype + prev_sample = prev_sample.to(model_output.dtype) + + # upon completion increase step index by one + self._step_index += 1 # pyright: ignore + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.scale_model_input + def scale_model_input(self, sample: torch.Tensor, *args, + **kwargs) -> torch.Tensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + Args: + sample (`torch.Tensor`): + The input sample. + Returns: + `torch.Tensor`: + A scaled input sample. + """ + return sample + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.scale_model_input + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.IntTensor, + ) -> torch.Tensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to( + device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point( + timesteps): + # mps does not support float64 + schedule_timesteps = self.timesteps.to( + original_samples.device, dtype=torch.float32) + timesteps = timesteps.to( + original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index + if self.begin_index is None: + step_indices = [ + self.index_for_timestep(t, schedule_timesteps) + for t in timesteps + ] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timesteps.shape[0] + else: + # add noise is called before first denoising step to create initial latent(img2img) + step_indices = [self.begin_index] * timesteps.shape[0] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + noisy_samples = alpha_t * original_samples + sigma_t * noise + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/dfm/src/megatron/model/wan/inference/utils/fm_solvers_unipc.py b/dfm/src/megatron/model/wan/inference/utils/fm_solvers_unipc.py new file mode 100644 index 00000000..8d960583 --- /dev/null +++ b/dfm/src/megatron/model/wan/inference/utils/fm_solvers_unipc.py @@ -0,0 +1,801 @@ +# Copied from https://github.com/huggingface/diffusers/blob/v0.31.0/src/diffusers/schedulers/scheduling_unipc_multistep.py +# Convert unipc for flow matching + +import math +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.schedulers.scheduling_utils import ( + KarrasDiffusionSchedulers, + SchedulerMixin, + SchedulerOutput, +) +from diffusers.utils import deprecate, is_scipy_available + +if is_scipy_available(): + import scipy.stats + + +class FlowUniPCMultistepScheduler(SchedulerMixin, ConfigMixin): + """ + `UniPCMultistepScheduler` is a training-free framework designed for the fast sampling of diffusion models. + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + solver_order (`int`, default `2`): + The UniPC order which can be any positive integer. The effective order of accuracy is `solver_order + 1` + due to the UniC. It is recommended to use `solver_order=2` for guided sampling, and `solver_order=3` for + unconditional sampling. + prediction_type (`str`, defaults to "flow_prediction"): + Prediction type of the scheduler function; must be `flow_prediction` for this scheduler, which predicts + the flow of the diffusion process. + thresholding (`bool`, defaults to `False`): + Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such + as Stable Diffusion. + dynamic_thresholding_ratio (`float`, defaults to 0.995): + The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. + sample_max_value (`float`, defaults to 1.0): + The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`. + predict_x0 (`bool`, defaults to `True`): + Whether to use the updating algorithm on the predicted x0. + solver_type (`str`, default `bh2`): + Solver type for UniPC. It is recommended to use `bh1` for unconditional sampling when steps < 10, and `bh2` + otherwise. + lower_order_final (`bool`, default `True`): + Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can + stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10. + disable_corrector (`list`, default `[]`): + Decides which step to disable the corrector to mitigate the misalignment between `epsilon_theta(x_t, c)` + and `epsilon_theta(x_t^c, c)` which can influence convergence for a large guidance scale. Corrector is + usually disabled during the first few steps. + solver_p (`SchedulerMixin`, default `None`): + Any other scheduler that if specified, the algorithm becomes `solver_p + UniC`. + use_karras_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`, + the sigmas are determined according to a sequence of noise levels {σi}. + use_exponential_sigmas (`bool`, *optional*, defaults to `False`): + Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process. + timestep_spacing (`str`, defaults to `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + steps_offset (`int`, defaults to 0): + An offset added to the inference steps, as required by some model families. + final_sigmas_type (`str`, defaults to `"zero"`): + The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final + sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0. + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + solver_order: int = 2, + prediction_type: str = "flow_prediction", + shift: Optional[float] = 1.0, + use_dynamic_shifting=False, + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + predict_x0: bool = True, + solver_type: str = "bh2", + lower_order_final: bool = True, + disable_corrector: List[int] = [], + solver_p: SchedulerMixin = None, + timestep_spacing: str = "linspace", + steps_offset: int = 0, + final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min" + ): + + if solver_type not in ["bh1", "bh2"]: + if solver_type in ["midpoint", "heun", "logrho"]: + self.register_to_config(solver_type="bh2") + else: + raise NotImplementedError( + f"{solver_type} is not implemented for {self.__class__}") + + self.predict_x0 = predict_x0 + # setable values + self.num_inference_steps = None + alphas = np.linspace(1, 1 / num_train_timesteps, + num_train_timesteps)[::-1].copy() + sigmas = 1.0 - alphas + sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32) + + if not use_dynamic_shifting: + # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution + sigmas = shift * sigmas / (1 + + (shift - 1) * sigmas) # pyright: ignore + + self.sigmas = sigmas + self.timesteps = sigmas * num_train_timesteps + + self.model_outputs = [None] * solver_order + self.timestep_list = [None] * solver_order + self.lower_order_nums = 0 + self.disable_corrector = disable_corrector + self.solver_p = solver_p + self.last_sample = None + self._step_index = None + self._begin_index = None + + self.sigmas = self.sigmas.to( + "cpu") # to avoid too much CPU/GPU communication + self.sigma_min = self.sigmas[-1].item() + self.sigma_max = self.sigmas[0].item() + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + + # Modified from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.set_timesteps + def set_timesteps( + self, + num_inference_steps: Union[int, None] = None, + device: Union[str, torch.device] = None, + sigmas: Optional[List[float]] = None, + mu: Optional[Union[float, None]] = None, + shift: Optional[Union[float, None]] = None, + ): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + Args: + num_inference_steps (`int`): + Total number of the spacing of the time steps. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + + if self.config.use_dynamic_shifting and mu is None: + raise ValueError( + " you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`" + ) + + if sigmas is None: + sigmas = np.linspace(self.sigma_max, self.sigma_min, + num_inference_steps + + 1).copy()[:-1] # pyright: ignore + + if self.config.use_dynamic_shifting: + sigmas = self.time_shift(mu, 1.0, sigmas) # pyright: ignore + else: + if shift is None: + shift = self.config.shift + sigmas = shift * sigmas / (1 + + (shift - 1) * sigmas) # pyright: ignore + + if self.config.final_sigmas_type == "sigma_min": + sigma_last = ((1 - self.alphas_cumprod[0]) / + self.alphas_cumprod[0])**0.5 + elif self.config.final_sigmas_type == "zero": + sigma_last = 0 + else: + raise ValueError( + f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" + ) + + timesteps = sigmas * self.config.num_train_timesteps + sigmas = np.concatenate([sigmas, [sigma_last] + ]).astype(np.float32) # pyright: ignore + + self.sigmas = torch.from_numpy(sigmas) + self.timesteps = torch.from_numpy(timesteps).to( + device=device, dtype=torch.int64) + + self.num_inference_steps = len(timesteps) + + self.model_outputs = [ + None, + ] * self.config.solver_order + self.lower_order_nums = 0 + self.last_sample = None + if self.solver_p: + self.solver_p.set_timesteps(self.num_inference_steps, device=device) + + # add an index counter for schedulers that allow duplicated timesteps + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to( + "cpu") # to avoid too much CPU/GPU communication + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample + def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." + + https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, *remaining_dims = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float( + ) # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile( + abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + s = s.unsqueeze( + 1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = torch.clamp( + sample, -s, s + ) / s # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, *remaining_dims) + sample = sample.to(dtype) + + return sample + + # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._sigma_to_t + def _sigma_to_t(self, sigma): + return sigma * self.config.num_train_timesteps + + def _sigma_to_alpha_sigma_t(self, sigma): + return 1 - sigma, sigma + + # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.set_timesteps + def time_shift(self, mu: float, sigma: float, t: torch.Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1)**sigma) + + def convert_model_output( + self, + model_output: torch.Tensor, + *args, + sample: torch.Tensor = None, + **kwargs, + ) -> torch.Tensor: + r""" + Convert the model output to the corresponding type the UniPC algorithm needs. + + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model. + timestep (`int`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + + Returns: + `torch.Tensor`: + The converted model output. + """ + timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError( + "missing `sample` as a required keyward argument") + if timestep is not None: + deprecate( + "timesteps", + "1.0.0", + "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + sigma = self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + + if self.predict_x0: + if self.config.prediction_type == "flow_prediction": + sigma_t = self.sigmas[self.step_index] + x0_pred = sample - sigma_t * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`," + " `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler." + ) + + if self.config.thresholding: + x0_pred = self._threshold_sample(x0_pred) + + return x0_pred + else: + if self.config.prediction_type == "flow_prediction": + sigma_t = self.sigmas[self.step_index] + epsilon = sample - (1 - sigma_t) * model_output + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`," + " `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler." + ) + + if self.config.thresholding: + sigma_t = self.sigmas[self.step_index] + x0_pred = sample - sigma_t * model_output + x0_pred = self._threshold_sample(x0_pred) + epsilon = model_output + x0_pred + + return epsilon + + def multistep_uni_p_bh_update( + self, + model_output: torch.Tensor, + *args, + sample: torch.Tensor = None, + order: int = None, # pyright: ignore + **kwargs, + ) -> torch.Tensor: + """ + One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if is specified. + + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model at the current timestep. + prev_timestep (`int`): + The previous discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + order (`int`): + The order of UniP at this timestep (corresponds to the *p* in UniPC-p). + + Returns: + `torch.Tensor`: + The sample tensor at the previous timestep. + """ + prev_timestep = args[0] if len(args) > 0 else kwargs.pop( + "prev_timestep", None) + if sample is None: + if len(args) > 1: + sample = args[1] + else: + raise ValueError( + " missing `sample` as a required keyward argument") + if order is None: + if len(args) > 2: + order = args[2] + else: + raise ValueError( + " missing `order` as a required keyward argument") + if prev_timestep is not None: + deprecate( + "prev_timestep", + "1.0.0", + "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + model_output_list = self.model_outputs + + s0 = self.timestep_list[-1] + m0 = model_output_list[-1] + x = sample + + if self.solver_p: + x_t = self.solver_p.step(model_output, s0, x).prev_sample + return x_t + + sigma_t, sigma_s0 = self.sigmas[self.step_index + 1], self.sigmas[ + self.step_index] # pyright: ignore + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + + h = lambda_t - lambda_s0 + device = sample.device + + rks = [] + D1s = [] + for i in range(1, order): + si = self.step_index - i # pyright: ignore + mi = model_output_list[-(i + 1)] + alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si]) + lambda_si = torch.log(alpha_si) - torch.log(sigma_si) + rk = (lambda_si - lambda_s0) / h + rks.append(rk) + D1s.append((mi - m0) / rk) # pyright: ignore + + rks.append(1.0) + rks = torch.tensor(rks, device=device) + + R = [] + b = [] + + hh = -h if self.predict_x0 else h + h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1 + h_phi_k = h_phi_1 / hh - 1 + + factorial_i = 1 + + if self.config.solver_type == "bh1": + B_h = hh + elif self.config.solver_type == "bh2": + B_h = torch.expm1(hh) + else: + raise NotImplementedError() + + for i in range(1, order + 1): + R.append(torch.pow(rks, i - 1)) + b.append(h_phi_k * factorial_i / B_h) + factorial_i *= i + 1 + h_phi_k = h_phi_k / hh - 1 / factorial_i + + R = torch.stack(R) + b = torch.tensor(b, device=device) + + if len(D1s) > 0: + D1s = torch.stack(D1s, dim=1) # (B, K) + # for order 2, we use a simplified version + if order == 2: + rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device) + else: + rhos_p = torch.linalg.solve(R[:-1, :-1], + b[:-1]).to(device).to(x.dtype) + else: + D1s = None + + if self.predict_x0: + x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 + if D1s is not None: + pred_res = torch.einsum("k,bkc...->bc...", rhos_p, + D1s) # pyright: ignore + else: + pred_res = 0 + x_t = x_t_ - alpha_t * B_h * pred_res + else: + x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 + if D1s is not None: + pred_res = torch.einsum("k,bkc...->bc...", rhos_p, + D1s) # pyright: ignore + else: + pred_res = 0 + x_t = x_t_ - sigma_t * B_h * pred_res + + x_t = x_t.to(x.dtype) + return x_t + + def multistep_uni_c_bh_update( + self, + this_model_output: torch.Tensor, + *args, + last_sample: torch.Tensor = None, + this_sample: torch.Tensor = None, + order: int = None, # pyright: ignore + **kwargs, + ) -> torch.Tensor: + """ + One step for the UniC (B(h) version). + + Args: + this_model_output (`torch.Tensor`): + The model outputs at `x_t`. + this_timestep (`int`): + The current timestep `t`. + last_sample (`torch.Tensor`): + The generated sample before the last predictor `x_{t-1}`. + this_sample (`torch.Tensor`): + The generated sample after the last predictor `x_{t}`. + order (`int`): + The `p` of UniC-p at this step. The effective order of accuracy should be `order + 1`. + + Returns: + `torch.Tensor`: + The corrected sample tensor at the current timestep. + """ + this_timestep = args[0] if len(args) > 0 else kwargs.pop( + "this_timestep", None) + if last_sample is None: + if len(args) > 1: + last_sample = args[1] + else: + raise ValueError( + " missing`last_sample` as a required keyward argument") + if this_sample is None: + if len(args) > 2: + this_sample = args[2] + else: + raise ValueError( + " missing`this_sample` as a required keyward argument") + if order is None: + if len(args) > 3: + order = args[3] + else: + raise ValueError( + " missing`order` as a required keyward argument") + if this_timestep is not None: + deprecate( + "this_timestep", + "1.0.0", + "Passing `this_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", + ) + + model_output_list = self.model_outputs + + m0 = model_output_list[-1] + x = last_sample + x_t = this_sample + model_t = this_model_output + + sigma_t, sigma_s0 = self.sigmas[self.step_index], self.sigmas[ + self.step_index - 1] # pyright: ignore + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + + h = lambda_t - lambda_s0 + device = this_sample.device + + rks = [] + D1s = [] + for i in range(1, order): + si = self.step_index - (i + 1) # pyright: ignore + mi = model_output_list[-(i + 1)] + alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si]) + lambda_si = torch.log(alpha_si) - torch.log(sigma_si) + rk = (lambda_si - lambda_s0) / h + rks.append(rk) + D1s.append((mi - m0) / rk) # pyright: ignore + + rks.append(1.0) + rks = torch.tensor(rks, device=device) + + R = [] + b = [] + + hh = -h if self.predict_x0 else h + h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1 + h_phi_k = h_phi_1 / hh - 1 + + factorial_i = 1 + + if self.config.solver_type == "bh1": + B_h = hh + elif self.config.solver_type == "bh2": + B_h = torch.expm1(hh) + else: + raise NotImplementedError() + + for i in range(1, order + 1): + R.append(torch.pow(rks, i - 1)) + b.append(h_phi_k * factorial_i / B_h) + factorial_i *= i + 1 + h_phi_k = h_phi_k / hh - 1 / factorial_i + + R = torch.stack(R) + b = torch.tensor(b, device=device) + + if len(D1s) > 0: + D1s = torch.stack(D1s, dim=1) + else: + D1s = None + + # for order 1, we use a simplified version + if order == 1: + rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device) + else: + rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype) + + if self.predict_x0: + x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 + if D1s is not None: + corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s) + else: + corr_res = 0 + D1_t = model_t - m0 + x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t) + else: + x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 + if D1s is not None: + corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s) + else: + corr_res = 0 + D1_t = model_t - m0 + x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t) + x_t = x_t.to(x.dtype) + return x_t + + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + indices = (schedule_timesteps == timestep).nonzero() + + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + pos = 1 if len(indices) > 1 else 0 + + return indices[pos].item() + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index + def _init_step_index(self, timestep): + """ + Initialize the step_index counter for the scheduler. + """ + + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + def step(self, + model_output: torch.Tensor, + timestep: Union[int, torch.Tensor], + sample: torch.Tensor, + return_dict: bool = True, + generator=None) -> Union[SchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with + the multistep UniPC. + + Args: + model_output (`torch.Tensor`): + The direct output from learned diffusion model. + timestep (`int`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. + + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. + + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if self.step_index is None: + self._init_step_index(timestep) + + use_corrector = ( + self.step_index > 0 and + self.step_index - 1 not in self.disable_corrector and + self.last_sample is not None # pyright: ignore + ) + + model_output_convert = self.convert_model_output( + model_output, sample=sample) + if use_corrector: + sample = self.multistep_uni_c_bh_update( + this_model_output=model_output_convert, + last_sample=self.last_sample, + this_sample=sample, + order=self.this_order, + ) + + for i in range(self.config.solver_order - 1): + self.model_outputs[i] = self.model_outputs[i + 1] + self.timestep_list[i] = self.timestep_list[i + 1] + + self.model_outputs[-1] = model_output_convert + self.timestep_list[-1] = timestep # pyright: ignore + + if self.config.lower_order_final: + this_order = min(self.config.solver_order, + len(self.timesteps) - + self.step_index) # pyright: ignore + else: + this_order = self.config.solver_order + + self.this_order = min(this_order, + self.lower_order_nums + 1) # warmup for multistep + assert self.this_order > 0 + + self.last_sample = sample + prev_sample = self.multistep_uni_p_bh_update( + model_output=model_output, # pass the original non-converted model output, in case solver-p is used + sample=sample, + order=self.this_order, + ) + + if self.lower_order_nums < self.config.solver_order: + self.lower_order_nums += 1 + + # upon completion increase step index by one + self._step_index += 1 # pyright: ignore + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + def scale_model_input(self, sample: torch.Tensor, *args, + **kwargs) -> torch.Tensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.Tensor`): + The input sample. + + Returns: + `torch.Tensor`: + A scaled input sample. + """ + return sample + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.IntTensor, + ) -> torch.Tensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to( + device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point( + timesteps): + # mps does not support float64 + schedule_timesteps = self.timesteps.to( + original_samples.device, dtype=torch.float32) + timesteps = timesteps.to( + original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index + if self.begin_index is None: + step_indices = [ + self.index_for_timestep(t, schedule_timesteps) + for t in timesteps + ] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timesteps.shape[0] + else: + # add noise is called before first denoising step to create initial latent(img2img) + step_indices = [self.begin_index] * timesteps.shape[0] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) + noisy_samples = alpha_t * original_samples + sigma_t * noise + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/dfm/src/megatron/model/wan/inference/utils/utils.py b/dfm/src/megatron/model/wan/inference/utils/utils.py new file mode 100644 index 00000000..a57f9bb9 --- /dev/null +++ b/dfm/src/megatron/model/wan/inference/utils/utils.py @@ -0,0 +1,117 @@ +import argparse +import binascii +import os +import os.path as osp + +import imageio +import torch +import torchvision + +__all__ = ['cache_video', 'cache_image', 'str2bool'] + + +def rand_name(length=8, suffix=''): + name = binascii.b2a_hex(os.urandom(length)).decode('utf-8') + if suffix: + if not suffix.startswith('.'): + suffix = '.' + suffix + name += suffix + return name + + +def cache_video(tensor, + save_file=None, + fps=30, + suffix='.mp4', + nrow=8, + normalize=True, + value_range=(-1, 1), + retry=5): + # cache file + cache_file = osp.join('/tmp', rand_name( + suffix=suffix)) if save_file is None else save_file + + # save to cache + error = None + for _ in range(retry): + try: + # preprocess + tensor = tensor.clamp(min(value_range), max(value_range)) + tensor = torch.stack([ + torchvision.utils.make_grid( + u, nrow=nrow, normalize=normalize, value_range=value_range) + for u in tensor.unbind(2) + ], + dim=1).permute(1, 2, 3, 0) + tensor = (tensor * 255).type(torch.uint8).cpu() + + # write video + writer = imageio.get_writer( + cache_file, fps=fps, codec='libx264', quality=8) + for frame in tensor.numpy(): + writer.append_data(frame) + writer.close() + return cache_file + except Exception as e: + error = e + continue + else: + print(f'cache_video failed, error: {error}', flush=True) + return None + + +def cache_image(tensor, + save_file, + nrow=8, + normalize=True, + value_range=(-1, 1), + retry=5): + # cache file + suffix = osp.splitext(save_file)[1] + if suffix.lower() not in [ + '.jpg', '.jpeg', '.png', '.tiff', '.gif', '.webp' + ]: + suffix = '.png' + + # save to cache + error = None + for _ in range(retry): + try: + tensor = tensor.clamp(min(value_range), max(value_range)) + torchvision.utils.save_image( + tensor, + save_file, + nrow=nrow, + normalize=normalize, + value_range=value_range) + return save_file + except Exception as e: + error = e + continue + + +def str2bool(v): + """ + Convert a string to a boolean. + + Supported true values: 'yes', 'true', 't', 'y', '1' + Supported false values: 'no', 'false', 'f', 'n', '0' + + Args: + v (str): String to convert. + + Returns: + bool: Converted boolean value. + + Raises: + argparse.ArgumentTypeError: If the value cannot be converted to boolean. + """ + if isinstance(v, bool): + return v + v_lower = v.lower() + if v_lower in ('yes', 'true', 't', 'y', '1'): + return True + elif v_lower in ('no', 'false', 'f', 'n', '0'): + return False + else: + raise argparse.ArgumentTypeError('Boolean value expected (True/False)') diff --git a/dfm/src/megatron/model/wan/modules/__init__.py b/dfm/src/megatron/model/wan/modules/__init__.py new file mode 100644 index 00000000..435f1eef --- /dev/null +++ b/dfm/src/megatron/model/wan/modules/__init__.py @@ -0,0 +1,13 @@ +from .t5 import T5Decoder, T5Encoder, T5EncoderModel, T5Model +from .tokenizers import HuggingfaceTokenizer +from .vae import WanVAE + + +__all__ = [ + 'WanVAE', + 'T5Model', + 'T5Encoder', + 'T5Decoder', + 'T5EncoderModel', + 'HuggingfaceTokenizer', +] diff --git a/dfm/src/megatron/model/wan/modules/t5.py b/dfm/src/megatron/model/wan/modules/t5.py new file mode 100644 index 00000000..fecd989e --- /dev/null +++ b/dfm/src/megatron/model/wan/modules/t5.py @@ -0,0 +1,512 @@ +# Modified from transformers.models.t5.modeling_t5 +import logging +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .tokenizers import HuggingfaceTokenizer + +__all__ = [ + 'T5Model', + 'T5Encoder', + 'T5Decoder', + 'T5EncoderModel', +] + + +def fp16_clamp(x): + if x.dtype == torch.float16 and torch.isinf(x).any(): + clamp = torch.finfo(x.dtype).max - 1000 + x = torch.clamp(x, min=-clamp, max=clamp) + return x + + +def init_weights(m): + if isinstance(m, T5LayerNorm): + nn.init.ones_(m.weight) + elif isinstance(m, T5Model): + nn.init.normal_(m.token_embedding.weight, std=1.0) + elif isinstance(m, T5FeedForward): + nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5) + nn.init.normal_(m.fc1.weight, std=m.dim**-0.5) + nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5) + elif isinstance(m, T5Attention): + nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn)**-0.5) + nn.init.normal_(m.k.weight, std=m.dim**-0.5) + nn.init.normal_(m.v.weight, std=m.dim**-0.5) + nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn)**-0.5) + elif isinstance(m, T5RelativeEmbedding): + nn.init.normal_( + m.embedding.weight, std=(2 * m.num_buckets * m.num_heads)**-0.5) + + +class GELU(nn.Module): + + def forward(self, x): + return 0.5 * x * (1.0 + torch.tanh( + math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)))) + + +class T5LayerNorm(nn.Module): + + def __init__(self, dim, eps=1e-6): + super(T5LayerNorm, self).__init__() + self.dim = dim + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + + self.eps) + if self.weight.dtype in [torch.float16, torch.bfloat16]: + x = x.type_as(self.weight) + return self.weight * x + + +class T5Attention(nn.Module): + + def __init__(self, dim, dim_attn, num_heads, dropout=0.1): + assert dim_attn % num_heads == 0 + super(T5Attention, self).__init__() + self.dim = dim + self.dim_attn = dim_attn + self.num_heads = num_heads + self.head_dim = dim_attn // num_heads + + # layers + self.q = nn.Linear(dim, dim_attn, bias=False) + self.k = nn.Linear(dim, dim_attn, bias=False) + self.v = nn.Linear(dim, dim_attn, bias=False) + self.o = nn.Linear(dim_attn, dim, bias=False) + self.dropout = nn.Dropout(dropout) + + def forward(self, x, context=None, mask=None, pos_bias=None): + """ + x: [B, L1, C]. + context: [B, L2, C] or None. + mask: [B, L2] or [B, L1, L2] or None. + """ + # check inputs + context = x if context is None else context + b, n, c = x.size(0), self.num_heads, self.head_dim + + # compute query, key, value + q = self.q(x).view(b, -1, n, c) + k = self.k(context).view(b, -1, n, c) + v = self.v(context).view(b, -1, n, c) + + # attention bias + attn_bias = x.new_zeros(b, n, q.size(1), k.size(1)) + if pos_bias is not None: + attn_bias += pos_bias + if mask is not None: + assert mask.ndim in [2, 3] + mask = mask.view(b, 1, 1, + -1) if mask.ndim == 2 else mask.unsqueeze(1) + attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min) + + # compute attention (T5 does not use scaling) + attn = torch.einsum('binc,bjnc->bnij', q, k) + attn_bias + attn = F.softmax(attn.float(), dim=-1).type_as(attn) + x = torch.einsum('bnij,bjnc->binc', attn, v) + + # output + x = x.reshape(b, -1, n * c) + x = self.o(x) + x = self.dropout(x) + return x + + +class T5FeedForward(nn.Module): + + def __init__(self, dim, dim_ffn, dropout=0.1): + super(T5FeedForward, self).__init__() + self.dim = dim + self.dim_ffn = dim_ffn + + # layers + self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU()) + self.fc1 = nn.Linear(dim, dim_ffn, bias=False) + self.fc2 = nn.Linear(dim_ffn, dim, bias=False) + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + x = self.fc1(x) * self.gate(x) + x = self.dropout(x) + x = self.fc2(x) + x = self.dropout(x) + return x + + +class T5SelfAttention(nn.Module): + + def __init__(self, + dim, + dim_attn, + dim_ffn, + num_heads, + num_buckets, + shared_pos=True, + dropout=0.1): + super(T5SelfAttention, self).__init__() + self.dim = dim + self.dim_attn = dim_attn + self.dim_ffn = dim_ffn + self.num_heads = num_heads + self.num_buckets = num_buckets + self.shared_pos = shared_pos + + # layers + self.norm1 = T5LayerNorm(dim) + self.attn = T5Attention(dim, dim_attn, num_heads, dropout) + self.norm2 = T5LayerNorm(dim) + self.ffn = T5FeedForward(dim, dim_ffn, dropout) + self.pos_embedding = None if shared_pos else T5RelativeEmbedding( + num_buckets, num_heads, bidirectional=True) + + def forward(self, x, mask=None, pos_bias=None): + e = pos_bias if self.shared_pos else self.pos_embedding( + x.size(1), x.size(1)) + x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e)) + x = fp16_clamp(x + self.ffn(self.norm2(x))) + return x + + +class T5CrossAttention(nn.Module): + + def __init__(self, + dim, + dim_attn, + dim_ffn, + num_heads, + num_buckets, + shared_pos=True, + dropout=0.1): + super(T5CrossAttention, self).__init__() + self.dim = dim + self.dim_attn = dim_attn + self.dim_ffn = dim_ffn + self.num_heads = num_heads + self.num_buckets = num_buckets + self.shared_pos = shared_pos + + # layers + self.norm1 = T5LayerNorm(dim) + self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout) + self.norm2 = T5LayerNorm(dim) + self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout) + self.norm3 = T5LayerNorm(dim) + self.ffn = T5FeedForward(dim, dim_ffn, dropout) + self.pos_embedding = None if shared_pos else T5RelativeEmbedding( + num_buckets, num_heads, bidirectional=False) + + def forward(self, + x, + mask=None, + encoder_states=None, + encoder_mask=None, + pos_bias=None): + e = pos_bias if self.shared_pos else self.pos_embedding( + x.size(1), x.size(1)) + x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e)) + x = fp16_clamp(x + self.cross_attn( + self.norm2(x), context=encoder_states, mask=encoder_mask)) + x = fp16_clamp(x + self.ffn(self.norm3(x))) + return x + + +class T5RelativeEmbedding(nn.Module): + + def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128): + super(T5RelativeEmbedding, self).__init__() + self.num_buckets = num_buckets + self.num_heads = num_heads + self.bidirectional = bidirectional + self.max_dist = max_dist + + # layers + self.embedding = nn.Embedding(num_buckets, num_heads) + + def forward(self, lq, lk): + device = self.embedding.weight.device + # rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \ + # torch.arange(lq).unsqueeze(1).to(device) + rel_pos = torch.arange(lk, device=device).unsqueeze(0) - \ + torch.arange(lq, device=device).unsqueeze(1) + rel_pos = self._relative_position_bucket(rel_pos) + rel_pos_embeds = self.embedding(rel_pos) + rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze( + 0) # [1, N, Lq, Lk] + return rel_pos_embeds.contiguous() + + def _relative_position_bucket(self, rel_pos): + # preprocess + if self.bidirectional: + num_buckets = self.num_buckets // 2 + rel_buckets = (rel_pos > 0).long() * num_buckets + rel_pos = torch.abs(rel_pos) + else: + num_buckets = self.num_buckets + rel_buckets = 0 + rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos)) + + # embeddings for small and large positions + max_exact = num_buckets // 2 + rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) / + math.log(self.max_dist / max_exact) * + (num_buckets - max_exact)).long() + rel_pos_large = torch.min( + rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1)) + rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large) + return rel_buckets + + +class T5Encoder(nn.Module): + + def __init__(self, + vocab, + dim, + dim_attn, + dim_ffn, + num_heads, + num_layers, + num_buckets, + shared_pos=True, + dropout=0.1): + super(T5Encoder, self).__init__() + self.dim = dim + self.dim_attn = dim_attn + self.dim_ffn = dim_ffn + self.num_heads = num_heads + self.num_layers = num_layers + self.num_buckets = num_buckets + self.shared_pos = shared_pos + + # layers + self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \ + else nn.Embedding(vocab, dim) + self.pos_embedding = T5RelativeEmbedding( + num_buckets, num_heads, bidirectional=True) if shared_pos else None + self.dropout = nn.Dropout(dropout) + self.blocks = nn.ModuleList([ + T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, + shared_pos, dropout) for _ in range(num_layers) + ]) + self.norm = T5LayerNorm(dim) + + # initialize weights + self.apply(init_weights) + + def forward(self, ids, mask=None): + x = self.token_embedding(ids) + x = self.dropout(x) + e = self.pos_embedding(x.size(1), + x.size(1)) if self.shared_pos else None + for block in self.blocks: + x = block(x, mask, pos_bias=e) + x = self.norm(x) + x = self.dropout(x) + return x + + +class T5Decoder(nn.Module): + + def __init__(self, + vocab, + dim, + dim_attn, + dim_ffn, + num_heads, + num_layers, + num_buckets, + shared_pos=True, + dropout=0.1): + super(T5Decoder, self).__init__() + self.dim = dim + self.dim_attn = dim_attn + self.dim_ffn = dim_ffn + self.num_heads = num_heads + self.num_layers = num_layers + self.num_buckets = num_buckets + self.shared_pos = shared_pos + + # layers + self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \ + else nn.Embedding(vocab, dim) + self.pos_embedding = T5RelativeEmbedding( + num_buckets, num_heads, bidirectional=False) if shared_pos else None + self.dropout = nn.Dropout(dropout) + self.blocks = nn.ModuleList([ + T5CrossAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, + shared_pos, dropout) for _ in range(num_layers) + ]) + self.norm = T5LayerNorm(dim) + + # initialize weights + self.apply(init_weights) + + def forward(self, ids, mask=None, encoder_states=None, encoder_mask=None): + b, s = ids.size() + + # causal mask + if mask is None: + mask = torch.tril(torch.ones(1, s, s).to(ids.device)) + elif mask.ndim == 2: + mask = torch.tril(mask.unsqueeze(1).expand(-1, s, -1)) + + # layers + x = self.token_embedding(ids) + x = self.dropout(x) + e = self.pos_embedding(x.size(1), + x.size(1)) if self.shared_pos else None + for block in self.blocks: + x = block(x, mask, encoder_states, encoder_mask, pos_bias=e) + x = self.norm(x) + x = self.dropout(x) + return x + + +class T5Model(nn.Module): + + def __init__(self, + vocab_size, + dim, + dim_attn, + dim_ffn, + num_heads, + encoder_layers, + decoder_layers, + num_buckets, + shared_pos=True, + dropout=0.1): + super(T5Model, self).__init__() + self.vocab_size = vocab_size + self.dim = dim + self.dim_attn = dim_attn + self.dim_ffn = dim_ffn + self.num_heads = num_heads + self.encoder_layers = encoder_layers + self.decoder_layers = decoder_layers + self.num_buckets = num_buckets + + # layers + self.token_embedding = nn.Embedding(vocab_size, dim) + self.encoder = T5Encoder(self.token_embedding, dim, dim_attn, dim_ffn, + num_heads, encoder_layers, num_buckets, + shared_pos, dropout) + self.decoder = T5Decoder(self.token_embedding, dim, dim_attn, dim_ffn, + num_heads, decoder_layers, num_buckets, + shared_pos, dropout) + self.head = nn.Linear(dim, vocab_size, bias=False) + + # initialize weights + self.apply(init_weights) + + def forward(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask): + x = self.encoder(encoder_ids, encoder_mask) + x = self.decoder(decoder_ids, decoder_mask, x, encoder_mask) + x = self.head(x) + return x + + +def _t5(name, + encoder_only=False, + decoder_only=False, + return_tokenizer=False, + tokenizer_kwargs={}, + dtype=torch.float32, + device='cpu', + **kwargs): + # sanity check + assert not (encoder_only and decoder_only) + + # params + if encoder_only: + model_cls = T5Encoder + kwargs['vocab'] = kwargs.pop('vocab_size') + kwargs['num_layers'] = kwargs.pop('encoder_layers') + _ = kwargs.pop('decoder_layers') + elif decoder_only: + model_cls = T5Decoder + kwargs['vocab'] = kwargs.pop('vocab_size') + kwargs['num_layers'] = kwargs.pop('decoder_layers') + _ = kwargs.pop('encoder_layers') + else: + model_cls = T5Model + + # init model + with torch.device(device): + model = model_cls(**kwargs) + + # set device + model = model.to(dtype=dtype, device=device) + + # init tokenizer + if return_tokenizer: + from .tokenizers import HuggingfaceTokenizer + tokenizer = HuggingfaceTokenizer(f'google/{name}', **tokenizer_kwargs) + return model, tokenizer + else: + return model + + +def umt5_xxl(**kwargs): + cfg = dict( + vocab_size=256384, + dim=4096, + dim_attn=4096, + dim_ffn=10240, + num_heads=64, + encoder_layers=24, + decoder_layers=24, + num_buckets=32, + shared_pos=False, + dropout=0.1) + cfg.update(**kwargs) + return _t5('umt5-xxl', **cfg) + + +class T5EncoderModel: + + def __init__( + self, + text_len, + dtype=torch.bfloat16, + device=torch.cuda.current_device(), + checkpoint_path=None, + tokenizer_path=None, + shard_fn=None, + ): + self.text_len = text_len + self.dtype = dtype + self.device = device + self.checkpoint_path = checkpoint_path + self.tokenizer_path = tokenizer_path + + # init model + model = umt5_xxl( + encoder_only=True, + return_tokenizer=False, + dtype=dtype, + device=device).eval().requires_grad_(False) + logging.info(f'loading {checkpoint_path}') + model.load_state_dict(torch.load(checkpoint_path, map_location='cpu')) + self.model = model + if shard_fn is not None: + self.model = shard_fn(self.model, sync_module_states=False) + else: + self.model.to(self.device) + # init tokenizer + self.tokenizer = HuggingfaceTokenizer( + name=tokenizer_path, seq_len=text_len, clean='whitespace') + + def __call__(self, texts, device): + ids, mask = self.tokenizer( + texts, return_mask=True, add_special_tokens=True) + ids = ids.to(device) + mask = mask.to(device) + seq_lens = mask.gt(0).sum(dim=1).long() + context = self.model(ids, mask) + return [u[:v] for u, v in zip(context, seq_lens)] diff --git a/dfm/src/megatron/model/wan/modules/tokenizers.py b/dfm/src/megatron/model/wan/modules/tokenizers.py new file mode 100644 index 00000000..a69972ad --- /dev/null +++ b/dfm/src/megatron/model/wan/modules/tokenizers.py @@ -0,0 +1,81 @@ +import html +import string + +import ftfy +import regex as re +from transformers import AutoTokenizer + +__all__ = ['HuggingfaceTokenizer'] + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r'\s+', ' ', text) + text = text.strip() + return text + + +def canonicalize(text, keep_punctuation_exact_string=None): + text = text.replace('_', ' ') + if keep_punctuation_exact_string: + text = keep_punctuation_exact_string.join( + part.translate(str.maketrans('', '', string.punctuation)) + for part in text.split(keep_punctuation_exact_string)) + else: + text = text.translate(str.maketrans('', '', string.punctuation)) + text = text.lower() + text = re.sub(r'\s+', ' ', text) + return text.strip() + + +class HuggingfaceTokenizer: + + def __init__(self, name, seq_len=None, clean=None, **kwargs): + assert clean in (None, 'whitespace', 'lower', 'canonicalize') + self.name = name + self.seq_len = seq_len + self.clean = clean + + # init tokenizer + self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs) + self.vocab_size = self.tokenizer.vocab_size + + def __call__(self, sequence, **kwargs): + return_mask = kwargs.pop('return_mask', False) + + # arguments + _kwargs = {'return_tensors': 'pt'} + if self.seq_len is not None: + _kwargs.update({ + 'padding': 'max_length', + 'truncation': True, + 'max_length': self.seq_len + }) + _kwargs.update(**kwargs) + + # tokenization + if isinstance(sequence, str): + sequence = [sequence] + if self.clean: + sequence = [self._clean(u) for u in sequence] + ids = self.tokenizer(sequence, **_kwargs) + + # output + if return_mask: + return ids.input_ids, ids.attention_mask + else: + return ids.input_ids + + def _clean(self, text): + if self.clean == 'whitespace': + text = whitespace_clean(basic_clean(text)) + elif self.clean == 'lower': + text = whitespace_clean(basic_clean(text)).lower() + elif self.clean == 'canonicalize': + text = canonicalize(basic_clean(text)) + return text diff --git a/dfm/src/megatron/model/wan/modules/vae.py b/dfm/src/megatron/model/wan/modules/vae.py new file mode 100644 index 00000000..d4f1ef1d --- /dev/null +++ b/dfm/src/megatron/model/wan/modules/vae.py @@ -0,0 +1,662 @@ +import logging + +import torch +import torch.cuda.amp as amp +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +__all__ = [ + 'WanVAE', +] + +CACHE_T = 2 + + +class CausalConv3d(nn.Conv3d): + """ + Causal 3d convolusion. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._padding = (self.padding[2], self.padding[2], self.padding[1], + self.padding[1], 2 * self.padding[0], 0) + self.padding = (0, 0, 0) + + def forward(self, x, cache_x=None): + padding = list(self._padding) + if cache_x is not None and self._padding[4] > 0: + cache_x = cache_x.to(x.device) + x = torch.cat([cache_x, x], dim=2) + padding[4] -= cache_x.shape[2] + x = F.pad(x, padding) + + return super().forward(x) + + +class RMS_norm(nn.Module): + + def __init__(self, dim, channel_first=True, images=True, bias=False): + super().__init__() + broadcastable_dims = (1, 1, 1) if not images else (1, 1) + shape = (dim, *broadcastable_dims) if channel_first else (dim,) + + self.channel_first = channel_first + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.ones(shape)) + self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0. + + def forward(self, x): + return F.normalize( + x, dim=(1 if self.channel_first else + -1)) * self.scale * self.gamma + self.bias + + +class Upsample(nn.Upsample): + + def forward(self, x): + """ + Fix bfloat16 support for nearest neighbor interpolation. + """ + return super().forward(x.float()).type_as(x) + + +class Resample(nn.Module): + + def __init__(self, dim, mode): + assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d', + 'downsample3d') + super().__init__() + self.dim = dim + self.mode = mode + + # layers + if mode == 'upsample2d': + self.resample = nn.Sequential( + Upsample(scale_factor=(2., 2.), mode='nearest-exact'), + nn.Conv2d(dim, dim // 2, 3, padding=1)) + elif mode == 'upsample3d': + self.resample = nn.Sequential( + Upsample(scale_factor=(2., 2.), mode='nearest-exact'), + nn.Conv2d(dim, dim // 2, 3, padding=1)) + self.time_conv = CausalConv3d( + dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) + + elif mode == 'downsample2d': + self.resample = nn.Sequential( + nn.ZeroPad2d((0, 1, 0, 1)), + nn.Conv2d(dim, dim, 3, stride=(2, 2))) + elif mode == 'downsample3d': + self.resample = nn.Sequential( + nn.ZeroPad2d((0, 1, 0, 1)), + nn.Conv2d(dim, dim, 3, stride=(2, 2))) + self.time_conv = CausalConv3d( + dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)) + + else: + self.resample = nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + b, c, t, h, w = x.size() + if self.mode == 'upsample3d': + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = 'Rep' + feat_idx[0] += 1 + else: + + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[ + idx] is not None and feat_cache[idx] != 'Rep': + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + if cache_x.shape[2] < 2 and feat_cache[ + idx] is not None and feat_cache[idx] == 'Rep': + cache_x = torch.cat([ + torch.zeros_like(cache_x).to(cache_x.device), + cache_x + ], + dim=2) + if feat_cache[idx] == 'Rep': + x = self.time_conv(x) + else: + x = self.time_conv(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + + x = x.reshape(b, 2, c, t, h, w) + x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), + 3) + x = x.reshape(b, c, t * 2, h, w) + t = x.shape[2] + x = rearrange(x, 'b c t h w -> (b t) c h w') + x = self.resample(x) + x = rearrange(x, '(b t) c h w -> b c t h w', t=t) + + if self.mode == 'downsample3d': + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = x.clone() + feat_idx[0] += 1 + else: + + cache_x = x[:, :, -1:, :, :].clone() + # if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep': + # # cache last frame of last two chunk + # cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) + + x = self.time_conv( + torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + return x + + def init_weight(self, conv): + conv_weight = conv.weight + nn.init.zeros_(conv_weight) + c1, c2, t, h, w = conv_weight.size() + one_matrix = torch.eye(c1, c2) + init_matrix = one_matrix + nn.init.zeros_(conv_weight) + #conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5 + conv_weight.data[:, :, 1, 0, 0] = init_matrix #* 0.5 + conv.weight.data.copy_(conv_weight) + nn.init.zeros_(conv.bias.data) + + def init_weight2(self, conv): + conv_weight = conv.weight.data + nn.init.zeros_(conv_weight) + c1, c2, t, h, w = conv_weight.size() + init_matrix = torch.eye(c1 // 2, c2) + #init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2) + conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix + conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix + conv.weight.data.copy_(conv_weight) + nn.init.zeros_(conv.bias.data) + + +class ResidualBlock(nn.Module): + + def __init__(self, in_dim, out_dim, dropout=0.0): + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + + # layers + self.residual = nn.Sequential( + RMS_norm(in_dim, images=False), nn.SiLU(), + CausalConv3d(in_dim, out_dim, 3, padding=1), + RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout), + CausalConv3d(out_dim, out_dim, 3, padding=1)) + self.shortcut = CausalConv3d(in_dim, out_dim, 1) \ + if in_dim != out_dim else nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + h = self.shortcut(x) + for layer in self.residual: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + h + + +class AttentionBlock(nn.Module): + """ + Causal self-attention with a single head. + """ + + def __init__(self, dim): + super().__init__() + self.dim = dim + + # layers + self.norm = RMS_norm(dim) + self.to_qkv = nn.Conv2d(dim, dim * 3, 1) + self.proj = nn.Conv2d(dim, dim, 1) + + # zero out the last layer params + nn.init.zeros_(self.proj.weight) + + def forward(self, x): + identity = x + b, c, t, h, w = x.size() + x = rearrange(x, 'b c t h w -> (b t) c h w') + x = self.norm(x) + # compute query, key, value + q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3, + -1).permute(0, 1, 3, + 2).contiguous().chunk( + 3, dim=-1) + + # apply attention + x = F.scaled_dot_product_attention( + q, + k, + v, + ) + x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w) + + # output + x = self.proj(x) + x = rearrange(x, '(b t) c h w-> b c t h w', t=t) + return x + identity + + +class Encoder3d(nn.Module): + + def __init__(self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, False], + dropout=0.0): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + + # dimensions + dims = [dim * u for u in [1] + dim_mult] + scale = 1.0 + + # init block + self.conv1 = CausalConv3d(3, dims[0], 3, padding=1) + + # downsample blocks + downsamples = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + for _ in range(num_res_blocks): + downsamples.append(ResidualBlock(in_dim, out_dim, dropout)) + if scale in attn_scales: + downsamples.append(AttentionBlock(out_dim)) + in_dim = out_dim + + # downsample block + if i != len(dim_mult) - 1: + mode = 'downsample3d' if temperal_downsample[ + i] else 'downsample2d' + downsamples.append(Resample(out_dim, mode=mode)) + scale /= 2.0 + self.downsamples = nn.Sequential(*downsamples) + + # middle blocks + self.middle = nn.Sequential( + ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim), + ResidualBlock(out_dim, out_dim, dropout)) + + # output blocks + self.head = nn.Sequential( + RMS_norm(out_dim, images=False), nn.SiLU(), + CausalConv3d(out_dim, z_dim, 3, padding=1)) + + def forward(self, x, feat_cache=None, feat_idx=[0]): + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + ## downsamples + for layer in self.downsamples: + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## middle + for layer in self.middle: + if isinstance(layer, ResidualBlock) and feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## head + for layer in self.head: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + + +class Decoder3d(nn.Module): + + def __init__(self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_upsample=[False, True, True], + dropout=0.0): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_upsample = temperal_upsample + + # dimensions + dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + scale = 1.0 / 2**(len(dim_mult) - 2) + + # init block + self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1) + + # middle blocks + self.middle = nn.Sequential( + ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]), + ResidualBlock(dims[0], dims[0], dropout)) + + # upsample blocks + upsamples = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + # residual (+attention) blocks + if i == 1 or i == 2 or i == 3: + in_dim = in_dim // 2 + for _ in range(num_res_blocks + 1): + upsamples.append(ResidualBlock(in_dim, out_dim, dropout)) + if scale in attn_scales: + upsamples.append(AttentionBlock(out_dim)) + in_dim = out_dim + + # upsample block + if i != len(dim_mult) - 1: + mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d' + upsamples.append(Resample(out_dim, mode=mode)) + scale *= 2.0 + self.upsamples = nn.Sequential(*upsamples) + + # output blocks + self.head = nn.Sequential( + RMS_norm(out_dim, images=False), nn.SiLU(), + CausalConv3d(out_dim, 3, 3, padding=1)) + + def forward(self, x, feat_cache=None, feat_idx=[0]): + ## conv1 + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + ## middle + for layer in self.middle: + if isinstance(layer, ResidualBlock) and feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## upsamples + for layer in self.upsamples: + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## head + for layer in self.head: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat([ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( + cache_x.device), cache_x + ], + dim=2) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + + +def count_conv3d(model): + count = 0 + for m in model.modules(): + if isinstance(m, CausalConv3d): + count += 1 + return count + + +class WanVAE_(nn.Module): + + def __init__(self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, False], + dropout=0.0): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + self.temperal_upsample = temperal_downsample[::-1] + + # modules + self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks, + attn_scales, self.temperal_downsample, dropout) + self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1) + self.conv2 = CausalConv3d(z_dim, z_dim, 1) + self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks, + attn_scales, self.temperal_upsample, dropout) + + def forward(self, x): + mu, log_var = self.encode(x) + z = self.reparameterize(mu, log_var) + x_recon = self.decode(z) + return x_recon, mu, log_var + + def encode(self, x, scale): + self.clear_cache() + ## cache + t = x.shape[2] + iter_ = 1 + (t - 1) // 4 + ## 对encode输入的x,按时间拆分为1、4、4、4.... + for i in range(iter_): + self._enc_conv_idx = [0] + if i == 0: + out = self.encoder( + x[:, :, :1, :, :], + feat_cache=self._enc_feat_map, + feat_idx=self._enc_conv_idx) + else: + out_ = self.encoder( + x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :], + feat_cache=self._enc_feat_map, + feat_idx=self._enc_conv_idx) + out = torch.cat([out, out_], 2) + mu, log_var = self.conv1(out).chunk(2, dim=1) + if isinstance(scale[0], torch.Tensor): + mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view( + 1, self.z_dim, 1, 1, 1) + else: + mu = (mu - scale[0]) * scale[1] + self.clear_cache() + return mu + + def decode(self, z, scale): + self.clear_cache() + # z: [b,c,t,h,w] + if isinstance(scale[0], torch.Tensor): + z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view( + 1, self.z_dim, 1, 1, 1) + else: + z = z / scale[1] + scale[0] + iter_ = z.shape[2] + x = self.conv2(z) + for i in range(iter_): + self._conv_idx = [0] + if i == 0: + out = self.decoder( + x[:, :, i:i + 1, :, :], + feat_cache=self._feat_map, + feat_idx=self._conv_idx) + else: + out_ = self.decoder( + x[:, :, i:i + 1, :, :], + feat_cache=self._feat_map, + feat_idx=self._conv_idx) + out = torch.cat([out, out_], 2) + self.clear_cache() + return out + + def reparameterize(self, mu, log_var): + std = torch.exp(0.5 * log_var) + eps = torch.randn_like(std) + return eps * std + mu + + def sample(self, imgs, deterministic=False): + mu, log_var = self.encode(imgs) + if deterministic: + return mu + std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0)) + return mu + std * torch.randn_like(std) + + def clear_cache(self): + self._conv_num = count_conv3d(self.decoder) + self._conv_idx = [0] + self._feat_map = [None] * self._conv_num + #cache encode + self._enc_conv_num = count_conv3d(self.encoder) + self._enc_conv_idx = [0] + self._enc_feat_map = [None] * self._enc_conv_num + + +def _video_vae(pretrained_path=None, z_dim=None, device='cpu', **kwargs): + """ + Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL. + """ + # params + cfg = dict( + dim=96, + z_dim=z_dim, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[False, True, True], + dropout=0.0) + cfg.update(**kwargs) + + # init model + with torch.device('meta'): + model = WanVAE_(**cfg) + + # load checkpoint + logging.info(f'loading {pretrained_path}') + model.load_state_dict( + torch.load(pretrained_path, map_location=device), assign=True) + + return model + + +class WanVAE: + + def __init__(self, + z_dim=16, + vae_pth='cache/vae_step_411000.pth', + dtype=torch.float, + device="cuda"): + self.dtype = dtype + self.device = device + + mean = [ + -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508, + 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921 + ] + std = [ + 2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, + 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160 + ] + self.mean = torch.tensor(mean, dtype=dtype, device=device) + self.std = torch.tensor(std, dtype=dtype, device=device) + self.scale = [self.mean, 1.0 / self.std] + + # init model + self.model = _video_vae( + pretrained_path=vae_pth, + z_dim=z_dim, + ).eval().requires_grad_(False).to(device) + + def encode(self, videos): + """ + videos: A list of videos each with shape [C, T, H, W]. + """ + with amp.autocast(dtype=self.dtype): + return [ + self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0) + for u in videos + ] + + def decode(self, zs): + with amp.autocast(dtype=self.dtype): + return [ + self.model.decode(u.unsqueeze(0), + self.scale).float().clamp_(-1, 1).squeeze(0) + for u in zs + ] diff --git a/dfm/src/megatron/model/wan/rope_utils.py b/dfm/src/megatron/model/wan/rope_utils.py new file mode 100644 index 00000000..93d0e933 --- /dev/null +++ b/dfm/src/megatron/model/wan/rope_utils.py @@ -0,0 +1,65 @@ +import torch +from torch.cuda import amp +from megatron.bridge.models.wan.utils.utils import split_inputs_cp +from megatron.core import parallel_state + +class Wan3DRopeEmbeddings(torch.nn.Module): + """ + Wan 3D RoPE embeddings implementation. + Implements Wan's 3D RoPE embeddings for Mcore Attention based on Wan's implementation at https://github.com/Wan-Video/Wan2.1/blob/main/wan/modules/model.py. + """ + + def __init__(self, dim_head, max_position_len): + super().__init__() + self.freqs = torch.cat([ + self.rope_params(max_position_len, dim_head - 4 * (dim_head // 6)), + self.rope_params(max_position_len, 2 * (dim_head // 6)), + self.rope_params(max_position_len, 2 * (dim_head // 6)) + ], dim=1) + + def rope_params(self, max_position_len, dim_head, theta=10000): + assert dim_head % 2 == 0 + freqs = torch.outer( + torch.arange(max_position_len), + 1.0 / torch.pow(theta, + torch.arange(0, dim_head, 2).div(dim_head))) + return freqs + + def forward(self, n_head, dim_head, max_seq_len, grid_sizes, device): + self.freqs = self.freqs.to(device) # ??? do we need to put this here, or the when we move WanModel to device, it also move freqs to device? + + n, c = n_head, dim_head // 2 + + # split freqs + freqs = self.freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) + + freqs_real = [] + for i, (f, h, w) in enumerate(grid_sizes.tolist()): + seq_len = f * h * w + freqs_real_i = torch.cat([ + freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), + freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), + freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) + ], dim=-1).reshape(seq_len, 1, 1, -1) # <-- add 1,1 for batch/head broadcasting + + # Double dimension from c -> 2c with rotating angles as (x0, x0, x1, x1, ...), for interleaving RoPE + freqs_real_i = freqs_real_i.unsqueeze(-1).expand(-1, -1, -1, -1, 2).reshape(seq_len, 1, 1, dim_head) + + # Pad freqs_real_i to (max_seq_len, 1, 1, dim_head) with 0s + if freqs_real_i.shape[0] < max_seq_len: + pad_shape = (max_seq_len - freqs_real_i.shape[0], 1, 1, dim_head) + freqs_real_i = torch.cat( + [freqs_real_i, torch.zeros(pad_shape, dtype=freqs_real_i.dtype, device=freqs_real_i.device)] + ) + freqs_real.append(freqs_real_i) + + # Each freqs_real[i] is (max_seq_len, 1, 1, dim_head) + # We concatenate them along dim=1 to get (max_seq_len, batch_size, 1, dim_head) + freqs_real = torch.cat(freqs_real, dim=1) + + # TODO: if run context/sequence related parallel, then we need to scatter + # the freqs_real to the context parallel region, using specific cp_rank split method + if parallel_state.get_context_parallel_world_size() > 1: + freqs_real = split_inputs_cp(freqs_real, 0) + + return freqs_real \ No newline at end of file diff --git a/dfm/src/megatron/model/wan/utils/utils.py b/dfm/src/megatron/model/wan/utils/utils.py new file mode 100644 index 00000000..8551c6fc --- /dev/null +++ b/dfm/src/megatron/model/wan/utils/utils.py @@ -0,0 +1,128 @@ +import torch +from typing import Tuple +from torch.distributed import all_gather +import megatron.core.parallel_state as parallel_state +import math + +def grid_sizes_calculation( + input_shape: Tuple[int, int, int], # (F_latents, H_latents, W_latents) + patch_size: Tuple[int, int, int], # (pF, pH, pW) +) -> Tuple[int, int, int]: + """ + Compute the (f,h,w) output spatial/temporal dimensions of a Conv3d patch embedder. + """ + + F_latents, H_latents, W_latents = input_shape + pF, pH, pW = patch_size + F_patches = F_latents // pF + H_patches = H_latents // pH + W_patches = W_latents // pW + + return [F_patches, H_patches, W_patches] + + +def patchify(x, patch_size): + """ + Convert a list of reconstructed video tensor into patch embeddings. + This method is the inverse of `unpatchify`. + + Args: + x (list[torch.Tensor]): list of tensors, each with shape [c, F_patches * pF, H_patches * pH, W_patches * pW] + patch_size (tuple): (pF, pH, pW) + + Returns: + torch.Tensor: shape [ (F_patches * H_patches * W_patches), (c * pF * pH * pW)], + """ + out = [] + for u in x: + c, F_pF, H_pH, W_pW = u.shape + pF, pH, pW = patch_size + assert F_pF % pF == 0 and H_pH % pH == 0 and W_pW % pW == 0, \ + "Spatial dimensions must be divisible by patch size." + + F_patches, H_patches, W_patches = F_pF // pF, H_pH // pH, W_pW // pW + + # split spatial dims into (grid, patch) and reorder to match original patch layout: + # start: (c, F_patches * pF, H_patches * pH, W_patches * pW) + # reshape -> (c, F_patches, pF, H_patches, pH, W_patches, pW) + # permute -> (F_patches, H_patches, W_patches, pF, pH, pW, c) + t = u.reshape(c, F_patches, pF, H_patches, pH, W_patches, pW) + t = t.permute(1, 3, 5, 2, 4, 6, 0) + + num_patches = F_patches * H_patches * W_patches + out.append(t.reshape(num_patches, c * (pF * pH * pW))) + return out + + +def unpatchify(x: list[torch.Tensor], grid_sizes: list[Tuple[int, int, int]], out_dim: int, patch_size: Tuple[int, int, int]) -> list[torch.Tensor]: + """ + Reconstruct video tensors from patch embeddings into a list of videotensors. + + Args: + x (list[torch.Tensor]): + list of tensors, each with shape [seq_len, c * pF * pH * pW] + grid_sizes (list[Tuple[int, int, int]]): + list of tensors, each with original spatial-temporal grid dimensions before patching, + (3 dimensions correspond to F_patches, H_patches, W_patches) + + Returns: + list[torch.Tensor]: list of tensors, each with shape [c, F_latents, H_latents, W_latents] + """ + + c = out_dim + out = [] + for u, v in zip(x, grid_sizes): + u = u[:math.prod(v)].view(*v, *patch_size, c) + u = torch.einsum('fhwpqrc->cfphqwr', u) + u = u.reshape(c, *[i * j for i, j in zip(v, patch_size)]) + out.append(u) + return out + + +def split_inputs_cp(x: torch.Tensor, seq_dim: int = 0) -> torch.Tensor: + """ + Split input tensor along the sequence dimension for context parallelism. + + Args: + x: Input tensor to be split. (e.g. shape [seq_len, batch_size, ...]) + seq_dim: The dimension along which to split the input (sequence dimension). + + Returns: + A slice of the input tensor corresponding to the current rank. (e.g. shape [seq_len/cp_size, batch_size, ...]) + """ + + cp_size = parallel_state.get_context_parallel_world_size() + if cp_size > 1: + cp_rank = parallel_state.get_context_parallel_rank() + assert x.shape[seq_dim] % cp_size == 0, f"{x.shape[seq_dim]} cannot divide cp_size {cp_size}" + x = x.view(*x.shape[:seq_dim], cp_size, x.shape[seq_dim] // cp_size, *x.shape[(seq_dim + 1) :]) + seq_idx = torch.tensor([cp_rank], device=x.device) + x = x.index_select(seq_dim, seq_idx) + # Note that the new sequence length is the original sequence length / cp_size + x = x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :]) + return x + + +def cat_outputs_cp(x: torch.Tensor, seq_dim: int) -> torch.Tensor: + """ + Concatenate tensors from multiple processes along a specified dimension. + + Args: + x: Input tensor to be concatenated. (e.g. shape [seq_len/cp_size, batch_size, ...]) + seq_dim: The dimension along which to concatenate the input tensors. + + Returns: + A tensor with the concatenated tensors. (e.g. shape [seq_len, batch_size, ...]) + """ + + cp_group = parallel_state.get_context_parallel_group() + cp_size = parallel_state.get_context_parallel_world_size() + if cp_size > 1: + gathered_tensors = [torch.zeros_like(x) for _ in range(cp_size)] + # Attempt to gather tensors from all ranks + # PyTorch’s all_gather orders outputs by rank within the group, which matches how chunks were selected by cp_rank + all_gather(gathered_tensors, x, group=cp_group) + gathered_tensors = torch.cat(gathered_tensors, dim=seq_dim) + return gathered_tensors + else: + return x diff --git a/dfm/src/megatron/model/wan/wan_layer_spec.py b/dfm/src/megatron/model/wan/wan_layer_spec.py new file mode 100644 index 00000000..f98576ad --- /dev/null +++ b/dfm/src/megatron/model/wan/wan_layer_spec.py @@ -0,0 +1,591 @@ + +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=C0115,C0116,C0301 + +import copy +from dataclasses import dataclass +from typing import Union, Optional + +import torch +import torch.cuda.amp as amp +import torch.nn as nn +from megatron.core import parallel_state, tensor_parallel +from megatron.core.transformer.attention import ( + CrossAttention, + CrossAttentionSubmodules, + SelfAttention, + SelfAttentionSubmodules, +) +from megatron.core.transformer.custom_layers.transformer_engine import ( + TEColumnParallelLinear, + TEDotProductAttention, + TERowParallelLinear, +) +from megatron.core.transformer.enums import AttnMaskType +from megatron.core.transformer.identity_op import IdentityOp +from megatron.core.transformer.mlp import MLP, MLPSubmodules +from megatron.core.transformer.module import MegatronModule +from megatron.core.transformer.spec_utils import ModuleSpec, build_module +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules +from megatron.core.utils import make_viewless_tensor +from megatron.core.process_groups_config import ProcessGroupCollection +from megatron.core.extensions.transformer_engine import TENorm + +try: + import transformer_engine # pylint: disable=unused-import + + HAVE_TE = True + from megatron.core.extensions.transformer_engine import SplitAlongDim + +except ImportError: + HAVE_TE = False + SplitAlongDim = None + + +class WanLayerNorm(nn.LayerNorm): + + def __init__(self, dim, eps=1e-6, elementwise_affine=False): + super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps) + + def forward(self, x): + r""" + Args: + x(Tensor): Shape [B, L, C] + """ + return super().forward(x).type_as(x) + + +@dataclass +class WanSelfAttentionSubmodules: + """ + Configuration class for specifying the submodules of a self-attention. + """ + + linear_qkv: Union[ModuleSpec, type] = None + core_attention: Union[ModuleSpec, type] = None + linear_proj: Union[ModuleSpec, type] = None + layernorm_across_head: bool = False + q_layernorm: Union[ModuleSpec, type] = None + k_layernorm: Union[ModuleSpec, type] = None + + +@dataclass +class WanCrossAttentionSubmodules: + """ + Configuration class for specifying the submodules of a cross-attention. + """ + linear_q: Union[ModuleSpec, type] = None + linear_kv: Union[ModuleSpec, type] = None + core_attention: Union[ModuleSpec, type] = None + linear_proj: Union[ModuleSpec, type] = None + layernorm_across_head: bool = False + q_layernorm: Union[ModuleSpec, type] = None + k_layernorm: Union[ModuleSpec, type] = None + + +class WanSelfAttention(SelfAttention): + def __init__( + self, + config: TransformerConfig, + submodules: WanSelfAttentionSubmodules, + layer_number: int, + attn_mask_type: AttnMaskType, + cp_comm_type: str = None, + pg_collection: Optional[ProcessGroupCollection] = None, + ): + super().__init__( + config, + submodules, + layer_number, + attn_mask_type, + cp_comm_type, + pg_collection, + ) + + self.layernorm_across_head = submodules.layernorm_across_head + + # override q_layernorm + if submodules.q_layernorm is not None: + if self.layernorm_across_head: + q_layernorm_size = self.query_projection_size + else: + q_layernorm_size = self.hidden_size_per_attention_head + import transformer_engine as te + norm_config = copy.deepcopy(self.config) + norm_config.normalization = "RMSNorm" + self.q_layernorm = build_module( + submodules.q_layernorm, + eps=norm_config.layernorm_epsilon, + hidden_size=q_layernorm_size, + config=norm_config, + ) + else: + self.q_layernorm = None + + # override k_layernorm + if submodules.k_layernorm is not None: + if self.layernorm_across_head: + k_layernorm_size = self.kv_projection_size + else: + k_layernorm_size = self.hidden_size_per_attention_head + import transformer_engine as te + norm_config = copy.deepcopy(self.config) + norm_config.normalization = "RMSNorm" + self.k_layernorm = build_module( + submodules.k_layernorm, + eps=norm_config.layernorm_epsilon, + hidden_size=k_layernorm_size, + config=norm_config, + ) + else: + self.k_layernorm = None + + def get_query_key_value_tensors(self, hidden_states, key_value_states=None): + """ + Derives `query`, `key` and `value` tensors from `hidden_states`. + """ + # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn)] + mixed_qkv, _ = self.linear_qkv(hidden_states) + + # [sq, b, hp] --> [sq, b, ng, (np/ng + 2) * hn] + new_tensor_shape = mixed_qkv.size()[:-1] + ( + self.num_query_groups_per_partition, + ( + (self.num_attention_heads_per_partition // self.num_query_groups_per_partition + 2) + * self.hidden_size_per_attention_head + ), + ) + mixed_qkv = mixed_qkv.view(*new_tensor_shape) + + split_arg_list = [ + ( + self.num_attention_heads_per_partition + // self.num_query_groups_per_partition + * self.hidden_size_per_attention_head + ), + self.hidden_size_per_attention_head, + self.hidden_size_per_attention_head, + ] + + if SplitAlongDim is not None: + + # [sq, b, ng, (np/ng + 2) * hn] + # --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn] + (query, key, value) = SplitAlongDim(mixed_qkv, 3, split_arg_list) + else: + + # [sq, b, ng, (np/ng + 2) * hn] + # --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn] + (query, key, value) = torch.split(mixed_qkv, split_arg_list, dim=3) + + # [sq, b, ng, np/ng * hn] -> [sq, b, np, hn] + query = query.reshape(query.size(0), query.size(1), -1, self.hidden_size_per_attention_head) + + # gather query and key heads across TP ranks if self.layernorm_across_head is True + if self.layernorm_across_head and parallel_state.get_tensor_model_parallel_world_size() > 1: + query = query.transpose(-2, -1) + key = key.transpose(-2, -1) + query = tensor_parallel.gather_from_tensor_model_parallel_region(query) + key = tensor_parallel.gather_from_tensor_model_parallel_region(key) + query = query.transpose(-2, -1) + key = key.transpose(-2, -1) + + if self.q_layernorm is not None: + if self.layernorm_across_head: + q_flat = query.reshape(query.size(0), query.size(1), -1).contiguous() # [sq, b, np*hn] + q_flat = self.q_layernorm(q_flat) + query = q_flat.view(query.size(0), query.size(1), -1, self.hidden_size_per_attention_head) # [sq, b, np, hn] + else: + query = self.q_layernorm(query.contiguous()) + + if self.k_layernorm is not None: + if self.layernorm_across_head: + k_flat = key.reshape(key.size(0), key.size(1), -1).contiguous() + k_flat = self.k_layernorm(k_flat) + key = k_flat.view(key.size(0), key.size(1), -1, self.hidden_size_per_attention_head) + else: + key = self.k_layernorm(key.contiguous()) + + # scatter query and key heads across TP ranks if self.layernorm_across_head is True + if self.layernorm_across_head and parallel_state.get_tensor_model_parallel_world_size() > 1: + query = query.transpose(-2, -1) + key = key.transpose(-2, -1) + query = tensor_parallel.scatter_to_tensor_model_parallel_region(query) + key = tensor_parallel.scatter_to_tensor_model_parallel_region(key) + query = query.transpose(-2, -1) + key = key.transpose(-2, -1) + query = query.contiguous() # important becuase TE attention expects contiguous tensors + key = key.contiguous() # important becuase TE attention expects contiguous tensors + + if self.config.test_mode: + self.run_realtime_tests() + + return query, key, value + + +class WanCrossAttention(CrossAttention): + def __init__( + self, + config: TransformerConfig, + submodules: WanCrossAttentionSubmodules, + layer_number: int, + attn_mask_type: AttnMaskType, + cp_comm_type: str = None, + pg_collection: Optional[ProcessGroupCollection] = None, + ): + super().__init__( + config, + submodules, + layer_number, + attn_mask_type, + cp_comm_type, + pg_collection, + ) + + self.layernorm_across_head = submodules.layernorm_across_head + + # override q_layernorm + if submodules.q_layernorm is not None: + if self.layernorm_across_head: + q_layernorm_size = self.query_projection_size + else: + q_layernorm_size = self.hidden_size_per_attention_head + import transformer_engine as te + norm_config = copy.deepcopy(self.config) + norm_config.normalization = "RMSNorm" + self.q_layernorm = build_module( + submodules.q_layernorm, + eps=norm_config.layernorm_epsilon, + hidden_size=q_layernorm_size, + config=norm_config, + ) + else: + self.q_layernorm = None + + # override k_layernorm + if submodules.k_layernorm is not None: + if self.layernorm_across_head: + k_layernorm_size = self.kv_projection_size + else: + k_layernorm_size = self.hidden_size_per_attention_head + import transformer_engine as te + norm_config = copy.deepcopy(self.config) + norm_config.normalization = "RMSNorm" + self.k_layernorm = build_module( + submodules.k_layernorm, + eps=norm_config.layernorm_epsilon, + hidden_size=k_layernorm_size, + config=norm_config, + ) + else: + self.k_layernorm = None + + def get_query_key_value_tensors(self, hidden_states, key_value_states): + """ + Derives `query` tensor from `hidden_states`, and `key`/`value` tensors + from `key_value_states`. + """ + # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)] + mixed_kv, _ = self.linear_kv(key_value_states) + + # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn] + new_tensor_shape = mixed_kv.size()[:-1] + ( + self.num_attention_heads_per_partition, + 2 * self.hidden_size_per_attention_head, + ) + mixed_kv = mixed_kv.view(*new_tensor_shape) + + # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn] + (key, value) = tensor_parallel.split_tensor_along_last_dim(mixed_kv, 2) + + # Attention head [sq, b, h] --> [sq, b, hp] + query, _ = self.linear_q(hidden_states) + + # [sq, b, hp] --> [sq, b, np, hn] + new_tensor_shape = query.size()[:-1] + ( + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + ) + query = query.view(*new_tensor_shape) + + # gather query and key heads across TP ranks if self.layernorm_across_head is True + if self.layernorm_across_head and parallel_state.get_tensor_model_parallel_world_size() > 1: + query = query.transpose(-2, -1) + key = key.transpose(-2, -1) + query = tensor_parallel.gather_from_tensor_model_parallel_region(query) + key = tensor_parallel.gather_from_tensor_model_parallel_region(key) + query = query.transpose(-2, -1) + key = key.transpose(-2, -1) + + if self.q_layernorm is not None: + if self.layernorm_across_head: + q_flat = query.reshape(query.size(0), query.size(1), -1).contiguous() # [sq, b, np*hn] + q_flat = self.q_layernorm(q_flat) + query = q_flat.view(query.size(0), query.size(1), -1, self.hidden_size_per_attention_head) # [sq, b, np, hn] + else: + query = self.q_layernorm(query.contiguous()) + + if self.k_layernorm is not None: + if self.layernorm_across_head: + k_flat = key.reshape(key.size(0), key.size(1), -1).contiguous() + k_flat = self.k_layernorm(k_flat) + key = k_flat.view(key.size(0), key.size(1), -1, self.hidden_size_per_attention_head) + else: + key = self.k_layernorm(key.contiguous()) + + # scatter query and key heads across TP ranks if self.layernorm_across_head is True + if self.layernorm_across_head and parallel_state.get_tensor_model_parallel_world_size() > 1: + query = query.transpose(-2, -1) + key = key.transpose(-2, -1) + query = tensor_parallel.scatter_to_tensor_model_parallel_region(query) + key = tensor_parallel.scatter_to_tensor_model_parallel_region(key) + query = query.transpose(-2, -1) + key = key.transpose(-2, -1) + query = query.contiguous() # important becuase TE attention expects contiguous tensors + key = key.contiguous() # important becuase TE attention expects contiguous tensors + + return query, key, value + + +@dataclass +class WanWithAdaLNSubmodules(TransformerLayerSubmodules): + temporal_self_attention: Union[ModuleSpec, type] = IdentityOp + full_self_attention: Union[ModuleSpec, type] = IdentityOp + norm1: Union[ModuleSpec, type] = None + norm3: Union[ModuleSpec, type] = None + norm2: Union[ModuleSpec, type] = None + + +class WanAdaLN(MegatronModule): + """ + Adaptive Layer Normalization Module for DiT. + """ + + def __init__( + self, config: TransformerConfig + ): + super().__init__(config) + # modulation + self.modulation = nn.Parameter(torch.randn(1, 6, config.hidden_size) / config.hidden_size**0.5) + + setattr(self.modulation, "sequence_parallel", config.sequence_parallel) + + def forward(self, timestep_emb): + e = (self.modulation + timestep_emb).chunk(6, dim=1) + return e + + # @jit_fuser + def modulate(self, x, shift, scale): + return x * (1 + scale) + shift + + # @jit_fuser + def scale_add(self, residual, x, gate): + return residual + gate * x + + +class WanLayerWithAdaLN(TransformerLayer): + """A single transformer layer. + + Transformer layer takes input with size [s, b, h] and returns an + output of the same size. + + DiT with Adapative Layer Normalization. + """ + + def __init__( + self, + config: TransformerConfig, + submodules: TransformerLayerSubmodules, + layer_number: int = 1, + hidden_dropout: float = None, + pg_collection: Optional[ProcessGroupCollection] = None, + vp_stage: Optional[int] = None, + ): + super().__init__( + config=config, submodules=submodules, layer_number=layer_number, hidden_dropout=hidden_dropout + ) + + # # TODO: Override Cross Attention to disable TP Comm overlap as well. ??? + # # Not disabling will attempt re-use of buffer size same as Q and lead to incorrect tensor shapes. + # cp_override_config = copy.deepcopy(config) + # cp_override_config.tp_comm_overlap = False + # self.cross_attention = build_module( + # submodules.cross_attention, + # config=cp_override_config, + # layer_number=layer_number, + # ) + + self.full_self_attention = build_module( + submodules.full_self_attention, + config=self.config, + layer_number=layer_number, + ) + + self.adaLN = WanAdaLN(config=self.config) + self.norm1 = build_module( + submodules.norm1, + dim=config.hidden_size, + eps=config.layernorm_epsilon, + elementwise_affine=False + ) + self.norm3 = build_module( + submodules.norm3, + dim=config.hidden_size, + eps=config.layernorm_epsilon, + elementwise_affine=True, + ) + self.norm2 = build_module( + submodules.norm2, + dim=config.hidden_size, + eps=config.layernorm_epsilon, + elementwise_affine=False, + ) + + + def forward( + self, + hidden_states, + attention_mask=None, + context=None, + context_mask=None, + rotary_pos_emb=None, + rotary_pos_cos=None, + rotary_pos_sin=None, + attention_bias=None, + inference_params=None, + packed_seq_params=None, + sequence_len_offset=None, + inference_context=None, + ): + # the timestep embedding is stored in attention_mask argument + timestep_emb = attention_mask + rope_emb = rotary_pos_emb + + shift_full, scale_full, gate_full, shift_mlp, scale_mlp, gate_mlp = self.adaLN(timestep_emb) + # transpose to bring it to [1, b, ...] format + shift_full = shift_full.transpose(0, 1) + scale_full = scale_full.transpose(0, 1) + gate_full = gate_full.transpose(0, 1) + shift_mlp = shift_mlp.transpose(0, 1) + scale_mlp = scale_mlp.transpose(0, 1) + gate_mlp = gate_mlp.transpose(0, 1) + + # ******************************************** full self attention ******************************************* + + # adaLN with scale + shift + gate + pre_full_attn_layernorm_output_ada = self.adaLN.modulate( + self.norm1(hidden_states), + shift=shift_full, + scale=scale_full, + ) + + attention_output, bias = self.full_self_attention( + pre_full_attn_layernorm_output_ada, + attention_mask=None, + rotary_pos_emb=rope_emb, + rotary_pos_cos=rotary_pos_cos, + rotary_pos_sin=rotary_pos_sin, + packed_seq_params=packed_seq_params['self_attention'], + ) + if bias is not None: + attention_output = attention_output + bias + + hidden_states = self.adaLN.scale_add(residual=hidden_states, x=attention_output, gate=gate_full) + + # ******************************************** cross attention ****************************************************** + + attention_output, bias = self.cross_attention( + self.norm3(hidden_states), + attention_mask=context_mask, + key_value_states=context, + packed_seq_params=packed_seq_params['cross_attention'], + ) + if bias is not None: + attention_output = attention_output + bias + + hidden_states = hidden_states + attention_output + + # ******************************************** mlp ****************************************************** + + pre_mlp_layernorm_output_ada = self.adaLN.modulate( + self.norm2(hidden_states), + shift=shift_mlp, + scale=scale_mlp, + ) + + mlp_output, bias = self.mlp(pre_mlp_layernorm_output_ada) + if bias is not None: + mlp_output = mlp_output + bias + + hidden_states = self.adaLN.scale_add(residual=hidden_states, x=mlp_output, gate=gate_mlp) + + # TODO: Jit compiled function creates 'view' tensor. This tensor + # potentially gets saved in the MPU checkpoint function context, + # which rejects view tensors. While making a viewless tensor here + # won't result in memory savings (like the data loader, or + # p2p_communication), it serves to document the origin of this + # 'view' tensor. ??? + output = make_viewless_tensor(inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True) + # output = hidden_states + + return output, context + + +import transformer_engine as te +def get_wan_block_with_transformer_engine_spec() -> ModuleSpec: + params = {"attn_mask_type": AttnMaskType.padding} + return ModuleSpec( + module=WanLayerWithAdaLN, + submodules=WanWithAdaLNSubmodules( + norm1=WanLayerNorm, + norm3=WanLayerNorm, + norm2=WanLayerNorm, + full_self_attention=ModuleSpec( + module=WanSelfAttention, + params=params, + submodules=WanSelfAttentionSubmodules( + linear_qkv=TEColumnParallelLinear, + core_attention=TEDotProductAttention, + linear_proj=TERowParallelLinear, + layernorm_across_head=True, + q_layernorm=TENorm, + k_layernorm=TENorm, + ), + ), + cross_attention=ModuleSpec( + module=WanCrossAttention, + params=params, + submodules=WanCrossAttentionSubmodules( + linear_q=TEColumnParallelLinear, + linear_kv=TEColumnParallelLinear, + core_attention=TEDotProductAttention, + linear_proj=TERowParallelLinear, + layernorm_across_head=True, + q_layernorm=TENorm, + k_layernorm=TENorm, + ), + ), + mlp=ModuleSpec( + module=MLP, + submodules=MLPSubmodules( + linear_fc1=TEColumnParallelLinear, + # by default, activation_func is openai_gelu, which is equivalent to nn.GELU(approximate='tanh') + linear_fc2=TERowParallelLinear, + ), + ), + ), + ) diff --git a/dfm/src/megatron/model/wan/wan_model.py b/dfm/src/megatron/model/wan/wan_model.py new file mode 100644 index 00000000..d11b7803 --- /dev/null +++ b/dfm/src/megatron/model/wan/wan_model.py @@ -0,0 +1,332 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=C0115,C0116,C0301 + +from typing import Dict, Literal, Optional, Tuple, List, Union + +import math +import torch +import torch.cuda.amp as amp +import torch.nn as nn +from megatron.core import parallel_state, tensor_parallel +from megatron.core.dist_checkpointing.mapping import ShardedStateDict +from megatron.core.models.common.vision_module.vision_module import VisionModule +from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.transformer.enums import ModelType +from megatron.core.transformer.transformer_block import TransformerBlock +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.utils import make_sharded_tensor_for_checkpoint +from megatron.bridge.models.wan.wan_layer_spec import ( + get_wan_block_with_transformer_engine_spec as WanLayerWithAdaLNspec, +) +from megatron.bridge.models.wan.wan_layer_spec import WanLayerNorm +from torch import Tensor +from .rope_utils import Wan3DRopeEmbeddings + +def sinusoidal_embedding_1d(dim, position): + # preprocess + assert dim % 2 == 0 + half = dim // 2 + position = position + + # calculation + sinusoid = torch.outer( + position, torch.pow(10000, -torch.arange(half).to(position).div(half))) + x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) + return x + +class Head(nn.Module): + + def __init__(self, dim, out_dim, patch_size, eps=1e-6): + super().__init__() + self.dim = dim + self.out_dim = out_dim + self.patch_size = patch_size + self.eps = eps + + # layers + out_dim = math.prod(patch_size) * out_dim + self.norm = WanLayerNorm(dim, eps) + self.head = nn.Linear(dim, out_dim) + + # modulation + self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5) + + def forward(self, x, e): + r""" + Args: + x(Tensor): Shape [B, L1, C] + e(Tensor): Shape [B, C] + """ + e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1) + x = (self.head(self.norm(x) * (1 + e[1]) + e[0])) + return x + + +class WanModel(VisionModule): + """ + WanModel is a VisionModule that implements a Wan model. + Attributes: + config (TransformerConfig): Configuration for the transformer. + pre_process (bool): Whether to apply pre-processing steps. + post_process (bool): Whether to apply post-processing steps. + fp16_lm_cross_entropy (bool): Whether to use fp16 for cross-entropy loss. + parallel_output (bool): Whether to use parallel output. + transformer_decoder_layer_spec (WanLayerWithAdaLNspec): Specification for the transformer decoder layer. + model_type (ModelType): Type of the model. + """ + + def __init__( + self, + config: TransformerConfig, + pre_process: bool = True, + post_process: bool = True, + fp16_lm_cross_entropy: bool = False, + parallel_output: bool = True, + transformer_decoder_layer_spec=WanLayerWithAdaLNspec, + **kwargs, + ): + super(WanModel, self).__init__(config=config) + + self.config: TransformerConfig = config + + self.transformer_decoder_layer_spec = transformer_decoder_layer_spec() + self.pre_process = pre_process + self.post_process = post_process + self.fp16_lm_cross_entropy = fp16_lm_cross_entropy + self.parallel_output = parallel_output + + # megatron core pipelining currently depends on model type + # TODO: remove this dependency ? + self.model_type = ModelType.encoder_or_decoder + + self.num_heads = self.config.num_attention_heads + self.freq_dim = self.config.freq_dim + self.in_channels = self.config.in_channels + self.out_channels = self.config.out_channels + self.patch_spatial = self.config.patch_spatial + self.patch_temporal = self.config.patch_temporal + self.patch_size = (self.patch_temporal, self.patch_spatial, self.patch_spatial) + + # these attributes are unused for images/videos, we just set because bridge training requires for LLMs + self.share_embeddings_and_output_weights = False + + ###################################### + ########## Wan architecture ########## + + # embeddings + if self.pre_process: + self.patch_embedding = nn.Conv3d( + self.in_channels, self.config.hidden_size, kernel_size=self.patch_size, stride=self.patch_size) + + self.text_embedding = nn.Sequential( + nn.Linear(self.config.text_dim, self.config.hidden_size), nn.GELU(approximate='tanh'), + nn.Linear(self.config.hidden_size, self.config.hidden_size)) + + self.time_embedding = nn.Sequential( + nn.Linear(self.freq_dim, self.config.hidden_size), nn.SiLU(), nn.Linear(self.config.hidden_size, self.config.hidden_size)) + self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(self.config.hidden_size, self.config.hidden_size * 6)) + + self.rope_embeddings = Wan3DRopeEmbeddings(dim_head = self.config.hidden_size // self.num_heads, max_position_len = 1024) + + # decoder blocks + self.decoder = TransformerBlock( + config=self.config, + spec=self.transformer_decoder_layer_spec, + pre_process=self.pre_process, + post_process=self.post_process, + post_layer_norm=False, + ) + + # output head + if self.post_process: + self.head = Head(self.config.hidden_size, self.out_channels, self.patch_size, eps = 1e-6) + + + def forward( + self, + x: Tensor, + grid_sizes: list[Tuple[int, int, int]], + t: Tensor, + context: Tensor, + max_seq_len: int, + packed_seq_params: PackedSeqParams = None, + **kwargs, + ) -> Tensor: + """Forward pass. + + Args: + x List[Tensor]: list of vae encoded data (in_channel, f, h, w) + grid_sizes List[Tuple[int, int, int]]: list of grid sizes (f, h, w) + t Tensor: timesteps + context List[Tensor]: list of context (text_len, hidden_size) + max_seq_len int: maximum sequence length + packed_seq_params PackedSeqParams: packed sequence parameters + + Returns: + Tensor: output tensor (still patchified) of shape [seq_len, batch_size, hidden_size] + """ + ################################# + ########## Wan forward ########## + + # ============= embedders ============= + + # run input embedding + if self.pre_process: + # x.shape [s, b, c * pF * pH * pW] + seq_len, batch_size, _ = x.shape + c = self.out_channels + pF, pH, pW = self.patch_size + x = x.reshape(seq_len * batch_size, pF, pH, pW, c) # output: x.shape [s * b, pF, pH, pW, c] + x = x.permute(0, 4, 1, 2, 3) # output: x.shape [s * b, c, pF, pH, pW] + x = self.patch_embedding(x) # output: x.shape [s * b, hidden_size, 1, 1, 1] + x = x.flatten(1) # output: x.shape [s * b, hidden_size] + x = x.reshape(seq_len, batch_size, -1) # output: x.shape [s, b, hidden_size] + + # split sequence for sequence_parallel + # TODO: for PP, do we move scatter_to_sequence_parallel_region here or after "x = self.decoder.input_tensor" ??? + if self.config.sequence_parallel: + x = tensor_parallel.scatter_to_sequence_parallel_region(x) # output: x.shape [s * b // tp_size, hidden_size] + + else: + # intermediate stage of pipeline + x = self.decoder.input_tensor + + # time embeddings + e = self.time_embedding( + sinusoidal_embedding_1d(self.freq_dim, t).to(x.dtype) + ) + e0 = self.time_projection(e).unflatten(1, (6, self.config.hidden_size)) + + # context embeddings + context = self.text_embedding(context) # shape [text_len, b, hidden_size] + + + # ============= decoder ============= + # calculate rotary pos emb + n_head, dim_head = self.num_heads, self.config.hidden_size // self.num_heads + rotary_pos_emb = self.rope_embeddings(n_head, dim_head, max_seq_len, grid_sizes, t.device) # output: rotary_pos_emb.shape [s, b, 1, dim_head] + + # run decoder + x = self.decoder( + hidden_states=x, + attention_mask=e0, + context=context, + context_mask=None, + rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=None, + rotary_pos_sin=None, + packed_seq_params=packed_seq_params, + ) + + # return if not post_process + if not self.post_process: + return x + + # head + x = x.transpose(0, 1) # head expects shape [b, s, hidden_size] + x = self.head(x, e) # output: x.shape [b, s, c * pF * pH * pW] + x = x.transpose(0, 1) # reshape back to shape [s, b, c * pF * pH * pW] + + # gather outputs for sequence_parallel + # Note: in GPT models, because the vocab projection matrix is ColumnParallelLinear, the sequence is + # automatically gathered in ColumnParallelLinear forward pass. + # However, in Wan models, we need to gather the outputs manually. + if self.config.sequence_parallel: + x = tensor_parallel.gather_from_sequence_parallel_region(x) + + return x # output: x.shape [s, b, c * pF * pH * pW] + + + def set_input_tensor(self, input_tensor: Tensor) -> None: + """Sets input tensor to the model. + + See megatron.model.transformer.set_input_tensor() + + Args: + input_tensor (Tensor): Sets the input tensor for the model. + """ + # This is usually handled in schedules.py but some inference code still + # gives us non-lists or None + if not isinstance(input_tensor, list): + input_tensor = [input_tensor] + + assert len(input_tensor) == 1, "input_tensor should only be length 1 for gpt/bert" + self.decoder.set_input_tensor(input_tensor[0]) + + + def sharded_state_dict( + self, prefix: str = "module.", sharded_offsets: tuple = (), metadata: Optional[Dict] = None + ) -> ShardedStateDict: + """Sharded state dict implementation for GPTModel backward-compatibility (removing extra state). + + Args: + prefix (str): Module name prefix. + sharded_offsets (tuple): PP related offsets, expected to be empty at this module level. + metadata (Optional[Dict]): metadata controlling sharded state dict creation. + + Returns: + ShardedStateDict: sharded state dict for the GPTModel + """ + sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata) + + # DEBUGGING + # for module in ["t_embedder"]: + # for param_name, param in getattr(self, module).named_parameters(): + # weight_key = f"{prefix}{module}.{param_name}" + # self._set_embedder_weights_replica_id(param, sharded_state_dict, weight_key) + # DEBUGGING + # Ensure replica ids for non-transformer embedder weights include pipeline dimension + for module in ["text_embedding", "time_embedding", "time_projection"]: + if hasattr(self, module): + for param_name, param in getattr(self, module).named_parameters(): + weight_key = f"{prefix}{module}.{param_name}" + if weight_key in sharded_state_dict: + self._set_embedder_weights_replica_id(param, sharded_state_dict, weight_key) + + return sharded_state_dict + + + def _set_embedder_weights_replica_id( + self, tensor: Tensor, sharded_state_dict: ShardedStateDict, embedder_weight_key: str + ) -> None: + """set replica ids of the weights in t_embedder for sharded state dict. + + Args: + sharded_state_dict (ShardedStateDict): state dict with the weight to tie + weight_key (str): key of the weight in the state dict. + This entry will be replaced with a tied version + + Returns: None, acts in-place + """ + tp_rank = parallel_state.get_tensor_model_parallel_rank() + vpp_rank = parallel_state.get_virtual_pipeline_model_parallel_rank() + vpp_rank = vpp_rank if vpp_rank else 0 + vpp_world = parallel_state.get_virtual_pipeline_model_parallel_world_size() + vpp_world = vpp_world if vpp_world else 1 + pp_rank = parallel_state.get_pipeline_model_parallel_rank() + del sharded_state_dict[embedder_weight_key] + replica_id = ( + tp_rank, + (vpp_rank + pp_rank * vpp_world), + parallel_state.get_data_parallel_rank(with_context_parallel=True), + ) + + sharded_state_dict[embedder_weight_key] = make_sharded_tensor_for_checkpoint( + tensor=tensor, + key=embedder_weight_key, + replica_id=replica_id, + allow_shape_mismatch=False, + ) From 2ebfd503e04c9d6501839d179284c343f466dd09 Mon Sep 17 00:00:00 2001 From: Huy Vu2 Date: Thu, 30 Oct 2025 13:28:41 -0700 Subject: [PATCH 02/80] workable code --- dfm/src/megatron/data/Dit/base.py | 343 +++++++++++++++ dfm/src/megatron/data/Dit/data/__init__.py | 13 + .../Dit/data/diffusion_energon_datamodule.py | 176 ++++++++ .../data/Dit/data/diffusion_taskencoder.py | 256 +++++++++++ .../data/Dit/data/prepare_energon_dataset.py | 117 +++++ dfm/src/megatron/data/Dit/data/utils.py | 203 +++++++++ .../{ => wan}/prepare_energon_dataset_wan.py | 0 .../wan/prepare_energon_dataset_wan_images.py | 412 ++++++++++++++++++ .../data/{ => wan}/wan_energon_datamodule.py | 4 +- .../data/{ => wan}/wan_taskencoder.py | 72 +-- .../flow_matching/flow_inference_pipeline.py | 14 +- .../model/wan/flow_matching/flow_pipeline.py | 4 +- dfm/src/megatron/model/wan/rope_utils.py | 2 +- dfm/src/megatron/model/wan/wan_model.py | 4 +- 14 files changed, 1570 insertions(+), 50 deletions(-) create mode 100644 dfm/src/megatron/data/Dit/base.py create mode 100644 dfm/src/megatron/data/Dit/data/__init__.py create mode 100644 dfm/src/megatron/data/Dit/data/diffusion_energon_datamodule.py create mode 100644 dfm/src/megatron/data/Dit/data/diffusion_taskencoder.py create mode 100644 dfm/src/megatron/data/Dit/data/prepare_energon_dataset.py create mode 100644 dfm/src/megatron/data/Dit/data/utils.py rename dfm/src/megatron/data/{ => wan}/prepare_energon_dataset_wan.py (100%) create mode 100644 dfm/src/megatron/data/wan/prepare_energon_dataset_wan_images.py rename dfm/src/megatron/data/{ => wan}/wan_energon_datamodule.py (91%) rename dfm/src/megatron/data/{ => wan}/wan_taskencoder.py (82%) diff --git a/dfm/src/megatron/data/Dit/base.py b/dfm/src/megatron/data/Dit/base.py new file mode 100644 index 00000000..413dc686 --- /dev/null +++ b/dfm/src/megatron/data/Dit/base.py @@ -0,0 +1,343 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from copy import deepcopy +from typing import Any, Dict, Literal, Optional + +from megatron.core import parallel_state +from megatron.energon import WorkerConfig, get_savable_loader, get_train_dataset +from torch.utils.data import DataLoader +from typing_extensions import Self +import logging +logger = logging.getLogger(__name__) + + +class EnergonMultiModalDataModule: + """ + A PyTorch Lightning DataModule for handling multimodal datasets with images and text. + + This data module is designed to work with multimodal datasets that involve both images and text. + It provides a seamless interface to load training and validation data, manage batching, and handle + the state of the data pipeline across training epochs. The module integrates with the Megatron-Energon + framework for efficient data handling in large-scale distributed training. + + Attributes: + path (str): Path to the energon dataset. + tokenizer (Tokenizer): The tokenizer used for processing text. + image_processor (ImageProcessor): The image processor used for preprocessing images. + seq_length (int): The maximum sequence length for tokenized text. + micro_batch_size (int): The batch size for training and validation. + num_workers (int): Number of workers for data loading. + pin_memory (bool): Whether to pin memory in the DataLoader. + multimodal_sample_config (MultiModalSampleConfig): Configuration object for multimodal samples. + task_encoder (MultiModalTaskEncoder): Encoder responsible for encoding and batching samples. + init_global_step (int): The initial global step for the trainer, used for resuming training. + data_sampler (SequentialMegatronSampler): Sampler responsible for generating sequential samples. + train_dataloader_object (Optional): The DataLoader object for training data. + val_dataloader_object (Optional): The DataLoader object for validation data. + """ + + def __init__( + self, + path: str, + tokenizer, + image_processor, + seq_length: int = 2048, + micro_batch_size: int = 1, + global_batch_size: int = 1, + num_workers: int = 1, + num_val_workers: int | None = None, + pin_memory: bool = True, + shuffle_buffer_size: int = 100, + max_samples_per_sequence: int | None = None, + multimodal_sample_config: Optional[Any] = None, + task_encoder: Optional[Any] = None, + decoder_seq_length: Optional[int] = None, + packing_buffer_size: Optional[int] = None, + validation_task_encoder: Optional[Any] = None, + **kwargs, + ) -> None: + """ + Initialize the EnergonMultiModalDataModule. + + Parameters: + path (str): Path to the dataset. + tokenizer (Tokenizer): The tokenizer used for processing text. + image_processor (ImageProcessor): The image processor used for preprocessing images. + seq_length (int, optional): The maximum sequence length for tokenized text. Defaults to 2048. + micro_batch_size (int, optional): The batch size for training and validation. Defaults to 1. + num_workers (int, optional): Number of workers for data loading. Defaults to 1. + num_val_workers (int, optional): Number of workers for validation data loading. Defaults to num_workers. + pin_memory (bool, optional): Whether to pin memory in the DataLoader. Defaults to True. + multimodal_sample_config (MultiModalSampleConfig, optional): Configuration object for multimodal samples. + Defaults to MultiModalSampleConfig(). + shuffle_buffer_size (int, optional): Size of the shuffle buffer. Defaults to 100. + max_samples_per_sequence (int, optional): Maximum number of samples per sequence to load from memory. + Defaults to None (loads the whole tar file at once). + task_encoder (MultiModalTaskEncoder, optional): Encoder responsible for encoding and batching samples. + If not provided, a default (MultimodalTaskEncoder) encoder will be created. Defaults to None. + decoder_seq_length (int, optional): The max sequence length for the decoder. Used in encoder-decoder models + packing_buffer_size (int, optional): Size of the packing buffer for batched samples. Defaults to None. + validation_task_encoder (MultiModalTaskEncoder, optional): Encoder responsible for encoding + and batching samples for validation. Defaults to None and will be the same as task_encoder. + **kwargs: Additional keyword arguments. Will be passed to get_train_dataset() of Energon + """ + + super().__init__() + self.path = path + self.tokenizer = tokenizer + self.image_processor = image_processor + self.seq_length = seq_length + self.decoder_seq_length = decoder_seq_length + self.micro_batch_size = micro_batch_size + self.global_batch_size = global_batch_size + self.num_workers = num_workers + self.pin_memory = pin_memory + self.multimodal_sample_config = multimodal_sample_config + self.shuffle_buffer_size = shuffle_buffer_size + self.max_samples_per_sequence = max_samples_per_sequence + self.task_encoder = task_encoder + self.init_global_step = 0 + self.train_dataloader_object = None + self.val_dataloader_object = None + self.packing_buffer_size = packing_buffer_size + self.validation_task_encoder = validation_task_encoder or self.task_encoder + self.num_val_workers = num_val_workers or self.num_workers + self.kwargs = kwargs + + + def datasets_provider(self, worker_config, split: Literal['train', 'val'] = 'val'): + """ + Provide the dataset for training or validation. + + This method retrieves the dataset for the specified split (either 'train' or 'val') and configures + it according to the worker configuration. + + Parameters: + worker_config: Configuration for the data loader workers. + split (Literal['train', 'val'], optional): The data split to retrieve ('train' or 'val'). Defaults to 'val'. + + Returns: + Dataset: The dataset configured for the specified split. + """ + + if split not in {'train', 'val'}: + raise ValueError("Invalid value for split. Allowed values are 'train' or 'val'.") + + if split == "train": + task_encoder = self.task_encoder + else: + task_encoder = self.validation_task_encoder + + _dataset = get_train_dataset( + self.path, + batch_size=self.micro_batch_size, + task_encoder=task_encoder, + worker_config=worker_config, + packing_buffer_size=self.packing_buffer_size, + split_part=split, + shuffle_buffer_size=self.shuffle_buffer_size, + max_samples_per_sequence=self.max_samples_per_sequence, + **self.kwargs, + ) + + return _dataset + + def build(self): + return self.train_dataloader(), self.val_dataloader() + + def train_dataloader(self) -> Any: + """ + Initialize and return the training DataLoader. + + This method initializes the DataLoader for the training dataset. It uses the global step + from the trainer to configure the data sampler and ensures that the parallel state is initialized + correctly for distributed training. + + Returns: + TRAIN_DATALOADERS: The DataLoader for the training dataset. + """ + + logger.info(f"Multimodal train dataloader initializing with init_global_step {self.init_global_step}") + if self.train_dataloader_object: + return self.train_dataloader_object + if not parallel_state.is_initialized(): + logger.info( + f"Muiltimodal data loader parallel state is not initialized," + f"using default worker config with no_workers {self.num_workers}" + ) + worker_config = WorkerConfig.default_worker_config(self.num_workers) + else: + rank = parallel_state.get_data_parallel_rank() + world_size = parallel_state.get_data_parallel_world_size() + data_parallel_group = parallel_state.get_data_parallel_group() + logger.info( + f" Multimodal train dataloader initializing with" + f"rank {rank} world_size {world_size} data_parallel_group {data_parallel_group} ****** " + ) + worker_config = WorkerConfig( + rank=rank, + world_size=world_size, + num_workers=self.num_workers, + data_parallel_group=data_parallel_group, + worker_debug_path=None, + worker_log_level=0, + ) + train_dataset = self.datasets_provider(worker_config, split='train') + energon_dataloader = get_savable_loader(train_dataset, worker_config=worker_config) + self.train_dataloader_object = energon_dataloader + return self.train_dataloader_object + + def val_dataloader(self): + """ + Initialize and return the validation DataLoader. + + This method initializes the DataLoader for the validation dataset. It ensures that the parallel state + is initialized correctly for distributed training and returns a configured DataLoader object. + + Returns: + EVAL_DATALOADERS: The DataLoader for the validation dataset. + """ + if self.val_dataloader_object: + return self.val_dataloader_object + + if not parallel_state.is_initialized(): + logger.info( + f"Muiltimodal val data loader parallel state is not initialized," + f"using default worker config with no_workers {self.num_workers}" + ) + worker_config = WorkerConfig.default_worker_config(self.num_val_workers) + else: + rank = parallel_state.get_data_parallel_rank() + world_size = parallel_state.get_data_parallel_world_size() + data_parallel_group = parallel_state.get_data_parallel_group() + + logger.info(f"rank {rank} world_size {world_size} data_parallel_group {data_parallel_group}") + worker_config = WorkerConfig( + rank=rank, + world_size=world_size, + num_workers=self.num_workers, + data_parallel_group=data_parallel_group, + worker_debug_path=None, + worker_log_level=0, + ) + val_dataset = self.datasets_provider(worker_config, split='val') + energon_loader = get_savable_loader(val_dataset, worker_config=worker_config) + self.val_dataloader_object = energon_loader + return self.val_dataloader_object + + def test_dataloader(self) -> None: + """ + Return None as test dataset split does not exist. + + This method overrides the test_dataloader method and returns None since the test dataset split + is not defined or used in this module. + + Returns: + None + """ + logger.warning("Multimodal dataloader test dataset split does not exist") + return None + + def state_dict(self) -> Dict[str, Any]: + """ + Save the state of the data module. + + This method is called when saving a checkpoint. It generates and saves the state of the data module, + including the state of the dataloader and the number of consumed samples. + + Returns: + Dict[str, Any]: A dictionary containing the state of the data module. + """ + + if self.trainer: + dataloader_obj = self.trainer.train_dataloader + + state = [] + # All ranks should be zero except the dp rank. + if ( + parallel_state.get_context_parallel_rank() + or parallel_state.get_pipeline_model_parallel_rank() + or parallel_state.get_tensor_model_parallel_rank() + or parallel_state.get_expert_model_parallel_rank() + ) == 0: + # Save_state_global in energon assumes that we call it for only the first rank within each group that + # shares the same dataloader state. By making sure that current rank is the first rank in a model + # parallel group, we ensure this. + state = dataloader_obj.save_state_global(global_dst_rank=0) + + consumed_samples = self.data_sampler.compute_consumed_samples( + self.trainer.global_step - self.init_global_step + ) + + if state is None: + state = [] # Megatron core requires all the states on all the ranks to have same python + # type. Energon sends the state as a list + logger.info(f"Multimodal data loader saving dataloader state dict consumed samples {consumed_samples}") + return {'dataloader_state': state, 'consumed_samples': consumed_samples} + + logger.warning("trainer object not connected to data module object returning empty state") + return {} + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + """ + Load the state of the data module from a checkpoint. + + This method is called when loading a checkpoint. It restores the state of the data module, + including the state of the dataloader and the number of consumed samples. + + Parameters: + state_dict (Dict[str, Any]): The state dictionary containing the saved state of the data module. + """ + if not 'dataloader_state' in state_dict: + logger.warning( + f"Data loader state cannot be resumed from state_dict, " + f"it does not have the required key dataloader_state. It has {state_dict.keys()}" + ) + return + + state = state_dict['dataloader_state'] + try: + if self.trainer: + self.trainer.datamodule.train_dataloader().restore_state_global(state) + logger.info("Multimodal dataloader state restored") + else: + logger.error(f"Cannot restore state from state_dict {state_dict}") + raise ValueError( + "Cannot restore state from state_dict: " + "Is the trainer object is initialized and attached to datamodule???" + ) + except Exception as e: + logger.warning( + f"Failed to dataloader restore state due to [Please ensure you are using same version " + f"of energon while saving and loading, Continuing without restoring data loader] : {e}" + ) + + try: + from megatron.core.num_microbatches_calculator import update_num_microbatches + + except (ImportError, ModuleNotFoundError): + logger.warning("Megatron num_microbatches_calculator not found, using Apex version.") + from apex.transformer.pipeline_parallel.utils import update_num_microbatches + + consumed_samples = state_dict['consumed_samples'] + self.data_sampler.init_consumed_samples = consumed_samples + self.data_sampler.prev_consumed_samples = consumed_samples + logger.info(f"Multimodal dataloader load state dict with consumed_samples {consumed_samples}") + update_num_microbatches( + consumed_samples=consumed_samples, + consistency_check=False, + ) + + diff --git a/dfm/src/megatron/data/Dit/data/__init__.py b/dfm/src/megatron/data/Dit/data/__init__.py new file mode 100644 index 00000000..d9155f92 --- /dev/null +++ b/dfm/src/megatron/data/Dit/data/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/dfm/src/megatron/data/Dit/data/diffusion_energon_datamodule.py b/dfm/src/megatron/data/Dit/data/diffusion_energon_datamodule.py new file mode 100644 index 00000000..fa38e9c6 --- /dev/null +++ b/dfm/src/megatron/data/Dit/data/diffusion_energon_datamodule.py @@ -0,0 +1,176 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=C0115,C0116,C0301 + +from dataclasses import dataclass +import logging +from typing import Any, Dict, Literal + +from torch import int_repr + +from megatron.bridge.data.Dit.data.diffusion_taskencoder import BasicDiffusionTaskEncoder +from megatron.bridge.data.utils import DatasetBuildContext, DatasetProvider +from megatron.energon import DefaultTaskEncoder, get_train_dataset +from megatron.bridge.data.Dit.base import EnergonMultiModalDataModule + +@dataclass(kw_only=True) +class DiffusionDataModuleConfig(DatasetProvider): + path: str + seq_length: int + micro_batch_size: int + task_encoder_seq_length: int + global_batch_size: int + num_workers: int_repr + dataloader_type: str = "external" + + def __post_init__(self): + self.dataset = DiffusionDataModule( + path=self.path, + seq_length=self.seq_length, + task_encoder=BasicDiffusionTaskEncoder(seq_length=self.task_encoder_seq_length), + micro_batch_size=self.micro_batch_size, + global_batch_size=self.global_batch_size, + num_workers=self.num_workers) + self.sequence_length = self.dataset.seq_length + + def build_datasets(self, context: DatasetBuildContext): + return self.dataset.train_dataloader(), self.dataset.train_dataloader(), self.dataset.train_dataloader() + + + + +class DiffusionDataModule(EnergonMultiModalDataModule): + """ + A PyTorch Lightning DataModule for handling multimodal datasets with images and text. + + This data module is designed to work with multimodal datasets that involve both images and text. + It provides a seamless interface to load training and validation data, manage batching, and handle + the state of the data pipeline across training epochs. The module integrates with the Megatron-Energon + framework for efficient data handling in large-scale distributed training. + + Attributes: + path (str): Path to the energon dataset. + tokenizer (Tokenizer): The tokenizer used for processing text. + image_processor (ImageProcessor): The image processor used for preprocessing images. + seq_length (int): The maximum sequence length for tokenized text. + micro_batch_size (int): The batch size for training and validation. + num_workers (int): Number of workers for data loading. + pin_memory (bool): Whether to pin memory in the DataLoader. + multimodal_sample_config (MultiModalSampleConfig): Configuration object for multimodal samples. + task_encoder (MultiModalTaskEncoder): Encoder responsible for encoding and batching samples. + init_global_step (int): The initial global step for the trainer, used for resuming training. + data_sampler (SequentialMegatronSampler): Sampler responsible for generating sequential samples. + train_dataloader_object (Optional): The DataLoader object for training data. + val_dataloader_object (Optional): The DataLoader object for validation data. + """ + + def __init__( + self, + path: str, + seq_length: int = 2048, + micro_batch_size: int = 1, + global_batch_size: int = 8, + num_workers: int = 1, + pin_memory: bool = True, + task_encoder: DefaultTaskEncoder = None, + use_train_split_for_val: bool = False, + ) -> None: + """ + Initialize the SimpleMultiModalDataModule. + + Parameters: + path (str): Path to the dataset. + tokenizer (Tokenizer): The tokenizer used for processing text. + image_processor (ImageProcessor): The image processor used for preprocessing images. + seq_length (int, optional): The maximum sequence length for tokenized text. Defaults to 2048. + micro_batch_size (int, optional): The batch size for training and validation. Defaults to 1. + num_workers (int, optional): Number of workers for data loading. Defaults to 1. + pin_memory (bool, optional): Whether to pin memory in the DataLoader. Defaults to True. + """ + + super().__init__( + path=path, + tokenizer=None, + image_processor=None, + seq_length=seq_length, + micro_batch_size=micro_batch_size, + global_batch_size=global_batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + task_encoder=task_encoder, + ) + self.use_train_split_for_val = use_train_split_for_val + + def datasets_provider(self, worker_config, split: Literal["train", "val"] = "val"): + """ + Provide the dataset for training or validation. + + This method retrieves the dataset for the specified split (either 'train' or 'val') and configures + it according to the worker configuration. + + Parameters: + worker_config: Configuration for the data loader workers. + split (Literal['train', 'val'], optional): The data split to retrieve ('train' or 'val'). Defaults to 'val'. + + Returns: + Dataset: The dataset configured for the specified split. + """ + if split not in {"train", "val"}: + raise ValueError("Invalid value for split. Allowed values are 'train' or 'val'.") + if self.use_train_split_for_val: + split = "train" + _dataset = get_train_dataset( + self.path, + batch_size=self.micro_batch_size, + task_encoder=self.task_encoder, + worker_config=worker_config, + max_samples_per_sequence=None, + shuffle_buffer_size=100, + split_part=split, + batch_drop_last=True, + virtual_epoch_length=1_000_000_000, # a hack to avoid energon end of epoch warning + ) + return _dataset + + def val_dataloader(self): + """ + Configure the validation DataLoader. + + This method configures the DataLoader for validation data. + + Parameters: + worker_config: Configuration for the data loader workers. + + Returns: + DataLoader: The DataLoader for validation data. + """ + if self.use_train_split_for_val: + return self.train_dataloader() + return super().val_dataloader() + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + """ + Load the state of the data module from a checkpoint. + + This method is called when loading a checkpoint. It restores the state of the data module, + including the state of the dataloader and the number of consumed samples. + + Parameters: + state_dict (Dict[str, Any]): The state dictionary containing the saved state of the data module. + """ + try: + super().load_state_dict(state_dict) + except Exception as e: + logging.warning(f"datamodule.load_state_dict failed {e}") diff --git a/dfm/src/megatron/data/Dit/data/diffusion_taskencoder.py b/dfm/src/megatron/data/Dit/data/diffusion_taskencoder.py new file mode 100644 index 00000000..7faa1aaa --- /dev/null +++ b/dfm/src/megatron/data/Dit/data/diffusion_taskencoder.py @@ -0,0 +1,256 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=C0115,C0116,C0301 + +import torch +import torch.nn.functional as F +from einops import rearrange +from megatron.energon import DefaultTaskEncoder, SkipSample +from megatron.energon.task_encoder.cooking import Cooker, basic_sample_keys + + +def cook(sample: dict) -> dict: + """ + Processes a raw sample dictionary from energon dataset and returns a new dictionary with specific keys. + + Args: + sample (dict): The input dictionary containing the raw sample data. + + Returns: + dict: A new dictionary containing the processed sample data with the following keys: + - All keys from the result of `basic_sample_keys(sample)` + - 'json': The contains meta data like resolution, aspect ratio, fps, etc. + - 'pth': contains video latent tensor + - 'pickle': contains text embeddings + """ + return dict( + **basic_sample_keys(sample), + json=sample[".json"], + pth=sample[".pth"], + pickle=sample[".pickle"], + ) + + +class BasicDiffusionTaskEncoder(DefaultTaskEncoder): + """ + BasicDiffusionTaskEncoder is a class that encodes image/video samples for diffusion tasks. + Attributes: + cookers (list): A list of Cooker objects used for processing. + max_frames (int, optional): The maximum number of frames to consider from the video. Defaults to None. + text_embedding_padding_size (int): The padding size for text embeddings. Defaults to 512. + Methods: + __init__(*args, max_frames=None, text_embedding_padding_size=512, **kwargs): + Initializes the BasicDiffusionTaskEncoder with optional maximum frames and text embedding padding size. + encode_sample(sample: dict) -> dict: + Encodes a given sample dictionary containing video and text data. + Args: + sample (dict): A dictionary containing 'pth' for video latent and 'json' for additional info. + Returns: + dict: A dictionary containing encoded video, text embeddings, text mask, and loss mask. + Raises: + SkipSample: If the video latent contains NaNs, Infs, or is not divisible by the tensor parallel size. + """ + + cookers = [ + Cooker(cook), + ] + + def __init__( + self, + *args, + max_frames: int = None, + text_embedding_padding_size: int = 512, + seq_length: int = None, + patch_spatial: int = 2, + patch_temporal: int = 1, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.max_frames = max_frames + self.text_embedding_padding_size = text_embedding_padding_size + self.seq_length = seq_length + self.patch_spatial = patch_spatial + self.patch_temporal = patch_temporal + + def encode_sample(self, sample: dict) -> dict: + video_latent = sample["pth"] + + if torch.isnan(video_latent).any() or torch.isinf(video_latent).any(): + raise SkipSample() + if torch.max(torch.abs(video_latent)) > 1e3: + raise SkipSample() + + info = sample["json"] + # remove batch dimension + video_latent = video_latent.squeeze(0) + # print(f"video_latent shape at start: {video_latent.shape}") + C, T, H, W = video_latent.shape + seq_len = ( + video_latent.shape[-1] + * video_latent.shape[-2] + * video_latent.shape[-3] + // self.patch_spatial**2 + // self.patch_temporal + ) + # seq_len = 1536 + is_image = T == 1 + + # print(f"Skipping sample {sample['__key__']} because seq_len {seq_len} > self.seq_length {self.seq_length}") + if seq_len > self.seq_length: + print(f"Skipping sample {sample['__key__']} because seq_len {seq_len} > self.seq_length {self.seq_length}") + raise SkipSample() + + if self.max_frames is not None: + video_latent = video_latent[:, : self.max_frames, :, :] + + # tpcp_size = parallel_state.get_tensor_model_parallel_world_size() + # if parallel_state.get_context_parallel_world_size() > 1: + # tpcp_size *= parallel_state.get_context_parallel_world_size() * 2 + # if (T * H * W) % tpcp_size != 0: + # warnings.warn(f'skipping {video_latent.shape=} not divisible by {tpcp_size=}') + # raise SkipSample() + # print(f"video_latent shape before rearrange: {video_latent.shape}") + # video_latent shape before rearrange: torch.Size([16, 1, 64, 96]) + video_latent = rearrange( + video_latent, + "C (T pt) (H ph) (W pw) -> (T H W) (ph pw pt C)", + ph=self.patch_spatial, + pw=self.patch_spatial, + pt=self.patch_temporal, + ) + # print(f"video_latent shape after rearrange: {video_latent.shape}") + # After reaaranging: video_latent shape after rearrange: torch.Size([1536, 64]) + # convert sample["pickle"] to numpy, and remove batch dimension + sample["pickle"] = sample["pickle"].cpu().float().numpy().squeeze(0) + if is_image: + t5_text_embeddings = torch.from_numpy(sample["pickle"]).to(torch.bfloat16) + else: + t5_text_embeddings = torch.from_numpy(sample["pickle"][0]).to(torch.bfloat16) + t5_text_embeddings_seq_length = t5_text_embeddings.shape[0] + + if t5_text_embeddings_seq_length > self.text_embedding_padding_size: + t5_text_embeddings = t5_text_embeddings[: self.text_embedding_padding_size] + else: + t5_text_embeddings = F.pad( + t5_text_embeddings, + ( + 0, + 0, + 0, + self.text_embedding_padding_size - t5_text_embeddings_seq_length, + ), + ) + t5_text_mask = torch.ones(t5_text_embeddings_seq_length, dtype=torch.bfloat16) + + if is_image: + h, w = info["image_height"], info["image_width"] + fps = torch.tensor([30] * 1, dtype=torch.bfloat16) + num_frames = torch.tensor([1] * 1, dtype=torch.bfloat16) + else: + h, w = info["height"], info["width"] + fps = torch.tensor([info["framerate"]] * 1, dtype=torch.bfloat16) + num_frames = torch.tensor([info["num_frames"]] * 1, dtype=torch.bfloat16) + image_size = torch.tensor([[h, w, h, w]] * 1, dtype=torch.bfloat16) + + pos_ids = rearrange( + pos_id_3d.get_pos_id_3d(t=T // self.patch_temporal, h=H // self.patch_spatial, w=W // self.patch_spatial), + "T H W d -> (T H W) d", + ) + + if self.seq_length is not None: + pos_ids = F.pad(pos_ids, (0, 0, 0, self.seq_length - seq_len)) + loss_mask = torch.zeros(self.seq_length, dtype=torch.bfloat16) + loss_mask[:seq_len] = 1 + video_latent = F.pad(video_latent, (0, 0, 0, self.seq_length - seq_len)) + else: + loss_mask = torch.ones(seq_len, dtype=torch.bfloat16) + + print(f"Loss mask shape: {loss_mask.shape}") + print(f"video_latent shape final: {video_latent.shape}") + return dict( + video=video_latent, + t5_text_embeddings=t5_text_embeddings, + t5_text_mask=t5_text_mask, + image_size=image_size, + fps=fps, + num_frames=num_frames, + loss_mask=loss_mask, + seq_len_q=torch.tensor(seq_len, dtype=torch.int32), + seq_len_kv=torch.tensor(self.text_embedding_padding_size, dtype=torch.int32), + pos_ids=pos_ids, + latent_shape=torch.tensor([C, T, H, W], dtype=torch.int32), + ) + + +class PosID3D: + def __init__(self, *, max_t=32, max_h=128, max_w=128): + self.max_t = max_t + self.max_h = max_h + self.max_w = max_w + self.generate_pos_id() + + def generate_pos_id(self): + self.grid = torch.stack( + torch.meshgrid( + torch.arange(self.max_t, device="cpu"), + torch.arange(self.max_h, device="cpu"), + torch.arange(self.max_w, device="cpu"), + ), + dim=-1, + ) + + def get_pos_id_3d(self, *, t, h, w): + if t > self.max_t or h > self.max_h or w > self.max_w: + self.max_t = max(self.max_t, t) + self.max_h = max(self.max_h, h) + self.max_w = max(self.max_w, w) + self.generate_pos_id() + return self.grid[:t, :h, :w] + + +pos_id_3d = PosID3D() + + +def cook_raw_iamges(sample: dict) -> dict: + """ + Processes a raw sample dictionary from energon dataset and returns a new dictionary with specific keys. + + Args: + sample (dict): The input dictionary containing the raw sample data. + + Returns: + dict: A new dictionary containing the processed sample data with the following keys: + - All keys from the result of `basic_sample_keys(sample)` + - 'jpg': original images + - 'png': contains control images + - 'txt': contains raw text + """ + return dict( + **basic_sample_keys(sample), + images=sample["jpg"], + hint=sample["png"], + txt=sample["txt"], + ) + + +class RawImageDiffusionTaskEncoder(DefaultTaskEncoder): + """ + Dummy task encoder takes raw image input on CrudeDataset. + """ + + cookers = [ + # Cooker(cook), + Cooker(cook_raw_iamges), + ] diff --git a/dfm/src/megatron/data/Dit/data/prepare_energon_dataset.py b/dfm/src/megatron/data/Dit/data/prepare_energon_dataset.py new file mode 100644 index 00000000..56e57684 --- /dev/null +++ b/dfm/src/megatron/data/Dit/data/prepare_energon_dataset.py @@ -0,0 +1,117 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=C0115,C0116,C0301 + +import os +import pickle +from typing import Callable, List + +import nemo_run as run +import numpy as np +import torch +import torch.distributed as dist +import webdataset as wds + + +def get_start_end_idx_for_this_rank(dataset_size, rank, world_size): + """ + Calculate the start and end indices for a given rank in a distributed setting. + + Args: + dataset_size (int): The total size of the dataset. + rank (int): The rank of the current process. + world_size (int): The total number of processes. + + Returns: + tuple: A tuple containing the start index (int) and end index (int) for the given rank. + """ + split_size = dataset_size // world_size + start_idx = rank * split_size + # The last rank takes the remainder + end_idx = start_idx + split_size if rank != world_size - 1 else dataset_size + return start_idx, end_idx + + +def dummy_process_func(input): + """ + Generates a sample dictionary containing random image latent tensor, text embedding, + and metadata based on the provided input key. + + Args: + input (str): The key to be used in the sample dictionary. + + Returns: + dict: A dictionary containing the following keys: + - "__key__": The input key. + - ".pth": A randomly generated image latent tensor with shape (3, 1, 720, 1280) and dtype torch.bfloat16. + - ".pickle": A pickled numpy array representing a random text embedding with shape (512, 2048). + - ".json": A dictionary containing metadata with keys: + - "image_height": The height of the image (720). + - "image_width": The width of the image (1280). + """ + C, T, H, W = 3, 1, 720, 1280 + image_latent = torch.randn(C, T, H, W, dtype=torch.bfloat16) + text_embedding = np.random.randn(512, 2048) + sample = { + "__key__": input, + ".pth": image_latent, + ".pickle": pickle.dumps(text_embedding), + ".json": { + "image_height": H, + "image_width": W, + }, + } + return sample + + +@torch.no_grad() +@run.cli.entrypoint +def prepare(process_func: Callable, inputs: List[str], output_dir: str = "output"): + """ + distributed prepration webdataset using the provided processing function, and writes the processed samples to tar files. + + Args: + process_func (Callable): A function that processes a single input and returns the processed sample. + inputs (List[str]): A list of input file paths or data entries to be processed. + output_dir (str, optional): The directory where the output tar files will be saved. Defaults to 'output'. + """ + rank = dist.get_rank() + world_size = torch.distributed.get_world_size() + + start_idx, end_idx = get_start_end_idx_for_this_rank(len(inputs), rank, world_size) + os.makedirs(output_dir, exist_ok=True) + output_tar = os.path.join(output_dir, f"rank{rank}-%06d.tar") + with wds.ShardWriter(output_tar, maxcount=10000) as sink: + for i in range(start_idx, end_idx): + sample = process_func(inputs[i]) + # Write the sample to the tar file + sink.write(sample) + + +@run.cli.factory(target=prepare) +def prepare_dummy_image_dataset() -> run.Partial: + recipe = run.Partial( + prepare, + process_func=dummy_process_func, + inputs=list(str(i) for i in range(1000)), + ) + return recipe + + +if __name__ == "__main__": + dist.init_process_group("nccl") + local_rank = int(os.environ["LOCAL_RANK"]) + torch.cuda.set_device(local_rank) + run.cli.main(prepare, default_factory=prepare_dummy_image_dataset) diff --git a/dfm/src/megatron/data/Dit/data/utils.py b/dfm/src/megatron/data/Dit/data/utils.py new file mode 100644 index 00000000..dbe8ebad --- /dev/null +++ b/dfm/src/megatron/data/Dit/data/utils.py @@ -0,0 +1,203 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=C0115,C0116,C0301 + +import numpy as np + + +def minimal_crop(tensor, target_divisor): + """ + Crops the input tensor minimally so that the total number of elements + (T * H * W) is divisible by the specified target_divisor. + + Parameters: + - tensor: NumPy array of shape (C, T, H, W) + - target_divisor: Positive integer specifying the desired divisor + + Returns: + - cropped_tensor: Cropped tensor meeting the divisibility requirement + + Raises: + - ValueError: If it's impossible to meet the divisibility requirement + """ + if not isinstance(target_divisor, int) or target_divisor <= 0: + raise ValueError("target_divisor must be a positive integer greater than zero.") + + C, T, H, W = tensor.shape + total_elements = T * H * W + remainder = total_elements % target_divisor + + if remainder == 0: + return tensor # No cropping needed + + # Elements per unit length in each dimension + elements_per_T = H * W + elements_per_H = T * W + elements_per_W = T * H + + min_elements_removed = None + optimal_deltas = None + + # Limit the search range to avoid unnecessary computations + max_delta_T = min(T - 1, (remainder // elements_per_T) + 1) + max_delta_H = min(H - 1, (remainder // elements_per_H) + 1) + max_delta_W = min(W - 1, (remainder // elements_per_W) + 1) + + for delta_T in range(0, max_delta_T + 1): + for delta_H in range(0, max_delta_H + 1): + for delta_W in range(0, max_delta_W + 1): + if delta_T == delta_H == delta_W == 0: + continue # No cropping + + new_T = T - delta_T + new_H = H - delta_H + new_W = W - delta_W + + if new_T <= 0 or new_H <= 0 or new_W <= 0: + continue # Invalid dimensions + + new_total_elements = new_T * new_H * new_W + if new_total_elements % target_divisor == 0: + elements_removed = delta_T * elements_per_T + delta_H * elements_per_H + delta_W * elements_per_W + if min_elements_removed is None or elements_removed < min_elements_removed: + min_elements_removed = elements_removed + optimal_deltas = (delta_T, delta_H, delta_W) + + if optimal_deltas is None: + raise ValueError("Cannot crop tensor to meet divisibility requirement.") + + delta_T, delta_H, delta_W = optimal_deltas + + # Perform the cropping + # T dimension: crop from the end + end_T = T - delta_T + + # H dimension: center crop + start_H = delta_H // 2 + end_H = H - (delta_H - delta_H // 2) + + # W dimension: center crop + start_W = delta_W // 2 + end_W = W - (delta_W - delta_W // 2) + + cropped_tensor = tensor[:, :end_T, start_H:end_H, start_W:end_W] + return cropped_tensor + + +def test_no_cropping_needed(): + """Test when the tensor already meets the divisibility requirement.""" + C, T, H, W = 3, 8, 8, 8 + target_divisor = 8 + tensor = np.zeros((C, T, H, W)) + cropped_tensor = minimal_crop(tensor, target_divisor) + assert cropped_tensor.shape == (C, T, H, W) + assert (T * H * W) % target_divisor == 0 + + +def test_minimal_cropping_T_dimension(): + """Test minimal cropping along the T dimension.""" + C, T, H, W = 3, 9, 7, 6 + target_divisor = 8 + tensor = np.zeros((C, T, H, W)) + cropped_tensor = minimal_crop(tensor, target_divisor) + new_T = cropped_tensor.shape[1] + assert new_T == T - 1, cropped_tensor.shape + assert (new_T * H * W) % target_divisor == 0 + + +def test_minimal_cropping_H_dimension(): + """Test minimal cropping along the H dimension.""" + C, T, H, W = 3, 7, 9, 6 + target_divisor = 8 + tensor = np.zeros((C, T, H, W)) + cropped_tensor = minimal_crop(tensor, target_divisor) + new_H = cropped_tensor.shape[2] + assert new_H == H - 1, cropped_tensor.shape + assert (T * new_H * W) % target_divisor == 0 + + +def test_minimal_cropping_W_dimension(): + """Test minimal cropping along the W dimension.""" + C, T, H, W = 3, 4, 3, 9 + target_divisor = 8 + tensor = np.zeros((C, T, H, W)) + cropped_tensor = minimal_crop(tensor, target_divisor) + new_W = cropped_tensor.shape[3] + assert new_W == W - 1, cropped_tensor.shape + assert (T * H * new_W) % target_divisor == 0 + + +def test_cropping_multiple_dimensions(): + """Test when minimal cropping requires adjustments on multiple dimensions.""" + C, T, H, W = 3, 9, 9, 8 + target_divisor = 16 + tensor = np.zeros((C, T, H, W)) + cropped_tensor = minimal_crop(tensor, target_divisor) + new_T, new_H, new_W = cropped_tensor.shape[1:] + assert new_T <= T and new_H <= H and new_W <= W + assert (new_T * new_H * new_W) % target_divisor == 0 + + +def test_large_tensor_high_divisor(): + """Test with a larger tensor and higher target_divisor.""" + C, T, H, W = 3, 50, 50, 50 + target_divisor = 1024 + tensor = np.zeros((C, T, H, W)) + cropped_tensor = minimal_crop(tensor, target_divisor) + total_elements = cropped_tensor.shape[1] * cropped_tensor.shape[2] * cropped_tensor.shape[3] + assert total_elements % target_divisor == 0 + + +def test_impossible_cropping(): + """Test that an error is raised when it's impossible to meet the requirement.""" + C, T, H, W = 3, 1, 1, 1 + target_divisor = 2 + tensor = np.zeros((C, T, H, W)) + try: + minimal_crop(tensor, target_divisor) + except ValueError: + pass + + +def test_invalid_target_divisor(): + """Test that an error is raised when target_divisor is invalid.""" + C, T, H, W = 3, 8, 8, 8 + tensor = np.zeros((C, T, H, W)) + try: + minimal_crop(tensor, -1) + except ValueError: + pass + + +def test_minimal_elements_removed(): + """Test that the minimal number of elements are removed.""" + C, T, H, W = 3, 7, 7, 7 + target_divisor = 8 + tensor = np.zeros((C, T, H, W)) + cropped_tensor = minimal_crop(tensor, target_divisor) + elements_removed = (T * H * W) - (cropped_tensor.shape[1] * cropped_tensor.shape[2] * cropped_tensor.shape[3]) + print(cropped_tensor.shape) + assert elements_removed > 0 + assert (cropped_tensor.shape[1] * cropped_tensor.shape[2] * cropped_tensor.shape[3]) % target_divisor == 0 + + +test_no_cropping_needed() +test_minimal_elements_removed() +test_cropping_multiple_dimensions() +test_minimal_cropping_T_dimension() +test_minimal_cropping_H_dimension() +test_minimal_cropping_W_dimension() +test_impossible_cropping() +test_invalid_target_divisor() diff --git a/dfm/src/megatron/data/prepare_energon_dataset_wan.py b/dfm/src/megatron/data/wan/prepare_energon_dataset_wan.py similarity index 100% rename from dfm/src/megatron/data/prepare_energon_dataset_wan.py rename to dfm/src/megatron/data/wan/prepare_energon_dataset_wan.py diff --git a/dfm/src/megatron/data/wan/prepare_energon_dataset_wan_images.py b/dfm/src/megatron/data/wan/prepare_energon_dataset_wan_images.py new file mode 100644 index 00000000..67ea4539 --- /dev/null +++ b/dfm/src/megatron/data/wan/prepare_energon_dataset_wan_images.py @@ -0,0 +1,412 @@ +import os +import io +import json +import tarfile +from pathlib import Path +from typing import Dict, List, Optional, Tuple, Iterable + +import cv2 +import numpy as np +import torch + +from diffusers import AutoencoderKLWan +from transformers import AutoTokenizer, UMT5EncoderModel + + +def _map_interpolation(resize_mode: str) -> int: + interpolation_map = { + "bilinear": cv2.INTER_LINEAR, + "bicubic": cv2.INTER_CUBIC, + "nearest": cv2.INTER_NEAREST, + "area": cv2.INTER_AREA, + "lanczos": cv2.INTER_LANCZOS4, + } + if resize_mode not in interpolation_map: + raise ValueError( + f"Invalid resize_mode '{resize_mode}'. Choose from: {list(interpolation_map.keys())}" + ) + return interpolation_map[resize_mode] + + +def _calculate_resize_dimensions( + original_height: int, + original_width: int, + target_size: Optional[Tuple[int, int]], + maintain_aspect_ratio: bool, +) -> Tuple[int, int]: + if target_size is None: + return original_height, original_width + + target_height, target_width = target_size + if not maintain_aspect_ratio: + return target_height, target_width + + original_aspect = original_width / max(1, original_height) + target_aspect = target_width / max(1, target_height) + + if original_aspect > target_aspect: + new_width = target_width + new_height = int(round(target_width / max(1e-6, original_aspect))) + else: + new_height = target_height + new_width = int(round(target_height * original_aspect)) + + return new_height, new_width + + +def _resize_frame( + frame: np.ndarray, + target_size: Optional[Tuple[int, int]], + resize_mode: str, + maintain_aspect_ratio: bool, + center_crop: bool, +) -> np.ndarray: + if target_size is None: + return frame + + original_height, original_width = frame.shape[:2] + resize_height, resize_width = _calculate_resize_dimensions( + original_height, original_width, target_size, maintain_aspect_ratio + ) + + interpolation = _map_interpolation(resize_mode) + resized_frame = cv2.resize(frame, (resize_width, resize_height), interpolation=interpolation) + + if maintain_aspect_ratio and center_crop: + target_height, target_width = target_size + if resize_height != target_height or resize_width != target_width: + y_start = max(0, (resize_height - target_height) // 2) + x_start = max(0, (resize_width - target_width) // 2) + y_end = min(resize_height, y_start + target_height) + x_end = min(resize_width, x_start + target_width) + resized_frame = resized_frame[y_start:y_end, x_start:x_end] + + if resized_frame.shape[0] < target_height or resized_frame.shape[1] < target_width: + pad_height = max(0, target_height - resized_frame.shape[0]) + pad_width = max(0, target_width - resized_frame.shape[1]) + resized_frame = np.pad( + resized_frame, ((0, pad_height), (0, pad_width), (0, 0)), mode="constant", constant_values=0 + ) + + return resized_frame + + +def _decode_image_bytes_to_rgb(image_bytes: bytes) -> Optional[np.ndarray]: + array = np.frombuffer(image_bytes, dtype=np.uint8) + img_bgr = cv2.imdecode(array, cv2.IMREAD_COLOR) + if img_bgr is None: + return None + img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB) + return img_rgb + + +def _image_to_video_tensor( + image_rgb: np.ndarray, + target_size: Optional[Tuple[int, int]], + resize_mode: str, + maintain_aspect_ratio: bool, + center_crop: bool, + target_dtype: torch.dtype, +) -> torch.Tensor: + frame = _resize_frame(image_rgb, target_size, resize_mode, maintain_aspect_ratio, center_crop) + frame = frame.astype(np.float32) / 255.0 # H, W, C in [0,1] + + video_array = frame[None, ...] # T=1, H, W, C + video_tensor = torch.from_numpy(video_array) # T, H, W, C + video_tensor = video_tensor.permute(3, 0, 1, 2).unsqueeze(0) # 1, C, T=1, H, W + video_tensor = video_tensor.to(dtype=target_dtype) + return video_tensor + + +@torch.no_grad() +def _init_hf_models( + model_id: str, + device: str, + enable_memory_optimization: bool, +): + dtype = torch.float16 if device.startswith("cuda") else torch.float32 + + text_encoder = UMT5EncoderModel.from_pretrained( + model_id, + subfolder="text_encoder", + torch_dtype=dtype, + ) + text_encoder.to(device) + text_encoder.eval() + + vae = AutoencoderKLWan.from_pretrained( + model_id, + subfolder="vae", + torch_dtype=dtype, + ) + vae.to(device) + vae.eval() + if enable_memory_optimization: + vae.enable_slicing() + vae.enable_tiling() + + tokenizer = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer") + + return vae, text_encoder, tokenizer, dtype + + +@torch.no_grad() +def _encode_text( + tokenizer: AutoTokenizer, + text_encoder: UMT5EncoderModel, + device: str, + caption: str, +) -> torch.Tensor: + caption = (caption or "").strip() + inputs = tokenizer( + caption, + max_length=512, + padding="max_length", + truncation=True, + return_tensors="pt", + return_attention_mask=True, + ) + inputs = {k: v.to(device) for k, v in inputs.items()} + outputs = text_encoder(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"]).last_hidden_state + return outputs + + +@torch.no_grad() +def _encode_video_latents( + vae: AutoencoderKLWan, + device: str, + video_tensor: torch.Tensor, + deterministic_latents: bool, +) -> torch.Tensor: + video_tensor = video_tensor.to(device=device, dtype=vae.dtype) + video_tensor = video_tensor * 2.0 - 1.0 # [0,1] -> [-1,1] + + latent_dist = vae.encode(video_tensor) + if deterministic_latents: + video_latents = latent_dist.latent_dist.mean + else: + video_latents = latent_dist.latent_dist.sample() + + latent_mean = video_latents.mean().item() + latent_std = video_latents.std().item() + + if abs(latent_mean) < 0.5 and 0.5 < latent_std < 2.0: + final_latents = video_latents + else: + if not hasattr(vae.config, "latents_mean") or not hasattr(vae.config, "latents_std"): + raise ValueError("Wan2.1 VAE requires latents_mean and latents_std in config") + latents_mean = torch.tensor(vae.config.latents_mean, device=device, dtype=vae.dtype).view(1, -1, 1, 1, 1) + latents_std = torch.tensor(vae.config.latents_std, device=device, dtype=vae.dtype).view(1, -1, 1, 1, 1) + final_latents = (video_latents - latents_mean) / latents_std + + return final_latents + + +def _iter_tar_images_and_captions(tf: tarfile.TarFile, image_exts: Tuple[str, ...]) -> Iterable[Tuple[tarfile.TarInfo, Optional[tarfile.TarInfo]]]: + members = [m for m in tf.getmembers() if m.isfile()] + # Map stem -> caption member + txt_map: Dict[str, tarfile.TarInfo] = {} + for m in members: + name = os.path.basename(m.name) + if name.lower().endswith(".txt"): + stem = os.path.splitext(name)[0] + txt_map[stem] = m + + for m in members: + name = os.path.basename(m.name) + lower = name.lower() + if lower.endswith(image_exts): + stem = os.path.splitext(name)[0] + caption_member = txt_map.get(stem, None) + yield m, caption_member + + +def _read_tar_member_bytes(tf: tarfile.TarFile, member: tarfile.TarInfo) -> bytes: + f = tf.extractfile(member) + if f is None: + return b"" + with f: + return f.read() + + +def main(): + import argparse + + parser = argparse.ArgumentParser( + description=( + "Prepare WAN encodings for image tar shards. Each .tar is written to a same-named directory " + "containing per-image VAE latents (.pth), T5 embeddings (.pkl), and metadata (.json)." + ) + ) + parser.add_argument("--input_dir", type=str, required=True, help="Directory containing .tar shards of images") + parser.add_argument("--output_root", type=str, required=True, help="Root directory to write per-tar output dirs") + parser.add_argument( + "--model", + default="Wan-AI/Wan2.1-T2V-14B-Diffusers", + help=( + "Wan2.1 model ID (e.g., Wan-AI/Wan2.1-T2V-14B-Diffusers or Wan-AI/Wan2.1-T2V-1.3B-Diffusers)" + ), + ) + parser.add_argument( + "--stochastic", + action="store_true", + help="Use stochastic encoding (sampling) instead of deterministic posterior mean", + ) + parser.add_argument("--no-memory-optimization", action="store_true", help="Disable VAE slicing/tiling") + parser.add_argument( + "--image_exts", + type=str, + default=".jpg,.jpeg,.png,.webp", + help="Comma-separated list of image extensions to include", + ) + parser.add_argument("--height", type=int, default=None, help="Target height for image frames") + parser.add_argument("--width", type=int, default=None, help="Target width for image frames") + parser.add_argument( + "--resize_mode", + default="bilinear", + choices=["bilinear", "bicubic", "nearest", "area", "lanczos"], + help="Interpolation mode for resizing", + ) + parser.add_argument("--no-aspect-ratio", action="store_true", help="Disable aspect ratio preservation") + parser.add_argument("--center-crop", action="store_true", help="Center crop to exact target size after resize") + parser.add_argument( + "--skip-existing", + action="store_true", + help="Skip images whose output files already exist (all three: .pth, .pkl, .json)", + ) + + args = parser.parse_args() + + input_dir = Path(args.input_dir) + output_root = Path(args.output_root) + output_root.mkdir(parents=True, exist_ok=True) + + # Target size + target_size = None + if args.height is not None and args.width is not None: + target_size = (args.height, args.width) + elif (args.height is None) ^ (args.width is None): + parser.error("Both --height and --width must be specified together") + + # Init HF models + device = "cuda" + vae, text_encoder, tokenizer, model_dtype = _init_hf_models( + model_id=args.model, + device=device, + enable_memory_optimization=not args.no_memory_optimization, + ) + + image_exts = tuple(ext.strip().lower() for ext in args.image_exts.split(",") if ext.strip()) + + # Find tar files + tar_files = sorted([p for p in input_dir.iterdir() if p.is_file() and p.suffix.lower() == ".tar"]) + if not tar_files: + raise FileNotFoundError(f"No .tar files found in {input_dir}") + + # DEBUGGING + # for tar_path in tar_files: + for tar_path in tar_files[:1]: + tar_stem = tar_path.name[:-4] # drop .tar + out_dir = output_root / tar_stem + out_dir.mkdir(parents=True, exist_ok=True) + + processed = 0 + failed = 0 + + # Open tar for streaming read + try: + with tarfile.open(tar_path, mode="r:*") as tf: + for img_member, cap_member in _iter_tar_images_and_captions(tf, image_exts): + img_name = os.path.basename(img_member.name) + stem = os.path.splitext(img_name)[0] + + latents_path = out_dir / f"{stem}.pth" + text_path = out_dir / f"{stem}.pkl" + meta_path = out_dir / f"{stem}.json" + + if args.skip_existing and latents_path.exists() and text_path.exists() and meta_path.exists(): + continue + + try: + img_bytes = _read_tar_member_bytes(tf, img_member) + if not img_bytes: + failed += 1 + continue + rgb = _decode_image_bytes_to_rgb(img_bytes) + if rgb is None: + failed += 1 + continue + + caption_text = "" + if cap_member is not None: + try: + caption_bytes = _read_tar_member_bytes(tf, cap_member) + caption_text = caption_bytes.decode("utf-8", errors="ignore") + except Exception: + caption_text = "" + + video_tensor = _image_to_video_tensor( + image_rgb=rgb, + target_size=target_size, + resize_mode=args.resize_mode, + maintain_aspect_ratio=not args.no_aspect_ratio, + center_crop=args.center_crop, + target_dtype=model_dtype, + ) + + # Encode + text_embed = _encode_text(tokenizer, text_encoder, device, caption_text) + latents = _encode_video_latents( + vae, device, video_tensor, deterministic_latents=not args.stochastic + ) + + # Move to CPU and drop batch dim + text_embed_cpu = text_embed.detach().to(device="cpu")[0] + latents_cpu = latents.detach().to(device="cpu")[0] + + # Save outputs + torch.save(latents_cpu, latents_path) + # Use pickle for text embeddings to keep exact dtype/shape + with open(text_path, "wb") as f: + import pickle + + pickle.dump(text_embed_cpu, f, protocol=pickle.HIGHEST_PROTOCOL) + + # Metadata + C, T, H, W = video_tensor.shape[1:] + json_data = { + "source_tar": str(tar_path), + "tar_member": img_member.name, + "image_name": img_name, + "processed_frames": int(T), # always 1 + "processed_height": int(H), + "processed_width": int(W), + "caption": caption_text, + "deterministic_latents": bool(not args.stochastic), + "memory_optimization": bool(not args.no_memory_optimization), + "model_version": "wan2.1", + "resize_settings": { + "target_size": target_size, + "resize_mode": args.resize_mode, + "maintain_aspect_ratio": bool(not args.no_aspect_ratio), + "center_crop": bool(args.center_crop), + }, + } + with open(meta_path, "w", encoding="utf-8") as f: + json.dump(json_data, f, ensure_ascii=False) + + processed += 1 + except Exception: + failed += 1 + continue + except Exception as e: + print(f"Failed to process tar {tar_path}: {e}") + continue + + print(f"Processed tar {tar_path.name}: {processed} ok, {failed} failed. Output -> {out_dir}") + + +if __name__ == "__main__": + main() + + diff --git a/dfm/src/megatron/data/wan_energon_datamodule.py b/dfm/src/megatron/data/wan/wan_energon_datamodule.py similarity index 91% rename from dfm/src/megatron/data/wan_energon_datamodule.py rename to dfm/src/megatron/data/wan/wan_energon_datamodule.py index 98774e81..cc20dedc 100644 --- a/dfm/src/megatron/data/wan_energon_datamodule.py +++ b/dfm/src/megatron/data/wan/wan_energon_datamodule.py @@ -20,8 +20,8 @@ from torch import int_repr -from megatron.bridge.data.Dit.data.diffusion_energon_datamodule import DiffusionDataModule -from megatron.bridge.data.wan.wan_taskencoder import WanTaskEncoder +from dfm.src.megatron.data.Dit.data.diffusion_energon_datamodule import DiffusionDataModule +from dfm.src.megatron.data.wan.wan_taskencoder import WanTaskEncoder from megatron.bridge.data.utils import DatasetBuildContext, DatasetProvider @dataclass(kw_only=True) diff --git a/dfm/src/megatron/data/wan_taskencoder.py b/dfm/src/megatron/data/wan/wan_taskencoder.py similarity index 82% rename from dfm/src/megatron/data/wan_taskencoder.py rename to dfm/src/megatron/data/wan/wan_taskencoder.py index 097a8583..5d504b4d 100644 --- a/dfm/src/megatron/data/wan_taskencoder.py +++ b/dfm/src/megatron/data/wan/wan_taskencoder.py @@ -18,7 +18,7 @@ import torch.nn.functional as F from megatron.energon import DefaultTaskEncoder, SkipSample from megatron.energon.task_encoder.cooking import Cooker, basic_sample_keys -from megatron.bridge.models.wan.utils.utils import grid_sizes_calculation, patchify +from dfm.src.megatron.model.wan.utils.utils import grid_sizes_calculation, patchify from megatron.core import parallel_state @@ -73,45 +73,28 @@ def __init__( self.seq_length = seq_length - # def actual_encode_sample(self, sample: dict) -> dict: - - # video_latent = sample["pth"] - # context_embeddings = sample["pickle"] - # video_metadata = sample["json"] - - # # sanity quality check - # if torch.isnan(video_latent).any() or torch.isinf(video_latent).any(): - # raise SkipSample() - # if torch.max(torch.abs(video_latent)) > 1e3: - # raise SkipSample() - - # # calculate grid size - # grid_size = grid_sizes_calculation( - # input_shape = video_latent.shape[1:], - # patch_size = (self.patch_temporal, self.patch_spatial, self.patch_spatial), - # ) - - # ### Note: shape of sample's values - # # video_latent: [latents_channels, F_latents, W_latents, H_latents] - # # grid_size: [F_patches, W_patches, H_patches] - # # context_embeddings: [context_seq_len, text_embedding_dim] + def encode_sample(self, sample: dict) -> dict: - # return dict( - # video_latent=video_latent, - # grid_size=grid_size, - # context_embeddings=context_embeddings, - # video_metadata=video_metadata, - # ) + video_latent = sample["pth"] + context_embeddings = sample["pickle"] + video_metadata = sample["json"] + # sanity quality check + if torch.isnan(video_latent).any() or torch.isinf(video_latent).any(): + raise SkipSample() + if torch.max(torch.abs(video_latent)) > 1e3: + raise SkipSample() - def encode_sample(self, sample: dict) -> dict: + # calculate grid size + grid_size = grid_sizes_calculation( + input_shape = video_latent.shape[1:], + patch_size = (self.patch_temporal, self.patch_spatial, self.patch_spatial), + ) - # mock encode sample - video_latent = torch.tensor(torch.randn(16, 3, 104, 60), dtype=torch.float32) - # video_latent = torch.tensor(torch.randn(16, 24, 104, 60), dtype=torch.float32) - grid_size = torch.tensor([video_latent.shape[1] // self.patch_temporal, video_latent.shape[2] // self.patch_spatial, video_latent.shape[3] // self.patch_spatial], dtype=torch.int32) - context_embeddings = torch.tensor(torch.randn(512, 4096), dtype=torch.float32) - video_metadata = {} + ### Note: shape of sample's values + # video_latent: [latents_channels, F_latents, W_latents, H_latents] + # grid_size: [F_patches, W_patches, H_patches] + # context_embeddings: [context_seq_len, text_embedding_dim] return dict( video_latent=video_latent, @@ -121,6 +104,23 @@ def encode_sample(self, sample: dict) -> dict: ) + # def encode_sample(self, sample: dict) -> dict: + + # # mock encode sample + # video_latent = torch.tensor(torch.randn(16, 3, 104, 60), dtype=torch.float32) + # # video_latent = torch.tensor(torch.randn(16, 24, 104, 60), dtype=torch.float32) + # grid_size = torch.tensor([video_latent.shape[1] // self.patch_temporal, video_latent.shape[2] // self.patch_spatial, video_latent.shape[3] // self.patch_spatial], dtype=torch.int32) + # context_embeddings = torch.tensor(torch.randn(512, 4096), dtype=torch.float32) + # video_metadata = {} + + # return dict( + # video_latent=video_latent, + # grid_size=grid_size, + # context_embeddings=context_embeddings, + # video_metadata=video_metadata, + # ) + + def batch(self, samples: list[dict]) -> dict: # process video latents diff --git a/dfm/src/megatron/model/wan/flow_matching/flow_inference_pipeline.py b/dfm/src/megatron/model/wan/flow_matching/flow_inference_pipeline.py index fedef4f4..fe1ab4ed 100644 --- a/dfm/src/megatron/model/wan/flow_matching/flow_inference_pipeline.py +++ b/dfm/src/megatron/model/wan/flow_matching/flow_inference_pipeline.py @@ -14,20 +14,20 @@ import torch.distributed as dist from tqdm import tqdm -from megatron.bridge.models.wan.wan_model import WanModel +from dfm.src.megatron.model.wan.wan_model import WanModel from megatron.bridge.models.wan.wan_provider import WanModelProvider -from megatron.bridge.models.wan.modules.t5 import T5EncoderModel -from megatron.bridge.models.wan.modules import WanVAE -from megatron.bridge.models.wan.inference.utils.fm_solvers import ( +from dfm.src.megatron.model.wan.modules.t5 import T5EncoderModel +from dfm.src.megatron.model.wan.modules import WanVAE +from dfm.src.megatron.model.wan.inference.utils.fm_solvers import ( FlowDPMSolverMultistepScheduler, get_sampling_sigmas, retrieve_timesteps, ) -from megatron.bridge.models.wan.inference.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler -from megatron.bridge.models.wan.utils.utils import grid_sizes_calculation, patchify +from dfm.src.megatron.model.wan.inference.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler +from dfm.src.megatron.model.wan.utils.utils import grid_sizes_calculation, patchify from megatron.core import parallel_state from torch.nn import functional as F -from megatron.bridge.models.wan.utils.utils import split_inputs_cp, cat_outputs_cp +from dfm.src.megatron.model.wan.utils.utils import split_inputs_cp, cat_outputs_cp import math from typing import Tuple, Union diff --git a/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py b/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py index af82fa7e..62f9f322 100644 --- a/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py +++ b/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py @@ -19,8 +19,8 @@ from megatron.core import parallel_state from torch import Tensor from diffusers import WanPipeline -from megatron.bridge.models.wan.flow_matching.time_shift_utils import compute_density_for_timestep_sampling -from megatron.bridge.models.wan.utils.utils import patchify, split_inputs_cp +from dfm.src.megatron.model.wan.flow_matching.time_shift_utils import compute_density_for_timestep_sampling +from dfm.src.megatron.model.wan.utils.utils import patchify, split_inputs_cp class FlowPipeline: diff --git a/dfm/src/megatron/model/wan/rope_utils.py b/dfm/src/megatron/model/wan/rope_utils.py index 93d0e933..3bd37733 100644 --- a/dfm/src/megatron/model/wan/rope_utils.py +++ b/dfm/src/megatron/model/wan/rope_utils.py @@ -1,6 +1,6 @@ import torch from torch.cuda import amp -from megatron.bridge.models.wan.utils.utils import split_inputs_cp +from dfm.src.megatron.model.wan.utils.utils import split_inputs_cp from megatron.core import parallel_state class Wan3DRopeEmbeddings(torch.nn.Module): diff --git a/dfm/src/megatron/model/wan/wan_model.py b/dfm/src/megatron/model/wan/wan_model.py index d11b7803..dc409f33 100644 --- a/dfm/src/megatron/model/wan/wan_model.py +++ b/dfm/src/megatron/model/wan/wan_model.py @@ -28,10 +28,10 @@ from megatron.core.transformer.transformer_block import TransformerBlock from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.utils import make_sharded_tensor_for_checkpoint -from megatron.bridge.models.wan.wan_layer_spec import ( +from dfm.src.megatron.model.wan.wan_layer_spec import ( get_wan_block_with_transformer_engine_spec as WanLayerWithAdaLNspec, ) -from megatron.bridge.models.wan.wan_layer_spec import WanLayerNorm +from dfm.src.megatron.model.wan.wan_layer_spec import WanLayerNorm from torch import Tensor from .rope_utils import Wan3DRopeEmbeddings From 7b834f03db838eecd2523576da2d5f24c26c35b9 Mon Sep 17 00:00:00 2001 From: Huy Vu2 Date: Thu, 30 Oct 2025 21:30:03 -0700 Subject: [PATCH 03/80] workable thd --- .../flow_matching/flow_inference_pipeline.py | 35 +++++++++-------- .../model/wan/flow_matching/flow_pipeline.py | 18 ++++++--- dfm/src/megatron/model/wan/rope_utils.py | 9 +++-- dfm/src/megatron/model/wan/utils/utils.py | 39 +++++++++++++++++++ 4 files changed, 75 insertions(+), 26 deletions(-) diff --git a/dfm/src/megatron/model/wan/flow_matching/flow_inference_pipeline.py b/dfm/src/megatron/model/wan/flow_matching/flow_inference_pipeline.py index fe1ab4ed..e5023f22 100644 --- a/dfm/src/megatron/model/wan/flow_matching/flow_inference_pipeline.py +++ b/dfm/src/megatron/model/wan/flow_matching/flow_inference_pipeline.py @@ -99,9 +99,10 @@ def __init__( wan_checkpoint_dir = self._select_checkpoint_dir(checkpoint_dir, checkpoint_step) self.model = self.setup_model_from_checkpoint(wan_checkpoint_dir) - # DEBUGGING + # DEBUGGING thd # set qkv_format to to "thd" for context parallelism - self.model.config.qkv_format = "sbhd" + # self.model.config.qkv_format = "sbhd" + self.model.config.qkv_format = "thd" # set self.sp_size=1 for later use, just to respect the original Wan inference code self.sp_size = 1 @@ -486,11 +487,12 @@ def noop_no_sync(): timestep = [t] * batch_size timestep = torch.stack(timestep) - # run context parallelism slitting - if parallel_state.get_context_parallel_world_size() > 1: - latent_model_input = split_inputs_cp(latent_model_input, 0) - arg_c['context'] = split_inputs_cp(arg_c['context'], 0) - arg_null['context'] = split_inputs_cp(arg_null['context'], 0) + # DEBUGGING thd + # # run context parallelism slitting + # if parallel_state.get_context_parallel_world_size() > 1: + # latent_model_input = split_inputs_cp(latent_model_input, 0) + # arg_c['context'] = split_inputs_cp(arg_c['context'], 0) + # arg_null['context'] = split_inputs_cp(arg_null['context'], 0) self.model.to(self.device) noise_pred_cond = self.forward_pp_step( @@ -499,15 +501,16 @@ def noop_no_sync(): noise_pred_uncond = self.forward_pp_step( latent_model_input, grid_sizes=grid_sizes, max_video_seq_len=max_video_seq_len, timestep=timestep, arg_c=arg_null) - # run context parallelism gathering - if parallel_state.get_context_parallel_world_size() > 1: - arg_c['context'] = cat_outputs_cp(arg_c['context'], 0) # we need to cat the context back together for the next timestep - arg_null['context'] = cat_outputs_cp(arg_null['context'], 0) # we need to cat the context back together for the next timestep - # TODO: does this step slow down speed??? - noise_pred_cond = noise_pred_cond.contiguous() - noise_pred_uncond = noise_pred_uncond.contiguous() - noise_pred_cond = cat_outputs_cp(noise_pred_cond, 0) - noise_pred_uncond = cat_outputs_cp(noise_pred_uncond, 0) + # DEBUGGING thd + # # run context parallelism gathering + # if parallel_state.get_context_parallel_world_size() > 1: + # arg_c['context'] = cat_outputs_cp(arg_c['context'], 0) # we need to cat the context back together for the next timestep + # arg_null['context'] = cat_outputs_cp(arg_null['context'], 0) # we need to cat the context back together for the next timestep + # # TODO: does this step slow down speed??? + # noise_pred_cond = noise_pred_cond.contiguous() + # noise_pred_uncond = noise_pred_uncond.contiguous() + # noise_pred_cond = cat_outputs_cp(noise_pred_cond, 0) + # noise_pred_uncond = cat_outputs_cp(noise_pred_uncond, 0) # run unpatchify unpatchified_noise_pred_cond = noise_pred_cond diff --git a/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py b/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py index 62f9f322..0bf042fd 100644 --- a/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py +++ b/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py @@ -20,7 +20,7 @@ from torch import Tensor from diffusers import WanPipeline from dfm.src.megatron.model.wan.flow_matching.time_shift_utils import compute_density_for_timestep_sampling -from dfm.src.megatron.model.wan.utils.utils import patchify, split_inputs_cp +from dfm.src.megatron.model.wan.utils.utils import patchify, split_inputs_cp, thd_split_inputs_cp class FlowPipeline: @@ -150,11 +150,17 @@ def training_step( # ======================================================================== if parallel_state.get_context_parallel_world_size() > 1: - video_latents = split_inputs_cp(video_latents, 0) - noisy_latents = split_inputs_cp(noisy_latents, 0) - noise = split_inputs_cp(noise, 0) - context_embeddings = split_inputs_cp(context_embeddings, 0) - split_loss_mask = split_inputs_cp(loss_mask, 0) + # DEBUGGING thd + # video_latents = split_inputs_cp(video_latents, 0) + # noisy_latents = split_inputs_cp(noisy_latents, 0) + # noise = split_inputs_cp(noise, 0) + # context_embeddings = split_inputs_cp(context_embeddings, 0) + # split_loss_mask = split_inputs_cp(loss_mask, 0) + video_latents = thd_split_inputs_cp(video_latents, packed_seq_params['self_attention'].cu_seqlens_q, parallel_state.get_context_parallel_group()) + noisy_latents = thd_split_inputs_cp(noisy_latents, packed_seq_params['self_attention'].cu_seqlens_q, parallel_state.get_context_parallel_group()) + noise = thd_split_inputs_cp(noise, packed_seq_params['self_attention'].cu_seqlens_q, parallel_state.get_context_parallel_group()) + context_embeddings = thd_split_inputs_cp(context_embeddings, packed_seq_params['cross_attention'].cu_seqlens_kv, parallel_state.get_context_parallel_group()) + split_loss_mask = thd_split_inputs_cp(loss_mask, packed_seq_params['self_attention'].cu_seqlens_q, parallel_state.get_context_parallel_group()) else: video_latents = video_latents noisy_latents = noisy_latents diff --git a/dfm/src/megatron/model/wan/rope_utils.py b/dfm/src/megatron/model/wan/rope_utils.py index 3bd37733..f96d0364 100644 --- a/dfm/src/megatron/model/wan/rope_utils.py +++ b/dfm/src/megatron/model/wan/rope_utils.py @@ -57,9 +57,10 @@ def forward(self, n_head, dim_head, max_seq_len, grid_sizes, device): # We concatenate them along dim=1 to get (max_seq_len, batch_size, 1, dim_head) freqs_real = torch.cat(freqs_real, dim=1) - # TODO: if run context/sequence related parallel, then we need to scatter - # the freqs_real to the context parallel region, using specific cp_rank split method - if parallel_state.get_context_parallel_world_size() > 1: - freqs_real = split_inputs_cp(freqs_real, 0) + # DEBUGGING thd + # # TODO: if run context/sequence related parallel, then we need to scatter + # # the freqs_real to the context parallel region, using specific cp_rank split method + # if parallel_state.get_context_parallel_world_size() > 1: + # freqs_real = split_inputs_cp(freqs_real, 0) return freqs_real \ No newline at end of file diff --git a/dfm/src/megatron/model/wan/utils/utils.py b/dfm/src/megatron/model/wan/utils/utils.py index 8551c6fc..dea95241 100644 --- a/dfm/src/megatron/model/wan/utils/utils.py +++ b/dfm/src/megatron/model/wan/utils/utils.py @@ -3,6 +3,8 @@ from torch.distributed import all_gather import megatron.core.parallel_state as parallel_state import math +import torch.distributed as dist +import transformer_engine_torch as tex def grid_sizes_calculation( input_shape: Tuple[int, int, int], # (F_latents, H_latents, W_latents) @@ -126,3 +128,40 @@ def cat_outputs_cp(x: torch.Tensor, seq_dim: int) -> torch.Tensor: return gathered_tensors else: return x + +# DEBUGGING thd +def thd_split_inputs_cp(x: torch.Tensor, + cu_seqlens_q_padded: torch.Tensor, + cp_group: dist.ProcessGroup) -> torch.Tensor: + """ + Split a THD-packed tensor across CP ranks for inputs shaped [S, B, ...]. + + Args: + x: [S, B, ...] tensor (sequence first). + cu_seqlens_q_padded: 1D int32 THD cu_seqlens (padded) used for packing. + cp_group: context-parallel process group. + + Returns: + x_local: [S_local, B, ...] shard for this CP rank. + """ + # Move to [B, S, ...] to use THD partitioning along S + x_bs = x.transpose(0, 1).contiguous() # [B, S, ...] + + total_S = x_bs.size(1) + cp_size = dist.get_world_size(cp_group) + cp_rank = dist.get_rank(cp_group) + + # Compute this rank's THD partition indices (same API as during gather) + idx = tex.thd_get_partitioned_indices( + cu_seqlens_q_padded, # int32 offsets + total_S, + cp_size, + cp_rank, + ).to(device=x_bs.device, dtype=torch.long) # [S_local] + + # Take the shard along sequence dim + x_local_bs = x_bs.index_select(dim=1, index=idx).contiguous() # [B, S_local, ...] + + # Return to [S, B, ...] + x_local = x_local_bs.transpose(0, 1).contiguous() # [S_local, B, ...] + return x_local \ No newline at end of file From 2152abdd0558572850ba250885aac052f45dc82f Mon Sep 17 00:00:00 2001 From: Huy Vu2 Date: Fri, 31 Oct 2025 07:04:02 -0700 Subject: [PATCH 04/80] clean up, remove all CP for sbhd, CP now is only for thd --- .../flow_matching/flow_inference_pipeline.py | 25 ++----------------- .../model/wan/flow_matching/flow_pipeline.py | 8 +----- dfm/src/megatron/model/wan/rope_utils.py | 10 +++----- dfm/src/megatron/model/wan/utils/utils.py | 2 +- 4 files changed, 8 insertions(+), 37 deletions(-) diff --git a/dfm/src/megatron/model/wan/flow_matching/flow_inference_pipeline.py b/dfm/src/megatron/model/wan/flow_matching/flow_inference_pipeline.py index e5023f22..f316ea25 100644 --- a/dfm/src/megatron/model/wan/flow_matching/flow_inference_pipeline.py +++ b/dfm/src/megatron/model/wan/flow_matching/flow_inference_pipeline.py @@ -27,7 +27,6 @@ from dfm.src.megatron.model.wan.utils.utils import grid_sizes_calculation, patchify from megatron.core import parallel_state from torch.nn import functional as F -from dfm.src.megatron.model.wan.utils.utils import split_inputs_cp, cat_outputs_cp import math from typing import Tuple, Union @@ -99,10 +98,8 @@ def __init__( wan_checkpoint_dir = self._select_checkpoint_dir(checkpoint_dir, checkpoint_step) self.model = self.setup_model_from_checkpoint(wan_checkpoint_dir) - # DEBUGGING thd - # set qkv_format to to "thd" for context parallelism - # self.model.config.qkv_format = "sbhd" - self.model.config.qkv_format = "thd" + # if we use context parallelism, we need to set qkv_format to "thd" for context parallelism + self.model.config.qkv_format = "thd" # "sbhd" # set self.sp_size=1 for later use, just to respect the original Wan inference code self.sp_size = 1 @@ -487,13 +484,6 @@ def noop_no_sync(): timestep = [t] * batch_size timestep = torch.stack(timestep) - # DEBUGGING thd - # # run context parallelism slitting - # if parallel_state.get_context_parallel_world_size() > 1: - # latent_model_input = split_inputs_cp(latent_model_input, 0) - # arg_c['context'] = split_inputs_cp(arg_c['context'], 0) - # arg_null['context'] = split_inputs_cp(arg_null['context'], 0) - self.model.to(self.device) noise_pred_cond = self.forward_pp_step( latent_model_input, grid_sizes=grid_sizes, max_video_seq_len=max_video_seq_len, timestep=timestep, arg_c=arg_c) @@ -501,17 +491,6 @@ def noop_no_sync(): noise_pred_uncond = self.forward_pp_step( latent_model_input, grid_sizes=grid_sizes, max_video_seq_len=max_video_seq_len, timestep=timestep, arg_c=arg_null) - # DEBUGGING thd - # # run context parallelism gathering - # if parallel_state.get_context_parallel_world_size() > 1: - # arg_c['context'] = cat_outputs_cp(arg_c['context'], 0) # we need to cat the context back together for the next timestep - # arg_null['context'] = cat_outputs_cp(arg_null['context'], 0) # we need to cat the context back together for the next timestep - # # TODO: does this step slow down speed??? - # noise_pred_cond = noise_pred_cond.contiguous() - # noise_pred_uncond = noise_pred_uncond.contiguous() - # noise_pred_cond = cat_outputs_cp(noise_pred_cond, 0) - # noise_pred_uncond = cat_outputs_cp(noise_pred_uncond, 0) - # run unpatchify unpatchified_noise_pred_cond = noise_pred_cond unpatchified_noise_pred_cond = unpatchified_noise_pred_cond.transpose(0, 1) # bring sbhd -> bshd diff --git a/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py b/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py index 0bf042fd..72ddc47b 100644 --- a/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py +++ b/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py @@ -20,7 +20,7 @@ from torch import Tensor from diffusers import WanPipeline from dfm.src.megatron.model.wan.flow_matching.time_shift_utils import compute_density_for_timestep_sampling -from dfm.src.megatron.model.wan.utils.utils import patchify, split_inputs_cp, thd_split_inputs_cp +from dfm.src.megatron.model.wan.utils.utils import patchify, thd_split_inputs_cp class FlowPipeline: @@ -150,12 +150,6 @@ def training_step( # ======================================================================== if parallel_state.get_context_parallel_world_size() > 1: - # DEBUGGING thd - # video_latents = split_inputs_cp(video_latents, 0) - # noisy_latents = split_inputs_cp(noisy_latents, 0) - # noise = split_inputs_cp(noise, 0) - # context_embeddings = split_inputs_cp(context_embeddings, 0) - # split_loss_mask = split_inputs_cp(loss_mask, 0) video_latents = thd_split_inputs_cp(video_latents, packed_seq_params['self_attention'].cu_seqlens_q, parallel_state.get_context_parallel_group()) noisy_latents = thd_split_inputs_cp(noisy_latents, packed_seq_params['self_attention'].cu_seqlens_q, parallel_state.get_context_parallel_group()) noise = thd_split_inputs_cp(noise, packed_seq_params['self_attention'].cu_seqlens_q, parallel_state.get_context_parallel_group()) diff --git a/dfm/src/megatron/model/wan/rope_utils.py b/dfm/src/megatron/model/wan/rope_utils.py index f96d0364..00a2a519 100644 --- a/dfm/src/megatron/model/wan/rope_utils.py +++ b/dfm/src/megatron/model/wan/rope_utils.py @@ -1,6 +1,5 @@ import torch from torch.cuda import amp -from dfm.src.megatron.model.wan.utils.utils import split_inputs_cp from megatron.core import parallel_state class Wan3DRopeEmbeddings(torch.nn.Module): @@ -57,10 +56,9 @@ def forward(self, n_head, dim_head, max_seq_len, grid_sizes, device): # We concatenate them along dim=1 to get (max_seq_len, batch_size, 1, dim_head) freqs_real = torch.cat(freqs_real, dim=1) - # DEBUGGING thd - # # TODO: if run context/sequence related parallel, then we need to scatter - # # the freqs_real to the context parallel region, using specific cp_rank split method - # if parallel_state.get_context_parallel_world_size() > 1: - # freqs_real = split_inputs_cp(freqs_real, 0) + # Note: + # when running context_parallel, which must use "thd" for qkv_format, + # we don't need to scatter the freqs to the context parallel region, + # because mcore rope_utils will automatically retrieve the correct freqs for each context parallel region return freqs_real \ No newline at end of file diff --git a/dfm/src/megatron/model/wan/utils/utils.py b/dfm/src/megatron/model/wan/utils/utils.py index dea95241..9fc86555 100644 --- a/dfm/src/megatron/model/wan/utils/utils.py +++ b/dfm/src/megatron/model/wan/utils/utils.py @@ -129,7 +129,7 @@ def cat_outputs_cp(x: torch.Tensor, seq_dim: int) -> torch.Tensor: else: return x -# DEBUGGING thd + def thd_split_inputs_cp(x: torch.Tensor, cu_seqlens_q_padded: torch.Tensor, cp_group: dist.ProcessGroup) -> torch.Tensor: From 389a037b77a8357f523b8e70ed56318750299074 Mon Sep 17 00:00:00 2001 From: Huy Vu2 Date: Fri, 31 Oct 2025 14:37:10 -0700 Subject: [PATCH 05/80] run outside of Mbridge --- .../conf/wan_pretrain_override_example.yaml | 0 .../megatron/recipe/wan/example_commands.sh | 78 +++++ .../megatron/recipe/wan/inference_wan.py | 324 ++++++++++++++++++ .../megatron/recipe/wan/pretrain_wan.py | 184 ++++++++++ dfm/examples/megatron/recipe/wan/wan.py | 219 ++++++++++++ .../megatron/recipe/wan/wan_provider.py | 81 +++++ dfm/examples/megatron/recipe/wan/wan_step.py | 127 +++++++ .../flow_matching/flow_inference_pipeline.py | 2 +- 8 files changed, 1014 insertions(+), 1 deletion(-) create mode 100644 dfm/examples/megatron/recipe/wan/conf/wan_pretrain_override_example.yaml create mode 100644 dfm/examples/megatron/recipe/wan/example_commands.sh create mode 100644 dfm/examples/megatron/recipe/wan/inference_wan.py create mode 100644 dfm/examples/megatron/recipe/wan/pretrain_wan.py create mode 100644 dfm/examples/megatron/recipe/wan/wan.py create mode 100644 dfm/examples/megatron/recipe/wan/wan_provider.py create mode 100644 dfm/examples/megatron/recipe/wan/wan_step.py diff --git a/dfm/examples/megatron/recipe/wan/conf/wan_pretrain_override_example.yaml b/dfm/examples/megatron/recipe/wan/conf/wan_pretrain_override_example.yaml new file mode 100644 index 00000000..e69de29b diff --git a/dfm/examples/megatron/recipe/wan/example_commands.sh b/dfm/examples/megatron/recipe/wan/example_commands.sh new file mode 100644 index 00000000..13235aed --- /dev/null +++ b/dfm/examples/megatron/recipe/wan/example_commands.sh @@ -0,0 +1,78 @@ +### set path to Megatron-Bridge +DFM_PATH=/path/to/dfm +MBRIDGE_PATH=/path/to/megatron-bridge +export PYTHONPATH="${DFM_PATH}/.:${MBRIDGE_PATH}/src/.:/opt/NeMo-Framework-Launcher/launcher_scripts" + + +### install dependencies +pip install --upgrade git+https://github.com/NVIDIA/Megatron-LM.git@ce8185cbbe04f38beb74360e878450f2e8525885 +python3 -m pip install --upgrade diffusers +pip install easydict +pip install imageio +pip install imageio-ffmpeg + + +### Convert checkpoint +See ${MBRIDGE_PATH}/examples/conversion/convert_wan_checkpoints.py for details. + + +### Finetuning +export HF_TOKEN=... +export WANDB_API_KEY=... +EXP_NAME=... +PRETRAINED_CHECKPOINT=/path/to/pretrained_checkpoint +CHECKPOINT_DIR=/path/to/checkpoint_dir +DATASET_PATH=/path/to/dataset +cd ${MBRIDGE_PATH} +NVTE_FUSED_ATTN=1 torchrun --nproc_per_node=8 examples/recipes/wan/pretrain_wan.py \ + model.tensor_model_parallel_size=1 \ + model.pipeline_model_parallel_size=1 \ + model.context_parallel_size=4 \ + model.sequence_parallel=false \ + model.qkv_format=thd \ + dataset.path=${DATASET_PATH} \ + checkpoint.save=${CHECKPOINT_DIR} \ + checkpoint.load=${PRETRAINED_CHECKPOINT} \ + checkpoint.load_optim=false \ + checkpoint.save_interval=200 \ + optimizer.lr=5e-6 \ + optimizer.min_lr=5e-6 \ + train.eval_iters=0 \ + scheduler.lr_decay_style=constant \ + scheduler.lr_warmup_iters=0 \ + model.seq_length=2048 \ + dataset.seq_length=2048 \ + train.global_batch_size=1 \ + train.micro_batch_size=1 \ + dataset.global_batch_size=1 \ + dataset.micro_batch_size=1 \ + logger.log_interval=1 \ + logger.wandb_project="wan" \ + logger.wandb_exp_name=${EXP_NAME} \ + logger.wandb_save_dir=${CHECKPOINT_DIR} + + +### Inferencing +# Download T5 weights and VAE weights from "https://huggingface.co/Wan-AI/Wan2.1-T2V-1.3B/tree/main" +# T5: models_t5_umt5-xxl-enc-bf16.pth, google +# VAE: Wan2.1_VAE.pth +export HF_TOKEN=... +CHECKPOINT_DIR=/path/to/checkpoint_dir +T5_DIR=/path/to/t5_weights +VAE_DIR=/path/to/vae_weights +cd ${MBRIDGE_PATH} +NVTE_FUSED_ATTN=1 torchrun --nproc_per_node=1 examples/recipes/wan/inference_wan.py \ + --task t2v-1.3B \ + --sizes 832*480 \ + --checkpoint_dir ${CHECKPOINT_DIR} \ + --checkpoint_step 1000 \ + --t5_checkpoint_dir ${T5_DIR} \ + --vae_checkpoint_dir ${VAE_DIR} \ + --prompts "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \ + --frame_nums 81 \ + --tensor_parallel_size 1 \ + --context_parallel_size 1 \ + --pipeline_parallel_size 1 \ + --sequence_parallel False \ + --base_seed 42 \ + --sample_steps 50 \ No newline at end of file diff --git a/dfm/examples/megatron/recipe/wan/inference_wan.py b/dfm/examples/megatron/recipe/wan/inference_wan.py new file mode 100644 index 00000000..2f480a2b --- /dev/null +++ b/dfm/examples/megatron/recipe/wan/inference_wan.py @@ -0,0 +1,324 @@ +# Example of running script for Wan inference. +# NVTE_FUSED_ATTN=1 torchrun --nproc_per_node=1 examples/recipes/wan/inference_wan.py \ +# --task t2v-1.3B \ +# --sizes 480*832 \ +# --checkpoint_dir /path/to/wan_checkpoint_dir \ +# --t5_checkpoint_dir /path/to/t5_checkpoint_dir \ +# --vae_checkpoint_dir /path/to/vae_checkpoint_dir \ +# --frame_nums 81 \ +# --prompts "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \ +# --tensor_parallel_size 1 \ +# --context_parallel_size 1 \ +# --pipeline_parallel_size 1 \ +# --sequence_parallel False \ +# --base_seed 42 \ +# --sample_steps 50 + +import argparse +import logging +import os +import sys +import warnings +from datetime import datetime + +warnings.filterwarnings('ignore') + +import random + +import torch +import torch.distributed as dist +from PIL import Image + +from dfm.src.megatron.model.wan.flow_matching.flow_inference_pipeline import FlowInferencePipeline +from dfm.src.megatron.model.wan.inference.configs import SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS +from dfm.src.megatron.model.wan.inference.utils.utils import cache_video, str2bool + +EXAMPLE_PROMPT = { + "t2v-1.3B": { + "prompt": + "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.", + }, + "t2v-14B": { + "prompt": + "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.", + }, +} + + +def _validate_args(args): + # Basic check + assert args.checkpoint_dir is not None, "Please specify the checkpoint directory." + assert args.t5_checkpoint_dir is not None, "Please specify the T5 checkpoint directory." + assert args.vae_checkpoint_dir is not None, "Please specify the VAE checkpoint directory." + assert args.task in WAN_CONFIGS, f"Unsupport task: {args.task}" + assert args.task in EXAMPLE_PROMPT, f"Unsupport task: {args.task}" + + # The default sampling steps are 40 for image-to-video tasks and 50 for text-to-video tasks. + if args.sample_steps is None: + args.sample_steps = 50 + + if args.sample_shift is None: + args.sample_shift = 5.0 + + # Frames default handled later; no single frame arg anymore + + args.base_seed = args.base_seed if args.base_seed >= 0 else random.randint( + 0, sys.maxsize) + # Size check: only validate provided --sizes; default handled later + if args.sizes is not None and len(args.sizes) > 0: + for s in args.sizes: + assert s in SUPPORTED_SIZES[args.task], ( + f"Unsupport size {s} for task {args.task}, supported sizes are: " + f"{', '.join(SUPPORTED_SIZES[args.task])}") + + +def _parse_args(): + parser = argparse.ArgumentParser( + description="Generate a image or video from a text prompt or image using Wan" + ) + parser.add_argument( + "--task", + type=str, + default="t2v-14B", + choices=list(WAN_CONFIGS.keys()), + help="The task to run.") + parser.add_argument( + "--sizes", + type=str, + nargs="+", + default=None, + choices=list(SIZE_CONFIGS.keys()), + help="A list of sizes to generate multiple images or videos (WIDTH*HEIGHT). Example: --sizes 1280*720 1920*1080" + ) + parser.add_argument( + "--frame_nums", + type=int, + nargs="+", + default=None, + help="List of frame counts (each should be 4n+1). Broadcasts if single value." + ) + parser.add_argument( + "--checkpoint_dir", + type=str, + default=None, + help="The path to the main WAN checkpoint directory.") + parser.add_argument( + "--checkpoint_step", + type=int, + default=None, + help=( + "Optional training step to load, e.g. 1800 -> iter_0001800. " + "If not provided, the latest (largest) step in --checkpoint_dir is used.") + ) + parser.add_argument( + "--t5_checkpoint_dir", + type=str, + default=None, + help="Optional directory containing T5 checkpoint/tokenizer") + parser.add_argument( + "--vae_checkpoint_dir", + type=str, + default=None, + help="Optional directory containing VAE checkpoint") + parser.add_argument( + "--offload_model", + type=str2bool, + default=None, + help="Whether to offload the model to CPU after each model forward, reducing GPU memory usage." + ) + parser.add_argument( + "--t5_cpu", + action="store_true", + default=False, + help="Whether to place T5 model on CPU.") + parser.add_argument( + "--save_file", + type=str, + default=None, + help="The file to save the generated image or video to.") + parser.add_argument( + "--prompts", + type=str, + nargs="+", + default=None, + help="A list of prompts to generate multiple images or videos. Example: --prompts 'a cat' 'a dog'" + ) + parser.add_argument( + "--base_seed", + type=int, + default=-1, + help="The seed to use for generating the image or video.") + parser.add_argument( + "--sample_solver", + type=str, + default='unipc', + choices=['unipc', 'dpm++'], + help="The solver used to sample.") + parser.add_argument( + "--sample_steps", type=int, default=None, help="The sampling steps.") + parser.add_argument( + "--sample_shift", + type=float, + default=None, + help="Sampling shift factor for flow matching schedulers.") + parser.add_argument( + "--sample_guide_scale", + type=float, + default=5.0, + help="Classifier free guidance scale.") + parser.add_argument( + "--tensor_parallel_size", + type=int, + default=1, + help="Tensor parallel size.") + parser.add_argument( + "--context_parallel_size", + type=int, + default=1, + help="Context parallel size.") + parser.add_argument( + "--pipeline_parallel_size", + type=int, + default=1, + help="Pipeline parallel size.") + parser.add_argument( + "--sequence_parallel", + type=str2bool, + default=False, + help="Sequence parallel.") + + args = parser.parse_args() + + _validate_args(args) + + return args + + +def _init_logging(rank): + # logging + if rank == 0: + # set format + logging.basicConfig( + level=logging.INFO, + format="[%(asctime)s] %(levelname)s: %(message)s", + handlers=[logging.StreamHandler(stream=sys.stdout)]) + else: + logging.basicConfig(level=logging.ERROR) + + +def generate(args): + rank = int(os.getenv("RANK", 0)) + world_size = int(os.getenv("WORLD_SIZE", 1)) + local_rank = int(os.getenv("LOCAL_RANK", 0)) + device = local_rank + _init_logging(rank) + + if args.offload_model is None: + args.offload_model = False if world_size > 1 else True + logging.info( + f"offload_model is not specified, set to {args.offload_model}.") + if world_size > 1: + torch.cuda.set_device(local_rank) + dist.init_process_group( + backend="nccl", + init_method="env://", + rank=rank, + world_size=world_size) + + cfg = WAN_CONFIGS[args.task] + + logging.info(f"Generation job args: {args}") + logging.info(f"Generation model config: {cfg}") + + if dist.is_initialized(): + base_seed = [args.base_seed] if rank == 0 else [None] + dist.broadcast_object_list(base_seed, src=0) + args.base_seed = base_seed[0] + + if "t2v" in args.task: + # Resolve prompts list (default to example prompt) + if args.prompts is not None and len(args.prompts) > 0: + prompts = args.prompts + else: + prompts = [EXAMPLE_PROMPT[args.task]["prompt"]] + + # Resolve sizes list (default to first supported size for task) + if args.sizes is not None and len(args.sizes) > 0: + size_keys = args.sizes + else: + size_keys = [SUPPORTED_SIZES[args.task][0]] + + # Resolve frame counts list (default 81) + if args.frame_nums is not None and len(args.frame_nums) > 0: + frame_nums = args.frame_nums + else: + frame_nums = [81] + + # Enforce 1:1 pairing across lists + assert len(prompts) == len(size_keys) == len(frame_nums), ( + f"prompts ({len(prompts)}), sizes ({len(size_keys)}), and frame_nums ({len(frame_nums)}) " + f"must have the same length") + + logging.info("Creating flow inference pipeline.") + pipeline = FlowInferencePipeline( + config=cfg, + checkpoint_dir=args.checkpoint_dir, + checkpoint_step=args.checkpoint_step, + t5_checkpoint_dir=args.t5_checkpoint_dir, + vae_checkpoint_dir=args.vae_checkpoint_dir, + device_id=device, + rank=rank, + t5_cpu=args.t5_cpu, + tensor_parallel_size=args.tensor_parallel_size, + context_parallel_size=args.context_parallel_size, + pipeline_parallel_size=args.pipeline_parallel_size, + sequence_parallel=args.sequence_parallel, + pipeline_dtype=torch.float32, + ) + + # DEBUGGING + rank = dist.get_rank() + if rank == 0: + print("tensor_parallel_size:", args.tensor_parallel_size) + print("context_parallel_size:", args.context_parallel_size) + print("pipeline_parallel_size:", args.pipeline_parallel_size) + print("sequence_parallel:", args.sequence_parallel) + print("\n\n\n") + + logging.info( + f"Generating videos ...") + videos = pipeline.generate( + prompts=prompts, + sizes=[SIZE_CONFIGS[size] for size in size_keys], + frame_nums=frame_nums, + shift=args.sample_shift, + sample_solver=args.sample_solver, + sampling_steps=args.sample_steps, + guide_scale=args.sample_guide_scale, + seed=args.base_seed, + offload_model=args.offload_model) + + if rank == 0: + for i, video in enumerate(videos): + formatted_experiment_name = (args.save_file) if args.save_file is not None else "DefaultExp" + formatted_time = datetime.now().strftime("%Y%m%d_%H%M%S") + formatted_prompt = prompts[i].replace(" ", "_").replace("/", + "_")[:50] + suffix = '.mp4' + formatted_save_file = f"{args.task}_{formatted_experiment_name}_videoindex{int(i)}_size{size_keys[i].replace('*','x') if sys.platform=='win32' else size_keys[i]}_{formatted_prompt}_{formatted_time}" + suffix + + if "t2v" in args.task: + logging.info(f"Saving generated video to {formatted_save_file}") + cache_video( + tensor=video[None], + save_file=formatted_save_file, + fps=cfg.sample_fps, + nrow=1, + normalize=True, + value_range=(-1, 1)) + logging.info("Finished.") + + +if __name__ == "__main__": + args = _parse_args() + generate(args) diff --git a/dfm/examples/megatron/recipe/wan/pretrain_wan.py b/dfm/examples/megatron/recipe/wan/pretrain_wan.py new file mode 100644 index 00000000..7742397e --- /dev/null +++ b/dfm/examples/megatron/recipe/wan/pretrain_wan.py @@ -0,0 +1,184 @@ + +#!/usr/bin/env python3 +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Wan Pretraining Script with YAML and CLI Configuration Overrides. + +This script provides a flexible way to pretrain Wan models using Megatron-Bridge with support for +both YAML configuration files and command-line overrides using Hydra-style syntax. + +Examples: + Basic usage with default configuration: + $ torchrun --nproc_per_node=8 pretrain_wan.py + + Using a custom YAML config file: + $ torchrun --nproc_per_node=8 pretrain_wan.py --config-file my_custom_config.yaml + + Using CLI overrides only: + $ torchrun --nproc_per_node=8 pretrain_wan.py model.tensor_model_parallel_size=4 train.train_iters=100000 + + Combining YAML and CLI overrides (CLI takes precedence): + $ torchrun --nproc_per_node=8 pretrain_wan.py --config-file conf/my_config.yaml \ + model.pipeline_dtype=torch.float16 \ + train.global_batch_size=512 + +Configuration Precedence: + 1. Base configuration from pretrain_config() recipe + 2. YAML overrides from --config-file (if provided) + 3. CLI overrides (highest precedence) + +Supported Override Syntax: + - Standard assignment: key=value + - Nested assignment: section.subsection.key=value + - Addition: +new_key=value + - Deletion: ~key_to_remove + - Type conversion: Automatic for basic types (int, float, bool, str) + - Complex types: torch.dtype, enums, etc. are supported +""" + +import argparse +import logging +import os +import sys +from pathlib import Path +from typing import Tuple + +from omegaconf import OmegaConf + +from dfm.examples.megatron.recipe.wan.wan import pretrain_config +from megatron.bridge.training.config import ConfigContainer +from dfm.examples.megatron.recipe.wan.wan_step import WanForwardStep +from megatron.bridge.training.pretrain import pretrain +from megatron.bridge.training.utils.omegaconf_utils import ( + apply_overrides, + create_omegaconf_dict_config, + parse_hydra_overrides, +) +from megatron.bridge.utils.common_utils import get_rank_safe + + +logger: logging.Logger = logging.getLogger(__name__) + + +# Define paths relative to this script's location +# Assumes this script (pretrain_wan.py) is in Megatron-Bridge/examples/recipes/wan/ +# and the config is in a 'conf' subdirectory. +SCRIPT_DIR: Path = Path(__file__).parent.resolve() +DEFAULT_CONFIG_FILENAME: str = "wan_pretrain_override_example.yaml" +DEFAULT_CONFIG_FILE_PATH: Path = SCRIPT_DIR / "conf" / DEFAULT_CONFIG_FILENAME + +# DEBUGGING +import numpy as np +import torch +np.set_printoptions(precision=10, suppress=False) +torch.set_printoptions(precision=10, sci_mode=False) + +def parse_cli_args() -> Tuple[argparse.Namespace, list[str]]: + """Parse command line arguments, separating known script args from OmegaConf overrides.""" + parser = argparse.ArgumentParser( + description="Pretrain Wan model using Megatron-Bridge with YAML and CLI overrides", + formatter_class=argparse.RawTextHelpFormatter, + ) + parser.add_argument( + "--config-file", + type=str, + default=str(DEFAULT_CONFIG_FILE_PATH), + help="Path to the YAML OmegaConf override file. Default: conf/wan_pretrain_override_example.yaml", + ) + parser.add_argument("--debug", action="store_true", help="Enable debug logging") + + # Parse known args for the script, remaining will be treated as overrides + args, cli_dotlist_overrides = parser.parse_known_args() + return args, cli_dotlist_overrides + + +def main() -> None: + """ + Entry point for the Wan pretraining script. + + This function orchestrates the complete configuration workflow: + 1. Loads the base configuration from pretrain_config() recipe + 2. Applies YAML overrides from --config-file (if exists) + 3. Applies CLI overrides using Hydra-style syntax + 4. Starts Megatron pretraining with the final merged configuration + + Configuration merging preserves callable fields (like activation functions) + and handles type conversions automatically. + + Examples of CLI usage: + # Use default config with custom learning rate + torchrun --nproc_per_node=8 pretrain_wan.py optimizer.lr=0.0002 + + # Custom config file with additional overrides + torchrun --nproc_per_node=8 pretrain_wan.py --config-file my_config.yaml train.train_iters=50000 + + # Multiple overrides for distributed training + torchrun --nproc_per_node=8 pretrain_wan.py \ + model.tensor_model_parallel_size=4 \ + model.pipeline_model_parallel_size=2 \ + train.global_batch_size=512 + """ + args, cli_overrides = parse_cli_args() + + logger.info("Megatron-Bridge Wan Pretraining Script with YAML & CLI Overrides") + logger.info("------------------------------------------------------------------") + + # Load base configuration from the recipe as a Python dataclass + cfg: ConfigContainer = pretrain_config() + logger.info("Loaded base configuration") + + # Print configuration on rank 0 + if get_rank_safe() == 0: + cfg.print_yaml() + + # Convert the initial Python dataclass to an OmegaConf DictConfig for merging + merged_omega_conf, excluded_fields = create_omegaconf_dict_config(cfg) + + # Load and merge YAML overrides if a config file is provided + if args.config_file: + logger.debug(f"Loading YAML overrides from: {args.config_file}") + if not os.path.exists(args.config_file): + logger.error(f"Override YAML file not found: {args.config_file}") + sys.exit(1) + yaml_overrides_omega = OmegaConf.load(args.config_file) + merged_omega_conf = OmegaConf.merge(merged_omega_conf, yaml_overrides_omega) + logger.debug("YAML overrides merged successfully.") + + # Apply command-line overrides using Hydra-style parsing + if cli_overrides: + logger.debug(f"Applying Hydra-style command-line overrides: {cli_overrides}") + merged_omega_conf = parse_hydra_overrides(merged_omega_conf, cli_overrides) + logger.debug("Hydra-style command-line overrides applied successfully.") + + # Apply the final merged OmegaConf configuration back to the original ConfigContainer + logger.debug("Applying final merged configuration back to Python ConfigContainer...") + final_overrides_as_dict = OmegaConf.to_container(merged_omega_conf, resolve=True) + # Apply overrides while preserving excluded fields + apply_overrides(cfg, final_overrides_as_dict, excluded_fields) + + # Display final configuration + if get_rank_safe() == 0: + logger.info("--- Final Merged Configuration ---") + cfg.print_yaml() + logger.info("----------------------------------") + + # Start training + logger.debug("Starting pretraining...") + pretrain(config=cfg, forward_step_func=WanForwardStep()) + + +if __name__ == "__main__": + main() diff --git a/dfm/examples/megatron/recipe/wan/wan.py b/dfm/examples/megatron/recipe/wan/wan.py new file mode 100644 index 00000000..c321d26f --- /dev/null +++ b/dfm/examples/megatron/recipe/wan/wan.py @@ -0,0 +1,219 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import List, Optional, Union + +from dfm.src.megatron.data.wan.wan_energon_datamodule import WanDataModuleConfig +from dfm.examples.megatron.recipe.wan.wan_provider import WanModelProvider +import torch +from megatron.core.distributed import DistributedDataParallelConfig + +from megatron.bridge.recipes.utils.optimizer_utils import distributed_fused_adam_with_cosine_annealing +from megatron.bridge.recipes.utils.tokenizer_utils import DEFAULT_NULL_TOKENIZER_VOCAB_SIZE +from megatron.bridge.training.comm_overlap import CommOverlapConfig +from megatron.bridge.training.config import ( + CheckpointConfig, + ConfigContainer, + LoggerConfig, + RNGConfig, + TokenizerConfig, + TrainingConfig, +) +from megatron.bridge.training.mixed_precision import MixedPrecisionConfig, get_mixed_precision_config + + +def model_config( + tensor_parallelism: int = 1, + pipeline_parallelism: int = 1, + pipeline_parallelism_dtype: Optional[torch.dtype] = torch.bfloat16, + virtual_pipeline_parallelism: Optional[int] = 1, + context_parallelism: int = 1, + sequence_parallelism: bool = False, + seq_length: int = 1024, +) -> WanModelProvider: + """ + Configure the Wan model. + + Args: + tensor_parallelism (int): Degree of tensor model parallelism. + pipeline_parallelism (int): Degree of pipeline model parallelism. + pipeline_parallelism_dtype (Optional[torch.dtype]): Data type for pipeline parallelism. + virtual_pipeline_parallelism (Optional[int]): Size of virtual pipeline parallelism. + context_parallelism (int): Degree of context parallelism. + sequence_parallelism (bool): Whether to use sequence parallelism. + seq_length (int): Sequence length for the model. + Returns: + WanModelProvider: Configuration for the Wan model. + """ + return WanModelProvider( + tensor_model_parallel_size=tensor_parallelism, + pipeline_model_parallel_size=pipeline_parallelism, + pipeline_dtype=pipeline_parallelism_dtype, + virtual_pipeline_model_parallel_size=None, + context_parallel_size=context_parallelism, + sequence_parallel=sequence_parallelism, + seq_length=seq_length, + ) + + +def pretrain_config( + dir: Optional[str] = None, + name: str = "default", + # Dataset configuration + data_paths: Optional[List[str]] = None, + data_args_path: Optional[str] = None, + train_data_path: Optional[List[str]] = None, + valid_data_path: Optional[List[str]] = None, + test_data_path: Optional[List[str]] = None, + per_split_data_args_path: Optional[str] = None, + mock: bool = False, + # Model configuration + tensor_parallelism: int = 1, + pipeline_parallelism: int = 1, + pipeline_parallelism_dtype: Optional[torch.dtype] = torch.bfloat16, + virtual_pipeline_parallelism: Optional[int] = 1, + context_parallelism: int = 1, + sequence_parallelism: bool = False, + use_megatron_fsdp: bool = False, + # Training hyperparameters + train_iters: int = 10000, + global_batch_size: int = 4, + micro_batch_size: int = 1, + lr: float = 0.9e-4, + lr_warmup_iters: int = 2000, + # Precision recipe + # DEBUGGING + precision_config: Optional[Union[MixedPrecisionConfig, str]] = "bf16_mixed", + # precision_config: Optional[Union[MixedPrecisionConfig, str]] = MixedPrecisionConfig( + # fp32=True, + # params_dtype=torch.float32, + # pipeline_dtype=torch.float32, + # autocast_enabled=False, + # ), + comm_overlap_config: Optional[CommOverlapConfig] = None, +) -> ConfigContainer: + """ + Create a pre-training configuration for GPT3 175B model. + + The default configuration is expected to run on 64 nodes with 8 GPUs each. + + Args: + dir (Optional[str]): Base directory for saving logs and checkpoints. + name (str): Name of the pre-training run. + data_paths (Optional[List[str]]): List of paths to dataset files. If None, mock data will be used. + data_args_path (Optional[str]): Path to file containing data arguments. + train_data_path (Optional[List[str]]): List of training data paths. + valid_data_path (Optional[List[str]]): List of validation data paths. + test_data_path (Optional[List[str]]): List of test data paths. + per_split_data_args_path (Optional[str]): Path to JSON file with per-split data configuration. + mock (bool): Whether to use mock data. If True, ignores data_paths. + tensor_parallelism (int): Degree of tensor model parallelism. + pipeline_parallelism (int): Degree of pipeline model parallelism. + pipeline_parallelism_dtype (Optional[torch.dtype]): Data type for pipeline parallelism. + virtual_pipeline_parallelism (Optional[int]): Size of virtual pipeline parallelism. + context_parallelism (int): Degree of context parallelism to be passed to model_config. + sequence_parallelism (bool): Whether to use sequence parallelism. + train_iters (int): Total number of training iterations. + global_batch_size (int): Global batch size for training. + micro_batch_size (int): Micro batch size for training. + seq_length (int): Sequence length for training data. + lr (float): Learning rate. + min_lr (float): Minimum learning rate for cosine decay. + lr_warmup_iters (int): Number of warmup iterations for the learning rate. + precision_config (Optional[Union[MixedPrecisionConfig, str]]): Precision configuration for the model. + comm_overlap_config (Optional[CommOverlapConfig]): Communication overlap configuration for the model. + + Returns: + ConfigContainer: Configuration for pre-training. + """ + base_output_dir = dir if dir is not None else os.path.join(os.getcwd(), "nemo_experiments") + run_output_dir = os.path.join(base_output_dir, name) + checkpoint_dir = os.path.join(run_output_dir, "checkpoints") + tensorboard_dir = os.path.join(run_output_dir, "tb_logs") + + + model_cfg = model_config( + tensor_parallelism=tensor_parallelism, + pipeline_parallelism=pipeline_parallelism, + pipeline_parallelism_dtype=pipeline_parallelism_dtype, + virtual_pipeline_parallelism=virtual_pipeline_parallelism, + context_parallelism=context_parallelism, + sequence_parallelism=sequence_parallelism, + seq_length=1024, + ) + + opt_config, scheduler = distributed_fused_adam_with_cosine_annealing( + lr_warmup_iters=lr_warmup_iters, + lr_decay_iters=train_iters, + max_lr=lr, + ) + opt_config.use_precision_aware_optimizer = False + + if isinstance(precision_config, str): + precision_config = get_mixed_precision_config(precision_config) + + precision_config.grad_reduce_in_fp32 = False + + + # Config Container + cfg = ConfigContainer( + model=model_cfg, + train=TrainingConfig( + train_iters=train_iters, + eval_interval=2000, + eval_iters=32, + global_batch_size=global_batch_size, + micro_batch_size=micro_batch_size, + manual_gc=True, + manual_gc_interval=100, + manual_gc_eval=100, + ), + optimizer=opt_config, + scheduler=scheduler, + ddp=DistributedDataParallelConfig( + check_for_nan_in_grad=True, + grad_reduce_in_fp32=True, + overlap_grad_reduce=False, + overlap_param_gather=False, + average_in_collective=True, + use_distributed_optimizer=True, + use_megatron_fsdp=use_megatron_fsdp, # need use_distributed_optimizer=True + ), + dataset= WanDataModuleConfig( + path=None, + seq_length=1024, # we don't need to use this value, just add because Bridge training requires for LLMs + micro_batch_size=micro_batch_size, + global_batch_size=global_batch_size, + num_workers=10) + , + logger=LoggerConfig( + log_interval=10, + tensorboard_dir=tensorboard_dir, + log_timers_to_tensorboard=True, + ), + tokenizer=TokenizerConfig(tokenizer_type="NullTokenizer", vocab_size=DEFAULT_NULL_TOKENIZER_VOCAB_SIZE), + checkpoint=CheckpointConfig( + save_interval=2000, + save=checkpoint_dir, + load=checkpoint_dir, + ckpt_format="torch_dist", + fully_parallel_save=True, + ), + rng=RNGConfig(seed=1234), + comm_overlap=comm_overlap_config, + mixed_precision=precision_config, + ) + + return cfg diff --git a/dfm/examples/megatron/recipe/wan/wan_provider.py b/dfm/examples/megatron/recipe/wan/wan_provider.py new file mode 100644 index 00000000..8a38c682 --- /dev/null +++ b/dfm/examples/megatron/recipe/wan/wan_provider.py @@ -0,0 +1,81 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from dataclasses import dataclass + +import torch +from megatron.core import parallel_state +from megatron.bridge.models.transformer_config import TransformerConfig + +from megatron.bridge.models.model_provider import ModelProviderMixin +from megatron.core.models.common.vision_module.vision_module import VisionModule +from dfm.src.megatron.model.wan.wan_model import WanModel + +logger = logging.getLogger(__name__) + +@dataclass +class WanModelProvider(TransformerConfig, ModelProviderMixin[VisionModule]): + crossattn_emb_size: int = 1536 + add_bias_linear: bool = True + gated_linear_unit: bool = False + + num_layers: int = 30 + hidden_size: int = 1536 + ffn_hidden_size: int = 8960 + num_attention_heads: int = 12 + layernorm_epsilon: float = 1e-6 + normalization: str = "RMSNorm" + layernorm_zero_centered_gamma: bool = False + add_qkv_bias: bool = True + rotary_interleaved: bool = True + hidden_dropout: float = 0 + attention_dropout: float = 0 + fp16_lm_cross_entropy: bool = False + parallel_output: bool = True + bf16: bool = False + params_dtype: torch.dtype = torch.float32 + qkv_format: str = "sbhd" # "thd". NOTE: if we use context parallelism, we need to use "thd" + # these attributes are unused for images/videos, we just set because bridge training requires for LLMs + seq_length: int = 1024 + share_embeddings_and_output_weights: bool = False + vocab_size: int = 25256 * 8 + make_vocab_size_divisible_by: int = 128 + + # images/videos attributes + in_channels: int = 16 + out_channels: int = 16 + patch_spatial: int = 2 + patch_temporal: int = 1 + freq_dim: int = 256 + text_len: int = 512 + text_dim: int = 4096 + + def provide(self, pre_process=None, post_process=None, vp_stage=None) -> WanModel: + vp_size = self.virtual_pipeline_model_parallel_size + if vp_size: + p_size = self.pipeline_model_parallel_size + assert (self.num_layers // p_size) % vp_size == 0, ( + "Make sure the number of model chunks is the same across all pipeline stages." + ) + + model = WanModel + + return model( + self, + pre_process=parallel_state.is_pipeline_first_stage(), + post_process=parallel_state.is_pipeline_last_stage(), + fp16_lm_cross_entropy=self.fp16_lm_cross_entropy, + parallel_output=self.parallel_output, + ) \ No newline at end of file diff --git a/dfm/examples/megatron/recipe/wan/wan_step.py b/dfm/examples/megatron/recipe/wan/wan_step.py new file mode 100644 index 00000000..c06c9322 --- /dev/null +++ b/dfm/examples/megatron/recipe/wan/wan_step.py @@ -0,0 +1,127 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from functools import partial +from typing import Iterable + +import torch +from megatron.core import parallel_state +from megatron.core.models.common.vision_module.vision_module import VisionModule +from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.utils import get_model_config +from dfm.src.megatron.model.wan.flow_matching.flow_pipeline import FlowPipeline +from megatron.bridge.training.losses import masked_next_token_loss +from megatron.bridge.training.state import GlobalState + +logger = logging.getLogger(__name__) + +def wan_data_step(qkv_format, dataloader_iter): + batch = next(iter(dataloader_iter.iterable)) + + batch = {k: v.to(device="cuda", non_blocking=True) if torch.is_tensor(v) else v for k, v in batch.items()} + + # Construct packed sequence parameters + if ("seq_len_q" in batch) and ("seq_len_kv" in batch): + cu_seqlens = batch["seq_len_q"].cumsum(dim=0).to(torch.int32) + zero = torch.zeros(1, dtype=torch.int32, device="cuda") + cu_seqlens = torch.cat((zero, cu_seqlens)) + + cu_seqlens_kv = batch["seq_len_kv"].cumsum(dim=0).to(torch.int32) + cu_seqlens_kv = torch.cat((zero, cu_seqlens_kv)) + + batch["packed_seq_params"] = { + "self_attention": PackedSeqParams( + cu_seqlens_q=cu_seqlens, + cu_seqlens_kv=cu_seqlens, + qkv_format=qkv_format, + ), + "cross_attention": PackedSeqParams( + cu_seqlens_q=cu_seqlens, + cu_seqlens_kv=cu_seqlens_kv, + qkv_format=qkv_format, + ), + } + + return batch + + +class WanForwardStep: + def __init__(self): + self.diffusion_pipeline = FlowPipeline() + + + def __call__( + self, state: GlobalState, data_iterator: Iterable, model: VisionModule + ) -> tuple[torch.Tensor, partial]: + """ + Forward training step. + """ + timers = state.timers + straggler_timer = state.straggler_timer + + config = get_model_config(model) + + timers("batch-generator", log_level=2).start() + + qkv_format = getattr(config, "qkv_format", "sbhd") + with straggler_timer(bdata=True): + batch = wan_data_step( + qkv_format, data_iterator + ) + timers("batch-generator").stop() + + check_for_nan_in_loss = state.cfg.rerun_state_machine.check_for_nan_in_loss + check_for_spiky_loss = state.cfg.rerun_state_machine.check_for_spiky_loss + + # run diffusion training step + with straggler_timer: + if parallel_state.is_pipeline_last_stage(): + output_batch, loss, split_loss_mask = self.diffusion_pipeline.training_step(model, batch) + output_tensor = torch.mean(loss, dim=-1) + batch["loss_mask"] = split_loss_mask + else: + output_tensor = self.diffusion_pipeline.training_step(model, batch) + + # DEBUGGING + # TODO: do we need to gather output with sequence or context parallelism here + # especially when we have pipeline parallelism + + loss = output_tensor + if "loss_mask" not in batch or batch["loss_mask"] is None: + loss_mask = torch.ones_like(loss) + loss_mask = batch["loss_mask"] + + loss_function = self._create_loss_function(loss_mask, check_for_nan_in_loss, check_for_spiky_loss) + + return output_tensor, loss_function + + + def _create_loss_function(self, loss_mask: torch.Tensor, check_for_nan_in_loss: bool, check_for_spiky_loss: bool) -> partial: + """Create a partial loss function with the specified configuration. + + Args: + loss_mask: Used to mask out some portions of the loss + check_for_nan_in_loss: Whether to check for NaN values in the loss + check_for_spiky_loss: Whether to check for spiky loss values + + Returns: + A partial function that can be called with output_tensor to compute the loss + """ + return partial( + masked_next_token_loss, + loss_mask, + check_for_nan_in_loss=check_for_nan_in_loss, + check_for_spiky_loss=check_for_spiky_loss, + ) diff --git a/dfm/src/megatron/model/wan/flow_matching/flow_inference_pipeline.py b/dfm/src/megatron/model/wan/flow_matching/flow_inference_pipeline.py index f316ea25..ac5b8657 100644 --- a/dfm/src/megatron/model/wan/flow_matching/flow_inference_pipeline.py +++ b/dfm/src/megatron/model/wan/flow_matching/flow_inference_pipeline.py @@ -15,7 +15,7 @@ from tqdm import tqdm from dfm.src.megatron.model.wan.wan_model import WanModel -from megatron.bridge.models.wan.wan_provider import WanModelProvider +from dfm.examples.megatron.recipe.wan.wan_provider import WanModelProvider from dfm.src.megatron.model.wan.modules.t5 import T5EncoderModel from dfm.src.megatron.model.wan.modules import WanVAE from dfm.src.megatron.model.wan.inference.utils.fm_solvers import ( From daac350182d487b4c7477da06946be43ab59535b Mon Sep 17 00:00:00 2001 From: Abhinav Garg Date: Mon, 3 Nov 2025 21:46:04 +0000 Subject: [PATCH 06/80] Update example scripts and add new data module for multimodal datasets - Added comments to clarify file purposes in example_commands.sh, inference_wan.py, pretrain_wan.py, wan_provider.py, wan_step.py, and wan.py. - Introduced EnergonMultiModalDataModule for handling multimodal datasets in nemo_vfm. - Created SequentialMegatronSampler for efficient sequential sampling in large datasets. - Added new files for DIT attention and base data handling. This commit enhances documentation and introduces new functionalities for better data management and processing. --- .../megatron/recipe/wan/example_commands.sh | 4 +- .../megatron/recipe/wan/inference_wan.py | 3 + .../megatron/recipe/wan/pretrain_wan.py | 2 + dfm/examples/megatron/recipe/wan/wan.py | 1 + .../megatron/recipe/wan/wan_provider.py | 2 + dfm/examples/megatron/recipe/wan/wan_step.py | 2 + dfm/src/megatron/base/layerspec/__init__.py | 0 .../model/wan/flow_matching/flow_pipeline.py | 8 +- .../model/wan/inference/configs/__init__.py | 2 + dfm/src/megatron/model/wan/wan_layer_spec.py | 36 +- nemo_vfm/diffusion/data/base.py | 444 ++++++++++++++++++ .../models/dit/dit_attention_megatron.py | 0 12 files changed, 480 insertions(+), 24 deletions(-) create mode 100644 dfm/src/megatron/base/layerspec/__init__.py create mode 100644 nemo_vfm/diffusion/data/base.py create mode 100644 nemo_vfm/diffusion/models/dit/dit_attention_megatron.py diff --git a/dfm/examples/megatron/recipe/wan/example_commands.sh b/dfm/examples/megatron/recipe/wan/example_commands.sh index 13235aed..1505206e 100644 --- a/dfm/examples/megatron/recipe/wan/example_commands.sh +++ b/dfm/examples/megatron/recipe/wan/example_commands.sh @@ -1,3 +1,5 @@ +#Let's make a md file instead + ### set path to Megatron-Bridge DFM_PATH=/path/to/dfm MBRIDGE_PATH=/path/to/megatron-bridge @@ -13,7 +15,7 @@ pip install imageio-ffmpeg ### Convert checkpoint -See ${MBRIDGE_PATH}/examples/conversion/convert_wan_checkpoints.py for details. +# See ${MBRIDGE_PATH}/examples/conversion/convert_wan_checkpoints.py for details. ### Finetuning diff --git a/dfm/examples/megatron/recipe/wan/inference_wan.py b/dfm/examples/megatron/recipe/wan/inference_wan.py index 2f480a2b..3f84eafc 100644 --- a/dfm/examples/megatron/recipe/wan/inference_wan.py +++ b/dfm/examples/megatron/recipe/wan/inference_wan.py @@ -14,6 +14,9 @@ # --base_seed 42 \ # --sample_steps 50 + +# Goes to examples/megatron/recipe/wan + import argparse import logging import os diff --git a/dfm/examples/megatron/recipe/wan/pretrain_wan.py b/dfm/examples/megatron/recipe/wan/pretrain_wan.py index 7742397e..81b4f2f4 100644 --- a/dfm/examples/megatron/recipe/wan/pretrain_wan.py +++ b/dfm/examples/megatron/recipe/wan/pretrain_wan.py @@ -14,6 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. + +# Goes to examples/megatron/recipe/wan """ Wan Pretraining Script with YAML and CLI Configuration Overrides. diff --git a/dfm/examples/megatron/recipe/wan/wan.py b/dfm/examples/megatron/recipe/wan/wan.py index c321d26f..f64d0d72 100644 --- a/dfm/examples/megatron/recipe/wan/wan.py +++ b/dfm/examples/megatron/recipe/wan/wan.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +# Goes to src/megatron/reicepe/wan import os from typing import List, Optional, Union diff --git a/dfm/examples/megatron/recipe/wan/wan_provider.py b/dfm/examples/megatron/recipe/wan/wan_provider.py index 8a38c682..63229b37 100644 --- a/dfm/examples/megatron/recipe/wan/wan_provider.py +++ b/dfm/examples/megatron/recipe/wan/wan_provider.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. + +# Goes into the model import logging from dataclasses import dataclass diff --git a/dfm/examples/megatron/recipe/wan/wan_step.py b/dfm/examples/megatron/recipe/wan/wan_step.py index c06c9322..cb19386d 100644 --- a/dfm/examples/megatron/recipe/wan/wan_step.py +++ b/dfm/examples/megatron/recipe/wan/wan_step.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. + +# Move to the model and wan import logging from functools import partial from typing import Iterable diff --git a/dfm/src/megatron/base/layerspec/__init__.py b/dfm/src/megatron/base/layerspec/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py b/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py index 72ddc47b..da95f04a 100644 --- a/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py +++ b/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py @@ -12,16 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Dict, Optional, Tuple, List -import numpy as np import torch -from megatron.core import parallel_state -from torch import Tensor from diffusers import WanPipeline +from megatron.core import parallel_state from dfm.src.megatron.model.wan.flow_matching.time_shift_utils import compute_density_for_timestep_sampling from dfm.src.megatron.model.wan.utils.utils import patchify, thd_split_inputs_cp + class FlowPipeline: def __init__( @@ -205,7 +203,7 @@ def training_step( mean_weighted_loss = weighted_loss.mean() if torch.isnan(mean_weighted_loss) or mean_weighted_loss > 100: print(f"[ERROR] Loss explosion! Loss={mean_weighted_loss.item():.3f}") - print(f"[DEBUG] Stopping training - check hyperparameters") + print("[DEBUG] Stopping training - check hyperparameters") raise ValueError(f"Loss exploded: {mean_weighted_loss.item()}") return model_pred, weighted_loss, split_loss_mask diff --git a/dfm/src/megatron/model/wan/inference/configs/__init__.py b/dfm/src/megatron/model/wan/inference/configs/__init__.py index a28c03c5..938da53c 100644 --- a/dfm/src/megatron/model/wan/inference/configs/__init__.py +++ b/dfm/src/megatron/model/wan/inference/configs/__init__.py @@ -1,6 +1,8 @@ import copy import os +# Change to diffusors + os.environ['TOKENIZERS_PARALLELISM'] = 'false' from .wan_i2v_14B import i2v_14B diff --git a/dfm/src/megatron/model/wan/wan_layer_spec.py b/dfm/src/megatron/model/wan/wan_layer_spec.py index f98576ad..267c6ad0 100644 --- a/dfm/src/megatron/model/wan/wan_layer_spec.py +++ b/dfm/src/megatron/model/wan/wan_layer_spec.py @@ -17,17 +17,16 @@ import copy from dataclasses import dataclass -from typing import Union, Optional +from typing import Optional, Union import torch -import torch.cuda.amp as amp import torch.nn as nn from megatron.core import parallel_state, tensor_parallel +from megatron.core.extensions.transformer_engine import TENorm +from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.transformer.attention import ( CrossAttention, - CrossAttentionSubmodules, SelfAttention, - SelfAttentionSubmodules, ) from megatron.core.transformer.custom_layers.transformer_engine import ( TEColumnParallelLinear, @@ -42,8 +41,7 @@ from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules from megatron.core.utils import make_viewless_tensor -from megatron.core.process_groups_config import ProcessGroupCollection -from megatron.core.extensions.transformer_engine import TENorm + try: import transformer_engine # pylint: disable=unused-import @@ -57,6 +55,8 @@ class WanLayerNorm(nn.LayerNorm): + # Note to parth: Can we replace this with te layer norm or fuse with linear layer? + # (@huy) Remove this comment after you have answered the question. def __init__(self, dim, eps=1e-6, elementwise_affine=False): super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps) @@ -70,7 +70,7 @@ def forward(self, x): @dataclass -class WanSelfAttentionSubmodules: +class WanSelfAttentionSubmodules: # Call this DiTSelfAttentionSubmodules or DiTSelfAttentionConfig? """ Configuration class for specifying the submodules of a self-attention. """ @@ -78,13 +78,13 @@ class WanSelfAttentionSubmodules: linear_qkv: Union[ModuleSpec, type] = None core_attention: Union[ModuleSpec, type] = None linear_proj: Union[ModuleSpec, type] = None - layernorm_across_head: bool = False + layernorm_across_head: bool = False # Should be moved to Trasnformer config and not layerspec. (@huy to remove this) q_layernorm: Union[ModuleSpec, type] = None k_layernorm: Union[ModuleSpec, type] = None @dataclass -class WanCrossAttentionSubmodules: +class WanCrossAttentionSubmodules: # Call this DiTCrossAttentionSubmodules or DiTCrossAttentionConfig? """ Configuration class for specifying the submodules of a cross-attention. """ @@ -92,12 +92,12 @@ class WanCrossAttentionSubmodules: linear_kv: Union[ModuleSpec, type] = None core_attention: Union[ModuleSpec, type] = None linear_proj: Union[ModuleSpec, type] = None - layernorm_across_head: bool = False + layernorm_across_head: bool = False # Should be moved to Trasnformer config and not layerspec. (@huy to remove this) q_layernorm: Union[ModuleSpec, type] = None k_layernorm: Union[ModuleSpec, type] = None -class WanSelfAttention(SelfAttention): +class WanSelfAttention(SelfAttention): # Call this DitSelfAttention or DiTSelfAttentionConfig? def __init__( self, config: TransformerConfig, @@ -116,7 +116,7 @@ def __init__( pg_collection, ) - self.layernorm_across_head = submodules.layernorm_across_head + self.layernorm_across_head = getattr(self.config, "False", submodules.layernorm_across_head) # override q_layernorm if submodules.q_layernorm is not None: @@ -124,7 +124,6 @@ def __init__( q_layernorm_size = self.query_projection_size else: q_layernorm_size = self.hidden_size_per_attention_head - import transformer_engine as te norm_config = copy.deepcopy(self.config) norm_config.normalization = "RMSNorm" self.q_layernorm = build_module( @@ -142,7 +141,6 @@ def __init__( k_layernorm_size = self.kv_projection_size else: k_layernorm_size = self.hidden_size_per_attention_head - import transformer_engine as te norm_config = copy.deepcopy(self.config) norm_config.normalization = "RMSNorm" self.k_layernorm = build_module( @@ -237,7 +235,7 @@ def get_query_key_value_tensors(self, hidden_states, key_value_states=None): return query, key, value -class WanCrossAttention(CrossAttention): +class WanCrossAttention(CrossAttention): # DiTCrossAttention or DiTCrossAttentionConfig? def __init__( self, config: TransformerConfig, @@ -264,7 +262,6 @@ def __init__( q_layernorm_size = self.query_projection_size else: q_layernorm_size = self.hidden_size_per_attention_head - import transformer_engine as te norm_config = copy.deepcopy(self.config) norm_config.normalization = "RMSNorm" self.q_layernorm = build_module( @@ -282,7 +279,6 @@ def __init__( k_layernorm_size = self.kv_projection_size else: k_layernorm_size = self.hidden_size_per_attention_head - import transformer_engine as te norm_config = copy.deepcopy(self.config) norm_config.normalization = "RMSNorm" self.k_layernorm = build_module( @@ -322,6 +318,9 @@ def get_query_key_value_tensors(self, hidden_states, key_value_states): ) query = query.view(*new_tensor_shape) + # replace with our own implementation (Todo: @huy ) + query, key, value = super().get_query_key_value_tensors(hidden_states, key_value_states) + # gather query and key heads across TP ranks if self.layernorm_across_head is True if self.layernorm_across_head and parallel_state.get_tensor_model_parallel_world_size() > 1: query = query.transpose(-2, -1) @@ -545,7 +544,8 @@ def forward( return output, context -import transformer_engine as te + + def get_wan_block_with_transformer_engine_spec() -> ModuleSpec: params = {"attn_mask_type": AttnMaskType.padding} return ModuleSpec( diff --git a/nemo_vfm/diffusion/data/base.py b/nemo_vfm/diffusion/data/base.py new file mode 100644 index 00000000..f412b516 --- /dev/null +++ b/nemo_vfm/diffusion/data/base.py @@ -0,0 +1,444 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from copy import deepcopy +from typing import Any, Dict, Literal, Optional + +from megatron.core import parallel_state +from megatron.energon import WorkerConfig, get_savable_loader, get_train_dataset +from torch.utils.data import DataLoader +from typing_extensions import Self + +from nemo.collections.multimodal.data.energon.config import MultiModalSampleConfig +from nemo.collections.multimodal.data.energon.task_encoder import MultiModalTaskEncoder +from nemo.lightning.io.mixin import IOMixin, serialization, track_io +from nemo.lightning.pytorch.plugins import MegatronDataSampler +from nemo.utils import logging + + +class EnergonMultiModalDataModule(pl.LightningDataModule, IOMixin): + """ + A PyTorch Lightning DataModule for handling multimodal datasets with images and text. + + This data module is designed to work with multimodal datasets that involve both images and text. + It provides a seamless interface to load training and validation data, manage batching, and handle + the state of the data pipeline across training epochs. The module integrates with the Megatron-Energon + framework for efficient data handling in large-scale distributed training. + + Attributes: + path (str): Path to the energon dataset. + tokenizer (Tokenizer): The tokenizer used for processing text. + image_processor (ImageProcessor): The image processor used for preprocessing images. + seq_length (int): The maximum sequence length for tokenized text. + micro_batch_size (int): The batch size for training and validation. + num_workers (int): Number of workers for data loading. + pin_memory (bool): Whether to pin memory in the DataLoader. + multimodal_sample_config (MultiModalSampleConfig): Configuration object for multimodal samples. + task_encoder (MultiModalTaskEncoder): Encoder responsible for encoding and batching samples. + init_global_step (int): The initial global step for the trainer, used for resuming training. + data_sampler (SequentialMegatronSampler): Sampler responsible for generating sequential samples. + train_dataloader_object (Optional): The DataLoader object for training data. + val_dataloader_object (Optional): The DataLoader object for validation data. + """ + + def __init__( + self, + path: str, + tokenizer, + image_processor, + seq_length: int = 2048, + micro_batch_size: int = 1, + global_batch_size: int = 1, + num_workers: int = 1, + num_val_workers: int | None = None, + pin_memory: bool = True, + shuffle_buffer_size: int = 100, + max_samples_per_sequence: int | None = None, + multimodal_sample_config: Optional[MultiModalSampleConfig] = MultiModalSampleConfig(), + task_encoder: Optional[MultiModalTaskEncoder] = None, + decoder_seq_length: Optional[int] = None, + packing_buffer_size: Optional[int] = None, + validation_task_encoder: Optional[MultiModalTaskEncoder] = None, + **kwargs, + ) -> None: + """ + Initialize the EnergonMultiModalDataModule. + + Parameters: + path (str): Path to the dataset. + tokenizer (Tokenizer): The tokenizer used for processing text. + image_processor (ImageProcessor): The image processor used for preprocessing images. + seq_length (int, optional): The maximum sequence length for tokenized text. Defaults to 2048. + micro_batch_size (int, optional): The batch size for training and validation. Defaults to 1. + num_workers (int, optional): Number of workers for data loading. Defaults to 1. + num_val_workers (int, optional): Number of workers for validation data loading. Defaults to num_workers. + pin_memory (bool, optional): Whether to pin memory in the DataLoader. Defaults to True. + multimodal_sample_config (MultiModalSampleConfig, optional): Configuration object for multimodal samples. + Defaults to MultiModalSampleConfig(). + shuffle_buffer_size (int, optional): Size of the shuffle buffer. Defaults to 100. + max_samples_per_sequence (int, optional): Maximum number of samples per sequence to load from memory. + Defaults to None (loads the whole tar file at once). + task_encoder (MultiModalTaskEncoder, optional): Encoder responsible for encoding and batching samples. + If not provided, a default (MultimodalTaskEncoder) encoder will be created. Defaults to None. + decoder_seq_length (int, optional): The max sequence length for the decoder. Used in encoder-decoder models + packing_buffer_size (int, optional): Size of the packing buffer for batched samples. Defaults to None. + validation_task_encoder (MultiModalTaskEncoder, optional): Encoder responsible for encoding + and batching samples for validation. Defaults to None and will be the same as task_encoder. + **kwargs: Additional keyword arguments. Will be passed to get_train_dataset() of Energon + """ + + super().__init__() + self.path = path + self.tokenizer = tokenizer + self.image_processor = image_processor + self.seq_length = seq_length + self.decoder_seq_length = decoder_seq_length + self.micro_batch_size = micro_batch_size + self.global_batch_size = global_batch_size + self.num_workers = num_workers + self.pin_memory = pin_memory + self.multimodal_sample_config = multimodal_sample_config + self.shuffle_buffer_size = shuffle_buffer_size + self.max_samples_per_sequence = max_samples_per_sequence + self.task_encoder = task_encoder or MultiModalTaskEncoder( + tokenizer=self.tokenizer, + image_processor=self.image_processor, + multimodal_sample_config=multimodal_sample_config, + ) + self.init_global_step = 0 + self.data_sampler = SequentialMegatronSampler( + seq_len=self.seq_length, + decoder_seq_len=self.decoder_seq_length, + micro_batch_size=self.micro_batch_size, + global_batch_size=self.global_batch_size, + ) + self.train_dataloader_object = None + self.val_dataloader_object = None + self.packing_buffer_size = packing_buffer_size + self.validation_task_encoder = validation_task_encoder or self.task_encoder + self.num_val_workers = num_val_workers or self.num_workers + self.kwargs = kwargs + + def io_init(self, **kwargs) -> fdl.Config[Self]: + + cfg_kwargs = { + k: deepcopy(v) + for k, v in kwargs.items() + if k not in ['image_processor', 'task_encoder', 'validation_task_encoder'] + } + + for val in cfg_kwargs.values(): + if not serialization.find_node_traverser(type(val)): + track_io(type(val)) + cfg = fdl.Config(type(self), **cfg_kwargs) + return cfg + + def datasets_provider(self, worker_config, split: Literal['train', 'val'] = 'val'): + """ + Provide the dataset for training or validation. + + This method retrieves the dataset for the specified split (either 'train' or 'val') and configures + it according to the worker configuration. + + Parameters: + worker_config: Configuration for the data loader workers. + split (Literal['train', 'val'], optional): The data split to retrieve ('train' or 'val'). Defaults to 'val'. + + Returns: + Dataset: The dataset configured for the specified split. + """ + + if split not in {'train', 'val'}: + raise ValueError("Invalid value for split. Allowed values are 'train' or 'val'.") + + if split == "train": + task_encoder = self.task_encoder + else: + task_encoder = self.validation_task_encoder + + _dataset = get_train_dataset( + self.path, + batch_size=self.micro_batch_size, + task_encoder=task_encoder, + worker_config=worker_config, + packing_buffer_size=self.packing_buffer_size, + split_part=split, + shuffle_buffer_size=self.shuffle_buffer_size, + max_samples_per_sequence=self.max_samples_per_sequence, + **self.kwargs, + ) + + return _dataset + + def train_dataloader(self) -> TRAIN_DATALOADERS: + """ + Initialize and return the training DataLoader. + + This method initializes the DataLoader for the training dataset. It uses the global step + from the trainer to configure the data sampler and ensures that the parallel state is initialized + correctly for distributed training. + + Returns: + TRAIN_DATALOADERS: The DataLoader for the training dataset. + """ + if self.trainer: + self.init_global_step = self.trainer.global_step + self.data_sampler.init_global_step = self.init_global_step + logging.info(f"Multimodal train dataloader initializing with init_global_step {self.init_global_step}") + if self.train_dataloader_object: + return self.train_dataloader_object + if not parallel_state.is_initialized(): + logging.info( + f"Muiltimodal data loader parallel state is not initialized," + f"using default worker config with no_workers {self.num_workers}" + ) + worker_config = WorkerConfig.default_worker_config(self.num_workers) + else: + rank = parallel_state.get_data_parallel_rank() + world_size = parallel_state.get_data_parallel_world_size() + data_parallel_group = parallel_state.get_data_parallel_group() + logging.info( + f" Multimodal train dataloader initializing with" + f"rank {rank} world_size {world_size} data_parallel_group {data_parallel_group} ****** " + ) + worker_config = WorkerConfig( + rank=rank, + world_size=world_size, + num_workers=self.num_workers, + data_parallel_group=data_parallel_group, + worker_debug_path=None, + worker_log_level=0, + ) + train_dataset = self.datasets_provider(worker_config, split='train') + energon_dataloader = get_savable_loader(train_dataset, worker_config=worker_config) + self.train_dataloader_object = energon_dataloader + return self.train_dataloader_object + + def val_dataloader(self) -> EVAL_DATALOADERS: + """ + Initialize and return the validation DataLoader. + + This method initializes the DataLoader for the validation dataset. It ensures that the parallel state + is initialized correctly for distributed training and returns a configured DataLoader object. + + Returns: + EVAL_DATALOADERS: The DataLoader for the validation dataset. + """ + if self.val_dataloader_object: + return self.val_dataloader_object + + if not parallel_state.is_initialized(): + logging.info( + f"Muiltimodal val data loader parallel state is not initialized," + f"using default worker config with no_workers {self.num_workers}" + ) + worker_config = WorkerConfig.default_worker_config(self.num_val_workers) + else: + rank = parallel_state.get_data_parallel_rank() + world_size = parallel_state.get_data_parallel_world_size() + data_parallel_group = parallel_state.get_data_parallel_group() + + logging.info(f"rank {rank} world_size {world_size} data_parallel_group {data_parallel_group}") + worker_config = WorkerConfig( + rank=rank, + world_size=world_size, + num_workers=self.num_workers, + data_parallel_group=data_parallel_group, + worker_debug_path=None, + worker_log_level=0, + ) + val_dataset = self.datasets_provider(worker_config, split='val') + energon_loader = get_savable_loader(val_dataset, worker_config=worker_config) + self.val_dataloader_object = energon_loader + return self.val_dataloader_object + + def test_dataloader(self) -> None: + """ + Return None as test dataset split does not exist. + + This method overrides the test_dataloader method and returns None since the test dataset split + is not defined or used in this module. + + Returns: + None + """ + logging.warning("Multimodal dataloader test dataset split does not exist") + return None + + def state_dict(self) -> Dict[str, Any]: + """ + Save the state of the data module. + + This method is called when saving a checkpoint. It generates and saves the state of the data module, + including the state of the dataloader and the number of consumed samples. + + Returns: + Dict[str, Any]: A dictionary containing the state of the data module. + """ + + if self.trainer: + dataloader_obj = self.trainer.train_dataloader + + state = [] + # All ranks should be zero except the dp rank. + if ( + parallel_state.get_context_parallel_rank() + or parallel_state.get_pipeline_model_parallel_rank() + or parallel_state.get_tensor_model_parallel_rank() + or parallel_state.get_expert_model_parallel_rank() + ) == 0: + # Save_state_global in energon assumes that we call it for only the first rank within each group that + # shares the same dataloader state. By making sure that current rank is the first rank in a model + # parallel group, we ensure this. + state = dataloader_obj.save_state_global(global_dst_rank=0) + + consumed_samples = self.data_sampler.compute_consumed_samples( + self.trainer.global_step - self.init_global_step + ) + + if state is None: + state = [] # Megatron core requires all the states on all the ranks to have same python + # type. Energon sends the state as a list + logging.info(f"Multimodal data loader saving dataloader state dict consumed samples {consumed_samples}") + return {'dataloader_state': state, 'consumed_samples': consumed_samples} + + logging.warning("trainer object not connected to data module object returning empty state") + return {} + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + """ + Load the state of the data module from a checkpoint. + + This method is called when loading a checkpoint. It restores the state of the data module, + including the state of the dataloader and the number of consumed samples. + + Parameters: + state_dict (Dict[str, Any]): The state dictionary containing the saved state of the data module. + """ + if not 'dataloader_state' in state_dict: + logging.warning( + f"Data loader state cannot be resumed from state_dict, " + f"it does not have the required key dataloader_state. It has {state_dict.keys()}" + ) + return + + state = state_dict['dataloader_state'] + try: + if self.trainer: + self.trainer.datamodule.train_dataloader().restore_state_global(state) + logging.info("Multimodal dataloader state restored") + else: + logging.error(f"Cannot restore state from state_dict {state_dict}") + raise ValueError( + "Cannot restore state from state_dict: " + "Is the trainer object is initialized and attached to datamodule???" + ) + except Exception as e: + logging.warning( + f"Failed to dataloader restore state due to [Please ensure you are using same version " + f"of energon while saving and loading, Continuing without restoring data loader] : {e}" + ) + + try: + from megatron.core.num_microbatches_calculator import update_num_microbatches + + except (ImportError, ModuleNotFoundError): + logging.warning("Megatron num_microbatches_calculator not found, using Apex version.") + from apex.transformer.pipeline_parallel.utils import update_num_microbatches + + consumed_samples = state_dict['consumed_samples'] + self.data_sampler.init_consumed_samples = consumed_samples + self.data_sampler.prev_consumed_samples = consumed_samples + logging.info(f"Multimodal dataloader load state dict with consumed_samples {consumed_samples}") + update_num_microbatches( + consumed_samples=consumed_samples, + consistency_check=False, + ) + + +class SequentialMegatronSampler(MegatronDataSampler): + """ + A data sampler for sequential sampling in Megatron, designed to handle large datasets efficiently. + + This class extends the MegatronDataSampler to support sequential sampling for large datasets. + It includes functionality for handling micro-batches and tracking consumed samples across training steps. + + Attributes: + seq_len (int): The sequence length for each sample. + micro_batch_size (int): The number of samples in each micro-batch. + init_consumed_samples (int): The initial number of samples that have been consumed (used for resuming training). + prev_consumed_samples (int): Tracks the number of consumed samples before the current step. + if_first_step (int): Flag to indicate if it's the first training step. + prev_global_batch_size (Optional[int]): The global batch size from the previous step. + init_global_step (int): The initial global step at the start of training. + """ + + def __init__( + self, + seq_len: int, + micro_batch_size: int = 4, + global_batch_size: int = 8, + init_consumed_samples: int = 0, + decoder_seq_len: Optional[int] = None, + init_global_step=0, + ): + """ + Initialize the SequentialMegatronSampler. + + Parameters: + seq_len (int): The sequence length for each sample. + micro_batch_size (int, optional): The number of samples in each micro-batch. Defaults to 4. + init_consumed_samples (int, optional): The initial number of samples that have been consumed. Defaults to 0. + init_global_step (int, optional): The initial global step at the start of training. Defaults to 0. + """ + super().__init__( + seq_len=seq_len, + decoder_seq_len=decoder_seq_len, + micro_batch_size=micro_batch_size, + global_batch_size=global_batch_size, + init_consumed_samples=init_consumed_samples, + init_global_step=init_global_step, + ) + + def transform_dataloader(self, dataloader: DataLoader) -> DataLoader: + """ + Transform the DataLoader for sequential sampling. + + This method returns the DataLoader as is, but it can be overridden to apply specific transformations to + the DataLoader if needed. + + Parameters: + dataloader (DataLoader): The original DataLoader to be transformed. + + Returns: + DataLoader: The transformed DataLoader. + """ + return dataloader + + @property + def megatron_data_kwargs(self) -> Dict[str, Any]: + """ + Return the keyword arguments required for Megatron data handling. + + This property provides the necessary arguments that Megatron uses to handle data, including sequence length, + micro-batch size, and the number of micro-batches. + + Returns: + Dict[str, Any]: A dictionary containing the Megatron data handling arguments. + """ + return { + "seq_length": self.seq_len, + "micro_batch_size": self.micro_batch_size, + "num_microbatches": self.num_microbatches, + } diff --git a/nemo_vfm/diffusion/models/dit/dit_attention_megatron.py b/nemo_vfm/diffusion/models/dit/dit_attention_megatron.py new file mode 100644 index 00000000..e69de29b From d5d0106ae9a8aa836342b41f44094814dc5017ed Mon Sep 17 00:00:00 2001 From: Huy Vu2 Date: Mon, 3 Nov 2025 14:51:03 -0800 Subject: [PATCH 07/80] workable code before refactoring --- .../megatron/recipe/wan/example_commands.sh | 2 +- .../wan/prepare_energon_dataset_wan_images.py | 429 ++++++++++++------ .../model/wan/flow_matching/flow_pipeline.py | 46 +- 3 files changed, 319 insertions(+), 158 deletions(-) diff --git a/dfm/examples/megatron/recipe/wan/example_commands.sh b/dfm/examples/megatron/recipe/wan/example_commands.sh index 13235aed..e6ac4a6f 100644 --- a/dfm/examples/megatron/recipe/wan/example_commands.sh +++ b/dfm/examples/megatron/recipe/wan/example_commands.sh @@ -6,7 +6,7 @@ export PYTHONPATH="${DFM_PATH}/.:${MBRIDGE_PATH}/src/.:/opt/NeMo-Framework-Launc ### install dependencies pip install --upgrade git+https://github.com/NVIDIA/Megatron-LM.git@ce8185cbbe04f38beb74360e878450f2e8525885 -python3 -m pip install --upgrade diffusers +python3 -m pip install --upgrade diffusers==0.35.1 pip install easydict pip install imageio pip install imageio-ffmpeg diff --git a/dfm/src/megatron/data/wan/prepare_energon_dataset_wan_images.py b/dfm/src/megatron/data/wan/prepare_energon_dataset_wan_images.py index 67ea4539..70758ae6 100644 --- a/dfm/src/megatron/data/wan/prepare_energon_dataset_wan_images.py +++ b/dfm/src/megatron/data/wan/prepare_energon_dataset_wan_images.py @@ -1,16 +1,21 @@ import os import io import json +import pickle import tarfile from pathlib import Path -from typing import Dict, List, Optional, Tuple, Iterable +from typing import Dict, List, Optional, Tuple, Iterable, Any +import multiprocessing as mp +import math import cv2 import numpy as np import torch +import webdataset as wds from diffusers import AutoencoderKLWan from transformers import AutoTokenizer, UMT5EncoderModel +from tqdm import tqdm def _map_interpolation(resize_mode: str) -> int: @@ -65,29 +70,53 @@ def _resize_frame( return frame original_height, original_width = frame.shape[:2] + target_height, target_width = target_size + + interpolation = _map_interpolation(resize_mode) + + if not maintain_aspect_ratio: + resized_frame = cv2.resize(frame, (target_width, target_height), interpolation=interpolation) + return resized_frame + + if center_crop: + # Resize-to-cover: scale so both dims >= target, then center-crop to exact target + scale = max(target_height / max(1, original_height), target_width / max(1, original_width)) + resize_height = int(math.ceil(original_height * scale)) + resize_width = int(math.ceil(original_width * scale)) + + resized_frame = cv2.resize(frame, (resize_width, resize_height), interpolation=interpolation) + + y_start = max(0, (resize_height - target_height) // 2) + x_start = max(0, (resize_width - target_width) // 2) + y_end = y_start + target_height + x_end = x_start + target_width + + # Bound checks (should be safe due to ceil, but guard anyway) + y_start = min(y_start, max(0, resize_height - target_height)) + x_start = min(x_start, max(0, resize_width - target_width)) + y_end = min(y_end, resize_height) + x_end = min(x_end, resize_width) + + cropped = resized_frame[y_start:y_end, x_start:x_end] + + # If due to rounding one dim is still short, pad minimally (rare) + pad_h = max(0, target_height - cropped.shape[0]) + pad_w = max(0, target_width - cropped.shape[1]) + if pad_h > 0 or pad_w > 0: + cropped = np.pad( + cropped, + ((pad_h // 2, pad_h - pad_h // 2), (pad_w // 2, pad_w - pad_w // 2), (0, 0)), + mode="edge", + ) + cropped = cropped[:target_height, :target_width] + + return cropped + + # Aspect-preserving resize-to-fit (no crop): may be smaller than target in one dim resize_height, resize_width = _calculate_resize_dimensions( original_height, original_width, target_size, maintain_aspect_ratio ) - - interpolation = _map_interpolation(resize_mode) resized_frame = cv2.resize(frame, (resize_width, resize_height), interpolation=interpolation) - - if maintain_aspect_ratio and center_crop: - target_height, target_width = target_size - if resize_height != target_height or resize_width != target_width: - y_start = max(0, (resize_height - target_height) // 2) - x_start = max(0, (resize_width - target_width) // 2) - y_end = min(resize_height, y_start + target_height) - x_end = min(resize_width, x_start + target_width) - resized_frame = resized_frame[y_start:y_end, x_start:x_end] - - if resized_frame.shape[0] < target_height or resized_frame.shape[1] < target_width: - pad_height = max(0, target_height - resized_frame.shape[0]) - pad_width = max(0, target_width - resized_frame.shape[1]) - resized_frame = np.pad( - resized_frame, ((0, pad_height), (0, pad_width), (0, 0)), mode="constant", constant_values=0 - ) - return resized_frame @@ -100,6 +129,14 @@ def _decode_image_bytes_to_rgb(image_bytes: bytes) -> Optional[np.ndarray]: return img_rgb +def _select_target_size_for_image(image_rgb: np.ndarray) -> Tuple[int, int]: + h, w = image_rgb.shape[:2] + if h <= w: + return (480, 832) + else: + return (832, 480) + + def _image_to_video_tensor( image_rgb: np.ndarray, target_size: Optional[Tuple[int, int]], @@ -158,17 +195,22 @@ def _encode_text( caption: str, ) -> torch.Tensor: caption = (caption or "").strip() + # Pad to 512, then slice back to the non-padded length inputs = tokenizer( - caption, - max_length=512, + [caption], padding="max_length", truncation=True, + max_length=512, return_tensors="pt", return_attention_mask=True, + add_special_tokens=True, ) inputs = {k: v.to(device) for k, v in inputs.items()} - outputs = text_encoder(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"]).last_hidden_state - return outputs + outputs = text_encoder( + input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"] + ).last_hidden_state # [1, L, C] + seq_len = int(inputs["attention_mask"][0].sum().item()) + return outputs[0, :seq_len, :] @torch.no_grad() @@ -229,17 +271,179 @@ def _read_tar_member_bytes(tf: tarfile.TarFile, member: tarfile.TarInfo) -> byte return f.read() +def _process_tar_with_models( + tar_path: Path, + image_exts: Tuple[str, ...], + opts: Dict[str, Any], + device: str, + vae: AutoencoderKLWan, + text_encoder: UMT5EncoderModel, + tokenizer: AutoTokenizer, + model_dtype: torch.dtype, + sink: Any, + index: int, + tqdm_position: int = 0, +) -> int: + processed = 0 + failed = 0 + + try: + with tarfile.open(tar_path, mode="r:*") as tf: + pairs = list(_iter_tar_images_and_captions(tf, image_exts)) + for img_member, cap_member in tqdm( + pairs, total=len(pairs), desc=f"{tar_path.name}", unit="img", position=tqdm_position, leave=False + ): + img_name = os.path.basename(img_member.name) + + try: + img_bytes = _read_tar_member_bytes(tf, img_member) + if not img_bytes: + failed += 1 + continue + rgb = _decode_image_bytes_to_rgb(img_bytes) + if rgb is None: + failed += 1 + continue + + caption_text = "" + if cap_member is not None: + try: + caption_bytes = _read_tar_member_bytes(tf, cap_member) + caption_text = caption_bytes.decode("utf-8", errors="ignore") + except Exception: + caption_text = "" + + target_size = _select_target_size_for_image(rgb) + video_tensor = _image_to_video_tensor( + image_rgb=rgb, + target_size=target_size, + resize_mode=opts["resize_mode"], + maintain_aspect_ratio=not opts.get("no_aspect_ratio", False), + center_crop=opts.get("center_crop", False), + target_dtype=model_dtype, + ) + + text_embed = _encode_text(tokenizer, text_encoder, device, caption_text) + latents = _encode_video_latents( + vae, device, video_tensor, deterministic_latents=not opts.get("stochastic", False) + ) + + # text_embed is already sliced to non-padded tokens: [L_actual, C] + text_embed_cpu = text_embed.detach().to(device="cpu") + latents_cpu = latents.detach().to(device="cpu")[0] + + C, T, H, W = video_tensor.shape[1:] + json_data = { + "source_tar": str(tar_path), + "tar_member": img_member.name, + "image_name": img_name, + "processed_frames": int(T), + "processed_height": int(H), + "processed_width": int(W), + "caption": caption_text, + "deterministic_latents": bool(not opts.get("stochastic", False)), + "memory_optimization": bool(not opts.get("no_memory_optimization", False)), + "model_version": "wan2.1", + "resize_settings": { + "target_size": target_size, + "resize_mode": opts["resize_mode"], + "maintain_aspect_ratio": bool(not opts.get("no_aspect_ratio", False)), + "center_crop": bool(opts.get("center_crop", False)), + }, + } + + sample = { + "__key__": f"{index:09}", + "pth": latents_cpu, + "pickle": pickle.dumps(text_embed_cpu, protocol=pickle.HIGHEST_PROTOCOL), + "json": json_data, + } + sink.write(sample) + + index += 1 + processed += 1 + except Exception: + failed += 1 + continue + except Exception as e: + print(f"Failed to process tar {tar_path}: {e}") + return index + + print(f"Processed tar {tar_path.name}: {processed} ok, {failed} failed. WDS written") + return index + + +def _worker_run( + rank: int, + device: str, + tar_paths: List[str], + in_root: str, + out_root: str, + image_exts: Tuple[str, ...], + opts: Dict[str, Any], +): + try: + torch.cuda.set_device(int(device.split(":")[-1])) + except Exception: + pass + + vae, text_encoder, tokenizer, model_dtype = _init_hf_models( + model_id=opts["model"], + device=device, + enable_memory_optimization=not opts.get("no_memory_optimization", False), + ) + + out_root_path = Path(out_root) + in_root_path = Path(in_root) + + # DEBUGGING + for tar_str in tar_paths: + # for tar_str in tar_paths[:1]: + tar_path = Path(tar_str) + # Mirror the original directory structure from input_dir under output_root + try: + rel_parent = tar_path.parent.relative_to(in_root_path) + except Exception: + rel_parent = Path("") + out_dir = out_root_path / rel_parent + out_dir.mkdir(parents=True, exist_ok=True) + + out_tar = out_dir / f"{tar_path.stem}.tar" + if opts.get("skip_existing") and out_tar.exists(): + continue + + index = 0 + with wds.TarWriter(str(out_tar)) as sink: + index = _process_tar_with_models( + tar_path=tar_path, + image_exts=image_exts, + opts=opts, + device=device, + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + model_dtype=model_dtype, + sink=sink, + index=index, + tqdm_position=rank, + ) + def main(): import argparse parser = argparse.ArgumentParser( description=( - "Prepare WAN encodings for image tar shards. Each .tar is written to a same-named directory " - "containing per-image VAE latents (.pth), T5 embeddings (.pkl), and metadata (.json)." + "Prepare WAN encodings for image tar shards and write WebDataset shards (pth, pickle, json)." ) ) parser.add_argument("--input_dir", type=str, required=True, help="Directory containing .tar shards of images") - parser.add_argument("--output_root", type=str, required=True, help="Root directory to write per-tar output dirs") + parser.add_argument("--output_dir", type=str, required=False, help="Directory to write webdataset shards") + parser.add_argument( + "--output_root", + type=str, + required=False, + help="Deprecated alias for --output_dir; if provided, will be used as output_dir", + ) parser.add_argument( "--model", default="Wan-AI/Wan2.1-T2V-14B-Diffusers", @@ -259,8 +463,6 @@ def main(): default=".jpg,.jpeg,.png,.webp", help="Comma-separated list of image extensions to include", ) - parser.add_argument("--height", type=int, default=None, help="Target height for image frames") - parser.add_argument("--width", type=int, default=None, help="Target width for image frames") parser.add_argument( "--resize_mode", default="bilinear", @@ -272,30 +474,26 @@ def main(): parser.add_argument( "--skip-existing", action="store_true", - help="Skip images whose output files already exist (all three: .pth, .pkl, .json)", + help="No-op in WDS mode; retained for CLI compatibility", + ) + parser.add_argument("--shard_maxcount", type=int, default=10000, help="Max samples per WDS shard") + parser.add_argument( + "--gpus", + type=str, + default="0", + help="Comma-separated GPU indices to use (e.g., '0,1,2,3')", ) args = parser.parse_args() input_dir = Path(args.input_dir) - output_root = Path(args.output_root) + # Resolve output directory (support legacy --output_root) + resolved_output_dir = args.output_dir or args.output_root + if resolved_output_dir is None: + parser.error("--output_dir must be specified (or legacy --output_root)") + output_root = Path(resolved_output_dir) output_root.mkdir(parents=True, exist_ok=True) - # Target size - target_size = None - if args.height is not None and args.width is not None: - target_size = (args.height, args.width) - elif (args.height is None) ^ (args.width is None): - parser.error("Both --height and --width must be specified together") - - # Init HF models - device = "cuda" - vae, text_encoder, tokenizer, model_dtype = _init_hf_models( - model_id=args.model, - device=device, - enable_memory_optimization=not args.no_memory_optimization, - ) - image_exts = tuple(ext.strip().lower() for ext in args.image_exts.split(",") if ext.strip()) # Find tar files @@ -303,107 +501,48 @@ def main(): if not tar_files: raise FileNotFoundError(f"No .tar files found in {input_dir}") - # DEBUGGING - # for tar_path in tar_files: - for tar_path in tar_files[:1]: - tar_stem = tar_path.name[:-4] # drop .tar - out_dir = output_root / tar_stem - out_dir.mkdir(parents=True, exist_ok=True) - - processed = 0 - failed = 0 - - # Open tar for streaming read - try: - with tarfile.open(tar_path, mode="r:*") as tf: - for img_member, cap_member in _iter_tar_images_and_captions(tf, image_exts): - img_name = os.path.basename(img_member.name) - stem = os.path.splitext(img_name)[0] - - latents_path = out_dir / f"{stem}.pth" - text_path = out_dir / f"{stem}.pkl" - meta_path = out_dir / f"{stem}.json" + # Parse GPU list and shard tars + gpu_ids = [s.strip() for s in args.gpus.split(",") if s.strip()] + devices = [f"cuda:{gid}" for gid in gpu_ids] + num_workers = len(devices) if devices else 1 + + shards: List[List[str]] = [[] for _ in range(num_workers)] + for idx, tar_path in enumerate(tar_files): + shards[idx % num_workers].append(str(tar_path)) + + opts = { + "model": args.model, + "stochastic": bool(args.stochastic), + "no_memory_optimization": bool(args.no_memory_optimization), + "resize_mode": args.resize_mode, + "no_aspect_ratio": bool(args.no_aspect_ratio), + "center_crop": bool(args.center_crop), + "skip_existing": bool(args.skip_existing), + "shard_maxcount": int(args.shard_maxcount), + } - if args.skip_existing and latents_path.exists() and text_path.exists() and meta_path.exists(): - continue + # Ensure CUDA-safe multiprocessing + try: + mp.set_start_method("spawn", force=True) + except RuntimeError: + pass - try: - img_bytes = _read_tar_member_bytes(tf, img_member) - if not img_bytes: - failed += 1 - continue - rgb = _decode_image_bytes_to_rgb(img_bytes) - if rgb is None: - failed += 1 - continue - - caption_text = "" - if cap_member is not None: - try: - caption_bytes = _read_tar_member_bytes(tf, cap_member) - caption_text = caption_bytes.decode("utf-8", errors="ignore") - except Exception: - caption_text = "" - - video_tensor = _image_to_video_tensor( - image_rgb=rgb, - target_size=target_size, - resize_mode=args.resize_mode, - maintain_aspect_ratio=not args.no_aspect_ratio, - center_crop=args.center_crop, - target_dtype=model_dtype, - ) - - # Encode - text_embed = _encode_text(tokenizer, text_encoder, device, caption_text) - latents = _encode_video_latents( - vae, device, video_tensor, deterministic_latents=not args.stochastic - ) - - # Move to CPU and drop batch dim - text_embed_cpu = text_embed.detach().to(device="cpu")[0] - latents_cpu = latents.detach().to(device="cpu")[0] - - # Save outputs - torch.save(latents_cpu, latents_path) - # Use pickle for text embeddings to keep exact dtype/shape - with open(text_path, "wb") as f: - import pickle - - pickle.dump(text_embed_cpu, f, protocol=pickle.HIGHEST_PROTOCOL) - - # Metadata - C, T, H, W = video_tensor.shape[1:] - json_data = { - "source_tar": str(tar_path), - "tar_member": img_member.name, - "image_name": img_name, - "processed_frames": int(T), # always 1 - "processed_height": int(H), - "processed_width": int(W), - "caption": caption_text, - "deterministic_latents": bool(not args.stochastic), - "memory_optimization": bool(not args.no_memory_optimization), - "model_version": "wan2.1", - "resize_settings": { - "target_size": target_size, - "resize_mode": args.resize_mode, - "maintain_aspect_ratio": bool(not args.no_aspect_ratio), - "center_crop": bool(args.center_crop), - }, - } - with open(meta_path, "w", encoding="utf-8") as f: - json.dump(json_data, f, ensure_ascii=False) - - processed += 1 - except Exception: - failed += 1 - continue - except Exception as e: - print(f"Failed to process tar {tar_path}: {e}") - continue - - print(f"Processed tar {tar_path.name}: {processed} ok, {failed} failed. Output -> {out_dir}") + if num_workers == 1: + _worker_run(0, devices[0] if devices else "cuda:0", shards[0], str(input_dir), str(output_root), image_exts, opts) + else: + procs: List[mp.Process] = [] + for rank, device in enumerate(devices): + if not shards[rank]: + continue + p = mp.Process( + target=_worker_run, + args=(rank, device, shards[rank], str(input_dir), str(output_root), image_exts, opts), + daemon=False, + ) + p.start() + procs.append(p) + for p in procs: + p.join() if __name__ == "__main__": diff --git a/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py b/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py index 72ddc47b..81044b38 100644 --- a/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py +++ b/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py @@ -112,18 +112,34 @@ def training_step( # ======================================================================== # Manual Flow Matching Noise Addition # ======================================================================== - - # Generate noise - noise = torch.randn_like(torch.ones([1, 16, grid_sizes[0][0], grid_sizes[0][1]*2, grid_sizes[0][2]*2], device=video_latents.device), dtype=torch.float32) - noise = patchify(noise, (1, 2, 2))[0].unsqueeze(1) - # DEBUGGING - # because video_latents might be padded, we need to make sure noise also be padded to have the same shape - seq_noise = noise.shape[0] - seq_video = video_latents.shape[0] - if seq_noise < seq_video: - pad_len = seq_video - seq_noise - pad = torch.zeros((pad_len, noise.shape[1], noise.shape[2]), device=noise.device, dtype=noise.dtype) - noise = torch.cat([noise, pad], dim=0) + + # Generate noise for batch + noise = [] + in_channels = model.config.in_channels + patch_spatial = model.config.patch_spatial + patch_temporal = model.config.patch_temporal + for grid_size in grid_sizes: + sample_noise = torch.randn( + 1, + in_channels, + grid_size[0]*patch_temporal, + grid_size[1]*patch_spatial, + grid_size[2]*patch_spatial, + dtype=torch.float32, + device=video_latents.device, + ) + sample_noise = patchify(sample_noise, (patch_temporal, patch_spatial, patch_spatial))[0] # shape [noise_seq, c * ( pF * pH * pW)] + + # because video_latents might be padded, we need to make sure noise also be padded to have the same shape + noise_seq = sample_noise.shape[0] + video_seq = video_latents.shape[0] + if noise_seq < video_seq: + pad_len = video_seq - noise_seq + pad = torch.zeros((pad_len, sample_noise.shape[1]), device=sample_noise.device, dtype=sample_noise.dtype) + sample_noise = torch.cat([sample_noise, pad], dim=0) + noise.append(sample_noise) + noise = torch.stack(noise, dim=1) # shape [noise_seq, batch_size, c * ( pF * pH * pW)] + # CRITICAL: Manual flow matching (NOT scheduler.add_noise!) # x_t = (1 - σ) * x_0 + σ * ε @@ -162,6 +178,12 @@ def training_step( context_embeddings = context_embeddings split_loss_mask = loss_mask + # # DEBUGGING + # print(f"[DEBUG] [flow_pipeline] video_latents shape: {video_latents.shape}") + # print(f"[DEBUG] [flow_pipeline] noisy_latents shape: {noisy_latents.shape}") + # print(f"[DEBUG] [flow_pipeline] noise shape: {noise.shape}") + # print(f"[DEBUG] [flow_pipeline] context_embeddings shape: {context_embeddings.shape}") + # print(f"[DEBUG] [flow_pipeline] split_loss_mask shape: {split_loss_mask.shape}") # ======================================================================== # Forward Pass From 0430384f75cc02ae60b58fe118d68a6a04271a01 Mon Sep 17 00:00:00 2001 From: Huy Vu2 Date: Tue, 4 Nov 2025 00:12:59 -0800 Subject: [PATCH 08/80] refactor attention submodules + reorder files locations --- .../megatron/model/common/dit_attention.py | 292 +++++++++++++++ .../flow_matching/flow_inference_pipeline.py | 2 +- .../model/wan/inference/configs/__init__.py | 2 - dfm/src/megatron/model/wan/wan_layer_spec.py | 339 ++---------------- dfm/src/megatron/model/wan/wan_model.py | 4 +- .../megatron/model}/wan/wan_provider.py | 2 +- .../megatron/model}/wan/wan_step.py | 1 - .../megatron/recipes}/wan/wan.py | 3 +- .../wan_pretrain_override_example.yaml | 0 .../megatron/{recipe => recipes}/README.md | 0 .../recipes/wan/convert_wan_checkpoints.py | 20 ++ .../megatron/recipes/wan/example_commands.md | 88 +++++ .../megatron/recipes}/wan/example_commands.sh | 2 +- .../megatron/recipes}/wan/inference_wan.py | 3 - .../launch_data_processing_images_sbatch.sh | 57 +++ .../recipes/wan/launch_pretrain_wan_sbatch.sh | 88 +++++ .../recipes/wan/nemo_mcore_t5_sbatch_mcore.sh | 128 +++++++ .../megatron/recipes}/wan/pretrain_wan.py | 2 - 18 files changed, 702 insertions(+), 331 deletions(-) create mode 100644 dfm/src/megatron/model/common/dit_attention.py rename dfm/{examples/megatron/recipe => src/megatron/model}/wan/wan_provider.py (98%) rename dfm/{examples/megatron/recipe => src/megatron/model}/wan/wan_step.py (99%) rename dfm/{examples/megatron/recipe => src/megatron/recipes}/wan/wan.py (98%) rename {dfm/examples/megatron/recipe/wan/conf => examples/megatron/override_configs}/wan_pretrain_override_example.yaml (100%) rename examples/megatron/{recipe => recipes}/README.md (100%) create mode 100644 examples/megatron/recipes/wan/convert_wan_checkpoints.py create mode 100644 examples/megatron/recipes/wan/example_commands.md rename {dfm/examples/megatron/recipe => examples/megatron/recipes}/wan/example_commands.sh (99%) rename {dfm/examples/megatron/recipe => examples/megatron/recipes}/wan/inference_wan.py (99%) create mode 100644 examples/megatron/recipes/wan/launch_data_processing_images_sbatch.sh create mode 100644 examples/megatron/recipes/wan/launch_pretrain_wan_sbatch.sh create mode 100644 examples/megatron/recipes/wan/nemo_mcore_t5_sbatch_mcore.sh rename {dfm/examples/megatron/recipe => examples/megatron/recipes}/wan/pretrain_wan.py (99%) diff --git a/dfm/src/megatron/model/common/dit_attention.py b/dfm/src/megatron/model/common/dit_attention.py new file mode 100644 index 00000000..a5a2794c --- /dev/null +++ b/dfm/src/megatron/model/common/dit_attention.py @@ -0,0 +1,292 @@ +from typing import Optional + +import copy +import torch +from megatron.core import parallel_state, tensor_parallel +from megatron.core.extensions.transformer_engine import SplitAlongDim +from megatron.core.process_groups_config import ProcessGroupCollection +from megatron.core.transformer.attention import ( + CrossAttention, + SelfAttention, + SelfAttentionSubmodules, +) +from megatron.core.transformer.spec_utils import build_module +from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.transformer.enums import AttnMaskType +from dataclasses import dataclass +from typing import Union +from megatron.core.transformer.spec_utils import ModuleSpec + + +@dataclass +class DiTCrossAttentionSubmodules: + linear_q: Union[ModuleSpec, type] = None + linear_kv: Union[ModuleSpec, type] = None + core_attention: Union[ModuleSpec, type] = None + linear_proj: Union[ModuleSpec, type] = None + q_layernorm: Union[ModuleSpec, type] = None + k_layernorm: Union[ModuleSpec, type] = None + + +class DiTSelfAttention(SelfAttention): + def __init__( + self, + config: TransformerConfig, + submodules: SelfAttentionSubmodules, + layer_number: int, + attn_mask_type: AttnMaskType, + cp_comm_type: str = None, + pg_collection: Optional[ProcessGroupCollection] = None, + ): + super().__init__( + config, + submodules, + layer_number, + attn_mask_type, + cp_comm_type, + pg_collection, + ) + + self.layernorm_across_heads = self.config.layernorm_across_heads + + # override q_layernorm + if submodules.q_layernorm is not None: + if self.layernorm_across_heads: + q_layernorm_size = self.query_projection_size + else: + q_layernorm_size = self.hidden_size_per_attention_head + norm_config = copy.deepcopy(self.config) + norm_config.normalization = "RMSNorm" + self.q_layernorm = build_module( + submodules.q_layernorm, + eps=norm_config.layernorm_epsilon, + hidden_size=q_layernorm_size, + config=norm_config, + ) + else: + self.q_layernorm = None + + # override k_layernorm + if submodules.k_layernorm is not None: + if self.layernorm_across_heads: + k_layernorm_size = self.kv_projection_size + else: + k_layernorm_size = self.hidden_size_per_attention_head + norm_config = copy.deepcopy(self.config) + norm_config.normalization = "RMSNorm" + self.k_layernorm = build_module( + submodules.k_layernorm, + eps=norm_config.layernorm_epsilon, + hidden_size=k_layernorm_size, + config=norm_config, + ) + else: + self.k_layernorm = None + + def get_query_key_value_tensors(self, hidden_states, key_value_states=None): + """ + Derives `query`, `key` and `value` tensors from `hidden_states`. + """ + # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn)] + mixed_qkv, _ = self.linear_qkv(hidden_states) + + # [sq, b, hp] --> [sq, b, ng, (np/ng + 2) * hn] + new_tensor_shape = mixed_qkv.size()[:-1] + ( + self.num_query_groups_per_partition, + ( + (self.num_attention_heads_per_partition // self.num_query_groups_per_partition + 2) + * self.hidden_size_per_attention_head + ), + ) + mixed_qkv = mixed_qkv.view(*new_tensor_shape) + + split_arg_list = [ + ( + self.num_attention_heads_per_partition + // self.num_query_groups_per_partition + * self.hidden_size_per_attention_head + ), + self.hidden_size_per_attention_head, + self.hidden_size_per_attention_head, + ] + + if SplitAlongDim is not None: + + # [sq, b, ng, (np/ng + 2) * hn] + # --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn] + (query, key, value) = SplitAlongDim(mixed_qkv, 3, split_arg_list) + else: + + # [sq, b, ng, (np/ng + 2) * hn] + # --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn] + (query, key, value) = torch.split(mixed_qkv, split_arg_list, dim=3) + + # [sq, b, ng, np/ng * hn] -> [sq, b, np, hn] + query = query.reshape(query.size(0), query.size(1), -1, self.hidden_size_per_attention_head) + + # gather query and key heads across TP ranks if self.layernorm_across_heads is True + if self.layernorm_across_heads and parallel_state.get_tensor_model_parallel_world_size() > 1: + query = query.transpose(-2, -1) + key = key.transpose(-2, -1) + query = tensor_parallel.gather_from_tensor_model_parallel_region(query) + key = tensor_parallel.gather_from_tensor_model_parallel_region(key) + query = query.transpose(-2, -1) + key = key.transpose(-2, -1) + + if self.q_layernorm is not None: + if self.layernorm_across_heads: + q_flat = query.reshape(query.size(0), query.size(1), -1).contiguous() # [sq, b, np*hn] + q_flat = self.q_layernorm(q_flat) + query = q_flat.view(query.size(0), query.size(1), -1, self.hidden_size_per_attention_head) # [sq, b, np, hn] + else: + query = self.q_layernorm(query.contiguous()) + + if self.k_layernorm is not None: + if self.layernorm_across_heads: + k_flat = key.reshape(key.size(0), key.size(1), -1).contiguous() + k_flat = self.k_layernorm(k_flat) + key = k_flat.view(key.size(0), key.size(1), -1, self.hidden_size_per_attention_head) + else: + key = self.k_layernorm(key.contiguous()) + + # scatter query and key heads across TP ranks if self.layernorm_across_heads is True + if self.layernorm_across_heads and parallel_state.get_tensor_model_parallel_world_size() > 1: + query = query.transpose(-2, -1) + key = key.transpose(-2, -1) + query = tensor_parallel.scatter_to_tensor_model_parallel_region(query) + key = tensor_parallel.scatter_to_tensor_model_parallel_region(key) + query = query.transpose(-2, -1) + key = key.transpose(-2, -1) + query = query.contiguous() # important becuase TE attention expects contiguous tensors + key = key.contiguous() # important becuase TE attention expects contiguous tensors + + if self.config.test_mode: + self.run_realtime_tests() + + return query, key, value + + +class DiTCrossAttention(CrossAttention): + def __init__( + self, + config: TransformerConfig, + submodules: DiTCrossAttentionSubmodules, + layer_number: int, + attn_mask_type: AttnMaskType, + cp_comm_type: str = None, + pg_collection: Optional[ProcessGroupCollection] = None, + ): + super().__init__( + config, + submodules, + layer_number, + attn_mask_type, + cp_comm_type, + pg_collection, + ) + + self.layernorm_across_heads = self.config.layernorm_across_heads + + # override q_layernorm + if submodules.q_layernorm is not None: + if self.layernorm_across_heads: + q_layernorm_size = self.query_projection_size + else: + q_layernorm_size = self.hidden_size_per_attention_head + norm_config = copy.deepcopy(self.config) + norm_config.normalization = "RMSNorm" + self.q_layernorm = build_module( + submodules.q_layernorm, + eps=norm_config.layernorm_epsilon, + hidden_size=q_layernorm_size, + config=norm_config, + ) + else: + self.q_layernorm = None + + # override k_layernorm + if submodules.k_layernorm is not None: + if self.layernorm_across_heads: + k_layernorm_size = self.kv_projection_size + else: + k_layernorm_size = self.hidden_size_per_attention_head + norm_config = copy.deepcopy(self.config) + norm_config.normalization = "RMSNorm" + self.k_layernorm = build_module( + submodules.k_layernorm, + eps=norm_config.layernorm_epsilon, + hidden_size=k_layernorm_size, + config=norm_config, + ) + else: + self.k_layernorm = None + + def get_query_key_value_tensors(self, hidden_states, key_value_states): + """ + Derives `query` tensor from `hidden_states`, and `key`/`value` tensors + from `key_value_states`. + """ + # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)] + mixed_kv, _ = self.linear_kv(key_value_states) + + # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn] + new_tensor_shape = mixed_kv.size()[:-1] + ( + self.num_attention_heads_per_partition, + 2 * self.hidden_size_per_attention_head, + ) + mixed_kv = mixed_kv.view(*new_tensor_shape) + + # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn] + (key, value) = tensor_parallel.split_tensor_along_last_dim(mixed_kv, 2) + + # Attention head [sq, b, h] --> [sq, b, hp] + query, _ = self.linear_q(hidden_states) + + # [sq, b, hp] --> [sq, b, np, hn] + new_tensor_shape = query.size()[:-1] + ( + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + ) + query = query.view(*new_tensor_shape) + + # replace with our own implementation (Todo: @huy ) + query, key, value = super().get_query_key_value_tensors(hidden_states, key_value_states) + + # gather query and key heads across TP ranks if self.layernorm_across_heads is True + if self.layernorm_across_heads and parallel_state.get_tensor_model_parallel_world_size() > 1: + query = query.transpose(-2, -1) + key = key.transpose(-2, -1) + query = tensor_parallel.gather_from_tensor_model_parallel_region(query) + key = tensor_parallel.gather_from_tensor_model_parallel_region(key) + query = query.transpose(-2, -1) + key = key.transpose(-2, -1) + + if self.q_layernorm is not None: + if self.layernorm_across_heads: + q_flat = query.reshape(query.size(0), query.size(1), -1).contiguous() # [sq, b, np*hn] + q_flat = self.q_layernorm(q_flat) + query = q_flat.view(query.size(0), query.size(1), -1, self.hidden_size_per_attention_head) # [sq, b, np, hn] + else: + query = self.q_layernorm(query.contiguous()) + + if self.k_layernorm is not None: + if self.layernorm_across_heads: + k_flat = key.reshape(key.size(0), key.size(1), -1).contiguous() + k_flat = self.k_layernorm(k_flat) + key = k_flat.view(key.size(0), key.size(1), -1, self.hidden_size_per_attention_head) + else: + key = self.k_layernorm(key.contiguous()) + + # scatter query and key heads across TP ranks if self.layernorm_across_heads is True + if self.layernorm_across_heads and parallel_state.get_tensor_model_parallel_world_size() > 1: + query = query.transpose(-2, -1) + key = key.transpose(-2, -1) + query = tensor_parallel.scatter_to_tensor_model_parallel_region(query) + key = tensor_parallel.scatter_to_tensor_model_parallel_region(key) + query = query.transpose(-2, -1) + key = key.transpose(-2, -1) + query = query.contiguous() # important becuase TE attention expects contiguous tensors + key = key.contiguous() # important becuase TE attention expects contiguous tensors + + return query, key, value + diff --git a/dfm/src/megatron/model/wan/flow_matching/flow_inference_pipeline.py b/dfm/src/megatron/model/wan/flow_matching/flow_inference_pipeline.py index ac5b8657..b6a8864e 100644 --- a/dfm/src/megatron/model/wan/flow_matching/flow_inference_pipeline.py +++ b/dfm/src/megatron/model/wan/flow_matching/flow_inference_pipeline.py @@ -15,7 +15,7 @@ from tqdm import tqdm from dfm.src.megatron.model.wan.wan_model import WanModel -from dfm.examples.megatron.recipe.wan.wan_provider import WanModelProvider +from dfm.src.megatron.model.wan.wan_provider import WanModelProvider from dfm.src.megatron.model.wan.modules.t5 import T5EncoderModel from dfm.src.megatron.model.wan.modules import WanVAE from dfm.src.megatron.model.wan.inference.utils.fm_solvers import ( diff --git a/dfm/src/megatron/model/wan/inference/configs/__init__.py b/dfm/src/megatron/model/wan/inference/configs/__init__.py index 938da53c..a28c03c5 100644 --- a/dfm/src/megatron/model/wan/inference/configs/__init__.py +++ b/dfm/src/megatron/model/wan/inference/configs/__init__.py @@ -1,8 +1,6 @@ import copy import os -# Change to diffusors - os.environ['TOKENIZERS_PARALLELISM'] = 'false' from .wan_i2v_14B import i2v_14B diff --git a/dfm/src/megatron/model/wan/wan_layer_spec.py b/dfm/src/megatron/model/wan/wan_layer_spec.py index 267c6ad0..cdf58847 100644 --- a/dfm/src/megatron/model/wan/wan_layer_spec.py +++ b/dfm/src/megatron/model/wan/wan_layer_spec.py @@ -41,6 +41,8 @@ from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules from megatron.core.utils import make_viewless_tensor +from dfm.src.megatron.model.common.dit_attention import DiTCrossAttentionSubmodules, DiTSelfAttention, DiTCrossAttention +from megatron.core.transformer.attention import SelfAttentionSubmodules try: @@ -54,312 +56,21 @@ SplitAlongDim = None -class WanLayerNorm(nn.LayerNorm): - # Note to parth: Can we replace this with te layer norm or fuse with linear layer? - # (@huy) Remove this comment after you have answered the question. +# class WanLayerNorm(nn.LayerNorm): +# # Note to parth: Can we replace this with te layer norm or fuse with linear layer? +# # (@huy) Remove this comment after you have answered the question. - def __init__(self, dim, eps=1e-6, elementwise_affine=False): - super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps) +# def __init__(self, dim, eps=1e-6, elementwise_affine=False): +# super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps) - def forward(self, x): - r""" - Args: - x(Tensor): Shape [B, L, C] - """ - return super().forward(x).type_as(x) +# def forward(self, x): +# r""" +# Args: +# x(Tensor): Shape [B, L, C] +# """ +# return super().forward(x).type_as(x) -@dataclass -class WanSelfAttentionSubmodules: # Call this DiTSelfAttentionSubmodules or DiTSelfAttentionConfig? - """ - Configuration class for specifying the submodules of a self-attention. - """ - - linear_qkv: Union[ModuleSpec, type] = None - core_attention: Union[ModuleSpec, type] = None - linear_proj: Union[ModuleSpec, type] = None - layernorm_across_head: bool = False # Should be moved to Trasnformer config and not layerspec. (@huy to remove this) - q_layernorm: Union[ModuleSpec, type] = None - k_layernorm: Union[ModuleSpec, type] = None - - -@dataclass -class WanCrossAttentionSubmodules: # Call this DiTCrossAttentionSubmodules or DiTCrossAttentionConfig? - """ - Configuration class for specifying the submodules of a cross-attention. - """ - linear_q: Union[ModuleSpec, type] = None - linear_kv: Union[ModuleSpec, type] = None - core_attention: Union[ModuleSpec, type] = None - linear_proj: Union[ModuleSpec, type] = None - layernorm_across_head: bool = False # Should be moved to Trasnformer config and not layerspec. (@huy to remove this) - q_layernorm: Union[ModuleSpec, type] = None - k_layernorm: Union[ModuleSpec, type] = None - - -class WanSelfAttention(SelfAttention): # Call this DitSelfAttention or DiTSelfAttentionConfig? - def __init__( - self, - config: TransformerConfig, - submodules: WanSelfAttentionSubmodules, - layer_number: int, - attn_mask_type: AttnMaskType, - cp_comm_type: str = None, - pg_collection: Optional[ProcessGroupCollection] = None, - ): - super().__init__( - config, - submodules, - layer_number, - attn_mask_type, - cp_comm_type, - pg_collection, - ) - - self.layernorm_across_head = getattr(self.config, "False", submodules.layernorm_across_head) - - # override q_layernorm - if submodules.q_layernorm is not None: - if self.layernorm_across_head: - q_layernorm_size = self.query_projection_size - else: - q_layernorm_size = self.hidden_size_per_attention_head - norm_config = copy.deepcopy(self.config) - norm_config.normalization = "RMSNorm" - self.q_layernorm = build_module( - submodules.q_layernorm, - eps=norm_config.layernorm_epsilon, - hidden_size=q_layernorm_size, - config=norm_config, - ) - else: - self.q_layernorm = None - - # override k_layernorm - if submodules.k_layernorm is not None: - if self.layernorm_across_head: - k_layernorm_size = self.kv_projection_size - else: - k_layernorm_size = self.hidden_size_per_attention_head - norm_config = copy.deepcopy(self.config) - norm_config.normalization = "RMSNorm" - self.k_layernorm = build_module( - submodules.k_layernorm, - eps=norm_config.layernorm_epsilon, - hidden_size=k_layernorm_size, - config=norm_config, - ) - else: - self.k_layernorm = None - - def get_query_key_value_tensors(self, hidden_states, key_value_states=None): - """ - Derives `query`, `key` and `value` tensors from `hidden_states`. - """ - # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn)] - mixed_qkv, _ = self.linear_qkv(hidden_states) - - # [sq, b, hp] --> [sq, b, ng, (np/ng + 2) * hn] - new_tensor_shape = mixed_qkv.size()[:-1] + ( - self.num_query_groups_per_partition, - ( - (self.num_attention_heads_per_partition // self.num_query_groups_per_partition + 2) - * self.hidden_size_per_attention_head - ), - ) - mixed_qkv = mixed_qkv.view(*new_tensor_shape) - - split_arg_list = [ - ( - self.num_attention_heads_per_partition - // self.num_query_groups_per_partition - * self.hidden_size_per_attention_head - ), - self.hidden_size_per_attention_head, - self.hidden_size_per_attention_head, - ] - - if SplitAlongDim is not None: - - # [sq, b, ng, (np/ng + 2) * hn] - # --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn] - (query, key, value) = SplitAlongDim(mixed_qkv, 3, split_arg_list) - else: - - # [sq, b, ng, (np/ng + 2) * hn] - # --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn] - (query, key, value) = torch.split(mixed_qkv, split_arg_list, dim=3) - - # [sq, b, ng, np/ng * hn] -> [sq, b, np, hn] - query = query.reshape(query.size(0), query.size(1), -1, self.hidden_size_per_attention_head) - - # gather query and key heads across TP ranks if self.layernorm_across_head is True - if self.layernorm_across_head and parallel_state.get_tensor_model_parallel_world_size() > 1: - query = query.transpose(-2, -1) - key = key.transpose(-2, -1) - query = tensor_parallel.gather_from_tensor_model_parallel_region(query) - key = tensor_parallel.gather_from_tensor_model_parallel_region(key) - query = query.transpose(-2, -1) - key = key.transpose(-2, -1) - - if self.q_layernorm is not None: - if self.layernorm_across_head: - q_flat = query.reshape(query.size(0), query.size(1), -1).contiguous() # [sq, b, np*hn] - q_flat = self.q_layernorm(q_flat) - query = q_flat.view(query.size(0), query.size(1), -1, self.hidden_size_per_attention_head) # [sq, b, np, hn] - else: - query = self.q_layernorm(query.contiguous()) - - if self.k_layernorm is not None: - if self.layernorm_across_head: - k_flat = key.reshape(key.size(0), key.size(1), -1).contiguous() - k_flat = self.k_layernorm(k_flat) - key = k_flat.view(key.size(0), key.size(1), -1, self.hidden_size_per_attention_head) - else: - key = self.k_layernorm(key.contiguous()) - - # scatter query and key heads across TP ranks if self.layernorm_across_head is True - if self.layernorm_across_head and parallel_state.get_tensor_model_parallel_world_size() > 1: - query = query.transpose(-2, -1) - key = key.transpose(-2, -1) - query = tensor_parallel.scatter_to_tensor_model_parallel_region(query) - key = tensor_parallel.scatter_to_tensor_model_parallel_region(key) - query = query.transpose(-2, -1) - key = key.transpose(-2, -1) - query = query.contiguous() # important becuase TE attention expects contiguous tensors - key = key.contiguous() # important becuase TE attention expects contiguous tensors - - if self.config.test_mode: - self.run_realtime_tests() - - return query, key, value - - -class WanCrossAttention(CrossAttention): # DiTCrossAttention or DiTCrossAttentionConfig? - def __init__( - self, - config: TransformerConfig, - submodules: WanCrossAttentionSubmodules, - layer_number: int, - attn_mask_type: AttnMaskType, - cp_comm_type: str = None, - pg_collection: Optional[ProcessGroupCollection] = None, - ): - super().__init__( - config, - submodules, - layer_number, - attn_mask_type, - cp_comm_type, - pg_collection, - ) - - self.layernorm_across_head = submodules.layernorm_across_head - - # override q_layernorm - if submodules.q_layernorm is not None: - if self.layernorm_across_head: - q_layernorm_size = self.query_projection_size - else: - q_layernorm_size = self.hidden_size_per_attention_head - norm_config = copy.deepcopy(self.config) - norm_config.normalization = "RMSNorm" - self.q_layernorm = build_module( - submodules.q_layernorm, - eps=norm_config.layernorm_epsilon, - hidden_size=q_layernorm_size, - config=norm_config, - ) - else: - self.q_layernorm = None - - # override k_layernorm - if submodules.k_layernorm is not None: - if self.layernorm_across_head: - k_layernorm_size = self.kv_projection_size - else: - k_layernorm_size = self.hidden_size_per_attention_head - norm_config = copy.deepcopy(self.config) - norm_config.normalization = "RMSNorm" - self.k_layernorm = build_module( - submodules.k_layernorm, - eps=norm_config.layernorm_epsilon, - hidden_size=k_layernorm_size, - config=norm_config, - ) - else: - self.k_layernorm = None - - def get_query_key_value_tensors(self, hidden_states, key_value_states): - """ - Derives `query` tensor from `hidden_states`, and `key`/`value` tensors - from `key_value_states`. - """ - # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)] - mixed_kv, _ = self.linear_kv(key_value_states) - - # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn] - new_tensor_shape = mixed_kv.size()[:-1] + ( - self.num_attention_heads_per_partition, - 2 * self.hidden_size_per_attention_head, - ) - mixed_kv = mixed_kv.view(*new_tensor_shape) - - # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn] - (key, value) = tensor_parallel.split_tensor_along_last_dim(mixed_kv, 2) - - # Attention head [sq, b, h] --> [sq, b, hp] - query, _ = self.linear_q(hidden_states) - - # [sq, b, hp] --> [sq, b, np, hn] - new_tensor_shape = query.size()[:-1] + ( - self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head, - ) - query = query.view(*new_tensor_shape) - - # replace with our own implementation (Todo: @huy ) - query, key, value = super().get_query_key_value_tensors(hidden_states, key_value_states) - - # gather query and key heads across TP ranks if self.layernorm_across_head is True - if self.layernorm_across_head and parallel_state.get_tensor_model_parallel_world_size() > 1: - query = query.transpose(-2, -1) - key = key.transpose(-2, -1) - query = tensor_parallel.gather_from_tensor_model_parallel_region(query) - key = tensor_parallel.gather_from_tensor_model_parallel_region(key) - query = query.transpose(-2, -1) - key = key.transpose(-2, -1) - - if self.q_layernorm is not None: - if self.layernorm_across_head: - q_flat = query.reshape(query.size(0), query.size(1), -1).contiguous() # [sq, b, np*hn] - q_flat = self.q_layernorm(q_flat) - query = q_flat.view(query.size(0), query.size(1), -1, self.hidden_size_per_attention_head) # [sq, b, np, hn] - else: - query = self.q_layernorm(query.contiguous()) - - if self.k_layernorm is not None: - if self.layernorm_across_head: - k_flat = key.reshape(key.size(0), key.size(1), -1).contiguous() - k_flat = self.k_layernorm(k_flat) - key = k_flat.view(key.size(0), key.size(1), -1, self.hidden_size_per_attention_head) - else: - key = self.k_layernorm(key.contiguous()) - - # scatter query and key heads across TP ranks if self.layernorm_across_head is True - if self.layernorm_across_head and parallel_state.get_tensor_model_parallel_world_size() > 1: - query = query.transpose(-2, -1) - key = key.transpose(-2, -1) - query = tensor_parallel.scatter_to_tensor_model_parallel_region(query) - key = tensor_parallel.scatter_to_tensor_model_parallel_region(key) - query = query.transpose(-2, -1) - key = key.transpose(-2, -1) - query = query.contiguous() # important becuase TE attention expects contiguous tensors - key = key.contiguous() # important becuase TE attention expects contiguous tensors - - return query, key, value - - @dataclass class WanWithAdaLNSubmodules(TransformerLayerSubmodules): temporal_self_attention: Union[ModuleSpec, type] = IdentityOp @@ -437,19 +148,19 @@ def __init__( self.adaLN = WanAdaLN(config=self.config) self.norm1 = build_module( submodules.norm1, - dim=config.hidden_size, + normalized_shape=config.hidden_size, eps=config.layernorm_epsilon, elementwise_affine=False ) self.norm3 = build_module( submodules.norm3, - dim=config.hidden_size, + normalized_shape=config.hidden_size, eps=config.layernorm_epsilon, elementwise_affine=True, ) self.norm2 = build_module( submodules.norm2, - dim=config.hidden_size, + normalized_shape=config.hidden_size, eps=config.layernorm_epsilon, elementwise_affine=False, ) @@ -544,37 +255,33 @@ def forward( return output, context - - def get_wan_block_with_transformer_engine_spec() -> ModuleSpec: params = {"attn_mask_type": AttnMaskType.padding} return ModuleSpec( module=WanLayerWithAdaLN, submodules=WanWithAdaLNSubmodules( - norm1=WanLayerNorm, - norm3=WanLayerNorm, - norm2=WanLayerNorm, + norm1=nn.LayerNorm, + norm3=nn.LayerNorm, + norm2=nn.LayerNorm, full_self_attention=ModuleSpec( - module=WanSelfAttention, + module=DiTSelfAttention, params=params, - submodules=WanSelfAttentionSubmodules( + submodules=SelfAttentionSubmodules( linear_qkv=TEColumnParallelLinear, core_attention=TEDotProductAttention, linear_proj=TERowParallelLinear, - layernorm_across_head=True, q_layernorm=TENorm, k_layernorm=TENorm, ), ), cross_attention=ModuleSpec( - module=WanCrossAttention, + module=DiTCrossAttention, params=params, - submodules=WanCrossAttentionSubmodules( + submodules=DiTCrossAttentionSubmodules( linear_q=TEColumnParallelLinear, linear_kv=TEColumnParallelLinear, core_attention=TEDotProductAttention, linear_proj=TERowParallelLinear, - layernorm_across_head=True, q_layernorm=TENorm, k_layernorm=TENorm, ), diff --git a/dfm/src/megatron/model/wan/wan_model.py b/dfm/src/megatron/model/wan/wan_model.py index dc409f33..03aeebce 100644 --- a/dfm/src/megatron/model/wan/wan_model.py +++ b/dfm/src/megatron/model/wan/wan_model.py @@ -31,7 +31,6 @@ from dfm.src.megatron.model.wan.wan_layer_spec import ( get_wan_block_with_transformer_engine_spec as WanLayerWithAdaLNspec, ) -from dfm.src.megatron.model.wan.wan_layer_spec import WanLayerNorm from torch import Tensor from .rope_utils import Wan3DRopeEmbeddings @@ -47,6 +46,7 @@ def sinusoidal_embedding_1d(dim, position): x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) return x + class Head(nn.Module): def __init__(self, dim, out_dim, patch_size, eps=1e-6): @@ -58,7 +58,7 @@ def __init__(self, dim, out_dim, patch_size, eps=1e-6): # layers out_dim = math.prod(patch_size) * out_dim - self.norm = WanLayerNorm(dim, eps) + self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=eps) self.head = nn.Linear(dim, out_dim) # modulation diff --git a/dfm/examples/megatron/recipe/wan/wan_provider.py b/dfm/src/megatron/model/wan/wan_provider.py similarity index 98% rename from dfm/examples/megatron/recipe/wan/wan_provider.py rename to dfm/src/megatron/model/wan/wan_provider.py index 63229b37..90f41e14 100644 --- a/dfm/examples/megatron/recipe/wan/wan_provider.py +++ b/dfm/src/megatron/model/wan/wan_provider.py @@ -13,7 +13,6 @@ # limitations under the License. -# Goes into the model import logging from dataclasses import dataclass @@ -40,6 +39,7 @@ class WanModelProvider(TransformerConfig, ModelProviderMixin[VisionModule]): layernorm_epsilon: float = 1e-6 normalization: str = "RMSNorm" layernorm_zero_centered_gamma: bool = False + layernorm_across_heads: bool = True add_qkv_bias: bool = True rotary_interleaved: bool = True hidden_dropout: float = 0 diff --git a/dfm/examples/megatron/recipe/wan/wan_step.py b/dfm/src/megatron/model/wan/wan_step.py similarity index 99% rename from dfm/examples/megatron/recipe/wan/wan_step.py rename to dfm/src/megatron/model/wan/wan_step.py index cb19386d..43d5ff8a 100644 --- a/dfm/examples/megatron/recipe/wan/wan_step.py +++ b/dfm/src/megatron/model/wan/wan_step.py @@ -13,7 +13,6 @@ # limitations under the License. -# Move to the model and wan import logging from functools import partial from typing import Iterable diff --git a/dfm/examples/megatron/recipe/wan/wan.py b/dfm/src/megatron/recipes/wan/wan.py similarity index 98% rename from dfm/examples/megatron/recipe/wan/wan.py rename to dfm/src/megatron/recipes/wan/wan.py index f64d0d72..f784a842 100644 --- a/dfm/examples/megatron/recipe/wan/wan.py +++ b/dfm/src/megatron/recipes/wan/wan.py @@ -12,12 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Goes to src/megatron/reicepe/wan import os from typing import List, Optional, Union from dfm.src.megatron.data.wan.wan_energon_datamodule import WanDataModuleConfig -from dfm.examples.megatron.recipe.wan.wan_provider import WanModelProvider +from dfm.src.megatron.model.wan.wan_provider import WanModelProvider import torch from megatron.core.distributed import DistributedDataParallelConfig diff --git a/dfm/examples/megatron/recipe/wan/conf/wan_pretrain_override_example.yaml b/examples/megatron/override_configs/wan_pretrain_override_example.yaml similarity index 100% rename from dfm/examples/megatron/recipe/wan/conf/wan_pretrain_override_example.yaml rename to examples/megatron/override_configs/wan_pretrain_override_example.yaml diff --git a/examples/megatron/recipe/README.md b/examples/megatron/recipes/README.md similarity index 100% rename from examples/megatron/recipe/README.md rename to examples/megatron/recipes/README.md diff --git a/examples/megatron/recipes/wan/convert_wan_checkpoints.py b/examples/megatron/recipes/wan/convert_wan_checkpoints.py new file mode 100644 index 00000000..eaf8ef04 --- /dev/null +++ b/examples/megatron/recipes/wan/convert_wan_checkpoints.py @@ -0,0 +1,20 @@ +from megatron.bridge.models.hf_pretrained.wan import PreTrainedWAN +from megatron.bridge.models.wan.wan_bridge import WanBridge +from megatron.bridge.training.model_load_save import save_megatron_model +import os, random +os.environ["MASTER_ADDR"] = "127.0.0.1" +os.environ["MASTER_PORT"] = str(29500 + random.randint(0, 1000)) +os.environ["RANK"] = "0" +os.environ["WORLD_SIZE"] = "1" +os.environ["LOCAL_RANK"] = "0" +# +hf = PreTrainedWAN("Wan-AI/Wan2.1-T2V-1.3B-Diffusers") +# hf = PreTrainedWAN("Wan-AI/Wan2.1-T2V-14B-Diffusers") +bridge = WanBridge() +# +provider = bridge.provider_bridge(hf) +provider.perform_initialization = False +megatron_models = provider.provide_distributed_model(wrap_with_ddp=False, use_cpu_initialization=True) +# +bridge.load_weights_hf_to_megatron(hf, megatron_models) +save_megatron_model(megatron_models, "/opt/megatron_checkpoint", hf_tokenizer_path=None) \ No newline at end of file diff --git a/examples/megatron/recipes/wan/example_commands.md b/examples/megatron/recipes/wan/example_commands.md new file mode 100644 index 00000000..21b2492b --- /dev/null +++ b/examples/megatron/recipes/wan/example_commands.md @@ -0,0 +1,88 @@ +## WAN example commands + +### Set paths to Megatron-Bridge +```bash +DFM_PATH=/path/to/dfm +MBRIDGE_PATH=/path/to/megatron-bridge +export PYTHONPATH="${DFM_PATH}/.:${MBRIDGE_PATH}/src/.:/opt/NeMo-Framework-Launcher/launcher_scripts" +``` + +### Install dependencies +```bash +pip install --upgrade git+https://github.com/NVIDIA/Megatron-LM.git@ce8185cbbe04f38beb74360e878450f2e8525885 +python3 -m pip install --upgrade diffusers==0.35.1 +pip install easydict +pip install imageio +pip install imageio-ffmpeg +``` + +### Convert checkpoint +See `examples/conversion/convert_wan_checkpoints.py` under `MBRIDGE_PATH` for details. + +### Finetuning +Set environment variables and run training: +```bash +export HF_TOKEN=... +export WANDB_API_KEY=... +EXP_NAME=... +PRETRAINED_CHECKPOINT=/path/to/pretrained_checkpoint +CHECKPOINT_DIR=/path/to/checkpoint_dir +DATASET_PATH=/path/to/dataset +cd ${MBRIDGE_PATH} +NVTE_FUSED_ATTN=1 torchrun --nproc_per_node=8 examples/recipes/wan/pretrain_wan.py \ + model.tensor_model_parallel_size=1 \ + model.pipeline_model_parallel_size=1 \ + model.context_parallel_size=4 \ + model.sequence_parallel=false \ + model.qkv_format=thd \ + dataset.path=${DATASET_PATH} \ + checkpoint.save=${CHECKPOINT_DIR} \ + checkpoint.load=${PRETRAINED_CHECKPOINT} \ + checkpoint.load_optim=false \ + checkpoint.save_interval=200 \ + optimizer.lr=5e-6 \ + optimizer.min_lr=5e-6 \ + train.eval_iters=0 \ + scheduler.lr_decay_style=constant \ + scheduler.lr_warmup_iters=0 \ + model.seq_length=2048 \ + dataset.seq_length=2048 \ + train.global_batch_size=1 \ + train.micro_batch_size=1 \ + dataset.global_batch_size=1 \ + dataset.micro_batch_size=1 \ + logger.log_interval=1 \ + logger.wandb_project="wan" \ + logger.wandb_exp_name=${EXP_NAME} \ + logger.wandb_save_dir=${CHECKPOINT_DIR} +``` + +### Inference +Download T5 and VAE weights from the [Wan-AI/Wan2.1-T2V-1.3B repository](https://huggingface.co/Wan-AI/Wan2.1-T2V-1.3B/tree/main): +- T5: `models_t5_umt5-xxl-enc-bf16.pth`, provider `google` +- VAE: `Wan2.1_VAE.pth` + +Then run: +```bash +export HF_TOKEN=... +CHECKPOINT_DIR=/path/to/checkpoint_dir +T5_DIR=/path/to/t5_weights +VAE_DIR=/path/to/vae_weights +cd ${MBRIDGE_PATH} +NVTE_FUSED_ATTN=1 torchrun --nproc_per_node=1 examples/recipes/wan/inference_wan.py \ + --task t2v-1.3B \ + --sizes 832*480 \ + --checkpoint_dir ${CHECKPOINT_DIR} \ + --checkpoint_step 1000 \ + --t5_checkpoint_dir ${T5_DIR} \ + --vae_checkpoint_dir ${VAE_DIR} \ + --prompts "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \ + --frame_nums 81 \ + --tensor_parallel_size 1 \ + --context_parallel_size 1 \ + --pipeline_parallel_size 1 \ + --sequence_parallel False \ + --base_seed 42 \ + --sample_steps 50 +``` + diff --git a/dfm/examples/megatron/recipe/wan/example_commands.sh b/examples/megatron/recipes/wan/example_commands.sh similarity index 99% rename from dfm/examples/megatron/recipe/wan/example_commands.sh rename to examples/megatron/recipes/wan/example_commands.sh index be8a0447..6cf7f6e9 100644 --- a/dfm/examples/megatron/recipe/wan/example_commands.sh +++ b/examples/megatron/recipes/wan/example_commands.sh @@ -10,7 +10,7 @@ export PYTHONPATH="${DFM_PATH}/.:${MBRIDGE_PATH}/src/.:/opt/NeMo-Framework-Launc pip install --upgrade git+https://github.com/NVIDIA/Megatron-LM.git@ce8185cbbe04f38beb74360e878450f2e8525885 python3 -m pip install --upgrade diffusers==0.35.1 pip install easydict -pip install imageio +pip install e32qwimageio pip install imageio-ffmpeg diff --git a/dfm/examples/megatron/recipe/wan/inference_wan.py b/examples/megatron/recipes/wan/inference_wan.py similarity index 99% rename from dfm/examples/megatron/recipe/wan/inference_wan.py rename to examples/megatron/recipes/wan/inference_wan.py index 3f84eafc..2f480a2b 100644 --- a/dfm/examples/megatron/recipe/wan/inference_wan.py +++ b/examples/megatron/recipes/wan/inference_wan.py @@ -14,9 +14,6 @@ # --base_seed 42 \ # --sample_steps 50 - -# Goes to examples/megatron/recipe/wan - import argparse import logging import os diff --git a/examples/megatron/recipes/wan/launch_data_processing_images_sbatch.sh b/examples/megatron/recipes/wan/launch_data_processing_images_sbatch.sh new file mode 100644 index 00000000..90905cf0 --- /dev/null +++ b/examples/megatron/recipes/wan/launch_data_processing_images_sbatch.sh @@ -0,0 +1,57 @@ +#!/bin/bash + +# Parameters +#SBATCH --account=coreai_dlalgo_llm +#SBATCH --job-name=coreai_dlalgo_llm-run:dfm +#SBATCH --nodes=1 +#SBATCH --partition=batch +#SBATCH --time=04:00:00 + + +OUTPUT_DIR=/lustre/fsw/coreai_dlalgo_genai/huvu/data/nemo_vfm/datasets/coyo_dataset_wan/processed_coyo_dataset_wan + +cmd=" + +# install +DFM_PATH=/lustre/fsw/coreai_dlalgo_genai/huvu/codes/nemo_vfm/DFM_mcore_wan +MBRIDGE_PATH=/lustre/fsw/coreai_dlalgo_genai/huvu/codes/nemo_vfm/Megatron-Bridge_mcore_wan_official +MLM_PATH=/lustre/fsw/coreai_dlalgo_genai/huvu/codes/nemo_vfm/megatron-lm_latest +export PYTHONPATH="\$DFM_PATH/.:\$MBRIDGE_PATH/src/.:\$MLM_PATH/.:/opt/NeMo-Framework-Launcher/launcher_scripts" + + +# install dependencies +python3 -m pip install --upgrade diffusers +pip install easydict +pip install imageio +pip install imageio-ffmpeg +[apt update; apt install ffmpeg -y] -> for data preparation + + +cd \$DFM_PATH +export HF_TOKEN=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 +python dfm/src/megatron/data/wan/prepare_energon_dataset_wan_square_images.py \ + --input_dir /lustre/fsw/coreai_dlalgo_genai/datasets/coyo-700m/part_00000 \ + --output_dir /lustre/fsw/coreai_dlalgo_genai/huvu/data/nemo_vfm/datasets/coyo_dataset_wan/processed_coyo_dataset_wan/part_00000_square_wds \ + --model Wan-AI/Wan2.1-T2V-14B-Diffusers \ + --gpus 0,1,2,3,4,5,6,7 \ + --size 256 \ + --resize_mode bilinear \ + --save-image \ + --skip-existing \ + --stochastic + +" + +CONT="nvcr.io/nvidia/nemo:25.09.00" +MOUNT="/lustre/fsw/:/lustre/fsw/" +OUTFILE=$OUTPUT_DIR/slurm-%j.out +ERRFILE=$OUTPUT_DIR/error-%j.out +export CUDA_DEVICE_MAX_CONNECTIONS=1 + +echo "Running training script." +srun -o ${OUTFILE} -e ${ERRFILE} --mpi=pmix \ + --container-image="${CONT}" --container-mounts="${MOUNT}" \ + --no-container-mount-home \ + --ntasks-per-node=1 \ + -N ${SLURM_JOB_NUM_NODES} \ + bash -c "${cmd}" \ No newline at end of file diff --git a/examples/megatron/recipes/wan/launch_pretrain_wan_sbatch.sh b/examples/megatron/recipes/wan/launch_pretrain_wan_sbatch.sh new file mode 100644 index 00000000..bfef15cd --- /dev/null +++ b/examples/megatron/recipes/wan/launch_pretrain_wan_sbatch.sh @@ -0,0 +1,88 @@ +#!/bin/bash + +# Parameters +#SBATCH --account=coreai_dlalgo_llm +#SBATCH --job-name=coreai_dlalgo_llm-run:dfm +#SBATCH --nodes=1 +#SBATCH --partition=batch +#SBATCH --time=04:00:00 + + +EXP_NAME=sbatch_wan_1.3B_square_images_pretrain_mbs64gbs512_1tar +CHECKPOINT_DIR=/lustre/fsw/coreai_dlalgo_genai/huvu/data/nemo_vfm/results/wan_finetune/${EXP_NAME} +PROJECT=wan +MBS=64 +GBS=512 +LR=1e-4 +WARMUP_ITERS=10000 +# set this PRETRAIN_CHECKPOINT_DIR to CHECKPOINT_DIR to train from scratch +PRETRAIN_CHECKPOINT_DIR=${CHECKPOINT_DIR} +# PRETRAIN_CHECKPOINT_DIR=/lustre/fsw/coreai_dlalgo_genai/huvu/data/nemo_vfm/wan_checkpoints/megatron_checkpoint_1.3B + +# create checkpoint directory +mkdir -p ${CHECKPOINT_DIR} + +cmd=" + +# install +DFM_PATH=/lustre/fsw/coreai_dlalgo_genai/huvu/codes/nemo_vfm/DFM_mcore_wan +MBRIDGE_PATH=/lustre/fsw/coreai_dlalgo_genai/huvu/codes/nemo_vfm/Megatron-Bridge_mcore_wan_official +MLM_PATH=/lustre/fsw/coreai_dlalgo_genai/huvu/codes/nemo_vfm/megatron-lm_latest +export PYTHONPATH="\$DFM_PATH/.:\$MBRIDGE_PATH/src/.:\$MLM_PATH/.:/opt/NeMo-Framework-Launcher/launcher_scripts" + + +# install dependencies +python3 -m pip install --upgrade diffusers +pip install easydict +pip install imageio +pip install imageio-ffmpeg +[apt update; apt install ffmpeg -y] -> for data preparation + + +cd \$DFM_PATH +export HF_TOKEN=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 +export WANDB_API_KEY=497a93e5ac7cf1e0ec821741ef7bac27b36f2db8 +NVTE_FUSED_ATTN=1 torchrun --standalone --nproc_per_node=8 dfm/examples/megatron/recipe/wan/pretrain_wan.py \ + model.tensor_model_parallel_size=1 \ + model.pipeline_model_parallel_size=1 \ + model.context_parallel_size=1 \ + model.sequence_parallel=false \ + model.qkv_format=sbhd \ + dataset.path="/lustre/fsw/coreai_dlalgo_genai/huvu/data/nemo_vfm/datasets/coyo_dataset_wan/processed_coyo_dataset_wan/part_00000_square_wds_1tar" \ + checkpoint.save=${CHECKPOINT_DIR} \ + checkpoint.load=${PRETRAIN_CHECKPOINT_DIR} \ + checkpoint.load_optim=false \ + checkpoint.save_interval=2500 \ + optimizer.lr=${LR} \ + optimizer.min_lr=${LR} \ + train.eval_iters=0 \ + train.train_iters=1000000 \ + scheduler.lr_decay_style=constant \ + scheduler.lr_decay_style=constant \ + scheduler.lr_warmup_iters=${WARMUP_ITERS} \ + model.seq_length=2048 \ + dataset.seq_length=2048 \ + train.global_batch_size=${GBS} \ + train.micro_batch_size=${MBS} \ + dataset.global_batch_size=${GBS} \ + dataset.micro_batch_size=${MBS} \ + logger.log_interval=1 \ + logger.wandb_project=${PROJECT} \ + logger.wandb_exp_name=${EXP_NAME} \ + logger.wandb_save_dir=${CHECKPOINT_DIR} + +" + +CONT="nvcr.io/nvidia/nemo:25.09.00" +MOUNT="/lustre/fsw/:/lustre/fsw/" +OUTFILE=$CHECKPOINT_DIR/slurm-%j.out +ERRFILE=$CHECKPOINT_DIR/error-%j.out +export CUDA_DEVICE_MAX_CONNECTIONS=1 + +echo "Running training script." +srun -o ${OUTFILE} -e ${ERRFILE} --mpi=pmix \ + --container-image="${CONT}" --container-mounts="${MOUNT}" \ + --no-container-mount-home \ + --ntasks-per-node=1 \ + -N ${SLURM_JOB_NUM_NODES} \ + bash -c "${cmd}" \ No newline at end of file diff --git a/examples/megatron/recipes/wan/nemo_mcore_t5_sbatch_mcore.sh b/examples/megatron/recipes/wan/nemo_mcore_t5_sbatch_mcore.sh new file mode 100644 index 00000000..f086f52a --- /dev/null +++ b/examples/megatron/recipes/wan/nemo_mcore_t5_sbatch_mcore.sh @@ -0,0 +1,128 @@ +#!/bin/bash + +# Parameters +#SBATCH --account=coreai_dlalgo_llm +#SBATCH --job-name=coreai_dlalgo_llm-run:nemo_mcore_t5 +#SBATCH --nodes=1 +#SBATCH --partition=batch +#SBATCH --time=04:00:00 + +# sbatch script +mcore_t5=False + +NUM_DEVICES=8 +NEMO_DIR=/lustre/fsw/coreai_dlalgo_genai/huvu/codes/T5_nemo/NeMo +EXPNAME=nemo_nonmcore_t5_fp32_sbatch_mcore +PROJECT=t5_pretrain_sbatch + +CONFIG_NAME='megatron_t5_config' + +PRECISION=32 +AMP_O2=False +MICRO_BATCH_SIZE=64 +GLOBAL_BATCH_SIZE=512 +ACCUMULATE_GRAD_BATCHES=1 +TENSOR_MODEL_PARALLEL_SIZE=1 +VAL_CHECK_INTERVAL=2000 +MAX_STEPS=1000000 +# BLEND="[.0333,/lustre/fsw/coreai_dlalgo_genai/huvu/data/t5/training_data/symlinks/my-t5_00_bert_tokenizer_text_document,.0333,/lustre/fsw/coreai_dlalgo_genai/huvu/data/t5/training_data/symlinks/my-t5_01_bert_tokenizer_text_document,.0333,/lustre/fsw/coreai_dlalgo_genai/huvu/data/t5/training_data/symlinks/my-t5_02_bert_tokenizer_text_document,.0333,/lustre/fsw/coreai_dlalgo_genai/huvu/data/t5/training_data/symlinks/my-t5_03_bert_tokenizer_text_document,.0333,/lustre/fsw/coreai_dlalgo_genai/huvu/data/t5/training_data/symlinks/my-t5_04_bert_tokenizer_text_document,.0333,/lustre/fsw/coreai_dlalgo_genai/huvu/data/t5/training_data/symlinks/my-t5_05_bert_tokenizer_text_document,.0333,/lustre/fsw/coreai_dlalgo_genai/huvu/data/t5/training_data/symlinks/my-t5_06_bert_tokenizer_text_document,.0333,/lustre/fsw/coreai_dlalgo_genai/huvu/data/t5/training_data/symlinks/my-t5_07_bert_tokenizer_text_document,.0333,/lustre/fsw/coreai_dlalgo_genai/huvu/data/t5/training_data/symlinks/my-t5_08_bert_tokenizer_text_document,.0333,/lustre/fsw/coreai_dlalgo_genai/huvu/data/t5/training_data/symlinks/my-t5_09_bert_tokenizer_text_document,.0333,/lustre/fsw/coreai_dlalgo_genai/huvu/data/t5/training_data/symlinks/my-t5_10_bert_tokenizer_text_document,.0333,/lustre/fsw/coreai_dlalgo_genai/huvu/data/t5/training_data/symlinks/my-t5_11_bert_tokenizer_text_document,.0333,/lustre/fsw/coreai_dlalgo_genai/huvu/data/t5/training_data/symlinks/my-t5_12_bert_tokenizer_text_document,.0333,/lustre/fsw/coreai_dlalgo_genai/huvu/data/t5/training_data/symlinks/my-t5_13_bert_tokenizer_text_document,.0333,/lustre/fsw/coreai_dlalgo_genai/huvu/data/t5/training_data/symlinks/my-t5_14_bert_tokenizer_text_document,.0333,/lustre/fsw/coreai_dlalgo_genai/huvu/data/t5/training_data/symlinks/my-t5_15_bert_tokenizer_text_document,.0333,/lustre/fsw/coreai_dlalgo_genai/huvu/data/t5/training_data/symlinks/my-t5_16_bert_tokenizer_text_document,.0333,/lustre/fsw/coreai_dlalgo_genai/huvu/data/t5/training_data/symlinks/my-t5_17_bert_tokenizer_text_document,.0333,/lustre/fsw/coreai_dlalgo_genai/huvu/data/t5/training_data/symlinks/my-t5_18_bert_tokenizer_text_document,.0333,/lustre/fsw/coreai_dlalgo_genai/huvu/data/t5/training_data/symlinks/my-t5_19_bert_tokenizer_text_document,.0333,/lustre/fsw/coreai_dlalgo_genai/huvu/data/t5/training_data/symlinks/my-t5_20_bert_tokenizer_text_document,.0333,/lustre/fsw/coreai_dlalgo_genai/huvu/data/t5/training_data/symlinks/my-t5_21_bert_tokenizer_text_document,.0333,/lustre/fsw/coreai_dlalgo_genai/huvu/data/t5/training_data/symlinks/my-t5_22_bert_tokenizer_text_document,.0333,/lustre/fsw/coreai_dlalgo_genai/huvu/data/t5/training_data/symlinks/my-t5_23_bert_tokenizer_text_document,.0333,/lustre/fsw/coreai_dlalgo_genai/huvu/data/t5/training_data/symlinks/my-t5_24_bert_tokenizer_text_document,.0333,/lustre/fsw/coreai_dlalgo_genai/huvu/data/t5/training_data/symlinks/my-t5_25_bert_tokenizer_text_document,.0333,/lustre/fsw/coreai_dlalgo_genai/huvu/data/t5/training_data/symlinks/my-t5_26_bert_tokenizer_text_document,.0333,/lustre/fsw/coreai_dlalgo_genai/huvu/data/t5/training_data/symlinks/my-t5_27_bert_tokenizer_text_document,.0333,/lustre/fsw/coreai_dlalgo_genai/huvu/data/t5/training_data/symlinks/my-t5_28_bert_tokenizer_text_document,.0334,/lustre/fsw/coreai_dlalgo_genai/huvu/data/t5/training_data/symlinks/my-t5_29_bert_tokenizer_text_document]" +BLEND="[/lustre/fsw/coreai_dlalgo_genai/huvu/data/t5/training_data/symlinks/my-t5_00_bert_tokenizer_text_document]" + +# Model architecture +SEQ_LENGTH=512 +SEQ_LENGTH_DEC=128 +NUM_LAYERS=12 +HIDDEN_SIZE=768 +NUM_ATTENTION_HEADS=12 + +home_dir=/lustre/fsw/coreai_dlalgo_genai/huvu/data/t5/results/t5_nemo +exp_dir=$home_dir/${EXPNAME} +mkdir ${exp_dir} + +cmd=" + +cd /lustre/fsw/coreai_dlalgo_genai/huvu/codes/T5_nemo/NeMo +pip install -e . +NEMO=/lustre/fsw/coreai_dlalgo_genai/huvu/codes/T5_nemo/NeMo +export PYTHONPATH="${NEMO}/.:${PYTHONPATH}" +## Megatron-LM +cd /lustre/fsw/coreai_dlalgo_genai/huvu/codes/T5_nemo/megatron-lm +pip install -e . +MEGATRONLM=/lustre/fsw/coreai_dlalgo_genai/huvu/codes/T5_nemo/megatron-lm +export PYTHONPATH="${MEGATRONLM}/.:${PYTHONPATH}" +export PYTHONPATH="/lustre/fsw/coreai_dlalgo_genai/huvu/codes/T5_nemo/megatron-lm/.:/lustre/fsw/coreai_dlalgo_genai/huvu/codes/T5_nemo/NeMo/.:/opt/NeMo-Megatron-Launcher/launcher_scripts" + +export WANDB_API_KEY=497a93e5ac7cf1e0ec821741ef7bac27b36f2db8 + +if [[ ${PRECISION} != "32" ]]; then + export NVTE_FUSED_ATTN=0 + export NVTE_FLASH_ATTN=0 +fi + +python ${NEMO_DIR}/examples/nlp/language_modeling/megatron_t5_pretraining.py \ + --config-name=${CONFIG_NAME} \ + trainer.num_nodes=1 \ + trainer.devices=${NUM_DEVICES} \ + trainer.max_epochs=null \ + trainer.max_steps=${MAX_STEPS} \ + trainer.val_check_interval=${VAL_CHECK_INTERVAL} \ + trainer.accumulate_grad_batches=${ACCUMULATE_GRAD_BATCHES} \ + trainer.precision=${PRECISION} \ + trainer.log_every_n_steps=1 \ + model.megatron_amp_O2=${AMP_O2} \ + model.micro_batch_size=${MICRO_BATCH_SIZE} \ + model.global_batch_size=${GLOBAL_BATCH_SIZE} \ + model.tensor_model_parallel_size=${TENSOR_MODEL_PARALLEL_SIZE} \ + model.max_position_embeddings=${SEQ_LENGTH} \ + model.seq_length=${SEQ_LENGTH} \ + model.encoder.hidden_size=${HIDDEN_SIZE} \ + model.decoder.hidden_size=${HIDDEN_SIZE} \ + model.encoder.num_layers=${NUM_LAYERS} \ + model.decoder.num_layers=${NUM_LAYERS} \ + model.encoder.num_attention_heads=${NUM_ATTENTION_HEADS} \ + model.decoder.num_attention_heads=${NUM_ATTENTION_HEADS} \ + model.encoder.init_method_std=0.015 \ + model.decoder.init_method_std=0.015 \ + model.encoder.transformer_block_type='pre_ln' \ + model.decoder.transformer_block_type='pre_ln' \ + model.data.data_prefix=${BLEND} \ + model.data.seq_length=${SEQ_LENGTH} \ + model.tokenizer.vocab_file=/lustre/fsw/coreai_dlalgo_genai/huvu/data/t5/training_data/symlinks/bert-large-cased-vocab.txt \ + model.data.seq_length_dec=${SEQ_LENGTH_DEC} \ + model.data.splits_string=\'99982,9,9\' \ + model.data.num_workers=4 \ + model.optim.name=distributed_fused_adam \ + model.mcore_t5=${mcore_t5} \ + model.transformer_engine=True \ + +model.kv_channels=64 \ + exp_manager.create_wandb_logger=True \ + exp_manager.wandb_logger_kwargs.name=${EXPNAME} \ + exp_manager.wandb_logger_kwargs.project=${PROJECT} \ + +exp_manager.wandb_logger_kwargs.resume=True \ + exp_manager.explicit_log_dir=${exp_dir} \ + exp_manager.resume_if_exists=True \ + exp_manager.resume_ignore_no_checkpoint=True \ + exp_manager.create_checkpoint_callback=True \ + exp_manager.checkpoint_callback_params.monitor=val_loss \ + exp_manager.checkpoint_callback_params.save_top_k=3 \ + exp_manager.checkpoint_callback_params.mode=min \ + exp_manager.checkpoint_callback_params.always_save_nemo=False \ + ++exp_manager.checkpoint_callback_params.save_nemo_on_train_end=False \ + ++exp_manager.log_step_timing=True \ +" + +# ++model.async_grad_allreduce=False for training with bf16, O2, FusedAdam + +CONT="gitlab-master.nvidia.com/dl/joc/nemo-ci/train:pipe.14465850" +MOUNT="/lustre/fsw/:/lustre/fsw/" +OUTFILE=$exp_dir/slurm-%j.out +ERRFILE=$exp_dir/error-%j.out +export CUDA_DEVICE_MAX_CONNECTIONS=1 + +echo "Running training script." +srun -o ${OUTFILE} -e ${ERRFILE} --mpi=pmix \ + --container-image="${CONT}" --container-mounts="${MOUNT}" \ + --no-container-mount-home \ + --ntasks-per-node=8 \ + -N ${SLURM_JOB_NUM_NODES} \ + bash -c "${cmd}" \ No newline at end of file diff --git a/dfm/examples/megatron/recipe/wan/pretrain_wan.py b/examples/megatron/recipes/wan/pretrain_wan.py similarity index 99% rename from dfm/examples/megatron/recipe/wan/pretrain_wan.py rename to examples/megatron/recipes/wan/pretrain_wan.py index 81b4f2f4..7742397e 100644 --- a/dfm/examples/megatron/recipe/wan/pretrain_wan.py +++ b/examples/megatron/recipes/wan/pretrain_wan.py @@ -14,8 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - -# Goes to examples/megatron/recipe/wan """ Wan Pretraining Script with YAML and CLI Configuration Overrides. From dfff86bba2e522b0675f04880007b5151123cd2b Mon Sep 17 00:00:00 2001 From: Huy Vu2 Date: Tue, 4 Nov 2025 00:23:40 -0800 Subject: [PATCH 09/80] update refactor --- .../wan/prepare_energon_dataset_wan_images.py | 551 ------------------ .../recipes/wan/convert_wan_checkpoints.py | 20 - .../megatron/recipes/wan/example_commands.sh | 80 --- .../launch_data_processing_images_sbatch.sh | 57 -- .../recipes/wan/launch_pretrain_wan_sbatch.sh | 88 --- nemo_vfm/diffusion/data/base.py | 444 -------------- .../models/dit/dit_attention_megatron.py | 0 7 files changed, 1240 deletions(-) delete mode 100644 dfm/src/megatron/data/wan/prepare_energon_dataset_wan_images.py delete mode 100644 examples/megatron/recipes/wan/convert_wan_checkpoints.py delete mode 100644 examples/megatron/recipes/wan/example_commands.sh delete mode 100644 examples/megatron/recipes/wan/launch_data_processing_images_sbatch.sh delete mode 100644 examples/megatron/recipes/wan/launch_pretrain_wan_sbatch.sh delete mode 100644 nemo_vfm/diffusion/data/base.py delete mode 100644 nemo_vfm/diffusion/models/dit/dit_attention_megatron.py diff --git a/dfm/src/megatron/data/wan/prepare_energon_dataset_wan_images.py b/dfm/src/megatron/data/wan/prepare_energon_dataset_wan_images.py deleted file mode 100644 index 70758ae6..00000000 --- a/dfm/src/megatron/data/wan/prepare_energon_dataset_wan_images.py +++ /dev/null @@ -1,551 +0,0 @@ -import os -import io -import json -import pickle -import tarfile -from pathlib import Path -from typing import Dict, List, Optional, Tuple, Iterable, Any -import multiprocessing as mp -import math - -import cv2 -import numpy as np -import torch -import webdataset as wds - -from diffusers import AutoencoderKLWan -from transformers import AutoTokenizer, UMT5EncoderModel -from tqdm import tqdm - - -def _map_interpolation(resize_mode: str) -> int: - interpolation_map = { - "bilinear": cv2.INTER_LINEAR, - "bicubic": cv2.INTER_CUBIC, - "nearest": cv2.INTER_NEAREST, - "area": cv2.INTER_AREA, - "lanczos": cv2.INTER_LANCZOS4, - } - if resize_mode not in interpolation_map: - raise ValueError( - f"Invalid resize_mode '{resize_mode}'. Choose from: {list(interpolation_map.keys())}" - ) - return interpolation_map[resize_mode] - - -def _calculate_resize_dimensions( - original_height: int, - original_width: int, - target_size: Optional[Tuple[int, int]], - maintain_aspect_ratio: bool, -) -> Tuple[int, int]: - if target_size is None: - return original_height, original_width - - target_height, target_width = target_size - if not maintain_aspect_ratio: - return target_height, target_width - - original_aspect = original_width / max(1, original_height) - target_aspect = target_width / max(1, target_height) - - if original_aspect > target_aspect: - new_width = target_width - new_height = int(round(target_width / max(1e-6, original_aspect))) - else: - new_height = target_height - new_width = int(round(target_height * original_aspect)) - - return new_height, new_width - - -def _resize_frame( - frame: np.ndarray, - target_size: Optional[Tuple[int, int]], - resize_mode: str, - maintain_aspect_ratio: bool, - center_crop: bool, -) -> np.ndarray: - if target_size is None: - return frame - - original_height, original_width = frame.shape[:2] - target_height, target_width = target_size - - interpolation = _map_interpolation(resize_mode) - - if not maintain_aspect_ratio: - resized_frame = cv2.resize(frame, (target_width, target_height), interpolation=interpolation) - return resized_frame - - if center_crop: - # Resize-to-cover: scale so both dims >= target, then center-crop to exact target - scale = max(target_height / max(1, original_height), target_width / max(1, original_width)) - resize_height = int(math.ceil(original_height * scale)) - resize_width = int(math.ceil(original_width * scale)) - - resized_frame = cv2.resize(frame, (resize_width, resize_height), interpolation=interpolation) - - y_start = max(0, (resize_height - target_height) // 2) - x_start = max(0, (resize_width - target_width) // 2) - y_end = y_start + target_height - x_end = x_start + target_width - - # Bound checks (should be safe due to ceil, but guard anyway) - y_start = min(y_start, max(0, resize_height - target_height)) - x_start = min(x_start, max(0, resize_width - target_width)) - y_end = min(y_end, resize_height) - x_end = min(x_end, resize_width) - - cropped = resized_frame[y_start:y_end, x_start:x_end] - - # If due to rounding one dim is still short, pad minimally (rare) - pad_h = max(0, target_height - cropped.shape[0]) - pad_w = max(0, target_width - cropped.shape[1]) - if pad_h > 0 or pad_w > 0: - cropped = np.pad( - cropped, - ((pad_h // 2, pad_h - pad_h // 2), (pad_w // 2, pad_w - pad_w // 2), (0, 0)), - mode="edge", - ) - cropped = cropped[:target_height, :target_width] - - return cropped - - # Aspect-preserving resize-to-fit (no crop): may be smaller than target in one dim - resize_height, resize_width = _calculate_resize_dimensions( - original_height, original_width, target_size, maintain_aspect_ratio - ) - resized_frame = cv2.resize(frame, (resize_width, resize_height), interpolation=interpolation) - return resized_frame - - -def _decode_image_bytes_to_rgb(image_bytes: bytes) -> Optional[np.ndarray]: - array = np.frombuffer(image_bytes, dtype=np.uint8) - img_bgr = cv2.imdecode(array, cv2.IMREAD_COLOR) - if img_bgr is None: - return None - img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB) - return img_rgb - - -def _select_target_size_for_image(image_rgb: np.ndarray) -> Tuple[int, int]: - h, w = image_rgb.shape[:2] - if h <= w: - return (480, 832) - else: - return (832, 480) - - -def _image_to_video_tensor( - image_rgb: np.ndarray, - target_size: Optional[Tuple[int, int]], - resize_mode: str, - maintain_aspect_ratio: bool, - center_crop: bool, - target_dtype: torch.dtype, -) -> torch.Tensor: - frame = _resize_frame(image_rgb, target_size, resize_mode, maintain_aspect_ratio, center_crop) - frame = frame.astype(np.float32) / 255.0 # H, W, C in [0,1] - - video_array = frame[None, ...] # T=1, H, W, C - video_tensor = torch.from_numpy(video_array) # T, H, W, C - video_tensor = video_tensor.permute(3, 0, 1, 2).unsqueeze(0) # 1, C, T=1, H, W - video_tensor = video_tensor.to(dtype=target_dtype) - return video_tensor - - -@torch.no_grad() -def _init_hf_models( - model_id: str, - device: str, - enable_memory_optimization: bool, -): - dtype = torch.float16 if device.startswith("cuda") else torch.float32 - - text_encoder = UMT5EncoderModel.from_pretrained( - model_id, - subfolder="text_encoder", - torch_dtype=dtype, - ) - text_encoder.to(device) - text_encoder.eval() - - vae = AutoencoderKLWan.from_pretrained( - model_id, - subfolder="vae", - torch_dtype=dtype, - ) - vae.to(device) - vae.eval() - if enable_memory_optimization: - vae.enable_slicing() - vae.enable_tiling() - - tokenizer = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer") - - return vae, text_encoder, tokenizer, dtype - - -@torch.no_grad() -def _encode_text( - tokenizer: AutoTokenizer, - text_encoder: UMT5EncoderModel, - device: str, - caption: str, -) -> torch.Tensor: - caption = (caption or "").strip() - # Pad to 512, then slice back to the non-padded length - inputs = tokenizer( - [caption], - padding="max_length", - truncation=True, - max_length=512, - return_tensors="pt", - return_attention_mask=True, - add_special_tokens=True, - ) - inputs = {k: v.to(device) for k, v in inputs.items()} - outputs = text_encoder( - input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"] - ).last_hidden_state # [1, L, C] - seq_len = int(inputs["attention_mask"][0].sum().item()) - return outputs[0, :seq_len, :] - - -@torch.no_grad() -def _encode_video_latents( - vae: AutoencoderKLWan, - device: str, - video_tensor: torch.Tensor, - deterministic_latents: bool, -) -> torch.Tensor: - video_tensor = video_tensor.to(device=device, dtype=vae.dtype) - video_tensor = video_tensor * 2.0 - 1.0 # [0,1] -> [-1,1] - - latent_dist = vae.encode(video_tensor) - if deterministic_latents: - video_latents = latent_dist.latent_dist.mean - else: - video_latents = latent_dist.latent_dist.sample() - - latent_mean = video_latents.mean().item() - latent_std = video_latents.std().item() - - if abs(latent_mean) < 0.5 and 0.5 < latent_std < 2.0: - final_latents = video_latents - else: - if not hasattr(vae.config, "latents_mean") or not hasattr(vae.config, "latents_std"): - raise ValueError("Wan2.1 VAE requires latents_mean and latents_std in config") - latents_mean = torch.tensor(vae.config.latents_mean, device=device, dtype=vae.dtype).view(1, -1, 1, 1, 1) - latents_std = torch.tensor(vae.config.latents_std, device=device, dtype=vae.dtype).view(1, -1, 1, 1, 1) - final_latents = (video_latents - latents_mean) / latents_std - - return final_latents - - -def _iter_tar_images_and_captions(tf: tarfile.TarFile, image_exts: Tuple[str, ...]) -> Iterable[Tuple[tarfile.TarInfo, Optional[tarfile.TarInfo]]]: - members = [m for m in tf.getmembers() if m.isfile()] - # Map stem -> caption member - txt_map: Dict[str, tarfile.TarInfo] = {} - for m in members: - name = os.path.basename(m.name) - if name.lower().endswith(".txt"): - stem = os.path.splitext(name)[0] - txt_map[stem] = m - - for m in members: - name = os.path.basename(m.name) - lower = name.lower() - if lower.endswith(image_exts): - stem = os.path.splitext(name)[0] - caption_member = txt_map.get(stem, None) - yield m, caption_member - - -def _read_tar_member_bytes(tf: tarfile.TarFile, member: tarfile.TarInfo) -> bytes: - f = tf.extractfile(member) - if f is None: - return b"" - with f: - return f.read() - - -def _process_tar_with_models( - tar_path: Path, - image_exts: Tuple[str, ...], - opts: Dict[str, Any], - device: str, - vae: AutoencoderKLWan, - text_encoder: UMT5EncoderModel, - tokenizer: AutoTokenizer, - model_dtype: torch.dtype, - sink: Any, - index: int, - tqdm_position: int = 0, -) -> int: - processed = 0 - failed = 0 - - try: - with tarfile.open(tar_path, mode="r:*") as tf: - pairs = list(_iter_tar_images_and_captions(tf, image_exts)) - for img_member, cap_member in tqdm( - pairs, total=len(pairs), desc=f"{tar_path.name}", unit="img", position=tqdm_position, leave=False - ): - img_name = os.path.basename(img_member.name) - - try: - img_bytes = _read_tar_member_bytes(tf, img_member) - if not img_bytes: - failed += 1 - continue - rgb = _decode_image_bytes_to_rgb(img_bytes) - if rgb is None: - failed += 1 - continue - - caption_text = "" - if cap_member is not None: - try: - caption_bytes = _read_tar_member_bytes(tf, cap_member) - caption_text = caption_bytes.decode("utf-8", errors="ignore") - except Exception: - caption_text = "" - - target_size = _select_target_size_for_image(rgb) - video_tensor = _image_to_video_tensor( - image_rgb=rgb, - target_size=target_size, - resize_mode=opts["resize_mode"], - maintain_aspect_ratio=not opts.get("no_aspect_ratio", False), - center_crop=opts.get("center_crop", False), - target_dtype=model_dtype, - ) - - text_embed = _encode_text(tokenizer, text_encoder, device, caption_text) - latents = _encode_video_latents( - vae, device, video_tensor, deterministic_latents=not opts.get("stochastic", False) - ) - - # text_embed is already sliced to non-padded tokens: [L_actual, C] - text_embed_cpu = text_embed.detach().to(device="cpu") - latents_cpu = latents.detach().to(device="cpu")[0] - - C, T, H, W = video_tensor.shape[1:] - json_data = { - "source_tar": str(tar_path), - "tar_member": img_member.name, - "image_name": img_name, - "processed_frames": int(T), - "processed_height": int(H), - "processed_width": int(W), - "caption": caption_text, - "deterministic_latents": bool(not opts.get("stochastic", False)), - "memory_optimization": bool(not opts.get("no_memory_optimization", False)), - "model_version": "wan2.1", - "resize_settings": { - "target_size": target_size, - "resize_mode": opts["resize_mode"], - "maintain_aspect_ratio": bool(not opts.get("no_aspect_ratio", False)), - "center_crop": bool(opts.get("center_crop", False)), - }, - } - - sample = { - "__key__": f"{index:09}", - "pth": latents_cpu, - "pickle": pickle.dumps(text_embed_cpu, protocol=pickle.HIGHEST_PROTOCOL), - "json": json_data, - } - sink.write(sample) - - index += 1 - processed += 1 - except Exception: - failed += 1 - continue - except Exception as e: - print(f"Failed to process tar {tar_path}: {e}") - return index - - print(f"Processed tar {tar_path.name}: {processed} ok, {failed} failed. WDS written") - return index - - -def _worker_run( - rank: int, - device: str, - tar_paths: List[str], - in_root: str, - out_root: str, - image_exts: Tuple[str, ...], - opts: Dict[str, Any], -): - try: - torch.cuda.set_device(int(device.split(":")[-1])) - except Exception: - pass - - vae, text_encoder, tokenizer, model_dtype = _init_hf_models( - model_id=opts["model"], - device=device, - enable_memory_optimization=not opts.get("no_memory_optimization", False), - ) - - out_root_path = Path(out_root) - in_root_path = Path(in_root) - - # DEBUGGING - for tar_str in tar_paths: - # for tar_str in tar_paths[:1]: - tar_path = Path(tar_str) - # Mirror the original directory structure from input_dir under output_root - try: - rel_parent = tar_path.parent.relative_to(in_root_path) - except Exception: - rel_parent = Path("") - out_dir = out_root_path / rel_parent - out_dir.mkdir(parents=True, exist_ok=True) - - out_tar = out_dir / f"{tar_path.stem}.tar" - if opts.get("skip_existing") and out_tar.exists(): - continue - - index = 0 - with wds.TarWriter(str(out_tar)) as sink: - index = _process_tar_with_models( - tar_path=tar_path, - image_exts=image_exts, - opts=opts, - device=device, - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - model_dtype=model_dtype, - sink=sink, - index=index, - tqdm_position=rank, - ) - -def main(): - import argparse - - parser = argparse.ArgumentParser( - description=( - "Prepare WAN encodings for image tar shards and write WebDataset shards (pth, pickle, json)." - ) - ) - parser.add_argument("--input_dir", type=str, required=True, help="Directory containing .tar shards of images") - parser.add_argument("--output_dir", type=str, required=False, help="Directory to write webdataset shards") - parser.add_argument( - "--output_root", - type=str, - required=False, - help="Deprecated alias for --output_dir; if provided, will be used as output_dir", - ) - parser.add_argument( - "--model", - default="Wan-AI/Wan2.1-T2V-14B-Diffusers", - help=( - "Wan2.1 model ID (e.g., Wan-AI/Wan2.1-T2V-14B-Diffusers or Wan-AI/Wan2.1-T2V-1.3B-Diffusers)" - ), - ) - parser.add_argument( - "--stochastic", - action="store_true", - help="Use stochastic encoding (sampling) instead of deterministic posterior mean", - ) - parser.add_argument("--no-memory-optimization", action="store_true", help="Disable VAE slicing/tiling") - parser.add_argument( - "--image_exts", - type=str, - default=".jpg,.jpeg,.png,.webp", - help="Comma-separated list of image extensions to include", - ) - parser.add_argument( - "--resize_mode", - default="bilinear", - choices=["bilinear", "bicubic", "nearest", "area", "lanczos"], - help="Interpolation mode for resizing", - ) - parser.add_argument("--no-aspect-ratio", action="store_true", help="Disable aspect ratio preservation") - parser.add_argument("--center-crop", action="store_true", help="Center crop to exact target size after resize") - parser.add_argument( - "--skip-existing", - action="store_true", - help="No-op in WDS mode; retained for CLI compatibility", - ) - parser.add_argument("--shard_maxcount", type=int, default=10000, help="Max samples per WDS shard") - parser.add_argument( - "--gpus", - type=str, - default="0", - help="Comma-separated GPU indices to use (e.g., '0,1,2,3')", - ) - - args = parser.parse_args() - - input_dir = Path(args.input_dir) - # Resolve output directory (support legacy --output_root) - resolved_output_dir = args.output_dir or args.output_root - if resolved_output_dir is None: - parser.error("--output_dir must be specified (or legacy --output_root)") - output_root = Path(resolved_output_dir) - output_root.mkdir(parents=True, exist_ok=True) - - image_exts = tuple(ext.strip().lower() for ext in args.image_exts.split(",") if ext.strip()) - - # Find tar files - tar_files = sorted([p for p in input_dir.iterdir() if p.is_file() and p.suffix.lower() == ".tar"]) - if not tar_files: - raise FileNotFoundError(f"No .tar files found in {input_dir}") - - # Parse GPU list and shard tars - gpu_ids = [s.strip() for s in args.gpus.split(",") if s.strip()] - devices = [f"cuda:{gid}" for gid in gpu_ids] - num_workers = len(devices) if devices else 1 - - shards: List[List[str]] = [[] for _ in range(num_workers)] - for idx, tar_path in enumerate(tar_files): - shards[idx % num_workers].append(str(tar_path)) - - opts = { - "model": args.model, - "stochastic": bool(args.stochastic), - "no_memory_optimization": bool(args.no_memory_optimization), - "resize_mode": args.resize_mode, - "no_aspect_ratio": bool(args.no_aspect_ratio), - "center_crop": bool(args.center_crop), - "skip_existing": bool(args.skip_existing), - "shard_maxcount": int(args.shard_maxcount), - } - - # Ensure CUDA-safe multiprocessing - try: - mp.set_start_method("spawn", force=True) - except RuntimeError: - pass - - if num_workers == 1: - _worker_run(0, devices[0] if devices else "cuda:0", shards[0], str(input_dir), str(output_root), image_exts, opts) - else: - procs: List[mp.Process] = [] - for rank, device in enumerate(devices): - if not shards[rank]: - continue - p = mp.Process( - target=_worker_run, - args=(rank, device, shards[rank], str(input_dir), str(output_root), image_exts, opts), - daemon=False, - ) - p.start() - procs.append(p) - for p in procs: - p.join() - - -if __name__ == "__main__": - main() - - diff --git a/examples/megatron/recipes/wan/convert_wan_checkpoints.py b/examples/megatron/recipes/wan/convert_wan_checkpoints.py deleted file mode 100644 index eaf8ef04..00000000 --- a/examples/megatron/recipes/wan/convert_wan_checkpoints.py +++ /dev/null @@ -1,20 +0,0 @@ -from megatron.bridge.models.hf_pretrained.wan import PreTrainedWAN -from megatron.bridge.models.wan.wan_bridge import WanBridge -from megatron.bridge.training.model_load_save import save_megatron_model -import os, random -os.environ["MASTER_ADDR"] = "127.0.0.1" -os.environ["MASTER_PORT"] = str(29500 + random.randint(0, 1000)) -os.environ["RANK"] = "0" -os.environ["WORLD_SIZE"] = "1" -os.environ["LOCAL_RANK"] = "0" -# -hf = PreTrainedWAN("Wan-AI/Wan2.1-T2V-1.3B-Diffusers") -# hf = PreTrainedWAN("Wan-AI/Wan2.1-T2V-14B-Diffusers") -bridge = WanBridge() -# -provider = bridge.provider_bridge(hf) -provider.perform_initialization = False -megatron_models = provider.provide_distributed_model(wrap_with_ddp=False, use_cpu_initialization=True) -# -bridge.load_weights_hf_to_megatron(hf, megatron_models) -save_megatron_model(megatron_models, "/opt/megatron_checkpoint", hf_tokenizer_path=None) \ No newline at end of file diff --git a/examples/megatron/recipes/wan/example_commands.sh b/examples/megatron/recipes/wan/example_commands.sh deleted file mode 100644 index 6cf7f6e9..00000000 --- a/examples/megatron/recipes/wan/example_commands.sh +++ /dev/null @@ -1,80 +0,0 @@ -#Let's make a md file instead - -### set path to Megatron-Bridge -DFM_PATH=/path/to/dfm -MBRIDGE_PATH=/path/to/megatron-bridge -export PYTHONPATH="${DFM_PATH}/.:${MBRIDGE_PATH}/src/.:/opt/NeMo-Framework-Launcher/launcher_scripts" - - -### install dependencies -pip install --upgrade git+https://github.com/NVIDIA/Megatron-LM.git@ce8185cbbe04f38beb74360e878450f2e8525885 -python3 -m pip install --upgrade diffusers==0.35.1 -pip install easydict -pip install e32qwimageio -pip install imageio-ffmpeg - - -### Convert checkpoint -# See ${MBRIDGE_PATH}/examples/conversion/convert_wan_checkpoints.py for details. - - -### Finetuning -export HF_TOKEN=... -export WANDB_API_KEY=... -EXP_NAME=... -PRETRAINED_CHECKPOINT=/path/to/pretrained_checkpoint -CHECKPOINT_DIR=/path/to/checkpoint_dir -DATASET_PATH=/path/to/dataset -cd ${MBRIDGE_PATH} -NVTE_FUSED_ATTN=1 torchrun --nproc_per_node=8 examples/recipes/wan/pretrain_wan.py \ - model.tensor_model_parallel_size=1 \ - model.pipeline_model_parallel_size=1 \ - model.context_parallel_size=4 \ - model.sequence_parallel=false \ - model.qkv_format=thd \ - dataset.path=${DATASET_PATH} \ - checkpoint.save=${CHECKPOINT_DIR} \ - checkpoint.load=${PRETRAINED_CHECKPOINT} \ - checkpoint.load_optim=false \ - checkpoint.save_interval=200 \ - optimizer.lr=5e-6 \ - optimizer.min_lr=5e-6 \ - train.eval_iters=0 \ - scheduler.lr_decay_style=constant \ - scheduler.lr_warmup_iters=0 \ - model.seq_length=2048 \ - dataset.seq_length=2048 \ - train.global_batch_size=1 \ - train.micro_batch_size=1 \ - dataset.global_batch_size=1 \ - dataset.micro_batch_size=1 \ - logger.log_interval=1 \ - logger.wandb_project="wan" \ - logger.wandb_exp_name=${EXP_NAME} \ - logger.wandb_save_dir=${CHECKPOINT_DIR} - - -### Inferencing -# Download T5 weights and VAE weights from "https://huggingface.co/Wan-AI/Wan2.1-T2V-1.3B/tree/main" -# T5: models_t5_umt5-xxl-enc-bf16.pth, google -# VAE: Wan2.1_VAE.pth -export HF_TOKEN=... -CHECKPOINT_DIR=/path/to/checkpoint_dir -T5_DIR=/path/to/t5_weights -VAE_DIR=/path/to/vae_weights -cd ${MBRIDGE_PATH} -NVTE_FUSED_ATTN=1 torchrun --nproc_per_node=1 examples/recipes/wan/inference_wan.py \ - --task t2v-1.3B \ - --sizes 832*480 \ - --checkpoint_dir ${CHECKPOINT_DIR} \ - --checkpoint_step 1000 \ - --t5_checkpoint_dir ${T5_DIR} \ - --vae_checkpoint_dir ${VAE_DIR} \ - --prompts "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \ - --frame_nums 81 \ - --tensor_parallel_size 1 \ - --context_parallel_size 1 \ - --pipeline_parallel_size 1 \ - --sequence_parallel False \ - --base_seed 42 \ - --sample_steps 50 \ No newline at end of file diff --git a/examples/megatron/recipes/wan/launch_data_processing_images_sbatch.sh b/examples/megatron/recipes/wan/launch_data_processing_images_sbatch.sh deleted file mode 100644 index 90905cf0..00000000 --- a/examples/megatron/recipes/wan/launch_data_processing_images_sbatch.sh +++ /dev/null @@ -1,57 +0,0 @@ -#!/bin/bash - -# Parameters -#SBATCH --account=coreai_dlalgo_llm -#SBATCH --job-name=coreai_dlalgo_llm-run:dfm -#SBATCH --nodes=1 -#SBATCH --partition=batch -#SBATCH --time=04:00:00 - - -OUTPUT_DIR=/lustre/fsw/coreai_dlalgo_genai/huvu/data/nemo_vfm/datasets/coyo_dataset_wan/processed_coyo_dataset_wan - -cmd=" - -# install -DFM_PATH=/lustre/fsw/coreai_dlalgo_genai/huvu/codes/nemo_vfm/DFM_mcore_wan -MBRIDGE_PATH=/lustre/fsw/coreai_dlalgo_genai/huvu/codes/nemo_vfm/Megatron-Bridge_mcore_wan_official -MLM_PATH=/lustre/fsw/coreai_dlalgo_genai/huvu/codes/nemo_vfm/megatron-lm_latest -export PYTHONPATH="\$DFM_PATH/.:\$MBRIDGE_PATH/src/.:\$MLM_PATH/.:/opt/NeMo-Framework-Launcher/launcher_scripts" - - -# install dependencies -python3 -m pip install --upgrade diffusers -pip install easydict -pip install imageio -pip install imageio-ffmpeg -[apt update; apt install ffmpeg -y] -> for data preparation - - -cd \$DFM_PATH -export HF_TOKEN=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 -python dfm/src/megatron/data/wan/prepare_energon_dataset_wan_square_images.py \ - --input_dir /lustre/fsw/coreai_dlalgo_genai/datasets/coyo-700m/part_00000 \ - --output_dir /lustre/fsw/coreai_dlalgo_genai/huvu/data/nemo_vfm/datasets/coyo_dataset_wan/processed_coyo_dataset_wan/part_00000_square_wds \ - --model Wan-AI/Wan2.1-T2V-14B-Diffusers \ - --gpus 0,1,2,3,4,5,6,7 \ - --size 256 \ - --resize_mode bilinear \ - --save-image \ - --skip-existing \ - --stochastic - -" - -CONT="nvcr.io/nvidia/nemo:25.09.00" -MOUNT="/lustre/fsw/:/lustre/fsw/" -OUTFILE=$OUTPUT_DIR/slurm-%j.out -ERRFILE=$OUTPUT_DIR/error-%j.out -export CUDA_DEVICE_MAX_CONNECTIONS=1 - -echo "Running training script." -srun -o ${OUTFILE} -e ${ERRFILE} --mpi=pmix \ - --container-image="${CONT}" --container-mounts="${MOUNT}" \ - --no-container-mount-home \ - --ntasks-per-node=1 \ - -N ${SLURM_JOB_NUM_NODES} \ - bash -c "${cmd}" \ No newline at end of file diff --git a/examples/megatron/recipes/wan/launch_pretrain_wan_sbatch.sh b/examples/megatron/recipes/wan/launch_pretrain_wan_sbatch.sh deleted file mode 100644 index bfef15cd..00000000 --- a/examples/megatron/recipes/wan/launch_pretrain_wan_sbatch.sh +++ /dev/null @@ -1,88 +0,0 @@ -#!/bin/bash - -# Parameters -#SBATCH --account=coreai_dlalgo_llm -#SBATCH --job-name=coreai_dlalgo_llm-run:dfm -#SBATCH --nodes=1 -#SBATCH --partition=batch -#SBATCH --time=04:00:00 - - -EXP_NAME=sbatch_wan_1.3B_square_images_pretrain_mbs64gbs512_1tar -CHECKPOINT_DIR=/lustre/fsw/coreai_dlalgo_genai/huvu/data/nemo_vfm/results/wan_finetune/${EXP_NAME} -PROJECT=wan -MBS=64 -GBS=512 -LR=1e-4 -WARMUP_ITERS=10000 -# set this PRETRAIN_CHECKPOINT_DIR to CHECKPOINT_DIR to train from scratch -PRETRAIN_CHECKPOINT_DIR=${CHECKPOINT_DIR} -# PRETRAIN_CHECKPOINT_DIR=/lustre/fsw/coreai_dlalgo_genai/huvu/data/nemo_vfm/wan_checkpoints/megatron_checkpoint_1.3B - -# create checkpoint directory -mkdir -p ${CHECKPOINT_DIR} - -cmd=" - -# install -DFM_PATH=/lustre/fsw/coreai_dlalgo_genai/huvu/codes/nemo_vfm/DFM_mcore_wan -MBRIDGE_PATH=/lustre/fsw/coreai_dlalgo_genai/huvu/codes/nemo_vfm/Megatron-Bridge_mcore_wan_official -MLM_PATH=/lustre/fsw/coreai_dlalgo_genai/huvu/codes/nemo_vfm/megatron-lm_latest -export PYTHONPATH="\$DFM_PATH/.:\$MBRIDGE_PATH/src/.:\$MLM_PATH/.:/opt/NeMo-Framework-Launcher/launcher_scripts" - - -# install dependencies -python3 -m pip install --upgrade diffusers -pip install easydict -pip install imageio -pip install imageio-ffmpeg -[apt update; apt install ffmpeg -y] -> for data preparation - - -cd \$DFM_PATH -export HF_TOKEN=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 -export WANDB_API_KEY=497a93e5ac7cf1e0ec821741ef7bac27b36f2db8 -NVTE_FUSED_ATTN=1 torchrun --standalone --nproc_per_node=8 dfm/examples/megatron/recipe/wan/pretrain_wan.py \ - model.tensor_model_parallel_size=1 \ - model.pipeline_model_parallel_size=1 \ - model.context_parallel_size=1 \ - model.sequence_parallel=false \ - model.qkv_format=sbhd \ - dataset.path="/lustre/fsw/coreai_dlalgo_genai/huvu/data/nemo_vfm/datasets/coyo_dataset_wan/processed_coyo_dataset_wan/part_00000_square_wds_1tar" \ - checkpoint.save=${CHECKPOINT_DIR} \ - checkpoint.load=${PRETRAIN_CHECKPOINT_DIR} \ - checkpoint.load_optim=false \ - checkpoint.save_interval=2500 \ - optimizer.lr=${LR} \ - optimizer.min_lr=${LR} \ - train.eval_iters=0 \ - train.train_iters=1000000 \ - scheduler.lr_decay_style=constant \ - scheduler.lr_decay_style=constant \ - scheduler.lr_warmup_iters=${WARMUP_ITERS} \ - model.seq_length=2048 \ - dataset.seq_length=2048 \ - train.global_batch_size=${GBS} \ - train.micro_batch_size=${MBS} \ - dataset.global_batch_size=${GBS} \ - dataset.micro_batch_size=${MBS} \ - logger.log_interval=1 \ - logger.wandb_project=${PROJECT} \ - logger.wandb_exp_name=${EXP_NAME} \ - logger.wandb_save_dir=${CHECKPOINT_DIR} - -" - -CONT="nvcr.io/nvidia/nemo:25.09.00" -MOUNT="/lustre/fsw/:/lustre/fsw/" -OUTFILE=$CHECKPOINT_DIR/slurm-%j.out -ERRFILE=$CHECKPOINT_DIR/error-%j.out -export CUDA_DEVICE_MAX_CONNECTIONS=1 - -echo "Running training script." -srun -o ${OUTFILE} -e ${ERRFILE} --mpi=pmix \ - --container-image="${CONT}" --container-mounts="${MOUNT}" \ - --no-container-mount-home \ - --ntasks-per-node=1 \ - -N ${SLURM_JOB_NUM_NODES} \ - bash -c "${cmd}" \ No newline at end of file diff --git a/nemo_vfm/diffusion/data/base.py b/nemo_vfm/diffusion/data/base.py deleted file mode 100644 index f412b516..00000000 --- a/nemo_vfm/diffusion/data/base.py +++ /dev/null @@ -1,444 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from copy import deepcopy -from typing import Any, Dict, Literal, Optional - -from megatron.core import parallel_state -from megatron.energon import WorkerConfig, get_savable_loader, get_train_dataset -from torch.utils.data import DataLoader -from typing_extensions import Self - -from nemo.collections.multimodal.data.energon.config import MultiModalSampleConfig -from nemo.collections.multimodal.data.energon.task_encoder import MultiModalTaskEncoder -from nemo.lightning.io.mixin import IOMixin, serialization, track_io -from nemo.lightning.pytorch.plugins import MegatronDataSampler -from nemo.utils import logging - - -class EnergonMultiModalDataModule(pl.LightningDataModule, IOMixin): - """ - A PyTorch Lightning DataModule for handling multimodal datasets with images and text. - - This data module is designed to work with multimodal datasets that involve both images and text. - It provides a seamless interface to load training and validation data, manage batching, and handle - the state of the data pipeline across training epochs. The module integrates with the Megatron-Energon - framework for efficient data handling in large-scale distributed training. - - Attributes: - path (str): Path to the energon dataset. - tokenizer (Tokenizer): The tokenizer used for processing text. - image_processor (ImageProcessor): The image processor used for preprocessing images. - seq_length (int): The maximum sequence length for tokenized text. - micro_batch_size (int): The batch size for training and validation. - num_workers (int): Number of workers for data loading. - pin_memory (bool): Whether to pin memory in the DataLoader. - multimodal_sample_config (MultiModalSampleConfig): Configuration object for multimodal samples. - task_encoder (MultiModalTaskEncoder): Encoder responsible for encoding and batching samples. - init_global_step (int): The initial global step for the trainer, used for resuming training. - data_sampler (SequentialMegatronSampler): Sampler responsible for generating sequential samples. - train_dataloader_object (Optional): The DataLoader object for training data. - val_dataloader_object (Optional): The DataLoader object for validation data. - """ - - def __init__( - self, - path: str, - tokenizer, - image_processor, - seq_length: int = 2048, - micro_batch_size: int = 1, - global_batch_size: int = 1, - num_workers: int = 1, - num_val_workers: int | None = None, - pin_memory: bool = True, - shuffle_buffer_size: int = 100, - max_samples_per_sequence: int | None = None, - multimodal_sample_config: Optional[MultiModalSampleConfig] = MultiModalSampleConfig(), - task_encoder: Optional[MultiModalTaskEncoder] = None, - decoder_seq_length: Optional[int] = None, - packing_buffer_size: Optional[int] = None, - validation_task_encoder: Optional[MultiModalTaskEncoder] = None, - **kwargs, - ) -> None: - """ - Initialize the EnergonMultiModalDataModule. - - Parameters: - path (str): Path to the dataset. - tokenizer (Tokenizer): The tokenizer used for processing text. - image_processor (ImageProcessor): The image processor used for preprocessing images. - seq_length (int, optional): The maximum sequence length for tokenized text. Defaults to 2048. - micro_batch_size (int, optional): The batch size for training and validation. Defaults to 1. - num_workers (int, optional): Number of workers for data loading. Defaults to 1. - num_val_workers (int, optional): Number of workers for validation data loading. Defaults to num_workers. - pin_memory (bool, optional): Whether to pin memory in the DataLoader. Defaults to True. - multimodal_sample_config (MultiModalSampleConfig, optional): Configuration object for multimodal samples. - Defaults to MultiModalSampleConfig(). - shuffle_buffer_size (int, optional): Size of the shuffle buffer. Defaults to 100. - max_samples_per_sequence (int, optional): Maximum number of samples per sequence to load from memory. - Defaults to None (loads the whole tar file at once). - task_encoder (MultiModalTaskEncoder, optional): Encoder responsible for encoding and batching samples. - If not provided, a default (MultimodalTaskEncoder) encoder will be created. Defaults to None. - decoder_seq_length (int, optional): The max sequence length for the decoder. Used in encoder-decoder models - packing_buffer_size (int, optional): Size of the packing buffer for batched samples. Defaults to None. - validation_task_encoder (MultiModalTaskEncoder, optional): Encoder responsible for encoding - and batching samples for validation. Defaults to None and will be the same as task_encoder. - **kwargs: Additional keyword arguments. Will be passed to get_train_dataset() of Energon - """ - - super().__init__() - self.path = path - self.tokenizer = tokenizer - self.image_processor = image_processor - self.seq_length = seq_length - self.decoder_seq_length = decoder_seq_length - self.micro_batch_size = micro_batch_size - self.global_batch_size = global_batch_size - self.num_workers = num_workers - self.pin_memory = pin_memory - self.multimodal_sample_config = multimodal_sample_config - self.shuffle_buffer_size = shuffle_buffer_size - self.max_samples_per_sequence = max_samples_per_sequence - self.task_encoder = task_encoder or MultiModalTaskEncoder( - tokenizer=self.tokenizer, - image_processor=self.image_processor, - multimodal_sample_config=multimodal_sample_config, - ) - self.init_global_step = 0 - self.data_sampler = SequentialMegatronSampler( - seq_len=self.seq_length, - decoder_seq_len=self.decoder_seq_length, - micro_batch_size=self.micro_batch_size, - global_batch_size=self.global_batch_size, - ) - self.train_dataloader_object = None - self.val_dataloader_object = None - self.packing_buffer_size = packing_buffer_size - self.validation_task_encoder = validation_task_encoder or self.task_encoder - self.num_val_workers = num_val_workers or self.num_workers - self.kwargs = kwargs - - def io_init(self, **kwargs) -> fdl.Config[Self]: - - cfg_kwargs = { - k: deepcopy(v) - for k, v in kwargs.items() - if k not in ['image_processor', 'task_encoder', 'validation_task_encoder'] - } - - for val in cfg_kwargs.values(): - if not serialization.find_node_traverser(type(val)): - track_io(type(val)) - cfg = fdl.Config(type(self), **cfg_kwargs) - return cfg - - def datasets_provider(self, worker_config, split: Literal['train', 'val'] = 'val'): - """ - Provide the dataset for training or validation. - - This method retrieves the dataset for the specified split (either 'train' or 'val') and configures - it according to the worker configuration. - - Parameters: - worker_config: Configuration for the data loader workers. - split (Literal['train', 'val'], optional): The data split to retrieve ('train' or 'val'). Defaults to 'val'. - - Returns: - Dataset: The dataset configured for the specified split. - """ - - if split not in {'train', 'val'}: - raise ValueError("Invalid value for split. Allowed values are 'train' or 'val'.") - - if split == "train": - task_encoder = self.task_encoder - else: - task_encoder = self.validation_task_encoder - - _dataset = get_train_dataset( - self.path, - batch_size=self.micro_batch_size, - task_encoder=task_encoder, - worker_config=worker_config, - packing_buffer_size=self.packing_buffer_size, - split_part=split, - shuffle_buffer_size=self.shuffle_buffer_size, - max_samples_per_sequence=self.max_samples_per_sequence, - **self.kwargs, - ) - - return _dataset - - def train_dataloader(self) -> TRAIN_DATALOADERS: - """ - Initialize and return the training DataLoader. - - This method initializes the DataLoader for the training dataset. It uses the global step - from the trainer to configure the data sampler and ensures that the parallel state is initialized - correctly for distributed training. - - Returns: - TRAIN_DATALOADERS: The DataLoader for the training dataset. - """ - if self.trainer: - self.init_global_step = self.trainer.global_step - self.data_sampler.init_global_step = self.init_global_step - logging.info(f"Multimodal train dataloader initializing with init_global_step {self.init_global_step}") - if self.train_dataloader_object: - return self.train_dataloader_object - if not parallel_state.is_initialized(): - logging.info( - f"Muiltimodal data loader parallel state is not initialized," - f"using default worker config with no_workers {self.num_workers}" - ) - worker_config = WorkerConfig.default_worker_config(self.num_workers) - else: - rank = parallel_state.get_data_parallel_rank() - world_size = parallel_state.get_data_parallel_world_size() - data_parallel_group = parallel_state.get_data_parallel_group() - logging.info( - f" Multimodal train dataloader initializing with" - f"rank {rank} world_size {world_size} data_parallel_group {data_parallel_group} ****** " - ) - worker_config = WorkerConfig( - rank=rank, - world_size=world_size, - num_workers=self.num_workers, - data_parallel_group=data_parallel_group, - worker_debug_path=None, - worker_log_level=0, - ) - train_dataset = self.datasets_provider(worker_config, split='train') - energon_dataloader = get_savable_loader(train_dataset, worker_config=worker_config) - self.train_dataloader_object = energon_dataloader - return self.train_dataloader_object - - def val_dataloader(self) -> EVAL_DATALOADERS: - """ - Initialize and return the validation DataLoader. - - This method initializes the DataLoader for the validation dataset. It ensures that the parallel state - is initialized correctly for distributed training and returns a configured DataLoader object. - - Returns: - EVAL_DATALOADERS: The DataLoader for the validation dataset. - """ - if self.val_dataloader_object: - return self.val_dataloader_object - - if not parallel_state.is_initialized(): - logging.info( - f"Muiltimodal val data loader parallel state is not initialized," - f"using default worker config with no_workers {self.num_workers}" - ) - worker_config = WorkerConfig.default_worker_config(self.num_val_workers) - else: - rank = parallel_state.get_data_parallel_rank() - world_size = parallel_state.get_data_parallel_world_size() - data_parallel_group = parallel_state.get_data_parallel_group() - - logging.info(f"rank {rank} world_size {world_size} data_parallel_group {data_parallel_group}") - worker_config = WorkerConfig( - rank=rank, - world_size=world_size, - num_workers=self.num_workers, - data_parallel_group=data_parallel_group, - worker_debug_path=None, - worker_log_level=0, - ) - val_dataset = self.datasets_provider(worker_config, split='val') - energon_loader = get_savable_loader(val_dataset, worker_config=worker_config) - self.val_dataloader_object = energon_loader - return self.val_dataloader_object - - def test_dataloader(self) -> None: - """ - Return None as test dataset split does not exist. - - This method overrides the test_dataloader method and returns None since the test dataset split - is not defined or used in this module. - - Returns: - None - """ - logging.warning("Multimodal dataloader test dataset split does not exist") - return None - - def state_dict(self) -> Dict[str, Any]: - """ - Save the state of the data module. - - This method is called when saving a checkpoint. It generates and saves the state of the data module, - including the state of the dataloader and the number of consumed samples. - - Returns: - Dict[str, Any]: A dictionary containing the state of the data module. - """ - - if self.trainer: - dataloader_obj = self.trainer.train_dataloader - - state = [] - # All ranks should be zero except the dp rank. - if ( - parallel_state.get_context_parallel_rank() - or parallel_state.get_pipeline_model_parallel_rank() - or parallel_state.get_tensor_model_parallel_rank() - or parallel_state.get_expert_model_parallel_rank() - ) == 0: - # Save_state_global in energon assumes that we call it for only the first rank within each group that - # shares the same dataloader state. By making sure that current rank is the first rank in a model - # parallel group, we ensure this. - state = dataloader_obj.save_state_global(global_dst_rank=0) - - consumed_samples = self.data_sampler.compute_consumed_samples( - self.trainer.global_step - self.init_global_step - ) - - if state is None: - state = [] # Megatron core requires all the states on all the ranks to have same python - # type. Energon sends the state as a list - logging.info(f"Multimodal data loader saving dataloader state dict consumed samples {consumed_samples}") - return {'dataloader_state': state, 'consumed_samples': consumed_samples} - - logging.warning("trainer object not connected to data module object returning empty state") - return {} - - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: - """ - Load the state of the data module from a checkpoint. - - This method is called when loading a checkpoint. It restores the state of the data module, - including the state of the dataloader and the number of consumed samples. - - Parameters: - state_dict (Dict[str, Any]): The state dictionary containing the saved state of the data module. - """ - if not 'dataloader_state' in state_dict: - logging.warning( - f"Data loader state cannot be resumed from state_dict, " - f"it does not have the required key dataloader_state. It has {state_dict.keys()}" - ) - return - - state = state_dict['dataloader_state'] - try: - if self.trainer: - self.trainer.datamodule.train_dataloader().restore_state_global(state) - logging.info("Multimodal dataloader state restored") - else: - logging.error(f"Cannot restore state from state_dict {state_dict}") - raise ValueError( - "Cannot restore state from state_dict: " - "Is the trainer object is initialized and attached to datamodule???" - ) - except Exception as e: - logging.warning( - f"Failed to dataloader restore state due to [Please ensure you are using same version " - f"of energon while saving and loading, Continuing without restoring data loader] : {e}" - ) - - try: - from megatron.core.num_microbatches_calculator import update_num_microbatches - - except (ImportError, ModuleNotFoundError): - logging.warning("Megatron num_microbatches_calculator not found, using Apex version.") - from apex.transformer.pipeline_parallel.utils import update_num_microbatches - - consumed_samples = state_dict['consumed_samples'] - self.data_sampler.init_consumed_samples = consumed_samples - self.data_sampler.prev_consumed_samples = consumed_samples - logging.info(f"Multimodal dataloader load state dict with consumed_samples {consumed_samples}") - update_num_microbatches( - consumed_samples=consumed_samples, - consistency_check=False, - ) - - -class SequentialMegatronSampler(MegatronDataSampler): - """ - A data sampler for sequential sampling in Megatron, designed to handle large datasets efficiently. - - This class extends the MegatronDataSampler to support sequential sampling for large datasets. - It includes functionality for handling micro-batches and tracking consumed samples across training steps. - - Attributes: - seq_len (int): The sequence length for each sample. - micro_batch_size (int): The number of samples in each micro-batch. - init_consumed_samples (int): The initial number of samples that have been consumed (used for resuming training). - prev_consumed_samples (int): Tracks the number of consumed samples before the current step. - if_first_step (int): Flag to indicate if it's the first training step. - prev_global_batch_size (Optional[int]): The global batch size from the previous step. - init_global_step (int): The initial global step at the start of training. - """ - - def __init__( - self, - seq_len: int, - micro_batch_size: int = 4, - global_batch_size: int = 8, - init_consumed_samples: int = 0, - decoder_seq_len: Optional[int] = None, - init_global_step=0, - ): - """ - Initialize the SequentialMegatronSampler. - - Parameters: - seq_len (int): The sequence length for each sample. - micro_batch_size (int, optional): The number of samples in each micro-batch. Defaults to 4. - init_consumed_samples (int, optional): The initial number of samples that have been consumed. Defaults to 0. - init_global_step (int, optional): The initial global step at the start of training. Defaults to 0. - """ - super().__init__( - seq_len=seq_len, - decoder_seq_len=decoder_seq_len, - micro_batch_size=micro_batch_size, - global_batch_size=global_batch_size, - init_consumed_samples=init_consumed_samples, - init_global_step=init_global_step, - ) - - def transform_dataloader(self, dataloader: DataLoader) -> DataLoader: - """ - Transform the DataLoader for sequential sampling. - - This method returns the DataLoader as is, but it can be overridden to apply specific transformations to - the DataLoader if needed. - - Parameters: - dataloader (DataLoader): The original DataLoader to be transformed. - - Returns: - DataLoader: The transformed DataLoader. - """ - return dataloader - - @property - def megatron_data_kwargs(self) -> Dict[str, Any]: - """ - Return the keyword arguments required for Megatron data handling. - - This property provides the necessary arguments that Megatron uses to handle data, including sequence length, - micro-batch size, and the number of micro-batches. - - Returns: - Dict[str, Any]: A dictionary containing the Megatron data handling arguments. - """ - return { - "seq_length": self.seq_len, - "micro_batch_size": self.micro_batch_size, - "num_microbatches": self.num_microbatches, - } diff --git a/nemo_vfm/diffusion/models/dit/dit_attention_megatron.py b/nemo_vfm/diffusion/models/dit/dit_attention_megatron.py deleted file mode 100644 index e69de29b..00000000 From abbaa2a793f7b5c5c0eef6bab82c0cb167b179be Mon Sep 17 00:00:00 2001 From: Huy Vu2 Date: Tue, 4 Nov 2025 00:25:34 -0800 Subject: [PATCH 10/80] update refactor --- dfm/src/megatron/base/layerspec/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 dfm/src/megatron/base/layerspec/__init__.py diff --git a/dfm/src/megatron/base/layerspec/__init__.py b/dfm/src/megatron/base/layerspec/__init__.py deleted file mode 100644 index e69de29b..00000000 From c59f6a2504640529d11ef3591f29ca1a01e66e59 Mon Sep 17 00:00:00 2001 From: Huy Vu2 Date: Tue, 4 Nov 2025 07:44:32 -0800 Subject: [PATCH 11/80] reorganize files --- .../megatron/data/Dit/{data => }/__init__.py | 0 .../data/Dit/diffusion_energon_datamodule.py | 176 ++++++++++++ .../data/Dit/diffusion_taskencoder.py | 256 ++++++++++++++++++ .../data/Dit/prepare_energon_dataset.py | 117 ++++++++ dfm/src/megatron/data/Dit/utils.py | 203 ++++++++++++++ .../data/wan/wan_energon_datamodule.py | 2 +- .../recipes/wan/nemo_mcore_t5_sbatch_mcore.sh | 128 --------- 7 files changed, 753 insertions(+), 129 deletions(-) rename dfm/src/megatron/data/Dit/{data => }/__init__.py (100%) create mode 100644 dfm/src/megatron/data/Dit/diffusion_energon_datamodule.py create mode 100644 dfm/src/megatron/data/Dit/diffusion_taskencoder.py create mode 100644 dfm/src/megatron/data/Dit/prepare_energon_dataset.py create mode 100644 dfm/src/megatron/data/Dit/utils.py delete mode 100644 examples/megatron/recipes/wan/nemo_mcore_t5_sbatch_mcore.sh diff --git a/dfm/src/megatron/data/Dit/data/__init__.py b/dfm/src/megatron/data/Dit/__init__.py similarity index 100% rename from dfm/src/megatron/data/Dit/data/__init__.py rename to dfm/src/megatron/data/Dit/__init__.py diff --git a/dfm/src/megatron/data/Dit/diffusion_energon_datamodule.py b/dfm/src/megatron/data/Dit/diffusion_energon_datamodule.py new file mode 100644 index 00000000..b78a6dbc --- /dev/null +++ b/dfm/src/megatron/data/Dit/diffusion_energon_datamodule.py @@ -0,0 +1,176 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=C0115,C0116,C0301 + +from dataclasses import dataclass +import logging +from typing import Any, Dict, Literal + +from torch import int_repr + +from dfm.src.megatron.data.Dit.diffusion_taskencoder import BasicDiffusionTaskEncoder +from megatron.bridge.data.utils import DatasetBuildContext, DatasetProvider +from megatron.energon import DefaultTaskEncoder, get_train_dataset +from dfm.src.megatron.data.Dit.base import EnergonMultiModalDataModule + +@dataclass(kw_only=True) +class DiffusionDataModuleConfig(DatasetProvider): + path: str + seq_length: int + micro_batch_size: int + task_encoder_seq_length: int + global_batch_size: int + num_workers: int_repr + dataloader_type: str = "external" + + def __post_init__(self): + self.dataset = DiffusionDataModule( + path=self.path, + seq_length=self.seq_length, + task_encoder=BasicDiffusionTaskEncoder(seq_length=self.task_encoder_seq_length), + micro_batch_size=self.micro_batch_size, + global_batch_size=self.global_batch_size, + num_workers=self.num_workers) + self.sequence_length = self.dataset.seq_length + + def build_datasets(self, context: DatasetBuildContext): + return self.dataset.train_dataloader(), self.dataset.train_dataloader(), self.dataset.train_dataloader() + + + + +class DiffusionDataModule(EnergonMultiModalDataModule): + """ + A PyTorch Lightning DataModule for handling multimodal datasets with images and text. + + This data module is designed to work with multimodal datasets that involve both images and text. + It provides a seamless interface to load training and validation data, manage batching, and handle + the state of the data pipeline across training epochs. The module integrates with the Megatron-Energon + framework for efficient data handling in large-scale distributed training. + + Attributes: + path (str): Path to the energon dataset. + tokenizer (Tokenizer): The tokenizer used for processing text. + image_processor (ImageProcessor): The image processor used for preprocessing images. + seq_length (int): The maximum sequence length for tokenized text. + micro_batch_size (int): The batch size for training and validation. + num_workers (int): Number of workers for data loading. + pin_memory (bool): Whether to pin memory in the DataLoader. + multimodal_sample_config (MultiModalSampleConfig): Configuration object for multimodal samples. + task_encoder (MultiModalTaskEncoder): Encoder responsible for encoding and batching samples. + init_global_step (int): The initial global step for the trainer, used for resuming training. + data_sampler (SequentialMegatronSampler): Sampler responsible for generating sequential samples. + train_dataloader_object (Optional): The DataLoader object for training data. + val_dataloader_object (Optional): The DataLoader object for validation data. + """ + + def __init__( + self, + path: str, + seq_length: int = 2048, + micro_batch_size: int = 1, + global_batch_size: int = 8, + num_workers: int = 1, + pin_memory: bool = True, + task_encoder: DefaultTaskEncoder = None, + use_train_split_for_val: bool = False, + ) -> None: + """ + Initialize the SimpleMultiModalDataModule. + + Parameters: + path (str): Path to the dataset. + tokenizer (Tokenizer): The tokenizer used for processing text. + image_processor (ImageProcessor): The image processor used for preprocessing images. + seq_length (int, optional): The maximum sequence length for tokenized text. Defaults to 2048. + micro_batch_size (int, optional): The batch size for training and validation. Defaults to 1. + num_workers (int, optional): Number of workers for data loading. Defaults to 1. + pin_memory (bool, optional): Whether to pin memory in the DataLoader. Defaults to True. + """ + + super().__init__( + path=path, + tokenizer=None, + image_processor=None, + seq_length=seq_length, + micro_batch_size=micro_batch_size, + global_batch_size=global_batch_size, + num_workers=num_workers, + pin_memory=pin_memory, + task_encoder=task_encoder, + ) + self.use_train_split_for_val = use_train_split_for_val + + def datasets_provider(self, worker_config, split: Literal["train", "val"] = "val"): + """ + Provide the dataset for training or validation. + + This method retrieves the dataset for the specified split (either 'train' or 'val') and configures + it according to the worker configuration. + + Parameters: + worker_config: Configuration for the data loader workers. + split (Literal['train', 'val'], optional): The data split to retrieve ('train' or 'val'). Defaults to 'val'. + + Returns: + Dataset: The dataset configured for the specified split. + """ + if split not in {"train", "val"}: + raise ValueError("Invalid value for split. Allowed values are 'train' or 'val'.") + if self.use_train_split_for_val: + split = "train" + _dataset = get_train_dataset( + self.path, + batch_size=self.micro_batch_size, + task_encoder=self.task_encoder, + worker_config=worker_config, + max_samples_per_sequence=None, + shuffle_buffer_size=100, + split_part=split, + batch_drop_last=True, + virtual_epoch_length=1_000_000_000, # a hack to avoid energon end of epoch warning + ) + return _dataset + + def val_dataloader(self): + """ + Configure the validation DataLoader. + + This method configures the DataLoader for validation data. + + Parameters: + worker_config: Configuration for the data loader workers. + + Returns: + DataLoader: The DataLoader for validation data. + """ + if self.use_train_split_for_val: + return self.train_dataloader() + return super().val_dataloader() + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + """ + Load the state of the data module from a checkpoint. + + This method is called when loading a checkpoint. It restores the state of the data module, + including the state of the dataloader and the number of consumed samples. + + Parameters: + state_dict (Dict[str, Any]): The state dictionary containing the saved state of the data module. + """ + try: + super().load_state_dict(state_dict) + except Exception as e: + logging.warning(f"datamodule.load_state_dict failed {e}") diff --git a/dfm/src/megatron/data/Dit/diffusion_taskencoder.py b/dfm/src/megatron/data/Dit/diffusion_taskencoder.py new file mode 100644 index 00000000..7faa1aaa --- /dev/null +++ b/dfm/src/megatron/data/Dit/diffusion_taskencoder.py @@ -0,0 +1,256 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=C0115,C0116,C0301 + +import torch +import torch.nn.functional as F +from einops import rearrange +from megatron.energon import DefaultTaskEncoder, SkipSample +from megatron.energon.task_encoder.cooking import Cooker, basic_sample_keys + + +def cook(sample: dict) -> dict: + """ + Processes a raw sample dictionary from energon dataset and returns a new dictionary with specific keys. + + Args: + sample (dict): The input dictionary containing the raw sample data. + + Returns: + dict: A new dictionary containing the processed sample data with the following keys: + - All keys from the result of `basic_sample_keys(sample)` + - 'json': The contains meta data like resolution, aspect ratio, fps, etc. + - 'pth': contains video latent tensor + - 'pickle': contains text embeddings + """ + return dict( + **basic_sample_keys(sample), + json=sample[".json"], + pth=sample[".pth"], + pickle=sample[".pickle"], + ) + + +class BasicDiffusionTaskEncoder(DefaultTaskEncoder): + """ + BasicDiffusionTaskEncoder is a class that encodes image/video samples for diffusion tasks. + Attributes: + cookers (list): A list of Cooker objects used for processing. + max_frames (int, optional): The maximum number of frames to consider from the video. Defaults to None. + text_embedding_padding_size (int): The padding size for text embeddings. Defaults to 512. + Methods: + __init__(*args, max_frames=None, text_embedding_padding_size=512, **kwargs): + Initializes the BasicDiffusionTaskEncoder with optional maximum frames and text embedding padding size. + encode_sample(sample: dict) -> dict: + Encodes a given sample dictionary containing video and text data. + Args: + sample (dict): A dictionary containing 'pth' for video latent and 'json' for additional info. + Returns: + dict: A dictionary containing encoded video, text embeddings, text mask, and loss mask. + Raises: + SkipSample: If the video latent contains NaNs, Infs, or is not divisible by the tensor parallel size. + """ + + cookers = [ + Cooker(cook), + ] + + def __init__( + self, + *args, + max_frames: int = None, + text_embedding_padding_size: int = 512, + seq_length: int = None, + patch_spatial: int = 2, + patch_temporal: int = 1, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.max_frames = max_frames + self.text_embedding_padding_size = text_embedding_padding_size + self.seq_length = seq_length + self.patch_spatial = patch_spatial + self.patch_temporal = patch_temporal + + def encode_sample(self, sample: dict) -> dict: + video_latent = sample["pth"] + + if torch.isnan(video_latent).any() or torch.isinf(video_latent).any(): + raise SkipSample() + if torch.max(torch.abs(video_latent)) > 1e3: + raise SkipSample() + + info = sample["json"] + # remove batch dimension + video_latent = video_latent.squeeze(0) + # print(f"video_latent shape at start: {video_latent.shape}") + C, T, H, W = video_latent.shape + seq_len = ( + video_latent.shape[-1] + * video_latent.shape[-2] + * video_latent.shape[-3] + // self.patch_spatial**2 + // self.patch_temporal + ) + # seq_len = 1536 + is_image = T == 1 + + # print(f"Skipping sample {sample['__key__']} because seq_len {seq_len} > self.seq_length {self.seq_length}") + if seq_len > self.seq_length: + print(f"Skipping sample {sample['__key__']} because seq_len {seq_len} > self.seq_length {self.seq_length}") + raise SkipSample() + + if self.max_frames is not None: + video_latent = video_latent[:, : self.max_frames, :, :] + + # tpcp_size = parallel_state.get_tensor_model_parallel_world_size() + # if parallel_state.get_context_parallel_world_size() > 1: + # tpcp_size *= parallel_state.get_context_parallel_world_size() * 2 + # if (T * H * W) % tpcp_size != 0: + # warnings.warn(f'skipping {video_latent.shape=} not divisible by {tpcp_size=}') + # raise SkipSample() + # print(f"video_latent shape before rearrange: {video_latent.shape}") + # video_latent shape before rearrange: torch.Size([16, 1, 64, 96]) + video_latent = rearrange( + video_latent, + "C (T pt) (H ph) (W pw) -> (T H W) (ph pw pt C)", + ph=self.patch_spatial, + pw=self.patch_spatial, + pt=self.patch_temporal, + ) + # print(f"video_latent shape after rearrange: {video_latent.shape}") + # After reaaranging: video_latent shape after rearrange: torch.Size([1536, 64]) + # convert sample["pickle"] to numpy, and remove batch dimension + sample["pickle"] = sample["pickle"].cpu().float().numpy().squeeze(0) + if is_image: + t5_text_embeddings = torch.from_numpy(sample["pickle"]).to(torch.bfloat16) + else: + t5_text_embeddings = torch.from_numpy(sample["pickle"][0]).to(torch.bfloat16) + t5_text_embeddings_seq_length = t5_text_embeddings.shape[0] + + if t5_text_embeddings_seq_length > self.text_embedding_padding_size: + t5_text_embeddings = t5_text_embeddings[: self.text_embedding_padding_size] + else: + t5_text_embeddings = F.pad( + t5_text_embeddings, + ( + 0, + 0, + 0, + self.text_embedding_padding_size - t5_text_embeddings_seq_length, + ), + ) + t5_text_mask = torch.ones(t5_text_embeddings_seq_length, dtype=torch.bfloat16) + + if is_image: + h, w = info["image_height"], info["image_width"] + fps = torch.tensor([30] * 1, dtype=torch.bfloat16) + num_frames = torch.tensor([1] * 1, dtype=torch.bfloat16) + else: + h, w = info["height"], info["width"] + fps = torch.tensor([info["framerate"]] * 1, dtype=torch.bfloat16) + num_frames = torch.tensor([info["num_frames"]] * 1, dtype=torch.bfloat16) + image_size = torch.tensor([[h, w, h, w]] * 1, dtype=torch.bfloat16) + + pos_ids = rearrange( + pos_id_3d.get_pos_id_3d(t=T // self.patch_temporal, h=H // self.patch_spatial, w=W // self.patch_spatial), + "T H W d -> (T H W) d", + ) + + if self.seq_length is not None: + pos_ids = F.pad(pos_ids, (0, 0, 0, self.seq_length - seq_len)) + loss_mask = torch.zeros(self.seq_length, dtype=torch.bfloat16) + loss_mask[:seq_len] = 1 + video_latent = F.pad(video_latent, (0, 0, 0, self.seq_length - seq_len)) + else: + loss_mask = torch.ones(seq_len, dtype=torch.bfloat16) + + print(f"Loss mask shape: {loss_mask.shape}") + print(f"video_latent shape final: {video_latent.shape}") + return dict( + video=video_latent, + t5_text_embeddings=t5_text_embeddings, + t5_text_mask=t5_text_mask, + image_size=image_size, + fps=fps, + num_frames=num_frames, + loss_mask=loss_mask, + seq_len_q=torch.tensor(seq_len, dtype=torch.int32), + seq_len_kv=torch.tensor(self.text_embedding_padding_size, dtype=torch.int32), + pos_ids=pos_ids, + latent_shape=torch.tensor([C, T, H, W], dtype=torch.int32), + ) + + +class PosID3D: + def __init__(self, *, max_t=32, max_h=128, max_w=128): + self.max_t = max_t + self.max_h = max_h + self.max_w = max_w + self.generate_pos_id() + + def generate_pos_id(self): + self.grid = torch.stack( + torch.meshgrid( + torch.arange(self.max_t, device="cpu"), + torch.arange(self.max_h, device="cpu"), + torch.arange(self.max_w, device="cpu"), + ), + dim=-1, + ) + + def get_pos_id_3d(self, *, t, h, w): + if t > self.max_t or h > self.max_h or w > self.max_w: + self.max_t = max(self.max_t, t) + self.max_h = max(self.max_h, h) + self.max_w = max(self.max_w, w) + self.generate_pos_id() + return self.grid[:t, :h, :w] + + +pos_id_3d = PosID3D() + + +def cook_raw_iamges(sample: dict) -> dict: + """ + Processes a raw sample dictionary from energon dataset and returns a new dictionary with specific keys. + + Args: + sample (dict): The input dictionary containing the raw sample data. + + Returns: + dict: A new dictionary containing the processed sample data with the following keys: + - All keys from the result of `basic_sample_keys(sample)` + - 'jpg': original images + - 'png': contains control images + - 'txt': contains raw text + """ + return dict( + **basic_sample_keys(sample), + images=sample["jpg"], + hint=sample["png"], + txt=sample["txt"], + ) + + +class RawImageDiffusionTaskEncoder(DefaultTaskEncoder): + """ + Dummy task encoder takes raw image input on CrudeDataset. + """ + + cookers = [ + # Cooker(cook), + Cooker(cook_raw_iamges), + ] diff --git a/dfm/src/megatron/data/Dit/prepare_energon_dataset.py b/dfm/src/megatron/data/Dit/prepare_energon_dataset.py new file mode 100644 index 00000000..56e57684 --- /dev/null +++ b/dfm/src/megatron/data/Dit/prepare_energon_dataset.py @@ -0,0 +1,117 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=C0115,C0116,C0301 + +import os +import pickle +from typing import Callable, List + +import nemo_run as run +import numpy as np +import torch +import torch.distributed as dist +import webdataset as wds + + +def get_start_end_idx_for_this_rank(dataset_size, rank, world_size): + """ + Calculate the start and end indices for a given rank in a distributed setting. + + Args: + dataset_size (int): The total size of the dataset. + rank (int): The rank of the current process. + world_size (int): The total number of processes. + + Returns: + tuple: A tuple containing the start index (int) and end index (int) for the given rank. + """ + split_size = dataset_size // world_size + start_idx = rank * split_size + # The last rank takes the remainder + end_idx = start_idx + split_size if rank != world_size - 1 else dataset_size + return start_idx, end_idx + + +def dummy_process_func(input): + """ + Generates a sample dictionary containing random image latent tensor, text embedding, + and metadata based on the provided input key. + + Args: + input (str): The key to be used in the sample dictionary. + + Returns: + dict: A dictionary containing the following keys: + - "__key__": The input key. + - ".pth": A randomly generated image latent tensor with shape (3, 1, 720, 1280) and dtype torch.bfloat16. + - ".pickle": A pickled numpy array representing a random text embedding with shape (512, 2048). + - ".json": A dictionary containing metadata with keys: + - "image_height": The height of the image (720). + - "image_width": The width of the image (1280). + """ + C, T, H, W = 3, 1, 720, 1280 + image_latent = torch.randn(C, T, H, W, dtype=torch.bfloat16) + text_embedding = np.random.randn(512, 2048) + sample = { + "__key__": input, + ".pth": image_latent, + ".pickle": pickle.dumps(text_embedding), + ".json": { + "image_height": H, + "image_width": W, + }, + } + return sample + + +@torch.no_grad() +@run.cli.entrypoint +def prepare(process_func: Callable, inputs: List[str], output_dir: str = "output"): + """ + distributed prepration webdataset using the provided processing function, and writes the processed samples to tar files. + + Args: + process_func (Callable): A function that processes a single input and returns the processed sample. + inputs (List[str]): A list of input file paths or data entries to be processed. + output_dir (str, optional): The directory where the output tar files will be saved. Defaults to 'output'. + """ + rank = dist.get_rank() + world_size = torch.distributed.get_world_size() + + start_idx, end_idx = get_start_end_idx_for_this_rank(len(inputs), rank, world_size) + os.makedirs(output_dir, exist_ok=True) + output_tar = os.path.join(output_dir, f"rank{rank}-%06d.tar") + with wds.ShardWriter(output_tar, maxcount=10000) as sink: + for i in range(start_idx, end_idx): + sample = process_func(inputs[i]) + # Write the sample to the tar file + sink.write(sample) + + +@run.cli.factory(target=prepare) +def prepare_dummy_image_dataset() -> run.Partial: + recipe = run.Partial( + prepare, + process_func=dummy_process_func, + inputs=list(str(i) for i in range(1000)), + ) + return recipe + + +if __name__ == "__main__": + dist.init_process_group("nccl") + local_rank = int(os.environ["LOCAL_RANK"]) + torch.cuda.set_device(local_rank) + run.cli.main(prepare, default_factory=prepare_dummy_image_dataset) diff --git a/dfm/src/megatron/data/Dit/utils.py b/dfm/src/megatron/data/Dit/utils.py new file mode 100644 index 00000000..dbe8ebad --- /dev/null +++ b/dfm/src/megatron/data/Dit/utils.py @@ -0,0 +1,203 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=C0115,C0116,C0301 + +import numpy as np + + +def minimal_crop(tensor, target_divisor): + """ + Crops the input tensor minimally so that the total number of elements + (T * H * W) is divisible by the specified target_divisor. + + Parameters: + - tensor: NumPy array of shape (C, T, H, W) + - target_divisor: Positive integer specifying the desired divisor + + Returns: + - cropped_tensor: Cropped tensor meeting the divisibility requirement + + Raises: + - ValueError: If it's impossible to meet the divisibility requirement + """ + if not isinstance(target_divisor, int) or target_divisor <= 0: + raise ValueError("target_divisor must be a positive integer greater than zero.") + + C, T, H, W = tensor.shape + total_elements = T * H * W + remainder = total_elements % target_divisor + + if remainder == 0: + return tensor # No cropping needed + + # Elements per unit length in each dimension + elements_per_T = H * W + elements_per_H = T * W + elements_per_W = T * H + + min_elements_removed = None + optimal_deltas = None + + # Limit the search range to avoid unnecessary computations + max_delta_T = min(T - 1, (remainder // elements_per_T) + 1) + max_delta_H = min(H - 1, (remainder // elements_per_H) + 1) + max_delta_W = min(W - 1, (remainder // elements_per_W) + 1) + + for delta_T in range(0, max_delta_T + 1): + for delta_H in range(0, max_delta_H + 1): + for delta_W in range(0, max_delta_W + 1): + if delta_T == delta_H == delta_W == 0: + continue # No cropping + + new_T = T - delta_T + new_H = H - delta_H + new_W = W - delta_W + + if new_T <= 0 or new_H <= 0 or new_W <= 0: + continue # Invalid dimensions + + new_total_elements = new_T * new_H * new_W + if new_total_elements % target_divisor == 0: + elements_removed = delta_T * elements_per_T + delta_H * elements_per_H + delta_W * elements_per_W + if min_elements_removed is None or elements_removed < min_elements_removed: + min_elements_removed = elements_removed + optimal_deltas = (delta_T, delta_H, delta_W) + + if optimal_deltas is None: + raise ValueError("Cannot crop tensor to meet divisibility requirement.") + + delta_T, delta_H, delta_W = optimal_deltas + + # Perform the cropping + # T dimension: crop from the end + end_T = T - delta_T + + # H dimension: center crop + start_H = delta_H // 2 + end_H = H - (delta_H - delta_H // 2) + + # W dimension: center crop + start_W = delta_W // 2 + end_W = W - (delta_W - delta_W // 2) + + cropped_tensor = tensor[:, :end_T, start_H:end_H, start_W:end_W] + return cropped_tensor + + +def test_no_cropping_needed(): + """Test when the tensor already meets the divisibility requirement.""" + C, T, H, W = 3, 8, 8, 8 + target_divisor = 8 + tensor = np.zeros((C, T, H, W)) + cropped_tensor = minimal_crop(tensor, target_divisor) + assert cropped_tensor.shape == (C, T, H, W) + assert (T * H * W) % target_divisor == 0 + + +def test_minimal_cropping_T_dimension(): + """Test minimal cropping along the T dimension.""" + C, T, H, W = 3, 9, 7, 6 + target_divisor = 8 + tensor = np.zeros((C, T, H, W)) + cropped_tensor = minimal_crop(tensor, target_divisor) + new_T = cropped_tensor.shape[1] + assert new_T == T - 1, cropped_tensor.shape + assert (new_T * H * W) % target_divisor == 0 + + +def test_minimal_cropping_H_dimension(): + """Test minimal cropping along the H dimension.""" + C, T, H, W = 3, 7, 9, 6 + target_divisor = 8 + tensor = np.zeros((C, T, H, W)) + cropped_tensor = minimal_crop(tensor, target_divisor) + new_H = cropped_tensor.shape[2] + assert new_H == H - 1, cropped_tensor.shape + assert (T * new_H * W) % target_divisor == 0 + + +def test_minimal_cropping_W_dimension(): + """Test minimal cropping along the W dimension.""" + C, T, H, W = 3, 4, 3, 9 + target_divisor = 8 + tensor = np.zeros((C, T, H, W)) + cropped_tensor = minimal_crop(tensor, target_divisor) + new_W = cropped_tensor.shape[3] + assert new_W == W - 1, cropped_tensor.shape + assert (T * H * new_W) % target_divisor == 0 + + +def test_cropping_multiple_dimensions(): + """Test when minimal cropping requires adjustments on multiple dimensions.""" + C, T, H, W = 3, 9, 9, 8 + target_divisor = 16 + tensor = np.zeros((C, T, H, W)) + cropped_tensor = minimal_crop(tensor, target_divisor) + new_T, new_H, new_W = cropped_tensor.shape[1:] + assert new_T <= T and new_H <= H and new_W <= W + assert (new_T * new_H * new_W) % target_divisor == 0 + + +def test_large_tensor_high_divisor(): + """Test with a larger tensor and higher target_divisor.""" + C, T, H, W = 3, 50, 50, 50 + target_divisor = 1024 + tensor = np.zeros((C, T, H, W)) + cropped_tensor = minimal_crop(tensor, target_divisor) + total_elements = cropped_tensor.shape[1] * cropped_tensor.shape[2] * cropped_tensor.shape[3] + assert total_elements % target_divisor == 0 + + +def test_impossible_cropping(): + """Test that an error is raised when it's impossible to meet the requirement.""" + C, T, H, W = 3, 1, 1, 1 + target_divisor = 2 + tensor = np.zeros((C, T, H, W)) + try: + minimal_crop(tensor, target_divisor) + except ValueError: + pass + + +def test_invalid_target_divisor(): + """Test that an error is raised when target_divisor is invalid.""" + C, T, H, W = 3, 8, 8, 8 + tensor = np.zeros((C, T, H, W)) + try: + minimal_crop(tensor, -1) + except ValueError: + pass + + +def test_minimal_elements_removed(): + """Test that the minimal number of elements are removed.""" + C, T, H, W = 3, 7, 7, 7 + target_divisor = 8 + tensor = np.zeros((C, T, H, W)) + cropped_tensor = minimal_crop(tensor, target_divisor) + elements_removed = (T * H * W) - (cropped_tensor.shape[1] * cropped_tensor.shape[2] * cropped_tensor.shape[3]) + print(cropped_tensor.shape) + assert elements_removed > 0 + assert (cropped_tensor.shape[1] * cropped_tensor.shape[2] * cropped_tensor.shape[3]) % target_divisor == 0 + + +test_no_cropping_needed() +test_minimal_elements_removed() +test_cropping_multiple_dimensions() +test_minimal_cropping_T_dimension() +test_minimal_cropping_H_dimension() +test_minimal_cropping_W_dimension() +test_impossible_cropping() +test_invalid_target_divisor() diff --git a/dfm/src/megatron/data/wan/wan_energon_datamodule.py b/dfm/src/megatron/data/wan/wan_energon_datamodule.py index cc20dedc..9a4fb09a 100644 --- a/dfm/src/megatron/data/wan/wan_energon_datamodule.py +++ b/dfm/src/megatron/data/wan/wan_energon_datamodule.py @@ -20,7 +20,7 @@ from torch import int_repr -from dfm.src.megatron.data.Dit.data.diffusion_energon_datamodule import DiffusionDataModule +from dfm.src.megatron.data.Dit.diffusion_energon_datamodule import DiffusionDataModule from dfm.src.megatron.data.wan.wan_taskencoder import WanTaskEncoder from megatron.bridge.data.utils import DatasetBuildContext, DatasetProvider diff --git a/examples/megatron/recipes/wan/nemo_mcore_t5_sbatch_mcore.sh b/examples/megatron/recipes/wan/nemo_mcore_t5_sbatch_mcore.sh deleted file mode 100644 index f086f52a..00000000 --- a/examples/megatron/recipes/wan/nemo_mcore_t5_sbatch_mcore.sh +++ /dev/null @@ -1,128 +0,0 @@ -#!/bin/bash - -# Parameters -#SBATCH --account=coreai_dlalgo_llm -#SBATCH --job-name=coreai_dlalgo_llm-run:nemo_mcore_t5 -#SBATCH --nodes=1 -#SBATCH --partition=batch -#SBATCH --time=04:00:00 - -# sbatch script -mcore_t5=False - -NUM_DEVICES=8 -NEMO_DIR=/lustre/fsw/coreai_dlalgo_genai/huvu/codes/T5_nemo/NeMo -EXPNAME=nemo_nonmcore_t5_fp32_sbatch_mcore -PROJECT=t5_pretrain_sbatch - -CONFIG_NAME='megatron_t5_config' - -PRECISION=32 -AMP_O2=False -MICRO_BATCH_SIZE=64 -GLOBAL_BATCH_SIZE=512 -ACCUMULATE_GRAD_BATCHES=1 -TENSOR_MODEL_PARALLEL_SIZE=1 -VAL_CHECK_INTERVAL=2000 -MAX_STEPS=1000000 -# BLEND="[.0333,/lustre/fsw/coreai_dlalgo_genai/huvu/data/t5/training_data/symlinks/my-t5_00_bert_tokenizer_text_document,.0333,/lustre/fsw/coreai_dlalgo_genai/huvu/data/t5/training_data/symlinks/my-t5_01_bert_tokenizer_text_document,.0333,/lustre/fsw/coreai_dlalgo_genai/huvu/data/t5/training_data/symlinks/my-t5_02_bert_tokenizer_text_document,.0333,/lustre/fsw/coreai_dlalgo_genai/huvu/data/t5/training_data/symlinks/my-t5_03_bert_tokenizer_text_document,.0333,/lustre/fsw/coreai_dlalgo_genai/huvu/data/t5/training_data/symlinks/my-t5_04_bert_tokenizer_text_document,.0333,/lustre/fsw/coreai_dlalgo_genai/huvu/data/t5/training_data/symlinks/my-t5_05_bert_tokenizer_text_document,.0333,/lustre/fsw/coreai_dlalgo_genai/huvu/data/t5/training_data/symlinks/my-t5_06_bert_tokenizer_text_document,.0333,/lustre/fsw/coreai_dlalgo_genai/huvu/data/t5/training_data/symlinks/my-t5_07_bert_tokenizer_text_document,.0333,/lustre/fsw/coreai_dlalgo_genai/huvu/data/t5/training_data/symlinks/my-t5_08_bert_tokenizer_text_document,.0333,/lustre/fsw/coreai_dlalgo_genai/huvu/data/t5/training_data/symlinks/my-t5_09_bert_tokenizer_text_document,.0333,/lustre/fsw/coreai_dlalgo_genai/huvu/data/t5/training_data/symlinks/my-t5_10_bert_tokenizer_text_document,.0333,/lustre/fsw/coreai_dlalgo_genai/huvu/data/t5/training_data/symlinks/my-t5_11_bert_tokenizer_text_document,.0333,/lustre/fsw/coreai_dlalgo_genai/huvu/data/t5/training_data/symlinks/my-t5_12_bert_tokenizer_text_document,.0333,/lustre/fsw/coreai_dlalgo_genai/huvu/data/t5/training_data/symlinks/my-t5_13_bert_tokenizer_text_document,.0333,/lustre/fsw/coreai_dlalgo_genai/huvu/data/t5/training_data/symlinks/my-t5_14_bert_tokenizer_text_document,.0333,/lustre/fsw/coreai_dlalgo_genai/huvu/data/t5/training_data/symlinks/my-t5_15_bert_tokenizer_text_document,.0333,/lustre/fsw/coreai_dlalgo_genai/huvu/data/t5/training_data/symlinks/my-t5_16_bert_tokenizer_text_document,.0333,/lustre/fsw/coreai_dlalgo_genai/huvu/data/t5/training_data/symlinks/my-t5_17_bert_tokenizer_text_document,.0333,/lustre/fsw/coreai_dlalgo_genai/huvu/data/t5/training_data/symlinks/my-t5_18_bert_tokenizer_text_document,.0333,/lustre/fsw/coreai_dlalgo_genai/huvu/data/t5/training_data/symlinks/my-t5_19_bert_tokenizer_text_document,.0333,/lustre/fsw/coreai_dlalgo_genai/huvu/data/t5/training_data/symlinks/my-t5_20_bert_tokenizer_text_document,.0333,/lustre/fsw/coreai_dlalgo_genai/huvu/data/t5/training_data/symlinks/my-t5_21_bert_tokenizer_text_document,.0333,/lustre/fsw/coreai_dlalgo_genai/huvu/data/t5/training_data/symlinks/my-t5_22_bert_tokenizer_text_document,.0333,/lustre/fsw/coreai_dlalgo_genai/huvu/data/t5/training_data/symlinks/my-t5_23_bert_tokenizer_text_document,.0333,/lustre/fsw/coreai_dlalgo_genai/huvu/data/t5/training_data/symlinks/my-t5_24_bert_tokenizer_text_document,.0333,/lustre/fsw/coreai_dlalgo_genai/huvu/data/t5/training_data/symlinks/my-t5_25_bert_tokenizer_text_document,.0333,/lustre/fsw/coreai_dlalgo_genai/huvu/data/t5/training_data/symlinks/my-t5_26_bert_tokenizer_text_document,.0333,/lustre/fsw/coreai_dlalgo_genai/huvu/data/t5/training_data/symlinks/my-t5_27_bert_tokenizer_text_document,.0333,/lustre/fsw/coreai_dlalgo_genai/huvu/data/t5/training_data/symlinks/my-t5_28_bert_tokenizer_text_document,.0334,/lustre/fsw/coreai_dlalgo_genai/huvu/data/t5/training_data/symlinks/my-t5_29_bert_tokenizer_text_document]" -BLEND="[/lustre/fsw/coreai_dlalgo_genai/huvu/data/t5/training_data/symlinks/my-t5_00_bert_tokenizer_text_document]" - -# Model architecture -SEQ_LENGTH=512 -SEQ_LENGTH_DEC=128 -NUM_LAYERS=12 -HIDDEN_SIZE=768 -NUM_ATTENTION_HEADS=12 - -home_dir=/lustre/fsw/coreai_dlalgo_genai/huvu/data/t5/results/t5_nemo -exp_dir=$home_dir/${EXPNAME} -mkdir ${exp_dir} - -cmd=" - -cd /lustre/fsw/coreai_dlalgo_genai/huvu/codes/T5_nemo/NeMo -pip install -e . -NEMO=/lustre/fsw/coreai_dlalgo_genai/huvu/codes/T5_nemo/NeMo -export PYTHONPATH="${NEMO}/.:${PYTHONPATH}" -## Megatron-LM -cd /lustre/fsw/coreai_dlalgo_genai/huvu/codes/T5_nemo/megatron-lm -pip install -e . -MEGATRONLM=/lustre/fsw/coreai_dlalgo_genai/huvu/codes/T5_nemo/megatron-lm -export PYTHONPATH="${MEGATRONLM}/.:${PYTHONPATH}" -export PYTHONPATH="/lustre/fsw/coreai_dlalgo_genai/huvu/codes/T5_nemo/megatron-lm/.:/lustre/fsw/coreai_dlalgo_genai/huvu/codes/T5_nemo/NeMo/.:/opt/NeMo-Megatron-Launcher/launcher_scripts" - -export WANDB_API_KEY=497a93e5ac7cf1e0ec821741ef7bac27b36f2db8 - -if [[ ${PRECISION} != "32" ]]; then - export NVTE_FUSED_ATTN=0 - export NVTE_FLASH_ATTN=0 -fi - -python ${NEMO_DIR}/examples/nlp/language_modeling/megatron_t5_pretraining.py \ - --config-name=${CONFIG_NAME} \ - trainer.num_nodes=1 \ - trainer.devices=${NUM_DEVICES} \ - trainer.max_epochs=null \ - trainer.max_steps=${MAX_STEPS} \ - trainer.val_check_interval=${VAL_CHECK_INTERVAL} \ - trainer.accumulate_grad_batches=${ACCUMULATE_GRAD_BATCHES} \ - trainer.precision=${PRECISION} \ - trainer.log_every_n_steps=1 \ - model.megatron_amp_O2=${AMP_O2} \ - model.micro_batch_size=${MICRO_BATCH_SIZE} \ - model.global_batch_size=${GLOBAL_BATCH_SIZE} \ - model.tensor_model_parallel_size=${TENSOR_MODEL_PARALLEL_SIZE} \ - model.max_position_embeddings=${SEQ_LENGTH} \ - model.seq_length=${SEQ_LENGTH} \ - model.encoder.hidden_size=${HIDDEN_SIZE} \ - model.decoder.hidden_size=${HIDDEN_SIZE} \ - model.encoder.num_layers=${NUM_LAYERS} \ - model.decoder.num_layers=${NUM_LAYERS} \ - model.encoder.num_attention_heads=${NUM_ATTENTION_HEADS} \ - model.decoder.num_attention_heads=${NUM_ATTENTION_HEADS} \ - model.encoder.init_method_std=0.015 \ - model.decoder.init_method_std=0.015 \ - model.encoder.transformer_block_type='pre_ln' \ - model.decoder.transformer_block_type='pre_ln' \ - model.data.data_prefix=${BLEND} \ - model.data.seq_length=${SEQ_LENGTH} \ - model.tokenizer.vocab_file=/lustre/fsw/coreai_dlalgo_genai/huvu/data/t5/training_data/symlinks/bert-large-cased-vocab.txt \ - model.data.seq_length_dec=${SEQ_LENGTH_DEC} \ - model.data.splits_string=\'99982,9,9\' \ - model.data.num_workers=4 \ - model.optim.name=distributed_fused_adam \ - model.mcore_t5=${mcore_t5} \ - model.transformer_engine=True \ - +model.kv_channels=64 \ - exp_manager.create_wandb_logger=True \ - exp_manager.wandb_logger_kwargs.name=${EXPNAME} \ - exp_manager.wandb_logger_kwargs.project=${PROJECT} \ - +exp_manager.wandb_logger_kwargs.resume=True \ - exp_manager.explicit_log_dir=${exp_dir} \ - exp_manager.resume_if_exists=True \ - exp_manager.resume_ignore_no_checkpoint=True \ - exp_manager.create_checkpoint_callback=True \ - exp_manager.checkpoint_callback_params.monitor=val_loss \ - exp_manager.checkpoint_callback_params.save_top_k=3 \ - exp_manager.checkpoint_callback_params.mode=min \ - exp_manager.checkpoint_callback_params.always_save_nemo=False \ - ++exp_manager.checkpoint_callback_params.save_nemo_on_train_end=False \ - ++exp_manager.log_step_timing=True \ -" - -# ++model.async_grad_allreduce=False for training with bf16, O2, FusedAdam - -CONT="gitlab-master.nvidia.com/dl/joc/nemo-ci/train:pipe.14465850" -MOUNT="/lustre/fsw/:/lustre/fsw/" -OUTFILE=$exp_dir/slurm-%j.out -ERRFILE=$exp_dir/error-%j.out -export CUDA_DEVICE_MAX_CONNECTIONS=1 - -echo "Running training script." -srun -o ${OUTFILE} -e ${ERRFILE} --mpi=pmix \ - --container-image="${CONT}" --container-mounts="${MOUNT}" \ - --no-container-mount-home \ - --ntasks-per-node=8 \ - -N ${SLURM_JOB_NUM_NODES} \ - bash -c "${cmd}" \ No newline at end of file From 0b91a1c4251410f761da741b27b3d909db09ff8e Mon Sep 17 00:00:00 2001 From: Huy Vu2 Date: Tue, 4 Nov 2025 07:48:20 -0800 Subject: [PATCH 12/80] reorganize files --- .../Dit/data/diffusion_energon_datamodule.py | 176 ------------ .../data/Dit/data/diffusion_taskencoder.py | 256 ------------------ .../data/Dit/data/prepare_energon_dataset.py | 117 -------- dfm/src/megatron/data/Dit/data/utils.py | 203 -------------- 4 files changed, 752 deletions(-) delete mode 100644 dfm/src/megatron/data/Dit/data/diffusion_energon_datamodule.py delete mode 100644 dfm/src/megatron/data/Dit/data/diffusion_taskencoder.py delete mode 100644 dfm/src/megatron/data/Dit/data/prepare_energon_dataset.py delete mode 100644 dfm/src/megatron/data/Dit/data/utils.py diff --git a/dfm/src/megatron/data/Dit/data/diffusion_energon_datamodule.py b/dfm/src/megatron/data/Dit/data/diffusion_energon_datamodule.py deleted file mode 100644 index fa38e9c6..00000000 --- a/dfm/src/megatron/data/Dit/data/diffusion_energon_datamodule.py +++ /dev/null @@ -1,176 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# pylint: disable=C0115,C0116,C0301 - -from dataclasses import dataclass -import logging -from typing import Any, Dict, Literal - -from torch import int_repr - -from megatron.bridge.data.Dit.data.diffusion_taskencoder import BasicDiffusionTaskEncoder -from megatron.bridge.data.utils import DatasetBuildContext, DatasetProvider -from megatron.energon import DefaultTaskEncoder, get_train_dataset -from megatron.bridge.data.Dit.base import EnergonMultiModalDataModule - -@dataclass(kw_only=True) -class DiffusionDataModuleConfig(DatasetProvider): - path: str - seq_length: int - micro_batch_size: int - task_encoder_seq_length: int - global_batch_size: int - num_workers: int_repr - dataloader_type: str = "external" - - def __post_init__(self): - self.dataset = DiffusionDataModule( - path=self.path, - seq_length=self.seq_length, - task_encoder=BasicDiffusionTaskEncoder(seq_length=self.task_encoder_seq_length), - micro_batch_size=self.micro_batch_size, - global_batch_size=self.global_batch_size, - num_workers=self.num_workers) - self.sequence_length = self.dataset.seq_length - - def build_datasets(self, context: DatasetBuildContext): - return self.dataset.train_dataloader(), self.dataset.train_dataloader(), self.dataset.train_dataloader() - - - - -class DiffusionDataModule(EnergonMultiModalDataModule): - """ - A PyTorch Lightning DataModule for handling multimodal datasets with images and text. - - This data module is designed to work with multimodal datasets that involve both images and text. - It provides a seamless interface to load training and validation data, manage batching, and handle - the state of the data pipeline across training epochs. The module integrates with the Megatron-Energon - framework for efficient data handling in large-scale distributed training. - - Attributes: - path (str): Path to the energon dataset. - tokenizer (Tokenizer): The tokenizer used for processing text. - image_processor (ImageProcessor): The image processor used for preprocessing images. - seq_length (int): The maximum sequence length for tokenized text. - micro_batch_size (int): The batch size for training and validation. - num_workers (int): Number of workers for data loading. - pin_memory (bool): Whether to pin memory in the DataLoader. - multimodal_sample_config (MultiModalSampleConfig): Configuration object for multimodal samples. - task_encoder (MultiModalTaskEncoder): Encoder responsible for encoding and batching samples. - init_global_step (int): The initial global step for the trainer, used for resuming training. - data_sampler (SequentialMegatronSampler): Sampler responsible for generating sequential samples. - train_dataloader_object (Optional): The DataLoader object for training data. - val_dataloader_object (Optional): The DataLoader object for validation data. - """ - - def __init__( - self, - path: str, - seq_length: int = 2048, - micro_batch_size: int = 1, - global_batch_size: int = 8, - num_workers: int = 1, - pin_memory: bool = True, - task_encoder: DefaultTaskEncoder = None, - use_train_split_for_val: bool = False, - ) -> None: - """ - Initialize the SimpleMultiModalDataModule. - - Parameters: - path (str): Path to the dataset. - tokenizer (Tokenizer): The tokenizer used for processing text. - image_processor (ImageProcessor): The image processor used for preprocessing images. - seq_length (int, optional): The maximum sequence length for tokenized text. Defaults to 2048. - micro_batch_size (int, optional): The batch size for training and validation. Defaults to 1. - num_workers (int, optional): Number of workers for data loading. Defaults to 1. - pin_memory (bool, optional): Whether to pin memory in the DataLoader. Defaults to True. - """ - - super().__init__( - path=path, - tokenizer=None, - image_processor=None, - seq_length=seq_length, - micro_batch_size=micro_batch_size, - global_batch_size=global_batch_size, - num_workers=num_workers, - pin_memory=pin_memory, - task_encoder=task_encoder, - ) - self.use_train_split_for_val = use_train_split_for_val - - def datasets_provider(self, worker_config, split: Literal["train", "val"] = "val"): - """ - Provide the dataset for training or validation. - - This method retrieves the dataset for the specified split (either 'train' or 'val') and configures - it according to the worker configuration. - - Parameters: - worker_config: Configuration for the data loader workers. - split (Literal['train', 'val'], optional): The data split to retrieve ('train' or 'val'). Defaults to 'val'. - - Returns: - Dataset: The dataset configured for the specified split. - """ - if split not in {"train", "val"}: - raise ValueError("Invalid value for split. Allowed values are 'train' or 'val'.") - if self.use_train_split_for_val: - split = "train" - _dataset = get_train_dataset( - self.path, - batch_size=self.micro_batch_size, - task_encoder=self.task_encoder, - worker_config=worker_config, - max_samples_per_sequence=None, - shuffle_buffer_size=100, - split_part=split, - batch_drop_last=True, - virtual_epoch_length=1_000_000_000, # a hack to avoid energon end of epoch warning - ) - return _dataset - - def val_dataloader(self): - """ - Configure the validation DataLoader. - - This method configures the DataLoader for validation data. - - Parameters: - worker_config: Configuration for the data loader workers. - - Returns: - DataLoader: The DataLoader for validation data. - """ - if self.use_train_split_for_val: - return self.train_dataloader() - return super().val_dataloader() - - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: - """ - Load the state of the data module from a checkpoint. - - This method is called when loading a checkpoint. It restores the state of the data module, - including the state of the dataloader and the number of consumed samples. - - Parameters: - state_dict (Dict[str, Any]): The state dictionary containing the saved state of the data module. - """ - try: - super().load_state_dict(state_dict) - except Exception as e: - logging.warning(f"datamodule.load_state_dict failed {e}") diff --git a/dfm/src/megatron/data/Dit/data/diffusion_taskencoder.py b/dfm/src/megatron/data/Dit/data/diffusion_taskencoder.py deleted file mode 100644 index 7faa1aaa..00000000 --- a/dfm/src/megatron/data/Dit/data/diffusion_taskencoder.py +++ /dev/null @@ -1,256 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# pylint: disable=C0115,C0116,C0301 - -import torch -import torch.nn.functional as F -from einops import rearrange -from megatron.energon import DefaultTaskEncoder, SkipSample -from megatron.energon.task_encoder.cooking import Cooker, basic_sample_keys - - -def cook(sample: dict) -> dict: - """ - Processes a raw sample dictionary from energon dataset and returns a new dictionary with specific keys. - - Args: - sample (dict): The input dictionary containing the raw sample data. - - Returns: - dict: A new dictionary containing the processed sample data with the following keys: - - All keys from the result of `basic_sample_keys(sample)` - - 'json': The contains meta data like resolution, aspect ratio, fps, etc. - - 'pth': contains video latent tensor - - 'pickle': contains text embeddings - """ - return dict( - **basic_sample_keys(sample), - json=sample[".json"], - pth=sample[".pth"], - pickle=sample[".pickle"], - ) - - -class BasicDiffusionTaskEncoder(DefaultTaskEncoder): - """ - BasicDiffusionTaskEncoder is a class that encodes image/video samples for diffusion tasks. - Attributes: - cookers (list): A list of Cooker objects used for processing. - max_frames (int, optional): The maximum number of frames to consider from the video. Defaults to None. - text_embedding_padding_size (int): The padding size for text embeddings. Defaults to 512. - Methods: - __init__(*args, max_frames=None, text_embedding_padding_size=512, **kwargs): - Initializes the BasicDiffusionTaskEncoder with optional maximum frames and text embedding padding size. - encode_sample(sample: dict) -> dict: - Encodes a given sample dictionary containing video and text data. - Args: - sample (dict): A dictionary containing 'pth' for video latent and 'json' for additional info. - Returns: - dict: A dictionary containing encoded video, text embeddings, text mask, and loss mask. - Raises: - SkipSample: If the video latent contains NaNs, Infs, or is not divisible by the tensor parallel size. - """ - - cookers = [ - Cooker(cook), - ] - - def __init__( - self, - *args, - max_frames: int = None, - text_embedding_padding_size: int = 512, - seq_length: int = None, - patch_spatial: int = 2, - patch_temporal: int = 1, - **kwargs, - ): - super().__init__(*args, **kwargs) - self.max_frames = max_frames - self.text_embedding_padding_size = text_embedding_padding_size - self.seq_length = seq_length - self.patch_spatial = patch_spatial - self.patch_temporal = patch_temporal - - def encode_sample(self, sample: dict) -> dict: - video_latent = sample["pth"] - - if torch.isnan(video_latent).any() or torch.isinf(video_latent).any(): - raise SkipSample() - if torch.max(torch.abs(video_latent)) > 1e3: - raise SkipSample() - - info = sample["json"] - # remove batch dimension - video_latent = video_latent.squeeze(0) - # print(f"video_latent shape at start: {video_latent.shape}") - C, T, H, W = video_latent.shape - seq_len = ( - video_latent.shape[-1] - * video_latent.shape[-2] - * video_latent.shape[-3] - // self.patch_spatial**2 - // self.patch_temporal - ) - # seq_len = 1536 - is_image = T == 1 - - # print(f"Skipping sample {sample['__key__']} because seq_len {seq_len} > self.seq_length {self.seq_length}") - if seq_len > self.seq_length: - print(f"Skipping sample {sample['__key__']} because seq_len {seq_len} > self.seq_length {self.seq_length}") - raise SkipSample() - - if self.max_frames is not None: - video_latent = video_latent[:, : self.max_frames, :, :] - - # tpcp_size = parallel_state.get_tensor_model_parallel_world_size() - # if parallel_state.get_context_parallel_world_size() > 1: - # tpcp_size *= parallel_state.get_context_parallel_world_size() * 2 - # if (T * H * W) % tpcp_size != 0: - # warnings.warn(f'skipping {video_latent.shape=} not divisible by {tpcp_size=}') - # raise SkipSample() - # print(f"video_latent shape before rearrange: {video_latent.shape}") - # video_latent shape before rearrange: torch.Size([16, 1, 64, 96]) - video_latent = rearrange( - video_latent, - "C (T pt) (H ph) (W pw) -> (T H W) (ph pw pt C)", - ph=self.patch_spatial, - pw=self.patch_spatial, - pt=self.patch_temporal, - ) - # print(f"video_latent shape after rearrange: {video_latent.shape}") - # After reaaranging: video_latent shape after rearrange: torch.Size([1536, 64]) - # convert sample["pickle"] to numpy, and remove batch dimension - sample["pickle"] = sample["pickle"].cpu().float().numpy().squeeze(0) - if is_image: - t5_text_embeddings = torch.from_numpy(sample["pickle"]).to(torch.bfloat16) - else: - t5_text_embeddings = torch.from_numpy(sample["pickle"][0]).to(torch.bfloat16) - t5_text_embeddings_seq_length = t5_text_embeddings.shape[0] - - if t5_text_embeddings_seq_length > self.text_embedding_padding_size: - t5_text_embeddings = t5_text_embeddings[: self.text_embedding_padding_size] - else: - t5_text_embeddings = F.pad( - t5_text_embeddings, - ( - 0, - 0, - 0, - self.text_embedding_padding_size - t5_text_embeddings_seq_length, - ), - ) - t5_text_mask = torch.ones(t5_text_embeddings_seq_length, dtype=torch.bfloat16) - - if is_image: - h, w = info["image_height"], info["image_width"] - fps = torch.tensor([30] * 1, dtype=torch.bfloat16) - num_frames = torch.tensor([1] * 1, dtype=torch.bfloat16) - else: - h, w = info["height"], info["width"] - fps = torch.tensor([info["framerate"]] * 1, dtype=torch.bfloat16) - num_frames = torch.tensor([info["num_frames"]] * 1, dtype=torch.bfloat16) - image_size = torch.tensor([[h, w, h, w]] * 1, dtype=torch.bfloat16) - - pos_ids = rearrange( - pos_id_3d.get_pos_id_3d(t=T // self.patch_temporal, h=H // self.patch_spatial, w=W // self.patch_spatial), - "T H W d -> (T H W) d", - ) - - if self.seq_length is not None: - pos_ids = F.pad(pos_ids, (0, 0, 0, self.seq_length - seq_len)) - loss_mask = torch.zeros(self.seq_length, dtype=torch.bfloat16) - loss_mask[:seq_len] = 1 - video_latent = F.pad(video_latent, (0, 0, 0, self.seq_length - seq_len)) - else: - loss_mask = torch.ones(seq_len, dtype=torch.bfloat16) - - print(f"Loss mask shape: {loss_mask.shape}") - print(f"video_latent shape final: {video_latent.shape}") - return dict( - video=video_latent, - t5_text_embeddings=t5_text_embeddings, - t5_text_mask=t5_text_mask, - image_size=image_size, - fps=fps, - num_frames=num_frames, - loss_mask=loss_mask, - seq_len_q=torch.tensor(seq_len, dtype=torch.int32), - seq_len_kv=torch.tensor(self.text_embedding_padding_size, dtype=torch.int32), - pos_ids=pos_ids, - latent_shape=torch.tensor([C, T, H, W], dtype=torch.int32), - ) - - -class PosID3D: - def __init__(self, *, max_t=32, max_h=128, max_w=128): - self.max_t = max_t - self.max_h = max_h - self.max_w = max_w - self.generate_pos_id() - - def generate_pos_id(self): - self.grid = torch.stack( - torch.meshgrid( - torch.arange(self.max_t, device="cpu"), - torch.arange(self.max_h, device="cpu"), - torch.arange(self.max_w, device="cpu"), - ), - dim=-1, - ) - - def get_pos_id_3d(self, *, t, h, w): - if t > self.max_t or h > self.max_h or w > self.max_w: - self.max_t = max(self.max_t, t) - self.max_h = max(self.max_h, h) - self.max_w = max(self.max_w, w) - self.generate_pos_id() - return self.grid[:t, :h, :w] - - -pos_id_3d = PosID3D() - - -def cook_raw_iamges(sample: dict) -> dict: - """ - Processes a raw sample dictionary from energon dataset and returns a new dictionary with specific keys. - - Args: - sample (dict): The input dictionary containing the raw sample data. - - Returns: - dict: A new dictionary containing the processed sample data with the following keys: - - All keys from the result of `basic_sample_keys(sample)` - - 'jpg': original images - - 'png': contains control images - - 'txt': contains raw text - """ - return dict( - **basic_sample_keys(sample), - images=sample["jpg"], - hint=sample["png"], - txt=sample["txt"], - ) - - -class RawImageDiffusionTaskEncoder(DefaultTaskEncoder): - """ - Dummy task encoder takes raw image input on CrudeDataset. - """ - - cookers = [ - # Cooker(cook), - Cooker(cook_raw_iamges), - ] diff --git a/dfm/src/megatron/data/Dit/data/prepare_energon_dataset.py b/dfm/src/megatron/data/Dit/data/prepare_energon_dataset.py deleted file mode 100644 index 56e57684..00000000 --- a/dfm/src/megatron/data/Dit/data/prepare_energon_dataset.py +++ /dev/null @@ -1,117 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# pylint: disable=C0115,C0116,C0301 - -import os -import pickle -from typing import Callable, List - -import nemo_run as run -import numpy as np -import torch -import torch.distributed as dist -import webdataset as wds - - -def get_start_end_idx_for_this_rank(dataset_size, rank, world_size): - """ - Calculate the start and end indices for a given rank in a distributed setting. - - Args: - dataset_size (int): The total size of the dataset. - rank (int): The rank of the current process. - world_size (int): The total number of processes. - - Returns: - tuple: A tuple containing the start index (int) and end index (int) for the given rank. - """ - split_size = dataset_size // world_size - start_idx = rank * split_size - # The last rank takes the remainder - end_idx = start_idx + split_size if rank != world_size - 1 else dataset_size - return start_idx, end_idx - - -def dummy_process_func(input): - """ - Generates a sample dictionary containing random image latent tensor, text embedding, - and metadata based on the provided input key. - - Args: - input (str): The key to be used in the sample dictionary. - - Returns: - dict: A dictionary containing the following keys: - - "__key__": The input key. - - ".pth": A randomly generated image latent tensor with shape (3, 1, 720, 1280) and dtype torch.bfloat16. - - ".pickle": A pickled numpy array representing a random text embedding with shape (512, 2048). - - ".json": A dictionary containing metadata with keys: - - "image_height": The height of the image (720). - - "image_width": The width of the image (1280). - """ - C, T, H, W = 3, 1, 720, 1280 - image_latent = torch.randn(C, T, H, W, dtype=torch.bfloat16) - text_embedding = np.random.randn(512, 2048) - sample = { - "__key__": input, - ".pth": image_latent, - ".pickle": pickle.dumps(text_embedding), - ".json": { - "image_height": H, - "image_width": W, - }, - } - return sample - - -@torch.no_grad() -@run.cli.entrypoint -def prepare(process_func: Callable, inputs: List[str], output_dir: str = "output"): - """ - distributed prepration webdataset using the provided processing function, and writes the processed samples to tar files. - - Args: - process_func (Callable): A function that processes a single input and returns the processed sample. - inputs (List[str]): A list of input file paths or data entries to be processed. - output_dir (str, optional): The directory where the output tar files will be saved. Defaults to 'output'. - """ - rank = dist.get_rank() - world_size = torch.distributed.get_world_size() - - start_idx, end_idx = get_start_end_idx_for_this_rank(len(inputs), rank, world_size) - os.makedirs(output_dir, exist_ok=True) - output_tar = os.path.join(output_dir, f"rank{rank}-%06d.tar") - with wds.ShardWriter(output_tar, maxcount=10000) as sink: - for i in range(start_idx, end_idx): - sample = process_func(inputs[i]) - # Write the sample to the tar file - sink.write(sample) - - -@run.cli.factory(target=prepare) -def prepare_dummy_image_dataset() -> run.Partial: - recipe = run.Partial( - prepare, - process_func=dummy_process_func, - inputs=list(str(i) for i in range(1000)), - ) - return recipe - - -if __name__ == "__main__": - dist.init_process_group("nccl") - local_rank = int(os.environ["LOCAL_RANK"]) - torch.cuda.set_device(local_rank) - run.cli.main(prepare, default_factory=prepare_dummy_image_dataset) diff --git a/dfm/src/megatron/data/Dit/data/utils.py b/dfm/src/megatron/data/Dit/data/utils.py deleted file mode 100644 index dbe8ebad..00000000 --- a/dfm/src/megatron/data/Dit/data/utils.py +++ /dev/null @@ -1,203 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# pylint: disable=C0115,C0116,C0301 - -import numpy as np - - -def minimal_crop(tensor, target_divisor): - """ - Crops the input tensor minimally so that the total number of elements - (T * H * W) is divisible by the specified target_divisor. - - Parameters: - - tensor: NumPy array of shape (C, T, H, W) - - target_divisor: Positive integer specifying the desired divisor - - Returns: - - cropped_tensor: Cropped tensor meeting the divisibility requirement - - Raises: - - ValueError: If it's impossible to meet the divisibility requirement - """ - if not isinstance(target_divisor, int) or target_divisor <= 0: - raise ValueError("target_divisor must be a positive integer greater than zero.") - - C, T, H, W = tensor.shape - total_elements = T * H * W - remainder = total_elements % target_divisor - - if remainder == 0: - return tensor # No cropping needed - - # Elements per unit length in each dimension - elements_per_T = H * W - elements_per_H = T * W - elements_per_W = T * H - - min_elements_removed = None - optimal_deltas = None - - # Limit the search range to avoid unnecessary computations - max_delta_T = min(T - 1, (remainder // elements_per_T) + 1) - max_delta_H = min(H - 1, (remainder // elements_per_H) + 1) - max_delta_W = min(W - 1, (remainder // elements_per_W) + 1) - - for delta_T in range(0, max_delta_T + 1): - for delta_H in range(0, max_delta_H + 1): - for delta_W in range(0, max_delta_W + 1): - if delta_T == delta_H == delta_W == 0: - continue # No cropping - - new_T = T - delta_T - new_H = H - delta_H - new_W = W - delta_W - - if new_T <= 0 or new_H <= 0 or new_W <= 0: - continue # Invalid dimensions - - new_total_elements = new_T * new_H * new_W - if new_total_elements % target_divisor == 0: - elements_removed = delta_T * elements_per_T + delta_H * elements_per_H + delta_W * elements_per_W - if min_elements_removed is None or elements_removed < min_elements_removed: - min_elements_removed = elements_removed - optimal_deltas = (delta_T, delta_H, delta_W) - - if optimal_deltas is None: - raise ValueError("Cannot crop tensor to meet divisibility requirement.") - - delta_T, delta_H, delta_W = optimal_deltas - - # Perform the cropping - # T dimension: crop from the end - end_T = T - delta_T - - # H dimension: center crop - start_H = delta_H // 2 - end_H = H - (delta_H - delta_H // 2) - - # W dimension: center crop - start_W = delta_W // 2 - end_W = W - (delta_W - delta_W // 2) - - cropped_tensor = tensor[:, :end_T, start_H:end_H, start_W:end_W] - return cropped_tensor - - -def test_no_cropping_needed(): - """Test when the tensor already meets the divisibility requirement.""" - C, T, H, W = 3, 8, 8, 8 - target_divisor = 8 - tensor = np.zeros((C, T, H, W)) - cropped_tensor = minimal_crop(tensor, target_divisor) - assert cropped_tensor.shape == (C, T, H, W) - assert (T * H * W) % target_divisor == 0 - - -def test_minimal_cropping_T_dimension(): - """Test minimal cropping along the T dimension.""" - C, T, H, W = 3, 9, 7, 6 - target_divisor = 8 - tensor = np.zeros((C, T, H, W)) - cropped_tensor = minimal_crop(tensor, target_divisor) - new_T = cropped_tensor.shape[1] - assert new_T == T - 1, cropped_tensor.shape - assert (new_T * H * W) % target_divisor == 0 - - -def test_minimal_cropping_H_dimension(): - """Test minimal cropping along the H dimension.""" - C, T, H, W = 3, 7, 9, 6 - target_divisor = 8 - tensor = np.zeros((C, T, H, W)) - cropped_tensor = minimal_crop(tensor, target_divisor) - new_H = cropped_tensor.shape[2] - assert new_H == H - 1, cropped_tensor.shape - assert (T * new_H * W) % target_divisor == 0 - - -def test_minimal_cropping_W_dimension(): - """Test minimal cropping along the W dimension.""" - C, T, H, W = 3, 4, 3, 9 - target_divisor = 8 - tensor = np.zeros((C, T, H, W)) - cropped_tensor = minimal_crop(tensor, target_divisor) - new_W = cropped_tensor.shape[3] - assert new_W == W - 1, cropped_tensor.shape - assert (T * H * new_W) % target_divisor == 0 - - -def test_cropping_multiple_dimensions(): - """Test when minimal cropping requires adjustments on multiple dimensions.""" - C, T, H, W = 3, 9, 9, 8 - target_divisor = 16 - tensor = np.zeros((C, T, H, W)) - cropped_tensor = minimal_crop(tensor, target_divisor) - new_T, new_H, new_W = cropped_tensor.shape[1:] - assert new_T <= T and new_H <= H and new_W <= W - assert (new_T * new_H * new_W) % target_divisor == 0 - - -def test_large_tensor_high_divisor(): - """Test with a larger tensor and higher target_divisor.""" - C, T, H, W = 3, 50, 50, 50 - target_divisor = 1024 - tensor = np.zeros((C, T, H, W)) - cropped_tensor = minimal_crop(tensor, target_divisor) - total_elements = cropped_tensor.shape[1] * cropped_tensor.shape[2] * cropped_tensor.shape[3] - assert total_elements % target_divisor == 0 - - -def test_impossible_cropping(): - """Test that an error is raised when it's impossible to meet the requirement.""" - C, T, H, W = 3, 1, 1, 1 - target_divisor = 2 - tensor = np.zeros((C, T, H, W)) - try: - minimal_crop(tensor, target_divisor) - except ValueError: - pass - - -def test_invalid_target_divisor(): - """Test that an error is raised when target_divisor is invalid.""" - C, T, H, W = 3, 8, 8, 8 - tensor = np.zeros((C, T, H, W)) - try: - minimal_crop(tensor, -1) - except ValueError: - pass - - -def test_minimal_elements_removed(): - """Test that the minimal number of elements are removed.""" - C, T, H, W = 3, 7, 7, 7 - target_divisor = 8 - tensor = np.zeros((C, T, H, W)) - cropped_tensor = minimal_crop(tensor, target_divisor) - elements_removed = (T * H * W) - (cropped_tensor.shape[1] * cropped_tensor.shape[2] * cropped_tensor.shape[3]) - print(cropped_tensor.shape) - assert elements_removed > 0 - assert (cropped_tensor.shape[1] * cropped_tensor.shape[2] * cropped_tensor.shape[3]) % target_divisor == 0 - - -test_no_cropping_needed() -test_minimal_elements_removed() -test_cropping_multiple_dimensions() -test_minimal_cropping_T_dimension() -test_minimal_cropping_H_dimension() -test_minimal_cropping_W_dimension() -test_impossible_cropping() -test_invalid_target_divisor() From aa2050466b8b9b8844d754cc61ea93c1f7a0e90e Mon Sep 17 00:00:00 2001 From: Huy Vu2 Date: Tue, 4 Nov 2025 23:14:31 -0800 Subject: [PATCH 13/80] refactoring code --- .../data/Dit/prepare_energon_dataset.py | 117 ------------------ .../megatron/data/{Dit => dit}/__init__.py | 0 dfm/src/megatron/data/{Dit => dit}/base.py | 0 .../diffusion_energon_datamodule.py | 4 +- .../{Dit => dit}/diffusion_taskencoder.py | 0 dfm/src/megatron/data/{Dit => dit}/utils.py | 0 .../data/wan/wan_energon_datamodule.py | 2 +- dfm/src/megatron/data/wan/wan_taskencoder.py | 10 +- .../megatron/model/common/dit_attention.py | 43 +++---- .../flow_matching/flow_inference_pipeline.py | 23 +++- .../model/wan/flow_matching/flow_pipeline.py | 15 +-- dfm/src/megatron/model/wan/wan_model.py | 16 +-- dfm/src/megatron/recipes/wan/wan.py | 7 -- .../conf}/wan_pretrain_override_example.yaml | 0 .../megatron/recipes/wan/example_commands.md | 10 ++ .../wan/prepare_energon_dataset_wan.py | 14 +++ examples/megatron/recipes/wan/pretrain_wan.py | 4 +- pyproject.toml | 7 +- 18 files changed, 88 insertions(+), 184 deletions(-) delete mode 100644 dfm/src/megatron/data/Dit/prepare_energon_dataset.py rename dfm/src/megatron/data/{Dit => dit}/__init__.py (100%) rename dfm/src/megatron/data/{Dit => dit}/base.py (100%) rename dfm/src/megatron/data/{Dit => dit}/diffusion_energon_datamodule.py (98%) rename dfm/src/megatron/data/{Dit => dit}/diffusion_taskencoder.py (100%) rename dfm/src/megatron/data/{Dit => dit}/utils.py (100%) rename examples/megatron/{override_configs => recipes/wan/conf}/wan_pretrain_override_example.yaml (100%) rename {dfm/src/megatron/data => examples/megatron/recipes}/wan/prepare_energon_dataset_wan.py (95%) diff --git a/dfm/src/megatron/data/Dit/prepare_energon_dataset.py b/dfm/src/megatron/data/Dit/prepare_energon_dataset.py deleted file mode 100644 index 56e57684..00000000 --- a/dfm/src/megatron/data/Dit/prepare_energon_dataset.py +++ /dev/null @@ -1,117 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# pylint: disable=C0115,C0116,C0301 - -import os -import pickle -from typing import Callable, List - -import nemo_run as run -import numpy as np -import torch -import torch.distributed as dist -import webdataset as wds - - -def get_start_end_idx_for_this_rank(dataset_size, rank, world_size): - """ - Calculate the start and end indices for a given rank in a distributed setting. - - Args: - dataset_size (int): The total size of the dataset. - rank (int): The rank of the current process. - world_size (int): The total number of processes. - - Returns: - tuple: A tuple containing the start index (int) and end index (int) for the given rank. - """ - split_size = dataset_size // world_size - start_idx = rank * split_size - # The last rank takes the remainder - end_idx = start_idx + split_size if rank != world_size - 1 else dataset_size - return start_idx, end_idx - - -def dummy_process_func(input): - """ - Generates a sample dictionary containing random image latent tensor, text embedding, - and metadata based on the provided input key. - - Args: - input (str): The key to be used in the sample dictionary. - - Returns: - dict: A dictionary containing the following keys: - - "__key__": The input key. - - ".pth": A randomly generated image latent tensor with shape (3, 1, 720, 1280) and dtype torch.bfloat16. - - ".pickle": A pickled numpy array representing a random text embedding with shape (512, 2048). - - ".json": A dictionary containing metadata with keys: - - "image_height": The height of the image (720). - - "image_width": The width of the image (1280). - """ - C, T, H, W = 3, 1, 720, 1280 - image_latent = torch.randn(C, T, H, W, dtype=torch.bfloat16) - text_embedding = np.random.randn(512, 2048) - sample = { - "__key__": input, - ".pth": image_latent, - ".pickle": pickle.dumps(text_embedding), - ".json": { - "image_height": H, - "image_width": W, - }, - } - return sample - - -@torch.no_grad() -@run.cli.entrypoint -def prepare(process_func: Callable, inputs: List[str], output_dir: str = "output"): - """ - distributed prepration webdataset using the provided processing function, and writes the processed samples to tar files. - - Args: - process_func (Callable): A function that processes a single input and returns the processed sample. - inputs (List[str]): A list of input file paths or data entries to be processed. - output_dir (str, optional): The directory where the output tar files will be saved. Defaults to 'output'. - """ - rank = dist.get_rank() - world_size = torch.distributed.get_world_size() - - start_idx, end_idx = get_start_end_idx_for_this_rank(len(inputs), rank, world_size) - os.makedirs(output_dir, exist_ok=True) - output_tar = os.path.join(output_dir, f"rank{rank}-%06d.tar") - with wds.ShardWriter(output_tar, maxcount=10000) as sink: - for i in range(start_idx, end_idx): - sample = process_func(inputs[i]) - # Write the sample to the tar file - sink.write(sample) - - -@run.cli.factory(target=prepare) -def prepare_dummy_image_dataset() -> run.Partial: - recipe = run.Partial( - prepare, - process_func=dummy_process_func, - inputs=list(str(i) for i in range(1000)), - ) - return recipe - - -if __name__ == "__main__": - dist.init_process_group("nccl") - local_rank = int(os.environ["LOCAL_RANK"]) - torch.cuda.set_device(local_rank) - run.cli.main(prepare, default_factory=prepare_dummy_image_dataset) diff --git a/dfm/src/megatron/data/Dit/__init__.py b/dfm/src/megatron/data/dit/__init__.py similarity index 100% rename from dfm/src/megatron/data/Dit/__init__.py rename to dfm/src/megatron/data/dit/__init__.py diff --git a/dfm/src/megatron/data/Dit/base.py b/dfm/src/megatron/data/dit/base.py similarity index 100% rename from dfm/src/megatron/data/Dit/base.py rename to dfm/src/megatron/data/dit/base.py diff --git a/dfm/src/megatron/data/Dit/diffusion_energon_datamodule.py b/dfm/src/megatron/data/dit/diffusion_energon_datamodule.py similarity index 98% rename from dfm/src/megatron/data/Dit/diffusion_energon_datamodule.py rename to dfm/src/megatron/data/dit/diffusion_energon_datamodule.py index b78a6dbc..3f7ea3e0 100644 --- a/dfm/src/megatron/data/Dit/diffusion_energon_datamodule.py +++ b/dfm/src/megatron/data/dit/diffusion_energon_datamodule.py @@ -20,10 +20,10 @@ from torch import int_repr -from dfm.src.megatron.data.Dit.diffusion_taskencoder import BasicDiffusionTaskEncoder +from dfm.src.megatron.data.dit.diffusion_taskencoder import BasicDiffusionTaskEncoder from megatron.bridge.data.utils import DatasetBuildContext, DatasetProvider from megatron.energon import DefaultTaskEncoder, get_train_dataset -from dfm.src.megatron.data.Dit.base import EnergonMultiModalDataModule +from dfm.src.megatron.data.dit.base import EnergonMultiModalDataModule @dataclass(kw_only=True) class DiffusionDataModuleConfig(DatasetProvider): diff --git a/dfm/src/megatron/data/Dit/diffusion_taskencoder.py b/dfm/src/megatron/data/dit/diffusion_taskencoder.py similarity index 100% rename from dfm/src/megatron/data/Dit/diffusion_taskencoder.py rename to dfm/src/megatron/data/dit/diffusion_taskencoder.py diff --git a/dfm/src/megatron/data/Dit/utils.py b/dfm/src/megatron/data/dit/utils.py similarity index 100% rename from dfm/src/megatron/data/Dit/utils.py rename to dfm/src/megatron/data/dit/utils.py diff --git a/dfm/src/megatron/data/wan/wan_energon_datamodule.py b/dfm/src/megatron/data/wan/wan_energon_datamodule.py index 9a4fb09a..8a209eda 100644 --- a/dfm/src/megatron/data/wan/wan_energon_datamodule.py +++ b/dfm/src/megatron/data/wan/wan_energon_datamodule.py @@ -20,7 +20,7 @@ from torch import int_repr -from dfm.src.megatron.data.Dit.diffusion_energon_datamodule import DiffusionDataModule +from dfm.src.megatron.data.dit.diffusion_energon_datamodule import DiffusionDataModule from dfm.src.megatron.data.wan.wan_taskencoder import WanTaskEncoder from megatron.bridge.data.utils import DatasetBuildContext, DatasetProvider diff --git a/dfm/src/megatron/data/wan/wan_taskencoder.py b/dfm/src/megatron/data/wan/wan_taskencoder.py index 5d504b4d..b4a21165 100644 --- a/dfm/src/megatron/data/wan/wan_taskencoder.py +++ b/dfm/src/megatron/data/wan/wan_taskencoder.py @@ -72,7 +72,7 @@ def __init__( self.patch_temporal = patch_temporal self.seq_length = seq_length - + ## actual encode_sample() for production def encode_sample(self, sample: dict) -> dict: video_latent = sample["pth"] @@ -103,12 +103,14 @@ def encode_sample(self, sample: dict) -> dict: video_metadata=video_metadata, ) - + ## mock encode_sample() for debugging # def encode_sample(self, sample: dict) -> dict: # # mock encode sample - # video_latent = torch.tensor(torch.randn(16, 3, 104, 60), dtype=torch.float32) - # # video_latent = torch.tensor(torch.randn(16, 24, 104, 60), dtype=torch.float32) + # F_latents = 24 + # H_latents = 104 + # W_latents = 60 + # video_latent = torch.tensor(torch.randn(16, F_latents, H_latents, W_latents), dtype=torch.float32) # grid_size = torch.tensor([video_latent.shape[1] // self.patch_temporal, video_latent.shape[2] // self.patch_spatial, video_latent.shape[3] // self.patch_spatial], dtype=torch.int32) # context_embeddings = torch.tensor(torch.randn(512, 4096), dtype=torch.float32) # video_metadata = {} diff --git a/dfm/src/megatron/model/common/dit_attention.py b/dfm/src/megatron/model/common/dit_attention.py index a5a2794c..113a7d3f 100644 --- a/dfm/src/megatron/model/common/dit_attention.py +++ b/dfm/src/megatron/model/common/dit_attention.py @@ -1,3 +1,17 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from typing import Optional import copy @@ -47,7 +61,7 @@ def __init__( pg_collection, ) - self.layernorm_across_heads = self.config.layernorm_across_heads + self.layernorm_across_heads = getattr(self.config, "layernorm_across_heads", False) # override q_layernorm if submodules.q_layernorm is not None: @@ -185,7 +199,7 @@ def __init__( pg_collection, ) - self.layernorm_across_heads = self.config.layernorm_across_heads + self.layernorm_across_heads = getattr(self.config, "layernorm_across_heads", False) # override q_layernorm if submodules.q_layernorm is not None: @@ -226,30 +240,7 @@ def get_query_key_value_tensors(self, hidden_states, key_value_states): Derives `query` tensor from `hidden_states`, and `key`/`value` tensors from `key_value_states`. """ - # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)] - mixed_kv, _ = self.linear_kv(key_value_states) - - # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn] - new_tensor_shape = mixed_kv.size()[:-1] + ( - self.num_attention_heads_per_partition, - 2 * self.hidden_size_per_attention_head, - ) - mixed_kv = mixed_kv.view(*new_tensor_shape) - - # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn] - (key, value) = tensor_parallel.split_tensor_along_last_dim(mixed_kv, 2) - - # Attention head [sq, b, h] --> [sq, b, hp] - query, _ = self.linear_q(hidden_states) - - # [sq, b, hp] --> [sq, b, np, hn] - new_tensor_shape = query.size()[:-1] + ( - self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head, - ) - query = query.view(*new_tensor_shape) - - # replace with our own implementation (Todo: @huy ) + query, key, value = super().get_query_key_value_tensors(hidden_states, key_value_states) # gather query and key heads across TP ranks if self.layernorm_across_heads is True diff --git a/dfm/src/megatron/model/wan/flow_matching/flow_inference_pipeline.py b/dfm/src/megatron/model/wan/flow_matching/flow_inference_pipeline.py index b6a8864e..0b0dd0c3 100644 --- a/dfm/src/megatron/model/wan/flow_matching/flow_inference_pipeline.py +++ b/dfm/src/megatron/model/wan/flow_matching/flow_inference_pipeline.py @@ -1,3 +1,17 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import gc import logging import math @@ -8,6 +22,10 @@ import re from contextlib import contextmanager from functools import partial +from megatron.core import parallel_state +from megatron.core.inference.communication_utils import broadcast_from_last_pipeline_stage, recv_from_prev_pipeline_rank_, send_to_next_pipeline_rank +from megatron.core.packed_seq_params import PackedSeqParams +from megatron.bridge.training.model_load_save import load_megatron_model as _load_megatron_model import torch import torch.cuda.amp as amp @@ -148,7 +166,6 @@ def setup_model_from_checkpoint(self, checkpoint_dir): provider.initialize_model_parallel(seed=0) ## Read from megatron checkpoint - from megatron.bridge.training.model_load_save import load_megatron_model as _load_megatron_model model = _load_megatron_model( checkpoint_dir, mp_overrides={ @@ -210,9 +227,6 @@ def forward_pp_step( Forward pass supporting pipeline parallelism. """ - from megatron.core import parallel_state - from megatron.core.inference.communication_utils import broadcast_from_last_pipeline_stage, recv_from_prev_pipeline_rank_, send_to_next_pipeline_rank - pp_world_size = parallel_state.get_pipeline_model_parallel_world_size() is_pp_first = parallel_state.is_pipeline_first_stage(ignore_virtual=True) is_pp_last = parallel_state.is_pipeline_last_stage(ignore_virtual=True) @@ -444,7 +458,6 @@ def noop_no_sync(): # sample videos latents = noises - from megatron.core.packed_seq_params import PackedSeqParams cu_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)]) cu_q = cu_q.to(torch.int32).to(self.device) cu_kv_self = cu_q diff --git a/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py b/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py index 56f6f689..0940b653 100644 --- a/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py +++ b/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py @@ -67,11 +67,9 @@ def training_step( batch_size = video_latents.shape[1] device = video_latents.device - # # # DEBUGGING precision - # # import torch.cuda.amp as amp - # # with amp.autocast(dtype=torch.bfloat16): - # # # Pass through model - # # ... + # TODO: should we do as in Wan Github repo: + # with amp.autocast(dtype=torch.bfloat16) + # # Pass through model # ======================================================================== # Flow Matching Timestep Sampling @@ -176,13 +174,6 @@ def training_step( context_embeddings = context_embeddings split_loss_mask = loss_mask - # # DEBUGGING - # print(f"[DEBUG] [flow_pipeline] video_latents shape: {video_latents.shape}") - # print(f"[DEBUG] [flow_pipeline] noisy_latents shape: {noisy_latents.shape}") - # print(f"[DEBUG] [flow_pipeline] noise shape: {noise.shape}") - # print(f"[DEBUG] [flow_pipeline] context_embeddings shape: {context_embeddings.shape}") - # print(f"[DEBUG] [flow_pipeline] split_loss_mask shape: {split_loss_mask.shape}") - # ======================================================================== # Forward Pass # ======================================================================== diff --git a/dfm/src/megatron/model/wan/wan_model.py b/dfm/src/megatron/model/wan/wan_model.py index 03aeebce..8b475f78 100644 --- a/dfm/src/megatron/model/wan/wan_model.py +++ b/dfm/src/megatron/model/wan/wan_model.py @@ -135,9 +135,13 @@ def __init__( nn.Linear(self.config.text_dim, self.config.hidden_size), nn.GELU(approximate='tanh'), nn.Linear(self.config.hidden_size, self.config.hidden_size)) - self.time_embedding = nn.Sequential( - nn.Linear(self.freq_dim, self.config.hidden_size), nn.SiLU(), nn.Linear(self.config.hidden_size, self.config.hidden_size)) - self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(self.config.hidden_size, self.config.hidden_size * 6)) + # As in diffuser's Wan implementation + from diffusers.models.embeddings import Timesteps + from nemo.collections.diffusion.models.dit.dit_embeddings import ParallelTimestepEmbedding + self.timesteps_proj = Timesteps(num_channels=self.freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0) + self.time_embedder = ParallelTimestepEmbedding(in_channels=self.freq_dim, time_embed_dim=self.config.hidden_size) + self.time_proj_act_fn = nn.SiLU() + self.time_proj = nn.Linear(self.config.hidden_size, self.config.hidden_size * 6) self.rope_embeddings = Wan3DRopeEmbeddings(dim_head = self.config.hidden_size // self.num_heads, max_position_len = 1024) @@ -205,10 +209,8 @@ def forward( x = self.decoder.input_tensor # time embeddings - e = self.time_embedding( - sinusoidal_embedding_1d(self.freq_dim, t).to(x.dtype) - ) - e0 = self.time_projection(e).unflatten(1, (6, self.config.hidden_size)) + e = self.time_embedder(self.timesteps_proj(t).to(x.dtype)) + e0 = self.time_proj(self.time_proj_act_fn(e)).unflatten(1, (6, self.config.hidden_size)) # context embeddings context = self.text_embedding(context) # shape [text_len, b, hidden_size] diff --git a/dfm/src/megatron/recipes/wan/wan.py b/dfm/src/megatron/recipes/wan/wan.py index f784a842..19fe9c2d 100644 --- a/dfm/src/megatron/recipes/wan/wan.py +++ b/dfm/src/megatron/recipes/wan/wan.py @@ -94,14 +94,7 @@ def pretrain_config( lr: float = 0.9e-4, lr_warmup_iters: int = 2000, # Precision recipe - # DEBUGGING precision_config: Optional[Union[MixedPrecisionConfig, str]] = "bf16_mixed", - # precision_config: Optional[Union[MixedPrecisionConfig, str]] = MixedPrecisionConfig( - # fp32=True, - # params_dtype=torch.float32, - # pipeline_dtype=torch.float32, - # autocast_enabled=False, - # ), comm_overlap_config: Optional[CommOverlapConfig] = None, ) -> ConfigContainer: """ diff --git a/examples/megatron/override_configs/wan_pretrain_override_example.yaml b/examples/megatron/recipes/wan/conf/wan_pretrain_override_example.yaml similarity index 100% rename from examples/megatron/override_configs/wan_pretrain_override_example.yaml rename to examples/megatron/recipes/wan/conf/wan_pretrain_override_example.yaml diff --git a/examples/megatron/recipes/wan/example_commands.md b/examples/megatron/recipes/wan/example_commands.md index 21b2492b..8ecb8a97 100644 --- a/examples/megatron/recipes/wan/example_commands.md +++ b/examples/megatron/recipes/wan/example_commands.md @@ -1,6 +1,16 @@ ## WAN example commands +### Launch container +Example command on EOS cluster: +``` +CONT="nvcr.io/nvidia/nemo:25.09.00" +MOUNT="/lustre/fsw/:/lustre/fsw/" +srun -t 02:00:00 --account coreai_dlalgo_llm -N 1 -J coreai_dlalgo_llm:* -p interactive --exclusive --container-image="${CONT}" --container-mounts="${MOUNT}" --pty bash +``` + + ### Set paths to Megatron-Bridge +Inside container: ```bash DFM_PATH=/path/to/dfm MBRIDGE_PATH=/path/to/megatron-bridge diff --git a/dfm/src/megatron/data/wan/prepare_energon_dataset_wan.py b/examples/megatron/recipes/wan/prepare_energon_dataset_wan.py similarity index 95% rename from dfm/src/megatron/data/wan/prepare_energon_dataset_wan.py rename to examples/megatron/recipes/wan/prepare_energon_dataset_wan.py index a8464aa6..05405964 100644 --- a/dfm/src/megatron/data/wan/prepare_energon_dataset_wan.py +++ b/examples/megatron/recipes/wan/prepare_energon_dataset_wan.py @@ -1,3 +1,17 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import os import json import pickle diff --git a/examples/megatron/recipes/wan/pretrain_wan.py b/examples/megatron/recipes/wan/pretrain_wan.py index 7742397e..251934d6 100644 --- a/examples/megatron/recipes/wan/pretrain_wan.py +++ b/examples/megatron/recipes/wan/pretrain_wan.py @@ -58,9 +58,9 @@ from omegaconf import OmegaConf -from dfm.examples.megatron.recipe.wan.wan import pretrain_config +from dfm.src.megatron.recipes.wan.wan import pretrain_config from megatron.bridge.training.config import ConfigContainer -from dfm.examples.megatron.recipe.wan.wan_step import WanForwardStep +from dfm.src.megatron.model.wan.wan_step import WanForwardStep from megatron.bridge.training.pretrain import pretrain from megatron.bridge.training.utils.omegaconf_utils import ( apply_overrides, diff --git a/pyproject.toml b/pyproject.toml index 20a4d190..eee69845 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,7 +54,12 @@ classifiers = [ "Topic :: Software Development :: Libraries", "Topic :: Utilities", ] -dependencies = [] +dependencies = [ + "diffusers==0.35.1", + "easydict", + "imageio", + "imageio-ffmpeg", +] [build-system] requires = ["setuptools>=61"] From d5f58c93a896bb24d06bb593d2dd668ec3de534e Mon Sep 17 00:00:00 2001 From: Huy Vu2 Date: Tue, 4 Nov 2025 23:52:12 -0800 Subject: [PATCH 14/80] add README for perf test --- .../megatron/recipes/wan/README_perf_test.md | 178 ++++++++++++++++++ 1 file changed, 178 insertions(+) create mode 100644 examples/megatron/recipes/wan/README_perf_test.md diff --git a/examples/megatron/recipes/wan/README_perf_test.md b/examples/megatron/recipes/wan/README_perf_test.md new file mode 100644 index 00000000..a89c8d40 --- /dev/null +++ b/examples/megatron/recipes/wan/README_perf_test.md @@ -0,0 +1,178 @@ +## WAN Model Setup and Usage (for Perf Test) + +This guide provides concise steps to set up the environment and run WAN pretraining and inference. It pins repo commits and shows explicit commands for the 1.3B and 14B configurations. + +## Container Launch + +```bash +CONT="nvcr.io/nvidia/nemo:25.09.00" +MOUNT="/lustre/fsw/:/lustre/fsw/" + +srun -t 02:00:00 \ + --account \ + -N 1 \ + -J \ + -p batch \ + --exclusive \ + --container-image="${CONT}" \ + --container-mounts="${MOUNT}" \ + --pty bash +``` + +## Setup Inside the Container + +Setup DFM, Megatron-Bridge, Megatron-LM with specific commits, and other dependencies. + +```bash +cd /opt/ + +# DFM (pinned) +git clone --no-checkout https://github.com/NVIDIA-NeMo/DFM.git +git -C DFM checkout aa2050466b8b9b8844d754cc61ea93c1f7a0e90e +export DFM_PATH=/opt/DFM + +# Megatron-Bridge (pinned) +rm -rf /opt/Megatron-Bridge +git clone --no-checkout https://github.com/huvunvidia/Megatron-Bridge.git +git -C Megatron-Bridge checkout 713ab548e4bfee307eb94a7bb3f57c17dbb31b50 + +# Megatron-LM (pinned) +rm -rf /opt/Megatron-LM +git clone --no-checkout https://github.com/NVIDIA/Megatron-LM.git +git -C Megatron-LM checkout ce8185cbbe04f38beb74360e878450f2e8525885 + +# Python path +export PYTHONPATH="${DFM_PATH}/.:/opt/Megatron-Bridge/.:/opt/Megatron-LM" + +# Python deps +python3 -m pip install --upgrade diffusers==0.35.1 +pip install easydict imageio imageio-ffmpeg +``` + +## Pretraining +Set data path and checkpoint directory: + +```bash +DATASET_PATH="/lustre/fsw/coreai_dlalgo_genai/huvu/data/nemo_vfm/datasets/shared_datasets/processed_arrietty_scene_automodel" +EXP_NAME=wan_debug_perf +CHECKPOINT_DIR="/lustre/fsw/coreai_dlalgo_genai/huvu/data/nemo_vfm/results/wan_finetune/${EXP_NAME}" + +export HF_TOKEN= +export WANDB_API_KEY= +cd ${DFM_PATH} +``` + + +### 1.3B configuration + +```bash +NVTE_FUSED_ATTN=1 torchrun --nproc_per_node=8 examples/megatron/recipes/wan/pretrain_wan.py \ + model.tensor_model_parallel_size=1 \ + model.pipeline_model_parallel_size=1 \ + model.context_parallel_size=4 \ + model.crossattn_emb_size=1536 \ + model.hidden_size=1536 \ + model.ffn_hidden_size=8960 \ + model.num_attention_heads=12 \ + model.num_layers=30 \ + model.qkv_format=thd \ + dataset.path="${DATASET_PATH}" \ + checkpoint.save="${CHECKPOINT_DIR}" \ + checkpoint.load="${CHECKPOINT_DIR}" \ + checkpoint.load_optim=false \ + checkpoint.save_interval=200 \ + optimizer.lr=5e-6 \ + optimizer.min_lr=5e-6 \ + train.eval_iters=0 \ + scheduler.lr_decay_style=constant \ + scheduler.lr_warmup_iters=0 \ + model.seq_length=2048 \ + dataset.seq_length=2048 \ + train.global_batch_size=2 \ + train.micro_batch_size=1 \ + dataset.global_batch_size=2 \ + dataset.micro_batch_size=1 \ + logger.log_interval=1 \ + logger.wandb_project="wan" \ + logger.wandb_exp_name="${EXP_NAME}" \ + logger.wandb_save_dir="${CHECKPOINT_DIR}" +``` + +### 14B configuration + +```bash +NVTE_FUSED_ATTN=1 torchrun --nproc_per_node=8 examples/megatron/recipes/wan/pretrain_wan.py \ + model.tensor_model_parallel_size=2 \ + model.pipeline_model_parallel_size=1 \ + model.context_parallel_size=4 \ + model.recompute_granularity=full \ + model.recompute_method=uniform \ + model.recompute_num_layers=1 \ + model.crossattn_emb_size=5120 \ + model.hidden_size=5120 \ + model.ffn_hidden_size=13824 \ + model.num_attention_heads=40 \ + model.num_layers=40 \ + model.qkv_format=thd \ + dataset.path="${DATASET_PATH}" \ + checkpoint.save="${CHECKPOINT_DIR}" \ + checkpoint.load="${CHECKPOINT_DIR}" \ + checkpoint.load_optim=false \ + checkpoint.save_interval=200 \ + optimizer.lr=5e-6 \ + optimizer.min_lr=5e-6 \ + train.eval_iters=0 \ + scheduler.lr_decay_style=constant \ + scheduler.lr_warmup_iters=0 \ + model.seq_length=2048 \ + dataset.seq_length=2048 \ + train.global_batch_size=2 \ + train.micro_batch_size=1 \ + dataset.global_batch_size=2 \ + dataset.micro_batch_size=1 \ + logger.log_interval=1 \ + logger.wandb_project="wan" \ + logger.wandb_exp_name="${EXP_NAME}" \ + logger.wandb_save_dir="${CHECKPOINT_DIR}" +``` + +### Using mock data (optional, for debugging) + +- Edit `dfm/src/megatron/data/wan/wan_taskencoder.py`. +- Comment out the production `encode_sample()` and uncomment the mock version. +- Adjust `video_size` (F_latents, H_latents, W_latents). Total `seq_len = F * H * W`. + +## Inference + +```bash +cd ${DFM_PATH} +export HF_TOKEN= + +T5_DIR="/lustre/fsw/coreai_dlalgo_genai/huvu/data/nemo_vfm/wan_checkpoints/t5" +VAE_DIR="/lustre/fsw/coreai_dlalgo_genai/huvu/data/nemo_vfm/wan_checkpoints/vae" +CKPT_DIR="/lustre/fsw/coreai_dlalgo_genai/huvu/data/nemo_vfm/datasets/shared_checkpoints/megatron_checkpoint_1.3B" + +NVTE_FUSED_ATTN=1 torchrun --nproc_per_node=1 examples/megatron/recipes/wan/inference_wan.py \ + --task t2v-1.3B \ + --sizes 480*832 \ + --checkpoint_dir "${CKPT_DIR}" \ + --checkpoint_step 0 \ + --t5_checkpoint_dir "${T5_DIR}" \ + --vae_checkpoint_dir "${VAE_DIR}" \ + --frame_nums 81 \ + --prompts "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \ + --tensor_parallel_size 1 \ + --context_parallel_size 1 \ + --pipeline_parallel_size 1 \ + --sequence_parallel False \ + --base_seed 42 \ + --sample_steps 50 +``` + +## Notes + +- Replace placeholders (tokens, account, dataset/checkpoint paths) with your own. +- Keep the specified commit hashes for compatibility. +- `NVTE_FUSED_ATTN=1` enables fused attention where supported. + + From 9b8e4fbbc712d2317a4566af0d4c74b41af44723 Mon Sep 17 00:00:00 2001 From: Huy Vu2 Date: Wed, 5 Nov 2025 13:58:53 -0800 Subject: [PATCH 15/80] using vae, t5, scheduler from Diffusers --- .../flow_matching/flow_inference_pipeline.py | 197 ++-- .../model/wan/inference/configs/__init__.py | 20 - .../wan/inference/configs/wan_i2v_14B.py | 35 - .../model/wan/inference/{utils => }/utils.py | 0 .../model/wan/inference/utils/fm_solvers.py | 858 ------------------ .../wan/inference/utils/fm_solvers_unipc.py | 801 ---------------- .../megatron/recipes/wan/inference_wan.py | 10 +- 7 files changed, 102 insertions(+), 1819 deletions(-) delete mode 100644 dfm/src/megatron/model/wan/inference/configs/wan_i2v_14B.py rename dfm/src/megatron/model/wan/inference/{utils => }/utils.py (100%) delete mode 100644 dfm/src/megatron/model/wan/inference/utils/fm_solvers.py delete mode 100644 dfm/src/megatron/model/wan/inference/utils/fm_solvers_unipc.py diff --git a/dfm/src/megatron/model/wan/flow_matching/flow_inference_pipeline.py b/dfm/src/megatron/model/wan/flow_matching/flow_inference_pipeline.py index 0b0dd0c3..f5dcea58 100644 --- a/dfm/src/megatron/model/wan/flow_matching/flow_inference_pipeline.py +++ b/dfm/src/megatron/model/wan/flow_matching/flow_inference_pipeline.py @@ -17,51 +17,69 @@ import math import os import random -import sys -import types import re +import sys from contextlib import contextmanager -from functools import partial -from megatron.core import parallel_state -from megatron.core.inference.communication_utils import broadcast_from_last_pipeline_stage, recv_from_prev_pipeline_rank_, send_to_next_pipeline_rank -from megatron.core.packed_seq_params import PackedSeqParams -from megatron.bridge.training.model_load_save import load_megatron_model as _load_megatron_model +from typing import Tuple import torch import torch.cuda.amp as amp import torch.distributed as dist +from diffusers import AutoencoderKLWan +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler +from megatron.bridge.training.model_load_save import load_megatron_model as _load_megatron_model +from megatron.core import parallel_state +from megatron.core.inference.communication_utils import ( + broadcast_from_last_pipeline_stage, + recv_from_prev_pipeline_rank_, + send_to_next_pipeline_rank, +) +from megatron.core.packed_seq_params import PackedSeqParams +from torch.nn import functional as F from tqdm import tqdm +from transformers import AutoTokenizer, UMT5EncoderModel -from dfm.src.megatron.model.wan.wan_model import WanModel -from dfm.src.megatron.model.wan.wan_provider import WanModelProvider -from dfm.src.megatron.model.wan.modules.t5 import T5EncoderModel -from dfm.src.megatron.model.wan.modules import WanVAE -from dfm.src.megatron.model.wan.inference.utils.fm_solvers import ( - FlowDPMSolverMultistepScheduler, - get_sampling_sigmas, - retrieve_timesteps, -) -from dfm.src.megatron.model.wan.inference.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler from dfm.src.megatron.model.wan.utils.utils import grid_sizes_calculation, patchify -from megatron.core import parallel_state -from torch.nn import functional as F +from dfm.src.megatron.model.wan.wan_provider import WanModelProvider -import math -from typing import Tuple, Union + +@torch.no_grad() +def _encode_text( + tokenizer: AutoTokenizer, + text_encoder: UMT5EncoderModel, + device: str, + caption: str, +) -> torch.Tensor: + caption = caption.strip() + inputs = tokenizer( + caption, + max_length=512, + padding="max_length", + truncation=True, + return_tensors="pt", + return_attention_mask=True, + ) + inputs = {k: v.to(device) for k, v in inputs.items()} + outputs = text_encoder(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"]).last_hidden_state + # Trim to the true (unpadded) sequence length using the attention mask + true_len = int(inputs["attention_mask"].sum(dim=-1).item()) + outputs = outputs[0, :true_len, :] + return outputs class FlowInferencePipeline: def __init__( self, config, - checkpoint_dir, + model_id="Wan-AI/Wan2.1-T2V-14B-Diffusers", + checkpoint_dir=None, checkpoint_step=None, t5_checkpoint_dir=None, vae_checkpoint_dir=None, device_id=0, rank=0, t5_cpu=False, - tensor_parallel_size=1, context_parallel_size=1, pipeline_parallel_size=1, @@ -89,6 +107,7 @@ def __init__( """ self.device = torch.device(f"cuda:{device_id}") self.config = config + self.model_id = model_id self.rank = rank self.t5_cpu = t5_cpu self.tensor_parallel_size = tensor_parallel_size @@ -99,19 +118,24 @@ def __init__( self.num_train_timesteps = config.num_train_timesteps self.param_dtype = config.param_dtype - self.text_encoder = T5EncoderModel( - text_len=config.text_len, - dtype=config.t5_dtype, - device=torch.device('cpu'), - checkpoint_path=os.path.join(t5_checkpoint_dir, config.t5_checkpoint), - tokenizer_path=os.path.join(t5_checkpoint_dir, config.t5_tokenizer), - shard_fn=None) + self.text_encoder = UMT5EncoderModel.from_pretrained( + model_id, + subfolder="text_encoder", + torch_dtype=config.t5_dtype, + ) + self.tokenizer = AutoTokenizer.from_pretrained( + model_id, + subfolder="tokenizer", + ) self.vae_stride = config.vae_stride - self.patch_size = config.patch_size - self.vae = WanVAE( - vae_pth=os.path.join(vae_checkpoint_dir, config.vae_checkpoint), - device=self.device) + self.patch_size = config.patch_size + self.vae = AutoencoderKLWan.from_pretrained( + model_id, + subfolder="vae", + torch_dtype=config.param_dtype, + ) + self.vae.to(self.device) wan_checkpoint_dir = self._select_checkpoint_dir(checkpoint_dir, checkpoint_step) self.model = self.setup_model_from_checkpoint(wan_checkpoint_dir) @@ -302,7 +326,6 @@ def generate(self, sizes, frame_nums, shift=5.0, - sample_solver='unipc', sampling_steps=50, guide_scale=5.0, n_prompt="", @@ -320,8 +343,6 @@ def generate(self, How many frames to sample from a video. The number should be 4n+1 shift (`float`, *optional*, defaults to 5.0): Noise schedule shift parameter. Affects temporal dynamics - sample_solver (`str`, *optional*, defaults to 'unipc'): - Solver used to sample the video. sampling_steps (`int`, *optional*, defaults to 40): Number of diffusion sampling steps. Higher values improve quality but slow generation guide_scale (`float`, *optional*, defaults 5.0): @@ -345,7 +366,7 @@ def generate(self, # preprocess target_shapes = [] for size, frame_num in zip(sizes, frame_nums): - target_shapes.append((self.vae.model.z_dim, (frame_num - 1) // self.vae_stride[0] + 1, + target_shapes.append((self.vae.config.z_dim, (frame_num - 1) // self.vae_stride[0] + 1, size[1] // self.vae_stride[1], size[0] // self.vae_stride[2])) @@ -366,23 +387,27 @@ def generate(self, ## process context + # we implement similar to Wan's diffuser setup + # (https://github.com/huggingface/diffusers/blob/0f252be0ed42006c125ef4429156cb13ae6c1d60/src/diffusers/pipelines/wan/pipeline_wan.py#L157) + # in which we pad the text to 512, pass through text encoder, and truncate to the actual tokens, then pad with 0s to 512. context_max_len = 512 context_lens = [] contexts = [] contexts_null = [] for prompt in prompts: if not self.t5_cpu: - self.text_encoder.model.to(self.device) - context = self.text_encoder([prompt], self.device)[0] - context_null = self.text_encoder([n_prompt], self.device)[0] + self.text_encoder.to(self.device) + context = _encode_text(self.tokenizer, self.text_encoder, self.device, prompt) + context_null = _encode_text(self.tokenizer, self.text_encoder, self.device, n_prompt) if offload_model: - self.text_encoder.model.cpu() + self.text_encoder.cpu() else: context = self.text_encoder([prompt], torch.device('cpu'))[0].to(self.device) context_null = self.text_encoder([n_prompt], torch.device('cpu'))[0].to(self.device) context_lens.append(context_max_len) # all samples have the same context_max_len contexts.append(context) contexts_null.append(context_null) + # pad to context_max_len tokens, and stack to a tensor of shape [s, b, hidden] contexts = [F.pad(context, (0, 0, 0, context_max_len - context.shape[0])) for context in contexts] contexts_null = [F.pad(context_null, (0, 0, 0, context_max_len - context_null.shape[0])) for context_null in contexts_null] @@ -421,39 +446,18 @@ def noop_no_sync(): # evaluation mode with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync(): - - if sample_solver == 'unipc': - # Create a prototype scheduler to compute shared timesteps - sample_scheduler = FlowUniPCMultistepScheduler( - num_train_timesteps=self.num_train_timesteps, - shift=1, - use_dynamic_shifting=False) - sample_scheduler.set_timesteps( - sampling_steps, device=self.device, shift=shift) - timesteps = sample_scheduler.timesteps - - # Instantiate per-sample schedulers so each sample maintains its own state - batch_size_for_schedulers = len(noises) - schedulers = [] - for _ in range(batch_size_for_schedulers): - s = FlowUniPCMultistepScheduler( - num_train_timesteps=self.num_train_timesteps, - shift=1, - use_dynamic_shifting=False) - s.set_timesteps(sampling_steps, device=self.device, shift=shift) - schedulers.append(s) - elif sample_solver == 'dpm++': - sample_scheduler = FlowDPMSolverMultistepScheduler( - num_train_timesteps=self.num_train_timesteps, - shift=1, - use_dynamic_shifting=False) - sampling_sigmas = get_sampling_sigmas(sampling_steps, shift) - timesteps, _ = retrieve_timesteps( - sample_scheduler, - device=self.device, - sigmas=sampling_sigmas) - else: - raise NotImplementedError("Unsupported solver.") + # Instantiate per-sample schedulers so each sample maintains its own state + batch_size_for_schedulers = len(noises) + schedulers = [] + for _ in range(batch_size_for_schedulers): + base_sched = FlowMatchEulerDiscreteScheduler.from_pretrained( + self.model_id, subfolder="scheduler" + ) + s = UniPCMultistepScheduler.from_config(base_sched.config, flow_shift=shift) + s.set_timesteps(sampling_steps, device=self.device) + + schedulers.append(s) + timesteps = schedulers[0].timesteps # sample videos latents = noises @@ -508,11 +512,11 @@ def noop_no_sync(): unpatchified_noise_pred_cond = noise_pred_cond unpatchified_noise_pred_cond = unpatchified_noise_pred_cond.transpose(0, 1) # bring sbhd -> bshd # when unpatchifying, the code will truncate the padded videos into the original video shape, based on the grid_sizes. - unpatchified_noise_pred_cond = self.unpatchify(unpatchified_noise_pred_cond, grid_sizes, self.vae.model.z_dim) + unpatchified_noise_pred_cond = self.unpatchify(unpatchified_noise_pred_cond, grid_sizes, self.vae.config.z_dim) unpatchified_noise_pred_uncond = noise_pred_uncond unpatchified_noise_pred_uncond = unpatchified_noise_pred_uncond.transpose(0, 1) # bring sbhd -> bshd # when unpatchifying, the code will truncate the padded videos into the original video shape, based on the grid_sizes. - unpatchified_noise_pred_uncond = self.unpatchify(unpatchified_noise_pred_uncond, grid_sizes, self.vae.model.z_dim) + unpatchified_noise_pred_uncond = self.unpatchify(unpatchified_noise_pred_uncond, grid_sizes, self.vae.config.z_dim) noise_preds = [] for i in range(batch_size): @@ -523,21 +527,11 @@ def noop_no_sync(): # step and update latents latents = [] for i in range(batch_size): - - if sample_solver == 'unipc': - temp_x0 = schedulers[i].step( - noise_preds[i].unsqueeze(0), - t, - unpatchified_latents[i].unsqueeze(0), - return_dict=False, - generator=seed_g)[0] - else: - temp_x0 = sample_scheduler.step( - noise_preds[i].unsqueeze(0), - t, - unpatchified_latents[i].unsqueeze(0), - return_dict=False, - generator=seed_g)[0] + temp_x0 =schedulers[i].step( + noise_preds[i].unsqueeze(0), + t, + unpatchified_latents[i].unsqueeze(0), + return_dict=False)[0] latents.append(temp_x0.squeeze(0)) x0 = latents @@ -545,15 +539,24 @@ def noop_no_sync(): self.model.cpu() torch.cuda.empty_cache() if self.rank == 0: - videos = self.vae.decode(x0) + # Diffusers' VAE decoding + latents = torch.stack(x0, dim=0) + latents = latents.to(self.vae.dtype) + latents_mean = ( + torch.tensor(self.vae.config.latents_mean) + .view(1, self.vae.config.z_dim, 1, 1, 1) + .to(latents.device, latents.dtype) + ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( + latents.device, latents.dtype + ) + latents = latents / latents_std + latents_mean + videos = self.vae.decode(latents).sample else: videos = None del noises, latents - if sample_solver == 'unipc': - del schedulers - else: - del sample_scheduler + del schedulers if offload_model: gc.collect() torch.cuda.synchronize() diff --git a/dfm/src/megatron/model/wan/inference/configs/__init__.py b/dfm/src/megatron/model/wan/inference/configs/__init__.py index a28c03c5..21679892 100644 --- a/dfm/src/megatron/model/wan/inference/configs/__init__.py +++ b/dfm/src/megatron/model/wan/inference/configs/__init__.py @@ -3,27 +3,12 @@ os.environ['TOKENIZERS_PARALLELISM'] = 'false' -from .wan_i2v_14B import i2v_14B from .wan_t2v_1_3B import t2v_1_3B from .wan_t2v_14B import t2v_14B -# the config of t2i_14B is the same as t2v_14B -t2i_14B = copy.deepcopy(t2v_14B) -t2i_14B.__name__ = 'Config: Wan T2I 14B' - -# the config of flf2v_14B is the same as i2v_14B -flf2v_14B = copy.deepcopy(i2v_14B) -flf2v_14B.__name__ = 'Config: Wan FLF2V 14B' -flf2v_14B.sample_neg_prompt = "镜头切换," + flf2v_14B.sample_neg_prompt - WAN_CONFIGS = { 't2v-14B': t2v_14B, 't2v-1.3B': t2v_1_3B, - 'i2v-14B': i2v_14B, - 't2i-14B': t2i_14B, - 'flf2v-14B': flf2v_14B, - 'vace-1.3B': t2v_1_3B, - 'vace-14B': t2v_14B, } SIZE_CONFIGS = { @@ -44,9 +29,4 @@ SUPPORTED_SIZES = { 't2v-14B': ('720*1280', '1280*720', '480*832', '832*480'), 't2v-1.3B': ('480*832', '832*480'), - 'i2v-14B': ('720*1280', '1280*720', '480*832', '832*480'), - 'flf2v-14B': ('720*1280', '1280*720', '480*832', '832*480'), - 't2i-14B': tuple(SIZE_CONFIGS.keys()), - 'vace-1.3B': ('480*832', '832*480'), - 'vace-14B': ('720*1280', '1280*720', '480*832', '832*480') } diff --git a/dfm/src/megatron/model/wan/inference/configs/wan_i2v_14B.py b/dfm/src/megatron/model/wan/inference/configs/wan_i2v_14B.py deleted file mode 100644 index 764d2ed8..00000000 --- a/dfm/src/megatron/model/wan/inference/configs/wan_i2v_14B.py +++ /dev/null @@ -1,35 +0,0 @@ -import torch -from easydict import EasyDict - -from .shared_config import wan_shared_cfg - -#------------------------ Wan I2V 14B ------------------------# - -i2v_14B = EasyDict(__name__='Config: Wan I2V 14B') -i2v_14B.update(wan_shared_cfg) -i2v_14B.sample_neg_prompt = "镜头晃动," + i2v_14B.sample_neg_prompt - -i2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth' -i2v_14B.t5_tokenizer = 'google/umt5-xxl' - -# clip -i2v_14B.clip_model = 'clip_xlm_roberta_vit_h_14' -i2v_14B.clip_dtype = torch.float16 -i2v_14B.clip_checkpoint = 'models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth' -i2v_14B.clip_tokenizer = 'xlm-roberta-large' - -# vae -i2v_14B.vae_checkpoint = 'Wan2.1_VAE.pth' -i2v_14B.vae_stride = (4, 8, 8) - -# transformer -i2v_14B.patch_size = (1, 2, 2) -i2v_14B.dim = 5120 -i2v_14B.ffn_dim = 13824 -i2v_14B.freq_dim = 256 -i2v_14B.num_heads = 40 -i2v_14B.num_layers = 40 -i2v_14B.window_size = (-1, -1) -i2v_14B.qk_norm = True -i2v_14B.cross_attn_norm = True -i2v_14B.eps = 1e-6 diff --git a/dfm/src/megatron/model/wan/inference/utils/utils.py b/dfm/src/megatron/model/wan/inference/utils.py similarity index 100% rename from dfm/src/megatron/model/wan/inference/utils/utils.py rename to dfm/src/megatron/model/wan/inference/utils.py diff --git a/dfm/src/megatron/model/wan/inference/utils/fm_solvers.py b/dfm/src/megatron/model/wan/inference/utils/fm_solvers.py deleted file mode 100644 index a38b755c..00000000 --- a/dfm/src/megatron/model/wan/inference/utils/fm_solvers.py +++ /dev/null @@ -1,858 +0,0 @@ -# Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py -# Convert dpm solver for flow matching - -import inspect -import math -from typing import List, Optional, Tuple, Union - -import numpy as np -import torch -from diffusers.configuration_utils import ConfigMixin, register_to_config -from diffusers.schedulers.scheduling_utils import ( - KarrasDiffusionSchedulers, - SchedulerMixin, - SchedulerOutput, -) -from diffusers.utils import deprecate, is_scipy_available -from diffusers.utils.torch_utils import randn_tensor - -if is_scipy_available(): - pass - - -def get_sampling_sigmas(sampling_steps, shift): - sigma = np.linspace(1, 0, sampling_steps + 1)[:sampling_steps] - sigma = (shift * sigma / (1 + (shift - 1) * sigma)) - - return sigma - - -def retrieve_timesteps( - scheduler, - num_inference_steps=None, - device=None, - timesteps=None, - sigmas=None, - **kwargs, -): - if timesteps is not None and sigmas is not None: - raise ValueError( - "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values" - ) - if timesteps is not None: - accepts_timesteps = "timesteps" in set( - inspect.signature(scheduler.set_timesteps).parameters.keys()) - if not accepts_timesteps: - raise ValueError( - f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" - f" timestep schedules. Please check whether you are using the correct scheduler." - ) - scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) - timesteps = scheduler.timesteps - num_inference_steps = len(timesteps) - elif sigmas is not None: - accept_sigmas = "sigmas" in set( - inspect.signature(scheduler.set_timesteps).parameters.keys()) - if not accept_sigmas: - raise ValueError( - f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" - f" sigmas schedules. Please check whether you are using the correct scheduler." - ) - scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) - timesteps = scheduler.timesteps - num_inference_steps = len(timesteps) - else: - scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) - timesteps = scheduler.timesteps - return timesteps, num_inference_steps - - -class FlowDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): - """ - `FlowDPMSolverMultistepScheduler` is a fast dedicated high-order solver for diffusion ODEs. - This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic - methods the library implements for all schedulers such as loading and saving. - Args: - num_train_timesteps (`int`, defaults to 1000): - The number of diffusion steps to train the model. This determines the resolution of the diffusion process. - solver_order (`int`, defaults to 2): - The DPMSolver order which can be `1`, `2`, or `3`. It is recommended to use `solver_order=2` for guided - sampling, and `solver_order=3` for unconditional sampling. This affects the number of model outputs stored - and used in multistep updates. - prediction_type (`str`, defaults to "flow_prediction"): - Prediction type of the scheduler function; must be `flow_prediction` for this scheduler, which predicts - the flow of the diffusion process. - shift (`float`, *optional*, defaults to 1.0): - A factor used to adjust the sigmas in the noise schedule. It modifies the step sizes during the sampling - process. - use_dynamic_shifting (`bool`, defaults to `False`): - Whether to apply dynamic shifting to the timesteps based on image resolution. If `True`, the shifting is - applied on the fly. - thresholding (`bool`, defaults to `False`): - Whether to use the "dynamic thresholding" method. This method adjusts the predicted sample to prevent - saturation and improve photorealism. - dynamic_thresholding_ratio (`float`, defaults to 0.995): - The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. - sample_max_value (`float`, defaults to 1.0): - The threshold value for dynamic thresholding. Valid only when `thresholding=True` and - `algorithm_type="dpmsolver++"`. - algorithm_type (`str`, defaults to `dpmsolver++`): - Algorithm type for the solver; can be `dpmsolver`, `dpmsolver++`, `sde-dpmsolver` or `sde-dpmsolver++`. The - `dpmsolver` type implements the algorithms in the [DPMSolver](https://huggingface.co/papers/2206.00927) - paper, and the `dpmsolver++` type implements the algorithms in the - [DPMSolver++](https://huggingface.co/papers/2211.01095) paper. It is recommended to use `dpmsolver++` or - `sde-dpmsolver++` with `solver_order=2` for guided sampling like in Stable Diffusion. - solver_type (`str`, defaults to `midpoint`): - Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the - sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers. - lower_order_final (`bool`, defaults to `True`): - Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can - stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10. - euler_at_final (`bool`, defaults to `False`): - Whether to use Euler's method in the final step. It is a trade-off between numerical stability and detail - richness. This can stabilize the sampling of the SDE variant of DPMSolver for small number of inference - steps, but sometimes may result in blurring. - final_sigmas_type (`str`, *optional*, defaults to "zero"): - The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final - sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0. - lambda_min_clipped (`float`, defaults to `-inf`): - Clipping threshold for the minimum value of `lambda(t)` for numerical stability. This is critical for the - cosine (`squaredcos_cap_v2`) noise schedule. - variance_type (`str`, *optional*): - Set to "learned" or "learned_range" for diffusion models that predict variance. If set, the model's output - contains the predicted Gaussian variance. - """ - - _compatibles = [e.name for e in KarrasDiffusionSchedulers] - order = 1 - - @register_to_config - def __init__( - self, - num_train_timesteps: int = 1000, - solver_order: int = 2, - prediction_type: str = "flow_prediction", - shift: Optional[float] = 1.0, - use_dynamic_shifting=False, - thresholding: bool = False, - dynamic_thresholding_ratio: float = 0.995, - sample_max_value: float = 1.0, - algorithm_type: str = "dpmsolver++", - solver_type: str = "midpoint", - lower_order_final: bool = True, - euler_at_final: bool = False, - final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min" - lambda_min_clipped: float = -float("inf"), - variance_type: Optional[str] = None, - invert_sigmas: bool = False, - ): - if algorithm_type in ["dpmsolver", "sde-dpmsolver"]: - deprecation_message = f"algorithm_type {algorithm_type} is deprecated and will be removed in a future version. Choose from `dpmsolver++` or `sde-dpmsolver++` instead" - deprecate("algorithm_types dpmsolver and sde-dpmsolver", "1.0.0", - deprecation_message) - - # settings for DPM-Solver - if algorithm_type not in [ - "dpmsolver", "dpmsolver++", "sde-dpmsolver", "sde-dpmsolver++" - ]: - if algorithm_type == "deis": - self.register_to_config(algorithm_type="dpmsolver++") - else: - raise NotImplementedError( - f"{algorithm_type} is not implemented for {self.__class__}") - - if solver_type not in ["midpoint", "heun"]: - if solver_type in ["logrho", "bh1", "bh2"]: - self.register_to_config(solver_type="midpoint") - else: - raise NotImplementedError( - f"{solver_type} is not implemented for {self.__class__}") - - if algorithm_type not in ["dpmsolver++", "sde-dpmsolver++" - ] and final_sigmas_type == "zero": - raise ValueError( - f"`final_sigmas_type` {final_sigmas_type} is not supported for `algorithm_type` {algorithm_type}. Please choose `sigma_min` instead." - ) - - # setable values - self.num_inference_steps = None - alphas = np.linspace(1, 1 / num_train_timesteps, - num_train_timesteps)[::-1].copy() - sigmas = 1.0 - alphas - sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32) - - if not use_dynamic_shifting: - # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution - sigmas = shift * sigmas / (1 + - (shift - 1) * sigmas) # pyright: ignore - - self.sigmas = sigmas - self.timesteps = sigmas * num_train_timesteps - - self.model_outputs = [None] * solver_order - self.lower_order_nums = 0 - self._step_index = None - self._begin_index = None - - # self.sigmas = self.sigmas.to( - # "cpu") # to avoid too much CPU/GPU communication - self.sigma_min = self.sigmas[-1].item() - self.sigma_max = self.sigmas[0].item() - - @property - def step_index(self): - """ - The index counter for current timestep. It will increase 1 after each scheduler step. - """ - return self._step_index - - @property - def begin_index(self): - """ - The index for the first timestep. It should be set from pipeline with `set_begin_index` method. - """ - return self._begin_index - - # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index - def set_begin_index(self, begin_index: int = 0): - """ - Sets the begin index for the scheduler. This function should be run from pipeline before the inference. - Args: - begin_index (`int`): - The begin index for the scheduler. - """ - self._begin_index = begin_index - - # Modified from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.set_timesteps - def set_timesteps( - self, - num_inference_steps: Union[int, None] = None, - device: Union[str, torch.device] = None, - sigmas: Optional[List[float]] = None, - mu: Optional[Union[float, None]] = None, - shift: Optional[Union[float, None]] = None, - ): - """ - Sets the discrete timesteps used for the diffusion chain (to be run before inference). - Args: - num_inference_steps (`int`): - Total number of the spacing of the time steps. - device (`str` or `torch.device`, *optional*): - The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. - """ - - if self.config.use_dynamic_shifting and mu is None: - raise ValueError( - " you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`" - ) - - if sigmas is None: - sigmas = np.linspace(self.sigma_max, self.sigma_min, - num_inference_steps + - 1).copy()[:-1] # pyright: ignore - - if self.config.use_dynamic_shifting: - sigmas = self.time_shift(mu, 1.0, sigmas) # pyright: ignore - else: - if shift is None: - shift = self.config.shift - sigmas = shift * sigmas / (1 + - (shift - 1) * sigmas) # pyright: ignore - - if self.config.final_sigmas_type == "sigma_min": - sigma_last = ((1 - self.alphas_cumprod[0]) / - self.alphas_cumprod[0])**0.5 - elif self.config.final_sigmas_type == "zero": - sigma_last = 0 - else: - raise ValueError( - f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" - ) - - timesteps = sigmas * self.config.num_train_timesteps - sigmas = np.concatenate([sigmas, [sigma_last] - ]).astype(np.float32) # pyright: ignore - - self.sigmas = torch.from_numpy(sigmas) - self.timesteps = torch.from_numpy(timesteps).to( - device=device, dtype=torch.int64) - - self.num_inference_steps = len(timesteps) - - self.model_outputs = [ - None, - ] * self.config.solver_order - self.lower_order_nums = 0 - - self._step_index = None - self._begin_index = None - # self.sigmas = self.sigmas.to( - # "cpu") # to avoid too much CPU/GPU communication - - # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample - def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: - """ - "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the - prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by - s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing - pixels from saturation at each step. We find that dynamic thresholding results in significantly better - photorealism as well as better image-text alignment, especially when using very large guidance weights." - https://arxiv.org/abs/2205.11487 - """ - dtype = sample.dtype - batch_size, channels, *remaining_dims = sample.shape - - if dtype not in (torch.float32, torch.float64): - sample = sample.float( - ) # upcast for quantile calculation, and clamp not implemented for cpu half - - # Flatten sample for doing quantile calculation along each image - sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) - - abs_sample = sample.abs() # "a certain percentile absolute pixel value" - - s = torch.quantile( - abs_sample, self.config.dynamic_thresholding_ratio, dim=1) - s = torch.clamp( - s, min=1, max=self.config.sample_max_value - ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] - s = s.unsqueeze( - 1) # (batch_size, 1) because clamp will broadcast along dim=0 - sample = torch.clamp( - sample, -s, s - ) / s # "we threshold xt0 to the range [-s, s] and then divide by s" - - sample = sample.reshape(batch_size, channels, *remaining_dims) - sample = sample.to(dtype) - - return sample - - # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._sigma_to_t - def _sigma_to_t(self, sigma): - return sigma * self.config.num_train_timesteps - - def _sigma_to_alpha_sigma_t(self, sigma): - return 1 - sigma, sigma - - # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.set_timesteps - def time_shift(self, mu: float, sigma: float, t: torch.Tensor): - return math.exp(mu) / (math.exp(mu) + (1 / t - 1)**sigma) - - # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.convert_model_output - def convert_model_output( - self, - model_output: torch.Tensor, - *args, - sample: torch.Tensor = None, - **kwargs, - ) -> torch.Tensor: - """ - Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is - designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an - integral of the data prediction model. - - The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise - prediction and data prediction models. - - Args: - model_output (`torch.Tensor`): - The direct output from the learned diffusion model. - sample (`torch.Tensor`): - A current instance of a sample created by the diffusion process. - Returns: - `torch.Tensor`: - The converted model output. - """ - timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) - if sample is None: - if len(args) > 1: - sample = args[1] - else: - raise ValueError( - "missing `sample` as a required keyward argument") - if timestep is not None: - deprecate( - "timesteps", - "1.0.0", - "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", - ) - - # DPM-Solver++ needs to solve an integral of the data prediction model. - if self.config.algorithm_type in ["dpmsolver++", "sde-dpmsolver++"]: - if self.config.prediction_type == "flow_prediction": - sigma_t = self.sigmas[self.step_index] - x0_pred = sample - sigma_t * model_output - else: - raise ValueError( - f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`," - " `v_prediction`, or `flow_prediction` for the FlowDPMSolverMultistepScheduler." - ) - - if self.config.thresholding: - x0_pred = self._threshold_sample(x0_pred) - - return x0_pred - - # DPM-Solver needs to solve an integral of the noise prediction model. - elif self.config.algorithm_type in ["dpmsolver", "sde-dpmsolver"]: - if self.config.prediction_type == "flow_prediction": - sigma_t = self.sigmas[self.step_index] - epsilon = sample - (1 - sigma_t) * model_output - else: - raise ValueError( - f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`," - " `v_prediction` or `flow_prediction` for the FlowDPMSolverMultistepScheduler." - ) - - if self.config.thresholding: - sigma_t = self.sigmas[self.step_index] - x0_pred = sample - sigma_t * model_output - x0_pred = self._threshold_sample(x0_pred) - epsilon = model_output + x0_pred - - return epsilon - - # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.dpm_solver_first_order_update - def dpm_solver_first_order_update( - self, - model_output: torch.Tensor, - *args, - sample: torch.Tensor = None, - noise: Optional[torch.Tensor] = None, - **kwargs, - ) -> torch.Tensor: - """ - One step for the first-order DPMSolver (equivalent to DDIM). - Args: - model_output (`torch.Tensor`): - The direct output from the learned diffusion model. - sample (`torch.Tensor`): - A current instance of a sample created by the diffusion process. - Returns: - `torch.Tensor`: - The sample tensor at the previous timestep. - """ - timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) - prev_timestep = args[1] if len(args) > 1 else kwargs.pop( - "prev_timestep", None) - if sample is None: - if len(args) > 2: - sample = args[2] - else: - raise ValueError( - " missing `sample` as a required keyward argument") - if timestep is not None: - deprecate( - "timesteps", - "1.0.0", - "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", - ) - - if prev_timestep is not None: - deprecate( - "prev_timestep", - "1.0.0", - "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", - ) - - sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[ - self.step_index] # pyright: ignore - alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) - alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s) - lambda_t = torch.log(alpha_t) - torch.log(sigma_t) - lambda_s = torch.log(alpha_s) - torch.log(sigma_s) - - h = lambda_t - lambda_s - if self.config.algorithm_type == "dpmsolver++": - x_t = (sigma_t / - sigma_s) * sample - (alpha_t * - (torch.exp(-h) - 1.0)) * model_output - elif self.config.algorithm_type == "dpmsolver": - x_t = (alpha_t / - alpha_s) * sample - (sigma_t * - (torch.exp(h) - 1.0)) * model_output - elif self.config.algorithm_type == "sde-dpmsolver++": - assert noise is not None - x_t = ((sigma_t / sigma_s * torch.exp(-h)) * sample + - (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output + - sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise) - elif self.config.algorithm_type == "sde-dpmsolver": - assert noise is not None - x_t = ((alpha_t / alpha_s) * sample - 2.0 * - (sigma_t * (torch.exp(h) - 1.0)) * model_output + - sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise) - return x_t # pyright: ignore - - # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_second_order_update - def multistep_dpm_solver_second_order_update( - self, - model_output_list: List[torch.Tensor], - *args, - sample: torch.Tensor = None, - noise: Optional[torch.Tensor] = None, - **kwargs, - ) -> torch.Tensor: - """ - One step for the second-order multistep DPMSolver. - Args: - model_output_list (`List[torch.Tensor]`): - The direct outputs from learned diffusion model at current and latter timesteps. - sample (`torch.Tensor`): - A current instance of a sample created by the diffusion process. - Returns: - `torch.Tensor`: - The sample tensor at the previous timestep. - """ - timestep_list = args[0] if len(args) > 0 else kwargs.pop( - "timestep_list", None) - prev_timestep = args[1] if len(args) > 1 else kwargs.pop( - "prev_timestep", None) - if sample is None: - if len(args) > 2: - sample = args[2] - else: - raise ValueError( - " missing `sample` as a required keyward argument") - if timestep_list is not None: - deprecate( - "timestep_list", - "1.0.0", - "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", - ) - - if prev_timestep is not None: - deprecate( - "prev_timestep", - "1.0.0", - "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", - ) - - sigma_t, sigma_s0, sigma_s1 = ( - self.sigmas[self.step_index + 1], # pyright: ignore - self.sigmas[self.step_index], - self.sigmas[self.step_index - 1], # pyright: ignore - ) - - alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) - alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) - alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) - - lambda_t = torch.log(alpha_t) - torch.log(sigma_t) - lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) - lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) - - m0, m1 = model_output_list[-1], model_output_list[-2] - - h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 - r0 = h_0 / h - D0, D1 = m0, (1.0 / r0) * (m0 - m1) - if self.config.algorithm_type == "dpmsolver++": - # See https://arxiv.org/abs/2211.01095 for detailed derivations - if self.config.solver_type == "midpoint": - x_t = ((sigma_t / sigma_s0) * sample - - (alpha_t * (torch.exp(-h) - 1.0)) * D0 - 0.5 * - (alpha_t * (torch.exp(-h) - 1.0)) * D1) - elif self.config.solver_type == "heun": - x_t = ((sigma_t / sigma_s0) * sample - - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + - (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1) - elif self.config.algorithm_type == "dpmsolver": - # See https://arxiv.org/abs/2206.00927 for detailed derivations - if self.config.solver_type == "midpoint": - x_t = ((alpha_t / alpha_s0) * sample - - (sigma_t * (torch.exp(h) - 1.0)) * D0 - 0.5 * - (sigma_t * (torch.exp(h) - 1.0)) * D1) - elif self.config.solver_type == "heun": - x_t = ((alpha_t / alpha_s0) * sample - - (sigma_t * (torch.exp(h) - 1.0)) * D0 - - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1) - elif self.config.algorithm_type == "sde-dpmsolver++": - assert noise is not None - if self.config.solver_type == "midpoint": - x_t = ((sigma_t / sigma_s0 * torch.exp(-h)) * sample + - (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + 0.5 * - (alpha_t * (1 - torch.exp(-2.0 * h))) * D1 + - sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise) - elif self.config.solver_type == "heun": - x_t = ((sigma_t / sigma_s0 * torch.exp(-h)) * sample + - (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + - (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / - (-2.0 * h) + 1.0)) * D1 + - sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise) - elif self.config.algorithm_type == "sde-dpmsolver": - assert noise is not None - if self.config.solver_type == "midpoint": - x_t = ((alpha_t / alpha_s0) * sample - 2.0 * - (sigma_t * (torch.exp(h) - 1.0)) * D0 - - (sigma_t * (torch.exp(h) - 1.0)) * D1 + - sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise) - elif self.config.solver_type == "heun": - x_t = ((alpha_t / alpha_s0) * sample - 2.0 * - (sigma_t * (torch.exp(h) - 1.0)) * D0 - 2.0 * - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 + - sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise) - return x_t # pyright: ignore - - # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.multistep_dpm_solver_third_order_update - def multistep_dpm_solver_third_order_update( - self, - model_output_list: List[torch.Tensor], - *args, - sample: torch.Tensor = None, - **kwargs, - ) -> torch.Tensor: - """ - One step for the third-order multistep DPMSolver. - Args: - model_output_list (`List[torch.Tensor]`): - The direct outputs from learned diffusion model at current and latter timesteps. - sample (`torch.Tensor`): - A current instance of a sample created by diffusion process. - Returns: - `torch.Tensor`: - The sample tensor at the previous timestep. - """ - - timestep_list = args[0] if len(args) > 0 else kwargs.pop( - "timestep_list", None) - prev_timestep = args[1] if len(args) > 1 else kwargs.pop( - "prev_timestep", None) - if sample is None: - if len(args) > 2: - sample = args[2] - else: - raise ValueError( - " missing`sample` as a required keyward argument") - if timestep_list is not None: - deprecate( - "timestep_list", - "1.0.0", - "Passing `timestep_list` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", - ) - - if prev_timestep is not None: - deprecate( - "prev_timestep", - "1.0.0", - "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", - ) - - sigma_t, sigma_s0, sigma_s1, sigma_s2 = ( - self.sigmas[self.step_index + 1], # pyright: ignore - self.sigmas[self.step_index], - self.sigmas[self.step_index - 1], # pyright: ignore - self.sigmas[self.step_index - 2], # pyright: ignore - ) - - alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) - alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) - alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) - alpha_s2, sigma_s2 = self._sigma_to_alpha_sigma_t(sigma_s2) - - lambda_t = torch.log(alpha_t) - torch.log(sigma_t) - lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) - lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) - lambda_s2 = torch.log(alpha_s2) - torch.log(sigma_s2) - - m0, m1, m2 = model_output_list[-1], model_output_list[ - -2], model_output_list[-3] - - h, h_0, h_1 = lambda_t - lambda_s0, lambda_s0 - lambda_s1, lambda_s1 - lambda_s2 - r0, r1 = h_0 / h, h_1 / h - D0 = m0 - D1_0, D1_1 = (1.0 / r0) * (m0 - m1), (1.0 / r1) * (m1 - m2) - D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1) - D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1) - if self.config.algorithm_type == "dpmsolver++": - # See https://arxiv.org/abs/2206.00927 for detailed derivations - x_t = ((sigma_t / sigma_s0) * sample - - (alpha_t * (torch.exp(-h) - 1.0)) * D0 + - (alpha_t * ((torch.exp(-h) - 1.0) / h + 1.0)) * D1 - - (alpha_t * ((torch.exp(-h) - 1.0 + h) / h**2 - 0.5)) * D2) - elif self.config.algorithm_type == "dpmsolver": - # See https://arxiv.org/abs/2206.00927 for detailed derivations - x_t = ((alpha_t / alpha_s0) * sample - (sigma_t * - (torch.exp(h) - 1.0)) * D0 - - (sigma_t * ((torch.exp(h) - 1.0) / h - 1.0)) * D1 - - (sigma_t * ((torch.exp(h) - 1.0 - h) / h**2 - 0.5)) * D2) - return x_t # pyright: ignore - - def index_for_timestep(self, timestep, schedule_timesteps=None): - if schedule_timesteps is None: - schedule_timesteps = self.timesteps - - indices = (schedule_timesteps == timestep).nonzero() - - # The sigma index that is taken for the **very** first `step` - # is always the second index (or the last index if there is only 1) - # This way we can ensure we don't accidentally skip a sigma in - # case we start in the middle of the denoising schedule (e.g. for image-to-image) - pos = 1 if len(indices) > 1 else 0 - - return indices[pos].item() - - def _init_step_index(self, timestep): - """ - Initialize the step_index counter for the scheduler. - """ - - if self.begin_index is None: - if isinstance(timestep, torch.Tensor): - timestep = timestep.to(self.timesteps.device) - self._step_index = self.index_for_timestep(timestep) - else: - self._step_index = self._begin_index - - # Modified from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.step - def step( - self, - model_output: torch.Tensor, - timestep: Union[int, torch.Tensor], - sample: torch.Tensor, - generator=None, - variance_noise: Optional[torch.Tensor] = None, - return_dict: bool = True, - ) -> Union[SchedulerOutput, Tuple]: - """ - Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with - the multistep DPMSolver. - Args: - model_output (`torch.Tensor`): - The direct output from learned diffusion model. - timestep (`int`): - The current discrete timestep in the diffusion chain. - sample (`torch.Tensor`): - A current instance of a sample created by the diffusion process. - generator (`torch.Generator`, *optional*): - A random number generator. - variance_noise (`torch.Tensor`): - Alternative to generating noise with `generator` by directly providing the noise for the variance - itself. Useful for methods such as [`LEdits++`]. - return_dict (`bool`): - Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. - Returns: - [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: - If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a - tuple is returned where the first element is the sample tensor. - """ - if self.num_inference_steps is None: - raise ValueError( - "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" - ) - - if self.step_index is None: - self._init_step_index(timestep) - - # Improve numerical stability for small number of steps - lower_order_final = (self.step_index == len(self.timesteps) - 1) and ( - self.config.euler_at_final or - (self.config.lower_order_final and len(self.timesteps) < 15) or - self.config.final_sigmas_type == "zero") - lower_order_second = ((self.step_index == len(self.timesteps) - 2) and - self.config.lower_order_final and - len(self.timesteps) < 15) - - model_output = self.convert_model_output(model_output, sample=sample) - for i in range(self.config.solver_order - 1): - self.model_outputs[i] = self.model_outputs[i + 1] - self.model_outputs[-1] = model_output - - # Upcast to avoid precision issues when computing prev_sample - sample = sample.to(torch.float32) - if self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++" - ] and variance_noise is None: - noise = randn_tensor( - model_output.shape, - generator=generator, - device=model_output.device, - dtype=torch.float32) - elif self.config.algorithm_type in ["sde-dpmsolver", "sde-dpmsolver++"]: - noise = variance_noise.to( - device=model_output.device, - dtype=torch.float32) # pyright: ignore - else: - noise = None - - if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final: - prev_sample = self.dpm_solver_first_order_update( - model_output, sample=sample, noise=noise) - elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second: - prev_sample = self.multistep_dpm_solver_second_order_update( - self.model_outputs, sample=sample, noise=noise) - else: - prev_sample = self.multistep_dpm_solver_third_order_update( - self.model_outputs, sample=sample) - - if self.lower_order_nums < self.config.solver_order: - self.lower_order_nums += 1 - - # Cast sample back to expected dtype - prev_sample = prev_sample.to(model_output.dtype) - - # upon completion increase step index by one - self._step_index += 1 # pyright: ignore - - if not return_dict: - return (prev_sample,) - - return SchedulerOutput(prev_sample=prev_sample) - - # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.scale_model_input - def scale_model_input(self, sample: torch.Tensor, *args, - **kwargs) -> torch.Tensor: - """ - Ensures interchangeability with schedulers that need to scale the denoising model input depending on the - current timestep. - Args: - sample (`torch.Tensor`): - The input sample. - Returns: - `torch.Tensor`: - A scaled input sample. - """ - return sample - - # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.scale_model_input - def add_noise( - self, - original_samples: torch.Tensor, - noise: torch.Tensor, - timesteps: torch.IntTensor, - ) -> torch.Tensor: - # Make sure sigmas and timesteps have the same device and dtype as original_samples - sigmas = self.sigmas.to( - device=original_samples.device, dtype=original_samples.dtype) - if original_samples.device.type == "mps" and torch.is_floating_point( - timesteps): - # mps does not support float64 - schedule_timesteps = self.timesteps.to( - original_samples.device, dtype=torch.float32) - timesteps = timesteps.to( - original_samples.device, dtype=torch.float32) - else: - schedule_timesteps = self.timesteps.to(original_samples.device) - timesteps = timesteps.to(original_samples.device) - - # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index - if self.begin_index is None: - step_indices = [ - self.index_for_timestep(t, schedule_timesteps) - for t in timesteps - ] - elif self.step_index is not None: - # add_noise is called after first denoising step (for inpainting) - step_indices = [self.step_index] * timesteps.shape[0] - else: - # add noise is called before first denoising step to create initial latent(img2img) - step_indices = [self.begin_index] * timesteps.shape[0] - - sigma = sigmas[step_indices].flatten() - while len(sigma.shape) < len(original_samples.shape): - sigma = sigma.unsqueeze(-1) - - alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) - noisy_samples = alpha_t * original_samples + sigma_t * noise - return noisy_samples - - def __len__(self): - return self.config.num_train_timesteps diff --git a/dfm/src/megatron/model/wan/inference/utils/fm_solvers_unipc.py b/dfm/src/megatron/model/wan/inference/utils/fm_solvers_unipc.py deleted file mode 100644 index 8d960583..00000000 --- a/dfm/src/megatron/model/wan/inference/utils/fm_solvers_unipc.py +++ /dev/null @@ -1,801 +0,0 @@ -# Copied from https://github.com/huggingface/diffusers/blob/v0.31.0/src/diffusers/schedulers/scheduling_unipc_multistep.py -# Convert unipc for flow matching - -import math -from typing import List, Optional, Tuple, Union - -import numpy as np -import torch -from diffusers.configuration_utils import ConfigMixin, register_to_config -from diffusers.schedulers.scheduling_utils import ( - KarrasDiffusionSchedulers, - SchedulerMixin, - SchedulerOutput, -) -from diffusers.utils import deprecate, is_scipy_available - -if is_scipy_available(): - import scipy.stats - - -class FlowUniPCMultistepScheduler(SchedulerMixin, ConfigMixin): - """ - `UniPCMultistepScheduler` is a training-free framework designed for the fast sampling of diffusion models. - - This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic - methods the library implements for all schedulers such as loading and saving. - - Args: - num_train_timesteps (`int`, defaults to 1000): - The number of diffusion steps to train the model. - solver_order (`int`, default `2`): - The UniPC order which can be any positive integer. The effective order of accuracy is `solver_order + 1` - due to the UniC. It is recommended to use `solver_order=2` for guided sampling, and `solver_order=3` for - unconditional sampling. - prediction_type (`str`, defaults to "flow_prediction"): - Prediction type of the scheduler function; must be `flow_prediction` for this scheduler, which predicts - the flow of the diffusion process. - thresholding (`bool`, defaults to `False`): - Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such - as Stable Diffusion. - dynamic_thresholding_ratio (`float`, defaults to 0.995): - The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. - sample_max_value (`float`, defaults to 1.0): - The threshold value for dynamic thresholding. Valid only when `thresholding=True` and `predict_x0=True`. - predict_x0 (`bool`, defaults to `True`): - Whether to use the updating algorithm on the predicted x0. - solver_type (`str`, default `bh2`): - Solver type for UniPC. It is recommended to use `bh1` for unconditional sampling when steps < 10, and `bh2` - otherwise. - lower_order_final (`bool`, default `True`): - Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can - stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10. - disable_corrector (`list`, default `[]`): - Decides which step to disable the corrector to mitigate the misalignment between `epsilon_theta(x_t, c)` - and `epsilon_theta(x_t^c, c)` which can influence convergence for a large guidance scale. Corrector is - usually disabled during the first few steps. - solver_p (`SchedulerMixin`, default `None`): - Any other scheduler that if specified, the algorithm becomes `solver_p + UniC`. - use_karras_sigmas (`bool`, *optional*, defaults to `False`): - Whether to use Karras sigmas for step sizes in the noise schedule during the sampling process. If `True`, - the sigmas are determined according to a sequence of noise levels {σi}. - use_exponential_sigmas (`bool`, *optional*, defaults to `False`): - Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process. - timestep_spacing (`str`, defaults to `"linspace"`): - The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and - Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. - steps_offset (`int`, defaults to 0): - An offset added to the inference steps, as required by some model families. - final_sigmas_type (`str`, defaults to `"zero"`): - The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final - sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0. - """ - - _compatibles = [e.name for e in KarrasDiffusionSchedulers] - order = 1 - - @register_to_config - def __init__( - self, - num_train_timesteps: int = 1000, - solver_order: int = 2, - prediction_type: str = "flow_prediction", - shift: Optional[float] = 1.0, - use_dynamic_shifting=False, - thresholding: bool = False, - dynamic_thresholding_ratio: float = 0.995, - sample_max_value: float = 1.0, - predict_x0: bool = True, - solver_type: str = "bh2", - lower_order_final: bool = True, - disable_corrector: List[int] = [], - solver_p: SchedulerMixin = None, - timestep_spacing: str = "linspace", - steps_offset: int = 0, - final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min" - ): - - if solver_type not in ["bh1", "bh2"]: - if solver_type in ["midpoint", "heun", "logrho"]: - self.register_to_config(solver_type="bh2") - else: - raise NotImplementedError( - f"{solver_type} is not implemented for {self.__class__}") - - self.predict_x0 = predict_x0 - # setable values - self.num_inference_steps = None - alphas = np.linspace(1, 1 / num_train_timesteps, - num_train_timesteps)[::-1].copy() - sigmas = 1.0 - alphas - sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32) - - if not use_dynamic_shifting: - # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution - sigmas = shift * sigmas / (1 + - (shift - 1) * sigmas) # pyright: ignore - - self.sigmas = sigmas - self.timesteps = sigmas * num_train_timesteps - - self.model_outputs = [None] * solver_order - self.timestep_list = [None] * solver_order - self.lower_order_nums = 0 - self.disable_corrector = disable_corrector - self.solver_p = solver_p - self.last_sample = None - self._step_index = None - self._begin_index = None - - self.sigmas = self.sigmas.to( - "cpu") # to avoid too much CPU/GPU communication - self.sigma_min = self.sigmas[-1].item() - self.sigma_max = self.sigmas[0].item() - - @property - def step_index(self): - """ - The index counter for current timestep. It will increase 1 after each scheduler step. - """ - return self._step_index - - @property - def begin_index(self): - """ - The index for the first timestep. It should be set from pipeline with `set_begin_index` method. - """ - return self._begin_index - - # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index - def set_begin_index(self, begin_index: int = 0): - """ - Sets the begin index for the scheduler. This function should be run from pipeline before the inference. - - Args: - begin_index (`int`): - The begin index for the scheduler. - """ - self._begin_index = begin_index - - # Modified from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.set_timesteps - def set_timesteps( - self, - num_inference_steps: Union[int, None] = None, - device: Union[str, torch.device] = None, - sigmas: Optional[List[float]] = None, - mu: Optional[Union[float, None]] = None, - shift: Optional[Union[float, None]] = None, - ): - """ - Sets the discrete timesteps used for the diffusion chain (to be run before inference). - Args: - num_inference_steps (`int`): - Total number of the spacing of the time steps. - device (`str` or `torch.device`, *optional*): - The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. - """ - - if self.config.use_dynamic_shifting and mu is None: - raise ValueError( - " you have to pass a value for `mu` when `use_dynamic_shifting` is set to be `True`" - ) - - if sigmas is None: - sigmas = np.linspace(self.sigma_max, self.sigma_min, - num_inference_steps + - 1).copy()[:-1] # pyright: ignore - - if self.config.use_dynamic_shifting: - sigmas = self.time_shift(mu, 1.0, sigmas) # pyright: ignore - else: - if shift is None: - shift = self.config.shift - sigmas = shift * sigmas / (1 + - (shift - 1) * sigmas) # pyright: ignore - - if self.config.final_sigmas_type == "sigma_min": - sigma_last = ((1 - self.alphas_cumprod[0]) / - self.alphas_cumprod[0])**0.5 - elif self.config.final_sigmas_type == "zero": - sigma_last = 0 - else: - raise ValueError( - f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" - ) - - timesteps = sigmas * self.config.num_train_timesteps - sigmas = np.concatenate([sigmas, [sigma_last] - ]).astype(np.float32) # pyright: ignore - - self.sigmas = torch.from_numpy(sigmas) - self.timesteps = torch.from_numpy(timesteps).to( - device=device, dtype=torch.int64) - - self.num_inference_steps = len(timesteps) - - self.model_outputs = [ - None, - ] * self.config.solver_order - self.lower_order_nums = 0 - self.last_sample = None - if self.solver_p: - self.solver_p.set_timesteps(self.num_inference_steps, device=device) - - # add an index counter for schedulers that allow duplicated timesteps - self._step_index = None - self._begin_index = None - self.sigmas = self.sigmas.to( - "cpu") # to avoid too much CPU/GPU communication - - # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample - def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor: - """ - "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the - prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by - s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing - pixels from saturation at each step. We find that dynamic thresholding results in significantly better - photorealism as well as better image-text alignment, especially when using very large guidance weights." - - https://arxiv.org/abs/2205.11487 - """ - dtype = sample.dtype - batch_size, channels, *remaining_dims = sample.shape - - if dtype not in (torch.float32, torch.float64): - sample = sample.float( - ) # upcast for quantile calculation, and clamp not implemented for cpu half - - # Flatten sample for doing quantile calculation along each image - sample = sample.reshape(batch_size, channels * np.prod(remaining_dims)) - - abs_sample = sample.abs() # "a certain percentile absolute pixel value" - - s = torch.quantile( - abs_sample, self.config.dynamic_thresholding_ratio, dim=1) - s = torch.clamp( - s, min=1, max=self.config.sample_max_value - ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] - s = s.unsqueeze( - 1) # (batch_size, 1) because clamp will broadcast along dim=0 - sample = torch.clamp( - sample, -s, s - ) / s # "we threshold xt0 to the range [-s, s] and then divide by s" - - sample = sample.reshape(batch_size, channels, *remaining_dims) - sample = sample.to(dtype) - - return sample - - # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._sigma_to_t - def _sigma_to_t(self, sigma): - return sigma * self.config.num_train_timesteps - - def _sigma_to_alpha_sigma_t(self, sigma): - return 1 - sigma, sigma - - # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.set_timesteps - def time_shift(self, mu: float, sigma: float, t: torch.Tensor): - return math.exp(mu) / (math.exp(mu) + (1 / t - 1)**sigma) - - def convert_model_output( - self, - model_output: torch.Tensor, - *args, - sample: torch.Tensor = None, - **kwargs, - ) -> torch.Tensor: - r""" - Convert the model output to the corresponding type the UniPC algorithm needs. - - Args: - model_output (`torch.Tensor`): - The direct output from the learned diffusion model. - timestep (`int`): - The current discrete timestep in the diffusion chain. - sample (`torch.Tensor`): - A current instance of a sample created by the diffusion process. - - Returns: - `torch.Tensor`: - The converted model output. - """ - timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None) - if sample is None: - if len(args) > 1: - sample = args[1] - else: - raise ValueError( - "missing `sample` as a required keyward argument") - if timestep is not None: - deprecate( - "timesteps", - "1.0.0", - "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", - ) - - sigma = self.sigmas[self.step_index] - alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) - - if self.predict_x0: - if self.config.prediction_type == "flow_prediction": - sigma_t = self.sigmas[self.step_index] - x0_pred = sample - sigma_t * model_output - else: - raise ValueError( - f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`," - " `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler." - ) - - if self.config.thresholding: - x0_pred = self._threshold_sample(x0_pred) - - return x0_pred - else: - if self.config.prediction_type == "flow_prediction": - sigma_t = self.sigmas[self.step_index] - epsilon = sample - (1 - sigma_t) * model_output - else: - raise ValueError( - f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`," - " `v_prediction` or `flow_prediction` for the UniPCMultistepScheduler." - ) - - if self.config.thresholding: - sigma_t = self.sigmas[self.step_index] - x0_pred = sample - sigma_t * model_output - x0_pred = self._threshold_sample(x0_pred) - epsilon = model_output + x0_pred - - return epsilon - - def multistep_uni_p_bh_update( - self, - model_output: torch.Tensor, - *args, - sample: torch.Tensor = None, - order: int = None, # pyright: ignore - **kwargs, - ) -> torch.Tensor: - """ - One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if is specified. - - Args: - model_output (`torch.Tensor`): - The direct output from the learned diffusion model at the current timestep. - prev_timestep (`int`): - The previous discrete timestep in the diffusion chain. - sample (`torch.Tensor`): - A current instance of a sample created by the diffusion process. - order (`int`): - The order of UniP at this timestep (corresponds to the *p* in UniPC-p). - - Returns: - `torch.Tensor`: - The sample tensor at the previous timestep. - """ - prev_timestep = args[0] if len(args) > 0 else kwargs.pop( - "prev_timestep", None) - if sample is None: - if len(args) > 1: - sample = args[1] - else: - raise ValueError( - " missing `sample` as a required keyward argument") - if order is None: - if len(args) > 2: - order = args[2] - else: - raise ValueError( - " missing `order` as a required keyward argument") - if prev_timestep is not None: - deprecate( - "prev_timestep", - "1.0.0", - "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", - ) - model_output_list = self.model_outputs - - s0 = self.timestep_list[-1] - m0 = model_output_list[-1] - x = sample - - if self.solver_p: - x_t = self.solver_p.step(model_output, s0, x).prev_sample - return x_t - - sigma_t, sigma_s0 = self.sigmas[self.step_index + 1], self.sigmas[ - self.step_index] # pyright: ignore - alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) - alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) - - lambda_t = torch.log(alpha_t) - torch.log(sigma_t) - lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) - - h = lambda_t - lambda_s0 - device = sample.device - - rks = [] - D1s = [] - for i in range(1, order): - si = self.step_index - i # pyright: ignore - mi = model_output_list[-(i + 1)] - alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si]) - lambda_si = torch.log(alpha_si) - torch.log(sigma_si) - rk = (lambda_si - lambda_s0) / h - rks.append(rk) - D1s.append((mi - m0) / rk) # pyright: ignore - - rks.append(1.0) - rks = torch.tensor(rks, device=device) - - R = [] - b = [] - - hh = -h if self.predict_x0 else h - h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1 - h_phi_k = h_phi_1 / hh - 1 - - factorial_i = 1 - - if self.config.solver_type == "bh1": - B_h = hh - elif self.config.solver_type == "bh2": - B_h = torch.expm1(hh) - else: - raise NotImplementedError() - - for i in range(1, order + 1): - R.append(torch.pow(rks, i - 1)) - b.append(h_phi_k * factorial_i / B_h) - factorial_i *= i + 1 - h_phi_k = h_phi_k / hh - 1 / factorial_i - - R = torch.stack(R) - b = torch.tensor(b, device=device) - - if len(D1s) > 0: - D1s = torch.stack(D1s, dim=1) # (B, K) - # for order 2, we use a simplified version - if order == 2: - rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device) - else: - rhos_p = torch.linalg.solve(R[:-1, :-1], - b[:-1]).to(device).to(x.dtype) - else: - D1s = None - - if self.predict_x0: - x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 - if D1s is not None: - pred_res = torch.einsum("k,bkc...->bc...", rhos_p, - D1s) # pyright: ignore - else: - pred_res = 0 - x_t = x_t_ - alpha_t * B_h * pred_res - else: - x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 - if D1s is not None: - pred_res = torch.einsum("k,bkc...->bc...", rhos_p, - D1s) # pyright: ignore - else: - pred_res = 0 - x_t = x_t_ - sigma_t * B_h * pred_res - - x_t = x_t.to(x.dtype) - return x_t - - def multistep_uni_c_bh_update( - self, - this_model_output: torch.Tensor, - *args, - last_sample: torch.Tensor = None, - this_sample: torch.Tensor = None, - order: int = None, # pyright: ignore - **kwargs, - ) -> torch.Tensor: - """ - One step for the UniC (B(h) version). - - Args: - this_model_output (`torch.Tensor`): - The model outputs at `x_t`. - this_timestep (`int`): - The current timestep `t`. - last_sample (`torch.Tensor`): - The generated sample before the last predictor `x_{t-1}`. - this_sample (`torch.Tensor`): - The generated sample after the last predictor `x_{t}`. - order (`int`): - The `p` of UniC-p at this step. The effective order of accuracy should be `order + 1`. - - Returns: - `torch.Tensor`: - The corrected sample tensor at the current timestep. - """ - this_timestep = args[0] if len(args) > 0 else kwargs.pop( - "this_timestep", None) - if last_sample is None: - if len(args) > 1: - last_sample = args[1] - else: - raise ValueError( - " missing`last_sample` as a required keyward argument") - if this_sample is None: - if len(args) > 2: - this_sample = args[2] - else: - raise ValueError( - " missing`this_sample` as a required keyward argument") - if order is None: - if len(args) > 3: - order = args[3] - else: - raise ValueError( - " missing`order` as a required keyward argument") - if this_timestep is not None: - deprecate( - "this_timestep", - "1.0.0", - "Passing `this_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`", - ) - - model_output_list = self.model_outputs - - m0 = model_output_list[-1] - x = last_sample - x_t = this_sample - model_t = this_model_output - - sigma_t, sigma_s0 = self.sigmas[self.step_index], self.sigmas[ - self.step_index - 1] # pyright: ignore - alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) - alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) - - lambda_t = torch.log(alpha_t) - torch.log(sigma_t) - lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) - - h = lambda_t - lambda_s0 - device = this_sample.device - - rks = [] - D1s = [] - for i in range(1, order): - si = self.step_index - (i + 1) # pyright: ignore - mi = model_output_list[-(i + 1)] - alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si]) - lambda_si = torch.log(alpha_si) - torch.log(sigma_si) - rk = (lambda_si - lambda_s0) / h - rks.append(rk) - D1s.append((mi - m0) / rk) # pyright: ignore - - rks.append(1.0) - rks = torch.tensor(rks, device=device) - - R = [] - b = [] - - hh = -h if self.predict_x0 else h - h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1 - h_phi_k = h_phi_1 / hh - 1 - - factorial_i = 1 - - if self.config.solver_type == "bh1": - B_h = hh - elif self.config.solver_type == "bh2": - B_h = torch.expm1(hh) - else: - raise NotImplementedError() - - for i in range(1, order + 1): - R.append(torch.pow(rks, i - 1)) - b.append(h_phi_k * factorial_i / B_h) - factorial_i *= i + 1 - h_phi_k = h_phi_k / hh - 1 / factorial_i - - R = torch.stack(R) - b = torch.tensor(b, device=device) - - if len(D1s) > 0: - D1s = torch.stack(D1s, dim=1) - else: - D1s = None - - # for order 1, we use a simplified version - if order == 1: - rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device) - else: - rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype) - - if self.predict_x0: - x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 - if D1s is not None: - corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s) - else: - corr_res = 0 - D1_t = model_t - m0 - x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t) - else: - x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 - if D1s is not None: - corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s) - else: - corr_res = 0 - D1_t = model_t - m0 - x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t) - x_t = x_t.to(x.dtype) - return x_t - - def index_for_timestep(self, timestep, schedule_timesteps=None): - if schedule_timesteps is None: - schedule_timesteps = self.timesteps - - indices = (schedule_timesteps == timestep).nonzero() - - # The sigma index that is taken for the **very** first `step` - # is always the second index (or the last index if there is only 1) - # This way we can ensure we don't accidentally skip a sigma in - # case we start in the middle of the denoising schedule (e.g. for image-to-image) - pos = 1 if len(indices) > 1 else 0 - - return indices[pos].item() - - # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index - def _init_step_index(self, timestep): - """ - Initialize the step_index counter for the scheduler. - """ - - if self.begin_index is None: - if isinstance(timestep, torch.Tensor): - timestep = timestep.to(self.timesteps.device) - self._step_index = self.index_for_timestep(timestep) - else: - self._step_index = self._begin_index - - def step(self, - model_output: torch.Tensor, - timestep: Union[int, torch.Tensor], - sample: torch.Tensor, - return_dict: bool = True, - generator=None) -> Union[SchedulerOutput, Tuple]: - """ - Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with - the multistep UniPC. - - Args: - model_output (`torch.Tensor`): - The direct output from learned diffusion model. - timestep (`int`): - The current discrete timestep in the diffusion chain. - sample (`torch.Tensor`): - A current instance of a sample created by the diffusion process. - return_dict (`bool`): - Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. - - Returns: - [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: - If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a - tuple is returned where the first element is the sample tensor. - - """ - if self.num_inference_steps is None: - raise ValueError( - "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" - ) - - if self.step_index is None: - self._init_step_index(timestep) - - use_corrector = ( - self.step_index > 0 and - self.step_index - 1 not in self.disable_corrector and - self.last_sample is not None # pyright: ignore - ) - - model_output_convert = self.convert_model_output( - model_output, sample=sample) - if use_corrector: - sample = self.multistep_uni_c_bh_update( - this_model_output=model_output_convert, - last_sample=self.last_sample, - this_sample=sample, - order=self.this_order, - ) - - for i in range(self.config.solver_order - 1): - self.model_outputs[i] = self.model_outputs[i + 1] - self.timestep_list[i] = self.timestep_list[i + 1] - - self.model_outputs[-1] = model_output_convert - self.timestep_list[-1] = timestep # pyright: ignore - - if self.config.lower_order_final: - this_order = min(self.config.solver_order, - len(self.timesteps) - - self.step_index) # pyright: ignore - else: - this_order = self.config.solver_order - - self.this_order = min(this_order, - self.lower_order_nums + 1) # warmup for multistep - assert self.this_order > 0 - - self.last_sample = sample - prev_sample = self.multistep_uni_p_bh_update( - model_output=model_output, # pass the original non-converted model output, in case solver-p is used - sample=sample, - order=self.this_order, - ) - - if self.lower_order_nums < self.config.solver_order: - self.lower_order_nums += 1 - - # upon completion increase step index by one - self._step_index += 1 # pyright: ignore - - if not return_dict: - return (prev_sample,) - - return SchedulerOutput(prev_sample=prev_sample) - - def scale_model_input(self, sample: torch.Tensor, *args, - **kwargs) -> torch.Tensor: - """ - Ensures interchangeability with schedulers that need to scale the denoising model input depending on the - current timestep. - - Args: - sample (`torch.Tensor`): - The input sample. - - Returns: - `torch.Tensor`: - A scaled input sample. - """ - return sample - - # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise - def add_noise( - self, - original_samples: torch.Tensor, - noise: torch.Tensor, - timesteps: torch.IntTensor, - ) -> torch.Tensor: - # Make sure sigmas and timesteps have the same device and dtype as original_samples - sigmas = self.sigmas.to( - device=original_samples.device, dtype=original_samples.dtype) - if original_samples.device.type == "mps" and torch.is_floating_point( - timesteps): - # mps does not support float64 - schedule_timesteps = self.timesteps.to( - original_samples.device, dtype=torch.float32) - timesteps = timesteps.to( - original_samples.device, dtype=torch.float32) - else: - schedule_timesteps = self.timesteps.to(original_samples.device) - timesteps = timesteps.to(original_samples.device) - - # begin_index is None when the scheduler is used for training or pipeline does not implement set_begin_index - if self.begin_index is None: - step_indices = [ - self.index_for_timestep(t, schedule_timesteps) - for t in timesteps - ] - elif self.step_index is not None: - # add_noise is called after first denoising step (for inpainting) - step_indices = [self.step_index] * timesteps.shape[0] - else: - # add noise is called before first denoising step to create initial latent(img2img) - step_indices = [self.begin_index] * timesteps.shape[0] - - sigma = sigmas[step_indices].flatten() - while len(sigma.shape) < len(original_samples.shape): - sigma = sigma.unsqueeze(-1) - - alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma) - noisy_samples = alpha_t * original_samples + sigma_t * noise - return noisy_samples - - def __len__(self): - return self.config.num_train_timesteps diff --git a/examples/megatron/recipes/wan/inference_wan.py b/examples/megatron/recipes/wan/inference_wan.py index 2f480a2b..1c04c1ca 100644 --- a/examples/megatron/recipes/wan/inference_wan.py +++ b/examples/megatron/recipes/wan/inference_wan.py @@ -31,7 +31,7 @@ from dfm.src.megatron.model.wan.flow_matching.flow_inference_pipeline import FlowInferencePipeline from dfm.src.megatron.model.wan.inference.configs import SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS -from dfm.src.megatron.model.wan.inference.utils.utils import cache_video, str2bool +from dfm.src.megatron.model.wan.inference.utils import cache_video, str2bool EXAMPLE_PROMPT = { "t2v-1.3B": { @@ -148,12 +148,6 @@ def _parse_args(): type=int, default=-1, help="The seed to use for generating the image or video.") - parser.add_argument( - "--sample_solver", - type=str, - default='unipc', - choices=['unipc', 'dpm++'], - help="The solver used to sample.") parser.add_argument( "--sample_steps", type=int, default=None, help="The sampling steps.") parser.add_argument( @@ -263,6 +257,7 @@ def generate(args): pipeline = FlowInferencePipeline( config=cfg, checkpoint_dir=args.checkpoint_dir, + model_id="Wan-AI/Wan2.1-T2V-14B-Diffusers", checkpoint_step=args.checkpoint_step, t5_checkpoint_dir=args.t5_checkpoint_dir, vae_checkpoint_dir=args.vae_checkpoint_dir, @@ -292,7 +287,6 @@ def generate(args): sizes=[SIZE_CONFIGS[size] for size in size_keys], frame_nums=frame_nums, shift=args.sample_shift, - sample_solver=args.sample_solver, sampling_steps=args.sample_steps, guide_scale=args.sample_guide_scale, seed=args.base_seed, From 7f414aeab8ebc791d9fecccacafc4796b835ecac Mon Sep 17 00:00:00 2001 From: Huy Vu2 Date: Wed, 5 Nov 2025 14:32:31 -0800 Subject: [PATCH 16/80] update repo, remove Wan's Github moduels --- .../wan/inference/configs/shared_config.py | 1 + .../megatron/model/wan/modules/__init__.py | 13 - dfm/src/megatron/model/wan/modules/t5.py | 512 -------------- .../megatron/model/wan/modules/tokenizers.py | 81 --- dfm/src/megatron/model/wan/modules/vae.py | 662 ------------------ dfm/src/megatron/model/wan/wan_layer_spec.py | 39 +- dfm/src/megatron/model/wan/wan_model.py | 10 +- dfm/src/megatron/model/wan/wan_provider.py | 7 +- .../megatron/recipes/wan/inference_wan.py | 31 +- examples/megatron/recipes/wan/pretrain_wan.py | 14 +- 10 files changed, 37 insertions(+), 1333 deletions(-) delete mode 100644 dfm/src/megatron/model/wan/modules/__init__.py delete mode 100644 dfm/src/megatron/model/wan/modules/t5.py delete mode 100644 dfm/src/megatron/model/wan/modules/tokenizers.py delete mode 100644 dfm/src/megatron/model/wan/modules/vae.py diff --git a/dfm/src/megatron/model/wan/inference/configs/shared_config.py b/dfm/src/megatron/model/wan/inference/configs/shared_config.py index 37d3ae0c..52b6d92b 100644 --- a/dfm/src/megatron/model/wan/inference/configs/shared_config.py +++ b/dfm/src/megatron/model/wan/inference/configs/shared_config.py @@ -1,6 +1,7 @@ import torch from easydict import EasyDict + #------------------------ Wan shared config ------------------------# wan_shared_cfg = EasyDict() diff --git a/dfm/src/megatron/model/wan/modules/__init__.py b/dfm/src/megatron/model/wan/modules/__init__.py deleted file mode 100644 index 435f1eef..00000000 --- a/dfm/src/megatron/model/wan/modules/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -from .t5 import T5Decoder, T5Encoder, T5EncoderModel, T5Model -from .tokenizers import HuggingfaceTokenizer -from .vae import WanVAE - - -__all__ = [ - 'WanVAE', - 'T5Model', - 'T5Encoder', - 'T5Decoder', - 'T5EncoderModel', - 'HuggingfaceTokenizer', -] diff --git a/dfm/src/megatron/model/wan/modules/t5.py b/dfm/src/megatron/model/wan/modules/t5.py deleted file mode 100644 index fecd989e..00000000 --- a/dfm/src/megatron/model/wan/modules/t5.py +++ /dev/null @@ -1,512 +0,0 @@ -# Modified from transformers.models.t5.modeling_t5 -import logging -import math - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from .tokenizers import HuggingfaceTokenizer - -__all__ = [ - 'T5Model', - 'T5Encoder', - 'T5Decoder', - 'T5EncoderModel', -] - - -def fp16_clamp(x): - if x.dtype == torch.float16 and torch.isinf(x).any(): - clamp = torch.finfo(x.dtype).max - 1000 - x = torch.clamp(x, min=-clamp, max=clamp) - return x - - -def init_weights(m): - if isinstance(m, T5LayerNorm): - nn.init.ones_(m.weight) - elif isinstance(m, T5Model): - nn.init.normal_(m.token_embedding.weight, std=1.0) - elif isinstance(m, T5FeedForward): - nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5) - nn.init.normal_(m.fc1.weight, std=m.dim**-0.5) - nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5) - elif isinstance(m, T5Attention): - nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn)**-0.5) - nn.init.normal_(m.k.weight, std=m.dim**-0.5) - nn.init.normal_(m.v.weight, std=m.dim**-0.5) - nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn)**-0.5) - elif isinstance(m, T5RelativeEmbedding): - nn.init.normal_( - m.embedding.weight, std=(2 * m.num_buckets * m.num_heads)**-0.5) - - -class GELU(nn.Module): - - def forward(self, x): - return 0.5 * x * (1.0 + torch.tanh( - math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)))) - - -class T5LayerNorm(nn.Module): - - def __init__(self, dim, eps=1e-6): - super(T5LayerNorm, self).__init__() - self.dim = dim - self.eps = eps - self.weight = nn.Parameter(torch.ones(dim)) - - def forward(self, x): - x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + - self.eps) - if self.weight.dtype in [torch.float16, torch.bfloat16]: - x = x.type_as(self.weight) - return self.weight * x - - -class T5Attention(nn.Module): - - def __init__(self, dim, dim_attn, num_heads, dropout=0.1): - assert dim_attn % num_heads == 0 - super(T5Attention, self).__init__() - self.dim = dim - self.dim_attn = dim_attn - self.num_heads = num_heads - self.head_dim = dim_attn // num_heads - - # layers - self.q = nn.Linear(dim, dim_attn, bias=False) - self.k = nn.Linear(dim, dim_attn, bias=False) - self.v = nn.Linear(dim, dim_attn, bias=False) - self.o = nn.Linear(dim_attn, dim, bias=False) - self.dropout = nn.Dropout(dropout) - - def forward(self, x, context=None, mask=None, pos_bias=None): - """ - x: [B, L1, C]. - context: [B, L2, C] or None. - mask: [B, L2] or [B, L1, L2] or None. - """ - # check inputs - context = x if context is None else context - b, n, c = x.size(0), self.num_heads, self.head_dim - - # compute query, key, value - q = self.q(x).view(b, -1, n, c) - k = self.k(context).view(b, -1, n, c) - v = self.v(context).view(b, -1, n, c) - - # attention bias - attn_bias = x.new_zeros(b, n, q.size(1), k.size(1)) - if pos_bias is not None: - attn_bias += pos_bias - if mask is not None: - assert mask.ndim in [2, 3] - mask = mask.view(b, 1, 1, - -1) if mask.ndim == 2 else mask.unsqueeze(1) - attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min) - - # compute attention (T5 does not use scaling) - attn = torch.einsum('binc,bjnc->bnij', q, k) + attn_bias - attn = F.softmax(attn.float(), dim=-1).type_as(attn) - x = torch.einsum('bnij,bjnc->binc', attn, v) - - # output - x = x.reshape(b, -1, n * c) - x = self.o(x) - x = self.dropout(x) - return x - - -class T5FeedForward(nn.Module): - - def __init__(self, dim, dim_ffn, dropout=0.1): - super(T5FeedForward, self).__init__() - self.dim = dim - self.dim_ffn = dim_ffn - - # layers - self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU()) - self.fc1 = nn.Linear(dim, dim_ffn, bias=False) - self.fc2 = nn.Linear(dim_ffn, dim, bias=False) - self.dropout = nn.Dropout(dropout) - - def forward(self, x): - x = self.fc1(x) * self.gate(x) - x = self.dropout(x) - x = self.fc2(x) - x = self.dropout(x) - return x - - -class T5SelfAttention(nn.Module): - - def __init__(self, - dim, - dim_attn, - dim_ffn, - num_heads, - num_buckets, - shared_pos=True, - dropout=0.1): - super(T5SelfAttention, self).__init__() - self.dim = dim - self.dim_attn = dim_attn - self.dim_ffn = dim_ffn - self.num_heads = num_heads - self.num_buckets = num_buckets - self.shared_pos = shared_pos - - # layers - self.norm1 = T5LayerNorm(dim) - self.attn = T5Attention(dim, dim_attn, num_heads, dropout) - self.norm2 = T5LayerNorm(dim) - self.ffn = T5FeedForward(dim, dim_ffn, dropout) - self.pos_embedding = None if shared_pos else T5RelativeEmbedding( - num_buckets, num_heads, bidirectional=True) - - def forward(self, x, mask=None, pos_bias=None): - e = pos_bias if self.shared_pos else self.pos_embedding( - x.size(1), x.size(1)) - x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e)) - x = fp16_clamp(x + self.ffn(self.norm2(x))) - return x - - -class T5CrossAttention(nn.Module): - - def __init__(self, - dim, - dim_attn, - dim_ffn, - num_heads, - num_buckets, - shared_pos=True, - dropout=0.1): - super(T5CrossAttention, self).__init__() - self.dim = dim - self.dim_attn = dim_attn - self.dim_ffn = dim_ffn - self.num_heads = num_heads - self.num_buckets = num_buckets - self.shared_pos = shared_pos - - # layers - self.norm1 = T5LayerNorm(dim) - self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout) - self.norm2 = T5LayerNorm(dim) - self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout) - self.norm3 = T5LayerNorm(dim) - self.ffn = T5FeedForward(dim, dim_ffn, dropout) - self.pos_embedding = None if shared_pos else T5RelativeEmbedding( - num_buckets, num_heads, bidirectional=False) - - def forward(self, - x, - mask=None, - encoder_states=None, - encoder_mask=None, - pos_bias=None): - e = pos_bias if self.shared_pos else self.pos_embedding( - x.size(1), x.size(1)) - x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e)) - x = fp16_clamp(x + self.cross_attn( - self.norm2(x), context=encoder_states, mask=encoder_mask)) - x = fp16_clamp(x + self.ffn(self.norm3(x))) - return x - - -class T5RelativeEmbedding(nn.Module): - - def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128): - super(T5RelativeEmbedding, self).__init__() - self.num_buckets = num_buckets - self.num_heads = num_heads - self.bidirectional = bidirectional - self.max_dist = max_dist - - # layers - self.embedding = nn.Embedding(num_buckets, num_heads) - - def forward(self, lq, lk): - device = self.embedding.weight.device - # rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \ - # torch.arange(lq).unsqueeze(1).to(device) - rel_pos = torch.arange(lk, device=device).unsqueeze(0) - \ - torch.arange(lq, device=device).unsqueeze(1) - rel_pos = self._relative_position_bucket(rel_pos) - rel_pos_embeds = self.embedding(rel_pos) - rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze( - 0) # [1, N, Lq, Lk] - return rel_pos_embeds.contiguous() - - def _relative_position_bucket(self, rel_pos): - # preprocess - if self.bidirectional: - num_buckets = self.num_buckets // 2 - rel_buckets = (rel_pos > 0).long() * num_buckets - rel_pos = torch.abs(rel_pos) - else: - num_buckets = self.num_buckets - rel_buckets = 0 - rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos)) - - # embeddings for small and large positions - max_exact = num_buckets // 2 - rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) / - math.log(self.max_dist / max_exact) * - (num_buckets - max_exact)).long() - rel_pos_large = torch.min( - rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1)) - rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large) - return rel_buckets - - -class T5Encoder(nn.Module): - - def __init__(self, - vocab, - dim, - dim_attn, - dim_ffn, - num_heads, - num_layers, - num_buckets, - shared_pos=True, - dropout=0.1): - super(T5Encoder, self).__init__() - self.dim = dim - self.dim_attn = dim_attn - self.dim_ffn = dim_ffn - self.num_heads = num_heads - self.num_layers = num_layers - self.num_buckets = num_buckets - self.shared_pos = shared_pos - - # layers - self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \ - else nn.Embedding(vocab, dim) - self.pos_embedding = T5RelativeEmbedding( - num_buckets, num_heads, bidirectional=True) if shared_pos else None - self.dropout = nn.Dropout(dropout) - self.blocks = nn.ModuleList([ - T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, - shared_pos, dropout) for _ in range(num_layers) - ]) - self.norm = T5LayerNorm(dim) - - # initialize weights - self.apply(init_weights) - - def forward(self, ids, mask=None): - x = self.token_embedding(ids) - x = self.dropout(x) - e = self.pos_embedding(x.size(1), - x.size(1)) if self.shared_pos else None - for block in self.blocks: - x = block(x, mask, pos_bias=e) - x = self.norm(x) - x = self.dropout(x) - return x - - -class T5Decoder(nn.Module): - - def __init__(self, - vocab, - dim, - dim_attn, - dim_ffn, - num_heads, - num_layers, - num_buckets, - shared_pos=True, - dropout=0.1): - super(T5Decoder, self).__init__() - self.dim = dim - self.dim_attn = dim_attn - self.dim_ffn = dim_ffn - self.num_heads = num_heads - self.num_layers = num_layers - self.num_buckets = num_buckets - self.shared_pos = shared_pos - - # layers - self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \ - else nn.Embedding(vocab, dim) - self.pos_embedding = T5RelativeEmbedding( - num_buckets, num_heads, bidirectional=False) if shared_pos else None - self.dropout = nn.Dropout(dropout) - self.blocks = nn.ModuleList([ - T5CrossAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, - shared_pos, dropout) for _ in range(num_layers) - ]) - self.norm = T5LayerNorm(dim) - - # initialize weights - self.apply(init_weights) - - def forward(self, ids, mask=None, encoder_states=None, encoder_mask=None): - b, s = ids.size() - - # causal mask - if mask is None: - mask = torch.tril(torch.ones(1, s, s).to(ids.device)) - elif mask.ndim == 2: - mask = torch.tril(mask.unsqueeze(1).expand(-1, s, -1)) - - # layers - x = self.token_embedding(ids) - x = self.dropout(x) - e = self.pos_embedding(x.size(1), - x.size(1)) if self.shared_pos else None - for block in self.blocks: - x = block(x, mask, encoder_states, encoder_mask, pos_bias=e) - x = self.norm(x) - x = self.dropout(x) - return x - - -class T5Model(nn.Module): - - def __init__(self, - vocab_size, - dim, - dim_attn, - dim_ffn, - num_heads, - encoder_layers, - decoder_layers, - num_buckets, - shared_pos=True, - dropout=0.1): - super(T5Model, self).__init__() - self.vocab_size = vocab_size - self.dim = dim - self.dim_attn = dim_attn - self.dim_ffn = dim_ffn - self.num_heads = num_heads - self.encoder_layers = encoder_layers - self.decoder_layers = decoder_layers - self.num_buckets = num_buckets - - # layers - self.token_embedding = nn.Embedding(vocab_size, dim) - self.encoder = T5Encoder(self.token_embedding, dim, dim_attn, dim_ffn, - num_heads, encoder_layers, num_buckets, - shared_pos, dropout) - self.decoder = T5Decoder(self.token_embedding, dim, dim_attn, dim_ffn, - num_heads, decoder_layers, num_buckets, - shared_pos, dropout) - self.head = nn.Linear(dim, vocab_size, bias=False) - - # initialize weights - self.apply(init_weights) - - def forward(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask): - x = self.encoder(encoder_ids, encoder_mask) - x = self.decoder(decoder_ids, decoder_mask, x, encoder_mask) - x = self.head(x) - return x - - -def _t5(name, - encoder_only=False, - decoder_only=False, - return_tokenizer=False, - tokenizer_kwargs={}, - dtype=torch.float32, - device='cpu', - **kwargs): - # sanity check - assert not (encoder_only and decoder_only) - - # params - if encoder_only: - model_cls = T5Encoder - kwargs['vocab'] = kwargs.pop('vocab_size') - kwargs['num_layers'] = kwargs.pop('encoder_layers') - _ = kwargs.pop('decoder_layers') - elif decoder_only: - model_cls = T5Decoder - kwargs['vocab'] = kwargs.pop('vocab_size') - kwargs['num_layers'] = kwargs.pop('decoder_layers') - _ = kwargs.pop('encoder_layers') - else: - model_cls = T5Model - - # init model - with torch.device(device): - model = model_cls(**kwargs) - - # set device - model = model.to(dtype=dtype, device=device) - - # init tokenizer - if return_tokenizer: - from .tokenizers import HuggingfaceTokenizer - tokenizer = HuggingfaceTokenizer(f'google/{name}', **tokenizer_kwargs) - return model, tokenizer - else: - return model - - -def umt5_xxl(**kwargs): - cfg = dict( - vocab_size=256384, - dim=4096, - dim_attn=4096, - dim_ffn=10240, - num_heads=64, - encoder_layers=24, - decoder_layers=24, - num_buckets=32, - shared_pos=False, - dropout=0.1) - cfg.update(**kwargs) - return _t5('umt5-xxl', **cfg) - - -class T5EncoderModel: - - def __init__( - self, - text_len, - dtype=torch.bfloat16, - device=torch.cuda.current_device(), - checkpoint_path=None, - tokenizer_path=None, - shard_fn=None, - ): - self.text_len = text_len - self.dtype = dtype - self.device = device - self.checkpoint_path = checkpoint_path - self.tokenizer_path = tokenizer_path - - # init model - model = umt5_xxl( - encoder_only=True, - return_tokenizer=False, - dtype=dtype, - device=device).eval().requires_grad_(False) - logging.info(f'loading {checkpoint_path}') - model.load_state_dict(torch.load(checkpoint_path, map_location='cpu')) - self.model = model - if shard_fn is not None: - self.model = shard_fn(self.model, sync_module_states=False) - else: - self.model.to(self.device) - # init tokenizer - self.tokenizer = HuggingfaceTokenizer( - name=tokenizer_path, seq_len=text_len, clean='whitespace') - - def __call__(self, texts, device): - ids, mask = self.tokenizer( - texts, return_mask=True, add_special_tokens=True) - ids = ids.to(device) - mask = mask.to(device) - seq_lens = mask.gt(0).sum(dim=1).long() - context = self.model(ids, mask) - return [u[:v] for u, v in zip(context, seq_lens)] diff --git a/dfm/src/megatron/model/wan/modules/tokenizers.py b/dfm/src/megatron/model/wan/modules/tokenizers.py deleted file mode 100644 index a69972ad..00000000 --- a/dfm/src/megatron/model/wan/modules/tokenizers.py +++ /dev/null @@ -1,81 +0,0 @@ -import html -import string - -import ftfy -import regex as re -from transformers import AutoTokenizer - -__all__ = ['HuggingfaceTokenizer'] - - -def basic_clean(text): - text = ftfy.fix_text(text) - text = html.unescape(html.unescape(text)) - return text.strip() - - -def whitespace_clean(text): - text = re.sub(r'\s+', ' ', text) - text = text.strip() - return text - - -def canonicalize(text, keep_punctuation_exact_string=None): - text = text.replace('_', ' ') - if keep_punctuation_exact_string: - text = keep_punctuation_exact_string.join( - part.translate(str.maketrans('', '', string.punctuation)) - for part in text.split(keep_punctuation_exact_string)) - else: - text = text.translate(str.maketrans('', '', string.punctuation)) - text = text.lower() - text = re.sub(r'\s+', ' ', text) - return text.strip() - - -class HuggingfaceTokenizer: - - def __init__(self, name, seq_len=None, clean=None, **kwargs): - assert clean in (None, 'whitespace', 'lower', 'canonicalize') - self.name = name - self.seq_len = seq_len - self.clean = clean - - # init tokenizer - self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs) - self.vocab_size = self.tokenizer.vocab_size - - def __call__(self, sequence, **kwargs): - return_mask = kwargs.pop('return_mask', False) - - # arguments - _kwargs = {'return_tensors': 'pt'} - if self.seq_len is not None: - _kwargs.update({ - 'padding': 'max_length', - 'truncation': True, - 'max_length': self.seq_len - }) - _kwargs.update(**kwargs) - - # tokenization - if isinstance(sequence, str): - sequence = [sequence] - if self.clean: - sequence = [self._clean(u) for u in sequence] - ids = self.tokenizer(sequence, **_kwargs) - - # output - if return_mask: - return ids.input_ids, ids.attention_mask - else: - return ids.input_ids - - def _clean(self, text): - if self.clean == 'whitespace': - text = whitespace_clean(basic_clean(text)) - elif self.clean == 'lower': - text = whitespace_clean(basic_clean(text)).lower() - elif self.clean == 'canonicalize': - text = canonicalize(basic_clean(text)) - return text diff --git a/dfm/src/megatron/model/wan/modules/vae.py b/dfm/src/megatron/model/wan/modules/vae.py deleted file mode 100644 index d4f1ef1d..00000000 --- a/dfm/src/megatron/model/wan/modules/vae.py +++ /dev/null @@ -1,662 +0,0 @@ -import logging - -import torch -import torch.cuda.amp as amp -import torch.nn as nn -import torch.nn.functional as F -from einops import rearrange - -__all__ = [ - 'WanVAE', -] - -CACHE_T = 2 - - -class CausalConv3d(nn.Conv3d): - """ - Causal 3d convolusion. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._padding = (self.padding[2], self.padding[2], self.padding[1], - self.padding[1], 2 * self.padding[0], 0) - self.padding = (0, 0, 0) - - def forward(self, x, cache_x=None): - padding = list(self._padding) - if cache_x is not None and self._padding[4] > 0: - cache_x = cache_x.to(x.device) - x = torch.cat([cache_x, x], dim=2) - padding[4] -= cache_x.shape[2] - x = F.pad(x, padding) - - return super().forward(x) - - -class RMS_norm(nn.Module): - - def __init__(self, dim, channel_first=True, images=True, bias=False): - super().__init__() - broadcastable_dims = (1, 1, 1) if not images else (1, 1) - shape = (dim, *broadcastable_dims) if channel_first else (dim,) - - self.channel_first = channel_first - self.scale = dim**0.5 - self.gamma = nn.Parameter(torch.ones(shape)) - self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0. - - def forward(self, x): - return F.normalize( - x, dim=(1 if self.channel_first else - -1)) * self.scale * self.gamma + self.bias - - -class Upsample(nn.Upsample): - - def forward(self, x): - """ - Fix bfloat16 support for nearest neighbor interpolation. - """ - return super().forward(x.float()).type_as(x) - - -class Resample(nn.Module): - - def __init__(self, dim, mode): - assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d', - 'downsample3d') - super().__init__() - self.dim = dim - self.mode = mode - - # layers - if mode == 'upsample2d': - self.resample = nn.Sequential( - Upsample(scale_factor=(2., 2.), mode='nearest-exact'), - nn.Conv2d(dim, dim // 2, 3, padding=1)) - elif mode == 'upsample3d': - self.resample = nn.Sequential( - Upsample(scale_factor=(2., 2.), mode='nearest-exact'), - nn.Conv2d(dim, dim // 2, 3, padding=1)) - self.time_conv = CausalConv3d( - dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) - - elif mode == 'downsample2d': - self.resample = nn.Sequential( - nn.ZeroPad2d((0, 1, 0, 1)), - nn.Conv2d(dim, dim, 3, stride=(2, 2))) - elif mode == 'downsample3d': - self.resample = nn.Sequential( - nn.ZeroPad2d((0, 1, 0, 1)), - nn.Conv2d(dim, dim, 3, stride=(2, 2))) - self.time_conv = CausalConv3d( - dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)) - - else: - self.resample = nn.Identity() - - def forward(self, x, feat_cache=None, feat_idx=[0]): - b, c, t, h, w = x.size() - if self.mode == 'upsample3d': - if feat_cache is not None: - idx = feat_idx[0] - if feat_cache[idx] is None: - feat_cache[idx] = 'Rep' - feat_idx[0] += 1 - else: - - cache_x = x[:, :, -CACHE_T:, :, :].clone() - if cache_x.shape[2] < 2 and feat_cache[ - idx] is not None and feat_cache[idx] != 'Rep': - # cache last frame of last two chunk - cache_x = torch.cat([ - feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( - cache_x.device), cache_x - ], - dim=2) - if cache_x.shape[2] < 2 and feat_cache[ - idx] is not None and feat_cache[idx] == 'Rep': - cache_x = torch.cat([ - torch.zeros_like(cache_x).to(cache_x.device), - cache_x - ], - dim=2) - if feat_cache[idx] == 'Rep': - x = self.time_conv(x) - else: - x = self.time_conv(x, feat_cache[idx]) - feat_cache[idx] = cache_x - feat_idx[0] += 1 - - x = x.reshape(b, 2, c, t, h, w) - x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), - 3) - x = x.reshape(b, c, t * 2, h, w) - t = x.shape[2] - x = rearrange(x, 'b c t h w -> (b t) c h w') - x = self.resample(x) - x = rearrange(x, '(b t) c h w -> b c t h w', t=t) - - if self.mode == 'downsample3d': - if feat_cache is not None: - idx = feat_idx[0] - if feat_cache[idx] is None: - feat_cache[idx] = x.clone() - feat_idx[0] += 1 - else: - - cache_x = x[:, :, -1:, :, :].clone() - # if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep': - # # cache last frame of last two chunk - # cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) - - x = self.time_conv( - torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) - feat_cache[idx] = cache_x - feat_idx[0] += 1 - return x - - def init_weight(self, conv): - conv_weight = conv.weight - nn.init.zeros_(conv_weight) - c1, c2, t, h, w = conv_weight.size() - one_matrix = torch.eye(c1, c2) - init_matrix = one_matrix - nn.init.zeros_(conv_weight) - #conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5 - conv_weight.data[:, :, 1, 0, 0] = init_matrix #* 0.5 - conv.weight.data.copy_(conv_weight) - nn.init.zeros_(conv.bias.data) - - def init_weight2(self, conv): - conv_weight = conv.weight.data - nn.init.zeros_(conv_weight) - c1, c2, t, h, w = conv_weight.size() - init_matrix = torch.eye(c1 // 2, c2) - #init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2) - conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix - conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix - conv.weight.data.copy_(conv_weight) - nn.init.zeros_(conv.bias.data) - - -class ResidualBlock(nn.Module): - - def __init__(self, in_dim, out_dim, dropout=0.0): - super().__init__() - self.in_dim = in_dim - self.out_dim = out_dim - - # layers - self.residual = nn.Sequential( - RMS_norm(in_dim, images=False), nn.SiLU(), - CausalConv3d(in_dim, out_dim, 3, padding=1), - RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout), - CausalConv3d(out_dim, out_dim, 3, padding=1)) - self.shortcut = CausalConv3d(in_dim, out_dim, 1) \ - if in_dim != out_dim else nn.Identity() - - def forward(self, x, feat_cache=None, feat_idx=[0]): - h = self.shortcut(x) - for layer in self.residual: - if isinstance(layer, CausalConv3d) and feat_cache is not None: - idx = feat_idx[0] - cache_x = x[:, :, -CACHE_T:, :, :].clone() - if cache_x.shape[2] < 2 and feat_cache[idx] is not None: - # cache last frame of last two chunk - cache_x = torch.cat([ - feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( - cache_x.device), cache_x - ], - dim=2) - x = layer(x, feat_cache[idx]) - feat_cache[idx] = cache_x - feat_idx[0] += 1 - else: - x = layer(x) - return x + h - - -class AttentionBlock(nn.Module): - """ - Causal self-attention with a single head. - """ - - def __init__(self, dim): - super().__init__() - self.dim = dim - - # layers - self.norm = RMS_norm(dim) - self.to_qkv = nn.Conv2d(dim, dim * 3, 1) - self.proj = nn.Conv2d(dim, dim, 1) - - # zero out the last layer params - nn.init.zeros_(self.proj.weight) - - def forward(self, x): - identity = x - b, c, t, h, w = x.size() - x = rearrange(x, 'b c t h w -> (b t) c h w') - x = self.norm(x) - # compute query, key, value - q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3, - -1).permute(0, 1, 3, - 2).contiguous().chunk( - 3, dim=-1) - - # apply attention - x = F.scaled_dot_product_attention( - q, - k, - v, - ) - x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w) - - # output - x = self.proj(x) - x = rearrange(x, '(b t) c h w-> b c t h w', t=t) - return x + identity - - -class Encoder3d(nn.Module): - - def __init__(self, - dim=128, - z_dim=4, - dim_mult=[1, 2, 4, 4], - num_res_blocks=2, - attn_scales=[], - temperal_downsample=[True, True, False], - dropout=0.0): - super().__init__() - self.dim = dim - self.z_dim = z_dim - self.dim_mult = dim_mult - self.num_res_blocks = num_res_blocks - self.attn_scales = attn_scales - self.temperal_downsample = temperal_downsample - - # dimensions - dims = [dim * u for u in [1] + dim_mult] - scale = 1.0 - - # init block - self.conv1 = CausalConv3d(3, dims[0], 3, padding=1) - - # downsample blocks - downsamples = [] - for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): - # residual (+attention) blocks - for _ in range(num_res_blocks): - downsamples.append(ResidualBlock(in_dim, out_dim, dropout)) - if scale in attn_scales: - downsamples.append(AttentionBlock(out_dim)) - in_dim = out_dim - - # downsample block - if i != len(dim_mult) - 1: - mode = 'downsample3d' if temperal_downsample[ - i] else 'downsample2d' - downsamples.append(Resample(out_dim, mode=mode)) - scale /= 2.0 - self.downsamples = nn.Sequential(*downsamples) - - # middle blocks - self.middle = nn.Sequential( - ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim), - ResidualBlock(out_dim, out_dim, dropout)) - - # output blocks - self.head = nn.Sequential( - RMS_norm(out_dim, images=False), nn.SiLU(), - CausalConv3d(out_dim, z_dim, 3, padding=1)) - - def forward(self, x, feat_cache=None, feat_idx=[0]): - if feat_cache is not None: - idx = feat_idx[0] - cache_x = x[:, :, -CACHE_T:, :, :].clone() - if cache_x.shape[2] < 2 and feat_cache[idx] is not None: - # cache last frame of last two chunk - cache_x = torch.cat([ - feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( - cache_x.device), cache_x - ], - dim=2) - x = self.conv1(x, feat_cache[idx]) - feat_cache[idx] = cache_x - feat_idx[0] += 1 - else: - x = self.conv1(x) - - ## downsamples - for layer in self.downsamples: - if feat_cache is not None: - x = layer(x, feat_cache, feat_idx) - else: - x = layer(x) - - ## middle - for layer in self.middle: - if isinstance(layer, ResidualBlock) and feat_cache is not None: - x = layer(x, feat_cache, feat_idx) - else: - x = layer(x) - - ## head - for layer in self.head: - if isinstance(layer, CausalConv3d) and feat_cache is not None: - idx = feat_idx[0] - cache_x = x[:, :, -CACHE_T:, :, :].clone() - if cache_x.shape[2] < 2 and feat_cache[idx] is not None: - # cache last frame of last two chunk - cache_x = torch.cat([ - feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( - cache_x.device), cache_x - ], - dim=2) - x = layer(x, feat_cache[idx]) - feat_cache[idx] = cache_x - feat_idx[0] += 1 - else: - x = layer(x) - return x - - -class Decoder3d(nn.Module): - - def __init__(self, - dim=128, - z_dim=4, - dim_mult=[1, 2, 4, 4], - num_res_blocks=2, - attn_scales=[], - temperal_upsample=[False, True, True], - dropout=0.0): - super().__init__() - self.dim = dim - self.z_dim = z_dim - self.dim_mult = dim_mult - self.num_res_blocks = num_res_blocks - self.attn_scales = attn_scales - self.temperal_upsample = temperal_upsample - - # dimensions - dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] - scale = 1.0 / 2**(len(dim_mult) - 2) - - # init block - self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1) - - # middle blocks - self.middle = nn.Sequential( - ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]), - ResidualBlock(dims[0], dims[0], dropout)) - - # upsample blocks - upsamples = [] - for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): - # residual (+attention) blocks - if i == 1 or i == 2 or i == 3: - in_dim = in_dim // 2 - for _ in range(num_res_blocks + 1): - upsamples.append(ResidualBlock(in_dim, out_dim, dropout)) - if scale in attn_scales: - upsamples.append(AttentionBlock(out_dim)) - in_dim = out_dim - - # upsample block - if i != len(dim_mult) - 1: - mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d' - upsamples.append(Resample(out_dim, mode=mode)) - scale *= 2.0 - self.upsamples = nn.Sequential(*upsamples) - - # output blocks - self.head = nn.Sequential( - RMS_norm(out_dim, images=False), nn.SiLU(), - CausalConv3d(out_dim, 3, 3, padding=1)) - - def forward(self, x, feat_cache=None, feat_idx=[0]): - ## conv1 - if feat_cache is not None: - idx = feat_idx[0] - cache_x = x[:, :, -CACHE_T:, :, :].clone() - if cache_x.shape[2] < 2 and feat_cache[idx] is not None: - # cache last frame of last two chunk - cache_x = torch.cat([ - feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( - cache_x.device), cache_x - ], - dim=2) - x = self.conv1(x, feat_cache[idx]) - feat_cache[idx] = cache_x - feat_idx[0] += 1 - else: - x = self.conv1(x) - - ## middle - for layer in self.middle: - if isinstance(layer, ResidualBlock) and feat_cache is not None: - x = layer(x, feat_cache, feat_idx) - else: - x = layer(x) - - ## upsamples - for layer in self.upsamples: - if feat_cache is not None: - x = layer(x, feat_cache, feat_idx) - else: - x = layer(x) - - ## head - for layer in self.head: - if isinstance(layer, CausalConv3d) and feat_cache is not None: - idx = feat_idx[0] - cache_x = x[:, :, -CACHE_T:, :, :].clone() - if cache_x.shape[2] < 2 and feat_cache[idx] is not None: - # cache last frame of last two chunk - cache_x = torch.cat([ - feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( - cache_x.device), cache_x - ], - dim=2) - x = layer(x, feat_cache[idx]) - feat_cache[idx] = cache_x - feat_idx[0] += 1 - else: - x = layer(x) - return x - - -def count_conv3d(model): - count = 0 - for m in model.modules(): - if isinstance(m, CausalConv3d): - count += 1 - return count - - -class WanVAE_(nn.Module): - - def __init__(self, - dim=128, - z_dim=4, - dim_mult=[1, 2, 4, 4], - num_res_blocks=2, - attn_scales=[], - temperal_downsample=[True, True, False], - dropout=0.0): - super().__init__() - self.dim = dim - self.z_dim = z_dim - self.dim_mult = dim_mult - self.num_res_blocks = num_res_blocks - self.attn_scales = attn_scales - self.temperal_downsample = temperal_downsample - self.temperal_upsample = temperal_downsample[::-1] - - # modules - self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks, - attn_scales, self.temperal_downsample, dropout) - self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1) - self.conv2 = CausalConv3d(z_dim, z_dim, 1) - self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks, - attn_scales, self.temperal_upsample, dropout) - - def forward(self, x): - mu, log_var = self.encode(x) - z = self.reparameterize(mu, log_var) - x_recon = self.decode(z) - return x_recon, mu, log_var - - def encode(self, x, scale): - self.clear_cache() - ## cache - t = x.shape[2] - iter_ = 1 + (t - 1) // 4 - ## 对encode输入的x,按时间拆分为1、4、4、4.... - for i in range(iter_): - self._enc_conv_idx = [0] - if i == 0: - out = self.encoder( - x[:, :, :1, :, :], - feat_cache=self._enc_feat_map, - feat_idx=self._enc_conv_idx) - else: - out_ = self.encoder( - x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :], - feat_cache=self._enc_feat_map, - feat_idx=self._enc_conv_idx) - out = torch.cat([out, out_], 2) - mu, log_var = self.conv1(out).chunk(2, dim=1) - if isinstance(scale[0], torch.Tensor): - mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view( - 1, self.z_dim, 1, 1, 1) - else: - mu = (mu - scale[0]) * scale[1] - self.clear_cache() - return mu - - def decode(self, z, scale): - self.clear_cache() - # z: [b,c,t,h,w] - if isinstance(scale[0], torch.Tensor): - z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view( - 1, self.z_dim, 1, 1, 1) - else: - z = z / scale[1] + scale[0] - iter_ = z.shape[2] - x = self.conv2(z) - for i in range(iter_): - self._conv_idx = [0] - if i == 0: - out = self.decoder( - x[:, :, i:i + 1, :, :], - feat_cache=self._feat_map, - feat_idx=self._conv_idx) - else: - out_ = self.decoder( - x[:, :, i:i + 1, :, :], - feat_cache=self._feat_map, - feat_idx=self._conv_idx) - out = torch.cat([out, out_], 2) - self.clear_cache() - return out - - def reparameterize(self, mu, log_var): - std = torch.exp(0.5 * log_var) - eps = torch.randn_like(std) - return eps * std + mu - - def sample(self, imgs, deterministic=False): - mu, log_var = self.encode(imgs) - if deterministic: - return mu - std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0)) - return mu + std * torch.randn_like(std) - - def clear_cache(self): - self._conv_num = count_conv3d(self.decoder) - self._conv_idx = [0] - self._feat_map = [None] * self._conv_num - #cache encode - self._enc_conv_num = count_conv3d(self.encoder) - self._enc_conv_idx = [0] - self._enc_feat_map = [None] * self._enc_conv_num - - -def _video_vae(pretrained_path=None, z_dim=None, device='cpu', **kwargs): - """ - Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL. - """ - # params - cfg = dict( - dim=96, - z_dim=z_dim, - dim_mult=[1, 2, 4, 4], - num_res_blocks=2, - attn_scales=[], - temperal_downsample=[False, True, True], - dropout=0.0) - cfg.update(**kwargs) - - # init model - with torch.device('meta'): - model = WanVAE_(**cfg) - - # load checkpoint - logging.info(f'loading {pretrained_path}') - model.load_state_dict( - torch.load(pretrained_path, map_location=device), assign=True) - - return model - - -class WanVAE: - - def __init__(self, - z_dim=16, - vae_pth='cache/vae_step_411000.pth', - dtype=torch.float, - device="cuda"): - self.dtype = dtype - self.device = device - - mean = [ - -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508, - 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921 - ] - std = [ - 2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, - 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160 - ] - self.mean = torch.tensor(mean, dtype=dtype, device=device) - self.std = torch.tensor(std, dtype=dtype, device=device) - self.scale = [self.mean, 1.0 / self.std] - - # init model - self.model = _video_vae( - pretrained_path=vae_pth, - z_dim=z_dim, - ).eval().requires_grad_(False).to(device) - - def encode(self, videos): - """ - videos: A list of videos each with shape [C, T, H, W]. - """ - with amp.autocast(dtype=self.dtype): - return [ - self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0) - for u in videos - ] - - def decode(self, zs): - with amp.autocast(dtype=self.dtype): - return [ - self.model.decode(u.unsqueeze(0), - self.scale).float().clamp_(-1, 1).squeeze(0) - for u in zs - ] diff --git a/dfm/src/megatron/model/wan/wan_layer_spec.py b/dfm/src/megatron/model/wan/wan_layer_spec.py index cdf58847..23a6f22f 100644 --- a/dfm/src/megatron/model/wan/wan_layer_spec.py +++ b/dfm/src/megatron/model/wan/wan_layer_spec.py @@ -15,19 +15,14 @@ # pylint: disable=C0115,C0116,C0301 -import copy from dataclasses import dataclass from typing import Optional, Union import torch import torch.nn as nn -from megatron.core import parallel_state, tensor_parallel from megatron.core.extensions.transformer_engine import TENorm from megatron.core.process_groups_config import ProcessGroupCollection -from megatron.core.transformer.attention import ( - CrossAttention, - SelfAttention, -) +from megatron.core.transformer.attention import SelfAttentionSubmodules from megatron.core.transformer.custom_layers.transformer_engine import ( TEColumnParallelLinear, TEDotProductAttention, @@ -41,34 +36,12 @@ from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules from megatron.core.utils import make_viewless_tensor -from dfm.src.megatron.model.common.dit_attention import DiTCrossAttentionSubmodules, DiTSelfAttention, DiTCrossAttention -from megatron.core.transformer.attention import SelfAttentionSubmodules - - -try: - import transformer_engine # pylint: disable=unused-import - HAVE_TE = True - from megatron.core.extensions.transformer_engine import SplitAlongDim - -except ImportError: - HAVE_TE = False - SplitAlongDim = None - - -# class WanLayerNorm(nn.LayerNorm): -# # Note to parth: Can we replace this with te layer norm or fuse with linear layer? -# # (@huy) Remove this comment after you have answered the question. - -# def __init__(self, dim, eps=1e-6, elementwise_affine=False): -# super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps) - -# def forward(self, x): -# r""" -# Args: -# x(Tensor): Shape [B, L, C] -# """ -# return super().forward(x).type_as(x) +from dfm.src.megatron.model.common.dit_attention import ( + DiTCrossAttention, + DiTCrossAttentionSubmodules, + DiTSelfAttention, +) @dataclass diff --git a/dfm/src/megatron/model/wan/wan_model.py b/dfm/src/megatron/model/wan/wan_model.py index 8b475f78..b43a6a70 100644 --- a/dfm/src/megatron/model/wan/wan_model.py +++ b/dfm/src/megatron/model/wan/wan_model.py @@ -14,11 +14,10 @@ # pylint: disable=C0115,C0116,C0301 -from typing import Dict, Literal, Optional, Tuple, List, Union - import math +from typing import Dict, Optional, Tuple + import torch -import torch.cuda.amp as amp import torch.nn as nn from megatron.core import parallel_state, tensor_parallel from megatron.core.dist_checkpointing.mapping import ShardedStateDict @@ -28,12 +27,15 @@ from megatron.core.transformer.transformer_block import TransformerBlock from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.utils import make_sharded_tensor_for_checkpoint +from torch import Tensor + from dfm.src.megatron.model.wan.wan_layer_spec import ( get_wan_block_with_transformer_engine_spec as WanLayerWithAdaLNspec, ) -from torch import Tensor + from .rope_utils import Wan3DRopeEmbeddings + def sinusoidal_embedding_1d(dim, position): # preprocess assert dim % 2 == 0 diff --git a/dfm/src/megatron/model/wan/wan_provider.py b/dfm/src/megatron/model/wan/wan_provider.py index 90f41e14..95f803e9 100644 --- a/dfm/src/megatron/model/wan/wan_provider.py +++ b/dfm/src/megatron/model/wan/wan_provider.py @@ -17,13 +17,14 @@ from dataclasses import dataclass import torch -from megatron.core import parallel_state -from megatron.bridge.models.transformer_config import TransformerConfig - from megatron.bridge.models.model_provider import ModelProviderMixin +from megatron.bridge.models.transformer_config import TransformerConfig +from megatron.core import parallel_state from megatron.core.models.common.vision_module.vision_module import VisionModule + from dfm.src.megatron.model.wan.wan_model import WanModel + logger = logging.getLogger(__name__) @dataclass diff --git a/examples/megatron/recipes/wan/inference_wan.py b/examples/megatron/recipes/wan/inference_wan.py index 1c04c1ca..b65d8672 100644 --- a/examples/megatron/recipes/wan/inference_wan.py +++ b/examples/megatron/recipes/wan/inference_wan.py @@ -1,18 +1,17 @@ -# Example of running script for Wan inference. -# NVTE_FUSED_ATTN=1 torchrun --nproc_per_node=1 examples/recipes/wan/inference_wan.py \ -# --task t2v-1.3B \ -# --sizes 480*832 \ -# --checkpoint_dir /path/to/wan_checkpoint_dir \ -# --t5_checkpoint_dir /path/to/t5_checkpoint_dir \ -# --vae_checkpoint_dir /path/to/vae_checkpoint_dir \ -# --frame_nums 81 \ -# --prompts "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \ -# --tensor_parallel_size 1 \ -# --context_parallel_size 1 \ -# --pipeline_parallel_size 1 \ -# --sequence_parallel False \ -# --base_seed 42 \ -# --sample_steps 50 +#!/usr/bin/env python3 +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import argparse import logging @@ -21,6 +20,7 @@ import warnings from datetime import datetime + warnings.filterwarnings('ignore') import random @@ -33,6 +33,7 @@ from dfm.src.megatron.model.wan.inference.configs import SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS from dfm.src.megatron.model.wan.inference.utils import cache_video, str2bool + EXAMPLE_PROMPT = { "t2v-1.3B": { "prompt": diff --git a/examples/megatron/recipes/wan/pretrain_wan.py b/examples/megatron/recipes/wan/pretrain_wan.py index 251934d6..6eae3e31 100644 --- a/examples/megatron/recipes/wan/pretrain_wan.py +++ b/examples/megatron/recipes/wan/pretrain_wan.py @@ -1,4 +1,3 @@ - #!/usr/bin/env python3 # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. # @@ -56,11 +55,7 @@ from pathlib import Path from typing import Tuple -from omegaconf import OmegaConf - -from dfm.src.megatron.recipes.wan.wan import pretrain_config from megatron.bridge.training.config import ConfigContainer -from dfm.src.megatron.model.wan.wan_step import WanForwardStep from megatron.bridge.training.pretrain import pretrain from megatron.bridge.training.utils.omegaconf_utils import ( apply_overrides, @@ -68,6 +63,10 @@ parse_hydra_overrides, ) from megatron.bridge.utils.common_utils import get_rank_safe +from omegaconf import OmegaConf + +from dfm.src.megatron.model.wan.wan_step import WanForwardStep +from dfm.src.megatron.recipes.wan.wan import pretrain_config logger: logging.Logger = logging.getLogger(__name__) @@ -80,11 +79,6 @@ DEFAULT_CONFIG_FILENAME: str = "wan_pretrain_override_example.yaml" DEFAULT_CONFIG_FILE_PATH: Path = SCRIPT_DIR / "conf" / DEFAULT_CONFIG_FILENAME -# DEBUGGING -import numpy as np -import torch -np.set_printoptions(precision=10, suppress=False) -torch.set_printoptions(precision=10, sci_mode=False) def parse_cli_args() -> Tuple[argparse.Namespace, list[str]]: """Parse command line arguments, separating known script args from OmegaConf overrides.""" From 2de59344897604beea76fd78944a0c8d9e1729d9 Mon Sep 17 00:00:00 2001 From: Huy Vu2 Date: Wed, 5 Nov 2025 17:06:09 -0800 Subject: [PATCH 17/80] fix Ruff --- .../data/wan/wan_energon_datamodule.py | 5 ++- dfm/src/megatron/data/wan/wan_taskencoder.py | 2 +- .../megatron/model/common/dit_attention.py | 12 +++---- .../flow_matching/flow_inference_pipeline.py | 31 ++----------------- .../model/wan/flow_matching/flow_pipeline.py | 3 +- .../wan/flow_matching/time_shift_utils.py | 1 + .../model/wan/inference/configs/__init__.py | 2 +- .../wan/inference/configs/wan_t2v_14B.py | 1 + .../wan/inference/configs/wan_t2v_1_3B.py | 1 + dfm/src/megatron/model/wan/inference/utils.py | 1 + dfm/src/megatron/model/wan/rope_utils.py | 3 +- .../megatron/model/wan/{utils => }/utils.py | 0 dfm/src/megatron/model/wan/wan_step.py | 6 ++-- dfm/src/megatron/recipes/wan/wan.py | 10 +++--- .../megatron/recipes/wan/inference_wan.py | 3 +- .../wan/prepare_energon_dataset_wan.py | 2 -- 16 files changed, 29 insertions(+), 54 deletions(-) rename dfm/src/megatron/model/wan/{utils => }/utils.py (100%) diff --git a/dfm/src/megatron/data/wan/wan_energon_datamodule.py b/dfm/src/megatron/data/wan/wan_energon_datamodule.py index 8a209eda..5649c7ca 100644 --- a/dfm/src/megatron/data/wan/wan_energon_datamodule.py +++ b/dfm/src/megatron/data/wan/wan_energon_datamodule.py @@ -15,14 +15,13 @@ # pylint: disable=C0115,C0116,C0301 from dataclasses import dataclass -import logging -from typing import Any, Dict, Literal +from megatron.bridge.data.utils import DatasetBuildContext, DatasetProvider from torch import int_repr from dfm.src.megatron.data.dit.diffusion_energon_datamodule import DiffusionDataModule from dfm.src.megatron.data.wan.wan_taskencoder import WanTaskEncoder -from megatron.bridge.data.utils import DatasetBuildContext, DatasetProvider + @dataclass(kw_only=True) class WanDataModuleConfig(DatasetProvider): diff --git a/dfm/src/megatron/data/wan/wan_taskencoder.py b/dfm/src/megatron/data/wan/wan_taskencoder.py index b4a21165..5c217770 100644 --- a/dfm/src/megatron/data/wan/wan_taskencoder.py +++ b/dfm/src/megatron/data/wan/wan_taskencoder.py @@ -18,7 +18,7 @@ import torch.nn.functional as F from megatron.energon import DefaultTaskEncoder, SkipSample from megatron.energon.task_encoder.cooking import Cooker, basic_sample_keys -from dfm.src.megatron.model.wan.utils.utils import grid_sizes_calculation, patchify +from dfm.src.megatron.model.wan.utils import grid_sizes_calculation, patchify from megatron.core import parallel_state diff --git a/dfm/src/megatron/model/common/dit_attention.py b/dfm/src/megatron/model/common/dit_attention.py index 113a7d3f..f667c0e7 100644 --- a/dfm/src/megatron/model/common/dit_attention.py +++ b/dfm/src/megatron/model/common/dit_attention.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional - import copy +from dataclasses import dataclass +from typing import Optional, Union + import torch from megatron.core import parallel_state, tensor_parallel from megatron.core.extensions.transformer_engine import SplitAlongDim @@ -24,12 +25,9 @@ SelfAttention, SelfAttentionSubmodules, ) -from megatron.core.transformer.spec_utils import build_module -from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.transformer.enums import AttnMaskType -from dataclasses import dataclass -from typing import Union -from megatron.core.transformer.spec_utils import ModuleSpec +from megatron.core.transformer.spec_utils import ModuleSpec, build_module +from megatron.core.transformer.transformer_config import TransformerConfig @dataclass diff --git a/dfm/src/megatron/model/wan/flow_matching/flow_inference_pipeline.py b/dfm/src/megatron/model/wan/flow_matching/flow_inference_pipeline.py index f5dcea58..adb97802 100644 --- a/dfm/src/megatron/model/wan/flow_matching/flow_inference_pipeline.py +++ b/dfm/src/megatron/model/wan/flow_matching/flow_inference_pipeline.py @@ -40,7 +40,7 @@ from tqdm import tqdm from transformers import AutoTokenizer, UMT5EncoderModel -from dfm.src.megatron.model.wan.utils.utils import grid_sizes_calculation, patchify +from dfm.src.megatron.model.wan.utils import grid_sizes_calculation, patchify, unpatchify from dfm.src.megatron.model.wan.wan_provider import WanModelProvider @@ -153,31 +153,6 @@ def __init__( self.sample_neg_prompt = config.sample_neg_prompt - def unpatchify(self, x: torch.Tensor, grid_sizes: torch.Tensor, out_dim: int) -> list[torch.Tensor]: - r""" - Reconstruct video tensors from patch embeddings into a list of videotensors. - - Args: - x (torch.Tensor): - Tensor of patchified features, with shape [seq_len, c * pF * pH * pW] - grid_sizes (Tensor): - Original spatial-temporal grid dimensions before patching, - shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches) - - Returns: - list[torch.Tensor]: list of tensors, each with shape [c, F_latents, H_latents, W_latents] - """ - - c = out_dim - out = [] - for u, v in zip(x, grid_sizes.tolist()): - u = u[:math.prod(v)].view(*v, *self.patch_size, c) - u = torch.einsum('fhwpqrc->cfphqwr', u) - u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)]) - out.append(u) - return out - - def setup_model_from_checkpoint(self, checkpoint_dir): provider = WanModelProvider() provider.tensor_model_parallel_size = self.tensor_parallel_size @@ -512,11 +487,11 @@ def noop_no_sync(): unpatchified_noise_pred_cond = noise_pred_cond unpatchified_noise_pred_cond = unpatchified_noise_pred_cond.transpose(0, 1) # bring sbhd -> bshd # when unpatchifying, the code will truncate the padded videos into the original video shape, based on the grid_sizes. - unpatchified_noise_pred_cond = self.unpatchify(unpatchified_noise_pred_cond, grid_sizes, self.vae.config.z_dim) + unpatchified_noise_pred_cond = unpatchify(unpatchified_noise_pred_cond, grid_sizes, self.vae.config.z_dim, self.patch_size) unpatchified_noise_pred_uncond = noise_pred_uncond unpatchified_noise_pred_uncond = unpatchified_noise_pred_uncond.transpose(0, 1) # bring sbhd -> bshd # when unpatchifying, the code will truncate the padded videos into the original video shape, based on the grid_sizes. - unpatchified_noise_pred_uncond = self.unpatchify(unpatchified_noise_pred_uncond, grid_sizes, self.vae.config.z_dim) + unpatchified_noise_pred_uncond = unpatchify(unpatchified_noise_pred_uncond, grid_sizes, self.vae.config.z_dim, self.patch_size) noise_preds = [] for i in range(batch_size): diff --git a/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py b/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py index 0940b653..2bbe5204 100644 --- a/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py +++ b/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py @@ -16,8 +16,9 @@ import torch from diffusers import WanPipeline from megatron.core import parallel_state + from dfm.src.megatron.model.wan.flow_matching.time_shift_utils import compute_density_for_timestep_sampling -from dfm.src.megatron.model.wan.utils.utils import patchify, thd_split_inputs_cp +from dfm.src.megatron.model.wan.utils import patchify, thd_split_inputs_cp class FlowPipeline: diff --git a/dfm/src/megatron/model/wan/flow_matching/time_shift_utils.py b/dfm/src/megatron/model/wan/flow_matching/time_shift_utils.py index 56faee4b..7da8c70e 100644 --- a/dfm/src/megatron/model/wan/flow_matching/time_shift_utils.py +++ b/dfm/src/megatron/model/wan/flow_matching/time_shift_utils.py @@ -1,6 +1,7 @@ # time_shift_utils.py - Timestep sampling and sigma computation utilities import math + import numpy as np import torch diff --git a/dfm/src/megatron/model/wan/inference/configs/__init__.py b/dfm/src/megatron/model/wan/inference/configs/__init__.py index 21679892..dd12de34 100644 --- a/dfm/src/megatron/model/wan/inference/configs/__init__.py +++ b/dfm/src/megatron/model/wan/inference/configs/__init__.py @@ -1,4 +1,3 @@ -import copy import os os.environ['TOKENIZERS_PARALLELISM'] = 'false' @@ -6,6 +5,7 @@ from .wan_t2v_1_3B import t2v_1_3B from .wan_t2v_14B import t2v_14B + WAN_CONFIGS = { 't2v-14B': t2v_14B, 't2v-1.3B': t2v_1_3B, diff --git a/dfm/src/megatron/model/wan/inference/configs/wan_t2v_14B.py b/dfm/src/megatron/model/wan/inference/configs/wan_t2v_14B.py index c793f7f6..0037887b 100644 --- a/dfm/src/megatron/model/wan/inference/configs/wan_t2v_14B.py +++ b/dfm/src/megatron/model/wan/inference/configs/wan_t2v_14B.py @@ -2,6 +2,7 @@ from .shared_config import wan_shared_cfg + #------------------------ Wan T2V 14B ------------------------# t2v_14B = EasyDict(__name__='Config: Wan T2V 14B') diff --git a/dfm/src/megatron/model/wan/inference/configs/wan_t2v_1_3B.py b/dfm/src/megatron/model/wan/inference/configs/wan_t2v_1_3B.py index c8458ce8..e9e2a451 100644 --- a/dfm/src/megatron/model/wan/inference/configs/wan_t2v_1_3B.py +++ b/dfm/src/megatron/model/wan/inference/configs/wan_t2v_1_3B.py @@ -2,6 +2,7 @@ from .shared_config import wan_shared_cfg + #------------------------ Wan T2V 1.3B ------------------------# t2v_1_3B = EasyDict(__name__='Config: Wan T2V 1.3B') diff --git a/dfm/src/megatron/model/wan/inference/utils.py b/dfm/src/megatron/model/wan/inference/utils.py index a57f9bb9..58ddcfa7 100644 --- a/dfm/src/megatron/model/wan/inference/utils.py +++ b/dfm/src/megatron/model/wan/inference/utils.py @@ -7,6 +7,7 @@ import torch import torchvision + __all__ = ['cache_video', 'cache_image', 'str2bool'] diff --git a/dfm/src/megatron/model/wan/rope_utils.py b/dfm/src/megatron/model/wan/rope_utils.py index 00a2a519..b9af64de 100644 --- a/dfm/src/megatron/model/wan/rope_utils.py +++ b/dfm/src/megatron/model/wan/rope_utils.py @@ -1,6 +1,5 @@ import torch -from torch.cuda import amp -from megatron.core import parallel_state + class Wan3DRopeEmbeddings(torch.nn.Module): """ diff --git a/dfm/src/megatron/model/wan/utils/utils.py b/dfm/src/megatron/model/wan/utils.py similarity index 100% rename from dfm/src/megatron/model/wan/utils/utils.py rename to dfm/src/megatron/model/wan/utils.py diff --git a/dfm/src/megatron/model/wan/wan_step.py b/dfm/src/megatron/model/wan/wan_step.py index 43d5ff8a..999728d1 100644 --- a/dfm/src/megatron/model/wan/wan_step.py +++ b/dfm/src/megatron/model/wan/wan_step.py @@ -18,13 +18,15 @@ from typing import Iterable import torch +from megatron.bridge.training.losses import masked_next_token_loss +from megatron.bridge.training.state import GlobalState from megatron.core import parallel_state from megatron.core.models.common.vision_module.vision_module import VisionModule from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.utils import get_model_config + from dfm.src.megatron.model.wan.flow_matching.flow_pipeline import FlowPipeline -from megatron.bridge.training.losses import masked_next_token_loss -from megatron.bridge.training.state import GlobalState + logger = logging.getLogger(__name__) diff --git a/dfm/src/megatron/recipes/wan/wan.py b/dfm/src/megatron/recipes/wan/wan.py index 19fe9c2d..306d9d8a 100644 --- a/dfm/src/megatron/recipes/wan/wan.py +++ b/dfm/src/megatron/recipes/wan/wan.py @@ -15,11 +15,7 @@ import os from typing import List, Optional, Union -from dfm.src.megatron.data.wan.wan_energon_datamodule import WanDataModuleConfig -from dfm.src.megatron.model.wan.wan_provider import WanModelProvider import torch -from megatron.core.distributed import DistributedDataParallelConfig - from megatron.bridge.recipes.utils.optimizer_utils import distributed_fused_adam_with_cosine_annealing from megatron.bridge.recipes.utils.tokenizer_utils import DEFAULT_NULL_TOKENIZER_VOCAB_SIZE from megatron.bridge.training.comm_overlap import CommOverlapConfig @@ -28,10 +24,14 @@ ConfigContainer, LoggerConfig, RNGConfig, - TokenizerConfig, + TokenizerConfig, TrainingConfig, ) from megatron.bridge.training.mixed_precision import MixedPrecisionConfig, get_mixed_precision_config +from megatron.core.distributed import DistributedDataParallelConfig + +from dfm.src.megatron.data.wan.wan_energon_datamodule import WanDataModuleConfig +from dfm.src.megatron.model.wan.wan_provider import WanModelProvider def model_config( diff --git a/examples/megatron/recipes/wan/inference_wan.py b/examples/megatron/recipes/wan/inference_wan.py index b65d8672..1e92a2c6 100644 --- a/examples/megatron/recipes/wan/inference_wan.py +++ b/examples/megatron/recipes/wan/inference_wan.py @@ -27,7 +27,6 @@ import torch import torch.distributed as dist -from PIL import Image from dfm.src.megatron.model.wan.flow_matching.flow_inference_pipeline import FlowInferencePipeline from dfm.src.megatron.model.wan.inference.configs import SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS @@ -282,7 +281,7 @@ def generate(args): print("\n\n\n") logging.info( - f"Generating videos ...") + "Generating videos ...") videos = pipeline.generate( prompts=prompts, sizes=[SIZE_CONFIGS[size] for size in size_keys], diff --git a/examples/megatron/recipes/wan/prepare_energon_dataset_wan.py b/examples/megatron/recipes/wan/prepare_energon_dataset_wan.py index 05405964..748c2c53 100644 --- a/examples/megatron/recipes/wan/prepare_energon_dataset_wan.py +++ b/examples/megatron/recipes/wan/prepare_energon_dataset_wan.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os import json import pickle from pathlib import Path @@ -22,7 +21,6 @@ import numpy as np import torch import webdataset as wds - from diffusers import AutoencoderKLWan from transformers import AutoTokenizer, UMT5EncoderModel From 6b46a7f6a7b00c7efd98f2c1cdaa78a754177452 Mon Sep 17 00:00:00 2001 From: Huy Vu2 Date: Wed, 5 Nov 2025 17:17:22 -0800 Subject: [PATCH 18/80] fix ruff + copyright --- dfm/src/megatron/data/dit/base.py | 7 +++--- .../data/dit/diffusion_energon_datamodule.py | 2 +- .../wan/flow_matching/time_shift_utils.py | 14 +++++++++++- .../wan/inference/configs/shared_config.py | 15 +++++++++++++ .../wan/inference/configs/wan_t2v_14B.py | 14 ++++++++++++ .../wan/inference/configs/wan_t2v_1_3B.py | 14 ++++++++++++ dfm/src/megatron/model/wan/inference/utils.py | 14 ++++++++++++ dfm/src/megatron/model/wan/rope_utils.py | 15 +++++++++++++ dfm/src/megatron/model/wan/utils.py | 22 ++++++++++++++++--- 9 files changed, 108 insertions(+), 9 deletions(-) diff --git a/dfm/src/megatron/data/dit/base.py b/dfm/src/megatron/data/dit/base.py index 413dc686..81219486 100644 --- a/dfm/src/megatron/data/dit/base.py +++ b/dfm/src/megatron/data/dit/base.py @@ -12,14 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from copy import deepcopy +import logging from typing import Any, Dict, Literal, Optional from megatron.core import parallel_state from megatron.energon import WorkerConfig, get_savable_loader, get_train_dataset -from torch.utils.data import DataLoader -from typing_extensions import Self -import logging + + logger = logging.getLogger(__name__) diff --git a/dfm/src/megatron/data/dit/diffusion_energon_datamodule.py b/dfm/src/megatron/data/dit/diffusion_energon_datamodule.py index 3f7ea3e0..30e2add1 100644 --- a/dfm/src/megatron/data/dit/diffusion_energon_datamodule.py +++ b/dfm/src/megatron/data/dit/diffusion_energon_datamodule.py @@ -20,10 +20,10 @@ from torch import int_repr +from dfm.src.megatron.data.dit.base import EnergonMultiModalDataModule from dfm.src.megatron.data.dit.diffusion_taskencoder import BasicDiffusionTaskEncoder from megatron.bridge.data.utils import DatasetBuildContext, DatasetProvider from megatron.energon import DefaultTaskEncoder, get_train_dataset -from dfm.src.megatron.data.dit.base import EnergonMultiModalDataModule @dataclass(kw_only=True) class DiffusionDataModuleConfig(DatasetProvider): diff --git a/dfm/src/megatron/model/wan/flow_matching/time_shift_utils.py b/dfm/src/megatron/model/wan/flow_matching/time_shift_utils.py index 7da8c70e..1ca78f47 100644 --- a/dfm/src/megatron/model/wan/flow_matching/time_shift_utils.py +++ b/dfm/src/megatron/model/wan/flow_matching/time_shift_utils.py @@ -1,4 +1,16 @@ -# time_shift_utils.py - Timestep sampling and sigma computation utilities +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import math diff --git a/dfm/src/megatron/model/wan/inference/configs/shared_config.py b/dfm/src/megatron/model/wan/inference/configs/shared_config.py index 52b6d92b..b4aa8e21 100644 --- a/dfm/src/megatron/model/wan/inference/configs/shared_config.py +++ b/dfm/src/megatron/model/wan/inference/configs/shared_config.py @@ -1,3 +1,18 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + import torch from easydict import EasyDict diff --git a/dfm/src/megatron/model/wan/inference/configs/wan_t2v_14B.py b/dfm/src/megatron/model/wan/inference/configs/wan_t2v_14B.py index 0037887b..9909bb5f 100644 --- a/dfm/src/megatron/model/wan/inference/configs/wan_t2v_14B.py +++ b/dfm/src/megatron/model/wan/inference/configs/wan_t2v_14B.py @@ -1,3 +1,17 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from easydict import EasyDict from .shared_config import wan_shared_cfg diff --git a/dfm/src/megatron/model/wan/inference/configs/wan_t2v_1_3B.py b/dfm/src/megatron/model/wan/inference/configs/wan_t2v_1_3B.py index e9e2a451..2fa292b4 100644 --- a/dfm/src/megatron/model/wan/inference/configs/wan_t2v_1_3B.py +++ b/dfm/src/megatron/model/wan/inference/configs/wan_t2v_1_3B.py @@ -1,3 +1,17 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from easydict import EasyDict from .shared_config import wan_shared_cfg diff --git a/dfm/src/megatron/model/wan/inference/utils.py b/dfm/src/megatron/model/wan/inference/utils.py index 58ddcfa7..199bf968 100644 --- a/dfm/src/megatron/model/wan/inference/utils.py +++ b/dfm/src/megatron/model/wan/inference/utils.py @@ -1,3 +1,17 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import argparse import binascii import os diff --git a/dfm/src/megatron/model/wan/rope_utils.py b/dfm/src/megatron/model/wan/rope_utils.py index b9af64de..0b6fcf4a 100644 --- a/dfm/src/megatron/model/wan/rope_utils.py +++ b/dfm/src/megatron/model/wan/rope_utils.py @@ -1,3 +1,18 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + import torch diff --git a/dfm/src/megatron/model/wan/utils.py b/dfm/src/megatron/model/wan/utils.py index 9fc86555..2a1a7025 100644 --- a/dfm/src/megatron/model/wan/utils.py +++ b/dfm/src/megatron/model/wan/utils.py @@ -1,10 +1,26 @@ -import torch +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math from typing import Tuple -from torch.distributed import all_gather + import megatron.core.parallel_state as parallel_state -import math +import torch import torch.distributed as dist import transformer_engine_torch as tex +from torch.distributed import all_gather + def grid_sizes_calculation( input_shape: Tuple[int, int, int], # (F_latents, H_latents, W_latents) From c1d89233b3b20e262f52d7b2d63243164b624271 Mon Sep 17 00:00:00 2001 From: Huy Vu2 Date: Wed, 5 Nov 2025 19:19:15 -0800 Subject: [PATCH 19/80] fix Ruff + Lint --- .../data/dit/diffusion_energon_datamodule.py | 5 +- dfm/src/megatron/data/wan/wan_taskencoder.py | 8 +- .../flow_matching/flow_inference_pipeline.py | 180 ++++++++++-------- .../model/wan/flow_matching/flow_pipeline.py | 125 ++++++------ .../wan/flow_matching/time_shift_utils.py | 33 ++-- .../model/wan/inference/configs/__init__.py | 32 ++-- .../wan/inference/configs/shared_config.py | 6 +- .../wan/inference/configs/wan_t2v_14B.py | 10 +- .../wan/inference/configs/wan_t2v_1_3B.py | 10 +- dfm/src/megatron/model/wan/inference/utils.py | 66 +++---- dfm/src/megatron/model/wan/rope_utils.py | 35 ++-- dfm/src/megatron/model/wan/utils.py | 19 +- dfm/src/megatron/model/wan/wan_layer_spec.py | 16 +- dfm/src/megatron/model/wan/wan_model.py | 64 ++++--- dfm/src/megatron/model/wan/wan_provider.py | 3 +- dfm/src/megatron/model/wan/wan_step.py | 15 +- dfm/src/megatron/recipes/wan/wan.py | 10 +- .../megatron/recipes/wan/inference_wan.py | 62 +++--- 18 files changed, 345 insertions(+), 354 deletions(-) diff --git a/dfm/src/megatron/data/dit/diffusion_energon_datamodule.py b/dfm/src/megatron/data/dit/diffusion_energon_datamodule.py index 30e2add1..b55109ba 100644 --- a/dfm/src/megatron/data/dit/diffusion_energon_datamodule.py +++ b/dfm/src/megatron/data/dit/diffusion_energon_datamodule.py @@ -14,16 +14,15 @@ # pylint: disable=C0115,C0116,C0301 -from dataclasses import dataclass import logging +from dataclasses import dataclass from typing import Any, Dict, Literal -from torch import int_repr - from dfm.src.megatron.data.dit.base import EnergonMultiModalDataModule from dfm.src.megatron.data.dit.diffusion_taskencoder import BasicDiffusionTaskEncoder from megatron.bridge.data.utils import DatasetBuildContext, DatasetProvider from megatron.energon import DefaultTaskEncoder, get_train_dataset +from torch import int_repr @dataclass(kw_only=True) class DiffusionDataModuleConfig(DatasetProvider): diff --git a/dfm/src/megatron/data/wan/wan_taskencoder.py b/dfm/src/megatron/data/wan/wan_taskencoder.py index 5c217770..1921ee52 100644 --- a/dfm/src/megatron/data/wan/wan_taskencoder.py +++ b/dfm/src/megatron/data/wan/wan_taskencoder.py @@ -14,12 +14,12 @@ # pylint: disable=C0115,C0116,C0301 -import torch -import torch.nn.functional as F -from megatron.energon import DefaultTaskEncoder, SkipSample -from megatron.energon.task_encoder.cooking import Cooker, basic_sample_keys from dfm.src.megatron.model.wan.utils import grid_sizes_calculation, patchify from megatron.core import parallel_state +from megatron.energon import DefaultTaskEncoder, SkipSample +from megatron.energon.task_encoder.cooking import basic_sample_keys, Cooker +import torch +import torch.nn.functional as F def cook(sample: dict) -> dict: diff --git a/dfm/src/megatron/model/wan/flow_matching/flow_inference_pipeline.py b/dfm/src/megatron/model/wan/flow_matching/flow_inference_pipeline.py index adb97802..bd7e4c19 100644 --- a/dfm/src/megatron/model/wan/flow_matching/flow_inference_pipeline.py +++ b/dfm/src/megatron/model/wan/flow_matching/flow_inference_pipeline.py @@ -68,7 +68,6 @@ def _encode_text( return outputs class FlowInferencePipeline: - def __init__( self, config, @@ -129,7 +128,7 @@ def __init__( ) self.vae_stride = config.vae_stride - self.patch_size = config.patch_size + self.patch_size = config.patch_size self.vae = AutoencoderKLWan.from_pretrained( model_id, subfolder="vae", @@ -139,9 +138,9 @@ def __init__( wan_checkpoint_dir = self._select_checkpoint_dir(checkpoint_dir, checkpoint_step) self.model = self.setup_model_from_checkpoint(wan_checkpoint_dir) - + # if we use context parallelism, we need to set qkv_format to "thd" for context parallelism - self.model.config.qkv_format = "thd" # "sbhd" + self.model.config.qkv_format = "thd" # "sbhd" # set self.sp_size=1 for later use, just to respect the original Wan inference code self.sp_size = 1 @@ -151,7 +150,7 @@ def __init__( self.model.to(self.device) self.sample_neg_prompt = config.sample_neg_prompt - + def setup_model_from_checkpoint(self, checkpoint_dir): provider = WanModelProvider() @@ -163,7 +162,7 @@ def setup_model_from_checkpoint(self, checkpoint_dir): # Once all overrides are set, finalize the model provider to ensure the post initialization logic is run provider.finalize() provider.initialize_model_parallel(seed=0) - + ## Read from megatron checkpoint model = _load_megatron_model( checkpoint_dir, @@ -210,6 +209,7 @@ def _select_checkpoint_dir(self, base_dir: str, checkpoint_step) -> str: raise FileNotFoundError( f"No checkpoints found under {base_dir}. Expected subdirectories named like 'iter_0001800'.") + logging.info(f"Auto-selected latest checkpoint: {latest_path}") return latest_path @@ -220,7 +220,7 @@ def forward_pp_step( grid_sizes: list[Tuple[int, int, int]], max_video_seq_len: int, timestep: torch.Tensor, - arg_c: dict, + arg_c: dict, ) -> torch.Tensor: """ Forward pass supporting pipeline parallelism. @@ -232,11 +232,7 @@ def forward_pp_step( # PP=1: no pipeline parallelism if pp_world_size == 1: - noise_pred_pp = self.model( - latent_model_input, - grid_sizes=grid_sizes, - t=timestep, - **arg_c) + noise_pred_pp = self.model(latent_model_input, grid_sizes=grid_sizes, t=timestep, **arg_c) return noise_pred_pp # PP>1: pipeline parallelism @@ -247,11 +243,7 @@ def forward_pp_step( if is_pp_first: # First stage: compute multimodal + first PP slice, send activations, then receive sampled token - hidden_states = self.model( - latent_model_input, - grid_sizes=grid_sizes, - t=timestep, - **arg_c) + hidden_states = self.model(latent_model_input, grid_sizes=grid_sizes, t=timestep, **arg_c) send_to_next_pipeline_rank(hidden_states) noise_pred_pp = broadcast_from_last_pipeline_stage(noise_pred_pp_shape, dtype=torch.float32) @@ -265,15 +257,13 @@ def forward_pp_step( device=latent_model_input[0].device, ) recv_from_prev_pipeline_rank_(recv_buffer) - recv_buffer = recv_buffer.to(torch.bfloat16) # ???? + recv_buffer = recv_buffer.to(torch.bfloat16) self.model.set_input_tensor(recv_buffer) - noise_pred_pp = self.model( - latent_model_input, - grid_sizes=grid_sizes, - t=timestep, - **arg_c) + noise_pred_pp = self.model(latent_model_input, grid_sizes=grid_sizes, t=timestep, **arg_c) - noise_pred_pp = broadcast_from_last_pipeline_stage(noise_pred_pp_shape, dtype=noise_pred_pp.dtype, tensor=noise_pred_pp.contiguous()) + noise_pred_pp = broadcast_from_last_pipeline_stage( + noise_pred_pp_shape, dtype=noise_pred_pp.dtype, tensor=noise_pred_pp.contiguous() + ) return noise_pred_pp # Intermediate stages: recv -> run local slice -> send -> receive broadcast token @@ -283,29 +273,27 @@ def forward_pp_step( device=latent_model_input[0].device, ) recv_from_prev_pipeline_rank_(recv_buffer) - recv_buffer = recv_buffer.to(torch.bfloat16) # ???? + recv_buffer = recv_buffer.to(torch.bfloat16) self.model.set_input_tensor(recv_buffer) - hidden_states = self.model( - latent_model_input, - grid_sizes=grid_sizes, - t=timestep, - **arg_c) + hidden_states = self.model(latent_model_input, grid_sizes=grid_sizes, t=timestep, **arg_c) send_to_next_pipeline_rank(hidden_states) noise_pred_pp = broadcast_from_last_pipeline_stage(noise_pred_pp_shape, dtype=torch.float32) return noise_pred_pp - def generate(self, - prompts, - sizes, - frame_nums, - shift=5.0, - sampling_steps=50, - guide_scale=5.0, - n_prompt="", - seed=-1, - offload_model=True): + def generate( + self, + prompts, + sizes, + frame_nums, + shift=5.0, + sampling_steps=50, + guide_scale=5.0, + n_prompt="", + seed=-1, + offload_model=True, + ): r""" Generates video frames from text prompt using diffusion process. @@ -337,20 +325,30 @@ def generate(self, - H: Frame height (from size) - W: Frame width from size) """ - + # preprocess target_shapes = [] for size, frame_num in zip(sizes, frame_nums): - target_shapes.append((self.vae.config.z_dim, (frame_num - 1) // self.vae_stride[0] + 1, - size[1] // self.vae_stride[1], - size[0] // self.vae_stride[2])) - + target_shapes.append( + ( + self.vae.config.z_dim, + (frame_num - 1) // self.vae_stride[0] + 1, + size[1] // self.vae_stride[1], + size[0] // self.vae_stride[2], + ) + ) max_video_seq_len = 0 seq_lens = [] for target_shape in target_shapes: - seq_len = math.ceil((target_shape[2] * target_shape[3]) / - (self.patch_size[1] * self.patch_size[2]) * - target_shape[1] / self.sp_size) * self.sp_size + seq_len = ( + math.ceil( + (target_shape[2] * target_shape[3]) + / (self.patch_size[1] * self.patch_size[2]) + * target_shape[1] + / self.sp_size + ) + * self.sp_size + ) seq_lens.append(seq_len) max_video_seq_len = max(seq_lens) @@ -360,7 +358,6 @@ def generate(self, seed_g = torch.Generator(device=self.device) seed_g.manual_seed(seed) - ## process context # we implement similar to Wan's diffuser setup # (https://github.com/huggingface/diffusers/blob/0f252be0ed42006c125ef4429156cb13ae6c1d60/src/diffusers/pipelines/wan/pipeline_wan.py#L157) @@ -377,15 +374,17 @@ def generate(self, if offload_model: self.text_encoder.cpu() else: - context = self.text_encoder([prompt], torch.device('cpu'))[0].to(self.device) - context_null = self.text_encoder([n_prompt], torch.device('cpu'))[0].to(self.device) - context_lens.append(context_max_len) # all samples have the same context_max_len + context = self.text_encoder([prompt], torch.device("cpu"))[0].to(self.device) + context_null = self.text_encoder([n_prompt], torch.device("cpu"))[0].to(self.device) + context_lens.append(context_max_len) # all samples have the same context_max_len contexts.append(context) contexts_null.append(context_null) # pad to context_max_len tokens, and stack to a tensor of shape [s, b, hidden] contexts = [F.pad(context, (0, 0, 0, context_max_len - context.shape[0])) for context in contexts] - contexts_null = [F.pad(context_null, (0, 0, 0, context_max_len - context_null.shape[0])) for context_null in contexts_null] + contexts_null = [ + F.pad(context_null, (0, 0, 0, context_max_len - context_null.shape[0])) for context_null in contexts_null + ] contexts = torch.stack(contexts, dim=1) contexts_null = torch.stack(contexts_null, dim=1) @@ -401,23 +400,26 @@ def generate(self, target_shape[3], dtype=torch.float32, device=self.device, - generator=seed_g) + generator=seed_g, + ) ) # calculate grid_sizes - grid_sizes = [grid_sizes_calculation( - input_shape =u.shape[1:], - patch_size=self.model.patch_size, - ) for u in noises] + grid_sizes = [ + grid_sizes_calculation( + input_shape=u.shape[1:], + patch_size=self.model.patch_size, + ) + for u in noises + ] grid_sizes = torch.tensor(grid_sizes, dtype=torch.long) - @contextmanager def noop_no_sync(): yield - no_sync = getattr(self.model, 'no_sync', noop_no_sync) + no_sync = getattr(self.model, "no_sync", noop_no_sync) # evaluation mode with amp.autocast(dtype=self.param_dtype), torch.no_grad(), no_sync(): @@ -425,9 +427,7 @@ def noop_no_sync(): batch_size_for_schedulers = len(noises) schedulers = [] for _ in range(batch_size_for_schedulers): - base_sched = FlowMatchEulerDiscreteScheduler.from_pretrained( - self.model_id, subfolder="scheduler" - ) + base_sched = FlowMatchEulerDiscreteScheduler.from_pretrained(self.model_id, subfolder="scheduler") s = UniPCMultistepScheduler.from_config(base_sched.config, flow_shift=shift) s.set_timesteps(sampling_steps, device=self.device) @@ -454,13 +454,15 @@ def noop_no_sync(): qkv_format=self.model.config.qkv_format, ), } - - arg_c = {'context': contexts, 'max_seq_len': max_video_seq_len, 'packed_seq_params': packed_seq_params} - arg_null = {'context': contexts_null, 'max_seq_len': max_video_seq_len, 'packed_seq_params': packed_seq_params} + arg_c = {"context": contexts, "max_seq_len": max_video_seq_len, "packed_seq_params": packed_seq_params} + arg_null = { + "context": contexts_null, + "max_seq_len": max_video_seq_len, + "packed_seq_params": packed_seq_params, + } for _, t in enumerate(tqdm(timesteps)): - batch_size = len(latents) # patchify latents @@ -471,42 +473,54 @@ def noop_no_sync(): latents[i] = F.pad(latents[i], (0, 0, 0, max_video_seq_len - latents[i].shape[0])) latents = torch.stack(latents, dim=1) - latent_model_input = latents timestep = [t] * batch_size timestep = torch.stack(timestep) self.model.to(self.device) noise_pred_cond = self.forward_pp_step( - latent_model_input, grid_sizes=grid_sizes, max_video_seq_len=max_video_seq_len, timestep=timestep, arg_c=arg_c) + latent_model_input, + grid_sizes=grid_sizes, + max_video_seq_len=max_video_seq_len, + timestep=timestep, + arg_c=arg_c, + ) noise_pred_uncond = self.forward_pp_step( - latent_model_input, grid_sizes=grid_sizes, max_video_seq_len=max_video_seq_len, timestep=timestep, arg_c=arg_null) + latent_model_input, + grid_sizes=grid_sizes, + max_video_seq_len=max_video_seq_len, + timestep=timestep, + arg_c=arg_null, + ) # run unpatchify unpatchified_noise_pred_cond = noise_pred_cond - unpatchified_noise_pred_cond = unpatchified_noise_pred_cond.transpose(0, 1) # bring sbhd -> bshd + unpatchified_noise_pred_cond = unpatchified_noise_pred_cond.transpose(0, 1) # bring sbhd -> bshd # when unpatchifying, the code will truncate the padded videos into the original video shape, based on the grid_sizes. - unpatchified_noise_pred_cond = unpatchify(unpatchified_noise_pred_cond, grid_sizes, self.vae.config.z_dim, self.patch_size) + unpatchified_noise_pred_cond = unpatchify( + unpatchified_noise_pred_cond, grid_sizes, self.vae.config.z_dim, self.patch_size + ) unpatchified_noise_pred_uncond = noise_pred_uncond - unpatchified_noise_pred_uncond = unpatchified_noise_pred_uncond.transpose(0, 1) # bring sbhd -> bshd + unpatchified_noise_pred_uncond = unpatchified_noise_pred_uncond.transpose(0, 1) # bring sbhd -> bshd # when unpatchifying, the code will truncate the padded videos into the original video shape, based on the grid_sizes. - unpatchified_noise_pred_uncond = unpatchify(unpatchified_noise_pred_uncond, grid_sizes, self.vae.config.z_dim, self.patch_size) + unpatchified_noise_pred_uncond = unpatchify( + unpatchified_noise_pred_uncond, grid_sizes, self.vae.config.z_dim, self.patch_size + ) noise_preds = [] for i in range(batch_size): noise_pred = unpatchified_noise_pred_uncond[i] + guide_scale * ( - unpatchified_noise_pred_cond[i] - unpatchified_noise_pred_uncond[i]) + unpatchified_noise_pred_cond[i] - unpatchified_noise_pred_uncond[i] + ) noise_preds.append(noise_pred) # step and update latents latents = [] for i in range(batch_size): - temp_x0 =schedulers[i].step( - noise_preds[i].unsqueeze(0), - t, - unpatchified_latents[i].unsqueeze(0), - return_dict=False)[0] + temp_x0 = schedulers[i].step( + noise_preds[i].unsqueeze(0), t, unpatchified_latents[i].unsqueeze(0), return_dict=False + )[0] latents.append(temp_x0.squeeze(0)) x0 = latents @@ -522,9 +536,9 @@ def noop_no_sync(): .view(1, self.vae.config.z_dim, 1, 1, 1) .to(latents.device, latents.dtype) ) - latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( - latents.device, latents.dtype - ) + latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view( + 1, self.vae.config.z_dim, 1, 1, 1 + ).to(latents.device, latents.dtype) latents = latents / latents_std + latents_mean videos = self.vae.decode(latents).sample else: diff --git a/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py b/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py index 2bbe5204..49bd5d72 100644 --- a/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py +++ b/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py @@ -22,7 +22,6 @@ class FlowPipeline: - def __init__( self, model_id="Wan-AI/Wan2.1-T2V-1.3B-Diffusers", @@ -33,10 +32,9 @@ def __init__( """ self.pipe = WanPipeline.from_pretrained(model_id, vae=None, torch_dtype=torch.float32, text_encoder=None) - def training_step( - self, - model, + self, + model, data_batch: dict[str, torch.Tensor], # Flow matching parameters use_sigma_noise: bool = True, @@ -55,13 +53,13 @@ def training_step( 3. Compute the loss based on the difference between the predictions and target. """ - video_latents = data_batch['video_latents'] - max_video_seq_len = data_batch['max_video_seq_len'] - context_embeddings = data_batch['context_embeddings'] - loss_mask = data_batch['loss_mask'] - grid_sizes = data_batch['grid_sizes'] - packed_seq_params = data_batch['packed_seq_params'] - video_metadata = data_batch['video_metadata'] + video_latents = data_batch["video_latents"] + max_video_seq_len = data_batch["max_video_seq_len"] + context_embeddings = data_batch["context_embeddings"] + loss_mask = data_batch["loss_mask"] + grid_sizes = data_batch["grid_sizes"] + packed_seq_params = data_batch["packed_seq_params"] + video_metadata = data_batch["video_metadata"] self.model = model @@ -75,12 +73,12 @@ def training_step( # ======================================================================== # Flow Matching Timestep Sampling # ======================================================================== - + num_train_timesteps = self.pipe.scheduler.config.num_train_timesteps - + if use_sigma_noise: use_uniform = torch.rand(1).item() < mix_uniform_ratio - + if use_uniform or timestep_sampling == "uniform": # Pure uniform: u ~ U(0, 1) u = torch.rand(size=(batch_size,), device=device) @@ -94,12 +92,12 @@ def training_step( logit_std=logit_std, ).to(device) sampling_method = timestep_sampling - + # Apply flow shift: σ = shift/(shift + (1/u - 1)) u_clamped = torch.clamp(u, min=1e-5) # Avoid division by zero sigma = flow_shift / (flow_shift + (1.0 / u_clamped - 1.0)) sigma = torch.clamp(sigma, 0.0, 1.0) - + else: # Simple uniform without shift u = torch.rand(size=(batch_size,), device=device) @@ -119,33 +117,33 @@ def training_step( sample_noise = torch.randn( 1, in_channels, - grid_size[0]*patch_temporal, - grid_size[1]*patch_spatial, - grid_size[2]*patch_spatial, + grid_size[0] * patch_temporal, + grid_size[1] * patch_spatial, + grid_size[2] * patch_spatial, dtype=torch.float32, device=video_latents.device, ) - sample_noise = patchify(sample_noise, (patch_temporal, patch_spatial, patch_spatial))[0] # shape [noise_seq, c * ( pF * pH * pW)] + sample_noise = patchify(sample_noise, (patch_temporal, patch_spatial, patch_spatial))[ + 0 + ] # shape [noise_seq, c * ( pF * pH * pW)] # because video_latents might be padded, we need to make sure noise also be padded to have the same shape noise_seq = sample_noise.shape[0] video_seq = video_latents.shape[0] if noise_seq < video_seq: pad_len = video_seq - noise_seq - pad = torch.zeros((pad_len, sample_noise.shape[1]), device=sample_noise.device, dtype=sample_noise.dtype) + pad = torch.zeros( + (pad_len, sample_noise.shape[1]), device=sample_noise.device, dtype=sample_noise.dtype + ) sample_noise = torch.cat([sample_noise, pad], dim=0) noise.append(sample_noise) - noise = torch.stack(noise, dim=1) # shape [noise_seq, batch_size, c * ( pF * pH * pW)] - + noise = torch.stack(noise, dim=1) # shape [noise_seq, batch_size, c * ( pF * pH * pW)] # CRITICAL: Manual flow matching (NOT scheduler.add_noise!) # x_t = (1 - σ) * x_0 + σ * ε sigma_reshaped = sigma.view(1, batch_size, 1) - noisy_latents = ( - (1.0 - sigma_reshaped) * video_latents.float() - + sigma_reshaped * noise - ) - + noisy_latents = (1.0 - sigma_reshaped) * video_latents.float() + sigma_reshaped * noise + # Timesteps for model [0, 1000] timesteps = sigma * num_train_timesteps @@ -161,13 +159,31 @@ def training_step( # ======================================================================== # Split accross context parallelism # ======================================================================== - + if parallel_state.get_context_parallel_world_size() > 1: - video_latents = thd_split_inputs_cp(video_latents, packed_seq_params['self_attention'].cu_seqlens_q, parallel_state.get_context_parallel_group()) - noisy_latents = thd_split_inputs_cp(noisy_latents, packed_seq_params['self_attention'].cu_seqlens_q, parallel_state.get_context_parallel_group()) - noise = thd_split_inputs_cp(noise, packed_seq_params['self_attention'].cu_seqlens_q, parallel_state.get_context_parallel_group()) - context_embeddings = thd_split_inputs_cp(context_embeddings, packed_seq_params['cross_attention'].cu_seqlens_kv, parallel_state.get_context_parallel_group()) - split_loss_mask = thd_split_inputs_cp(loss_mask, packed_seq_params['self_attention'].cu_seqlens_q, parallel_state.get_context_parallel_group()) + video_latents = thd_split_inputs_cp( + video_latents, + packed_seq_params["self_attention"].cu_seqlens_q, + parallel_state.get_context_parallel_group(), + ) + noisy_latents = thd_split_inputs_cp( + noisy_latents, + packed_seq_params["self_attention"].cu_seqlens_q, + parallel_state.get_context_parallel_group(), + ) + noise = thd_split_inputs_cp( + noise, packed_seq_params["self_attention"].cu_seqlens_q, parallel_state.get_context_parallel_group() + ) + context_embeddings = thd_split_inputs_cp( + context_embeddings, + packed_seq_params["cross_attention"].cu_seqlens_kv, + parallel_state.get_context_parallel_group(), + ) + split_loss_mask = thd_split_inputs_cp( + loss_mask, + packed_seq_params["self_attention"].cu_seqlens_q, + parallel_state.get_context_parallel_group(), + ) else: video_latents = video_latents noisy_latents = noisy_latents @@ -178,40 +194,35 @@ def training_step( # ======================================================================== # Forward Pass # ======================================================================== - + if parallel_state.is_pipeline_last_stage(): - model_pred = self.model( - x = noisy_latents, - grid_sizes = grid_sizes, - t = timesteps, - context = context_embeddings, - max_seq_len = max_video_seq_len, + x=noisy_latents, + grid_sizes=grid_sizes, + t=timesteps, + context=context_embeddings, + max_seq_len=max_video_seq_len, packed_seq_params=packed_seq_params, ) # ======================================================================== # Target: Flow Matching Velocity # ======================================================================== - + # Flow matching target: v = ε - x_0 target = noise - video_latents.float() - + # ======================================================================== # Loss with Flow Weighting # ======================================================================== - - loss = torch.nn.functional.mse_loss( - model_pred.float(), - target.float(), - reduction="none" - ) + + loss = torch.nn.functional.mse_loss(model_pred.float(), target.float(), reduction="none") # Flow weight: w = 1 + shift * σ - loss_weight = 1.0 + flow_shift * sigma # shape [batch_size] - loss_weight = loss_weight.view(1, batch_size, 1).to(device) # shape [1, batch_size, 1] + loss_weight = 1.0 + flow_shift * sigma # shape [batch_size] + loss_weight = loss_weight.view(1, batch_size, 1).to(device) # shape [1, batch_size, 1] unweighted_loss = loss - weighted_loss = (loss * loss_weight) # shape [seq_length / cp_size, batch_size, -1] + weighted_loss = loss * loss_weight # shape [seq_length / cp_size, batch_size, -1] # Safety check mean_weighted_loss = weighted_loss.mean() @@ -224,11 +235,11 @@ def training_step( else: hidden_states = self.model( - x = noisy_latents, - grid_sizes = grid_sizes, - t = timesteps, - context = context_embeddings, - max_seq_len = max_video_seq_len, + x=noisy_latents, + grid_sizes=grid_sizes, + t=timesteps, + context=context_embeddings, + max_seq_len=max_video_seq_len, packed_seq_params=packed_seq_params, ) diff --git a/dfm/src/megatron/model/wan/flow_matching/time_shift_utils.py b/dfm/src/megatron/model/wan/flow_matching/time_shift_utils.py index 1ca78f47..cb337959 100644 --- a/dfm/src/megatron/model/wan/flow_matching/time_shift_utils.py +++ b/dfm/src/megatron/model/wan/flow_matching/time_shift_utils.py @@ -28,7 +28,7 @@ def time_shift( ): """ Convert timesteps to sigmas with sequence-length-aware shifting. - + Args: t: timesteps in range [0, 1] image_seq_len: number of tokens (frames * height * width / patch_size^2) @@ -36,7 +36,7 @@ def time_shift( base_shift: base shift for linear mode max_shift: max shift for linear mode constant: shift value for constant mode (default 3.0 matches Pika) - + Returns: sigma values for noise scheduling """ @@ -44,17 +44,17 @@ def time_shift( # Linear interpolation based on sequence length mu = base_shift + (max_shift - base_shift) * (image_seq_len / 4096) return math.exp(mu) / (math.exp(mu) + (1 / t - 1)) - + elif shift_type == "sqrt": # Square root scaling (Flux-style) # Assuming 128x128 latent space (1024x1024 image) gives mu=3 mu = np.maximum(1.0, np.sqrt(image_seq_len / (128.0 * 128.0)) * 3.0) return mu / (mu + (1 / t - 1)) - + elif shift_type == "constant": # Constant shift (Pika default) return constant / (constant + (1 / t - 1)) - + else: # No shift, return original t return t @@ -69,49 +69,44 @@ def compute_density_for_timestep_sampling( ): """ Sample timesteps from different distributions for better training coverage. - + Args: weighting_scheme: "uniform", "logit_normal", or "mode" batch_size: number of samples to generate logit_mean: mean for logit-normal distribution logit_std: std for logit-normal distribution mode_scale: scale for mode-based sampling - + Returns: Tensor of shape (batch_size,) with values in [0, 1] """ if weighting_scheme == "logit_normal": # SD3-style logit-normal sampling - u = torch.normal( - mean=logit_mean, - std=logit_std, - size=(batch_size,), - device="cpu" - ) + u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu") u = torch.nn.functional.sigmoid(u) - + elif weighting_scheme == "mode": # Mode-based sampling (concentrates around certain timesteps) u = torch.rand(size=(batch_size,), device="cpu") u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u) - + else: # Uniform sampling (default) u = torch.rand(size=(batch_size,), device="cpu") - + return u def get_flow_match_loss_weight(sigma: torch.Tensor, shift: float = 3.0): """ Compute loss weights for flow matching based on sigma values. - + Higher sigma (more noise) typically gets higher weight. - + Args: sigma: sigma values in range [0, 1] shift: weight scaling factor - + Returns: Loss weights with same shape as sigma """ diff --git a/dfm/src/megatron/model/wan/inference/configs/__init__.py b/dfm/src/megatron/model/wan/inference/configs/__init__.py index dd12de34..a82afb99 100644 --- a/dfm/src/megatron/model/wan/inference/configs/__init__.py +++ b/dfm/src/megatron/model/wan/inference/configs/__init__.py @@ -1,32 +1,32 @@ import os -os.environ['TOKENIZERS_PARALLELISM'] = 'false' - -from .wan_t2v_1_3B import t2v_1_3B from .wan_t2v_14B import t2v_14B +from .wan_t2v_1_3B import t2v_1_3B + +os.environ["TOKENIZERS_PARALLELISM"] = "false" WAN_CONFIGS = { - 't2v-14B': t2v_14B, - 't2v-1.3B': t2v_1_3B, + "t2v-14B": t2v_14B, + "t2v-1.3B": t2v_1_3B, } SIZE_CONFIGS = { - '720*1280': (720, 1280), - '1280*720': (1280, 720), - '480*832': (480, 832), - '832*480': (832, 480), - '1024*1024': (1024, 1024), + "720*1280": (720, 1280), + "1280*720": (1280, 720), + "480*832": (480, 832), + "832*480": (832, 480), + "1024*1024": (1024, 1024), } MAX_AREA_CONFIGS = { - '720*1280': 720 * 1280, - '1280*720': 1280 * 720, - '480*832': 480 * 832, - '832*480': 832 * 480, + "720*1280": 720 * 1280, + "1280*720": 1280 * 720, + "480*832": 480 * 832, + "832*480": 832 * 480, } SUPPORTED_SIZES = { - 't2v-14B': ('720*1280', '1280*720', '480*832', '832*480'), - 't2v-1.3B': ('480*832', '832*480'), + "t2v-14B": ("720*1280", "1280*720", "480*832", "832*480"), + "t2v-1.3B": ("480*832", "832*480"), } diff --git a/dfm/src/megatron/model/wan/inference/configs/shared_config.py b/dfm/src/megatron/model/wan/inference/configs/shared_config.py index b4aa8e21..7b2e9432 100644 --- a/dfm/src/megatron/model/wan/inference/configs/shared_config.py +++ b/dfm/src/megatron/model/wan/inference/configs/shared_config.py @@ -17,11 +17,11 @@ from easydict import EasyDict -#------------------------ Wan shared config ------------------------# +# ------------------------ Wan shared config ------------------------# wan_shared_cfg = EasyDict() # t5 -wan_shared_cfg.t5_model = 'umt5_xxl' +wan_shared_cfg.t5_model = "umt5_xxl" wan_shared_cfg.t5_dtype = torch.bfloat16 wan_shared_cfg.text_len = 512 @@ -31,4 +31,4 @@ # inference wan_shared_cfg.num_train_timesteps = 1000 wan_shared_cfg.sample_fps = 16 -wan_shared_cfg.sample_neg_prompt = '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走' +wan_shared_cfg.sample_neg_prompt = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" diff --git a/dfm/src/megatron/model/wan/inference/configs/wan_t2v_14B.py b/dfm/src/megatron/model/wan/inference/configs/wan_t2v_14B.py index 9909bb5f..00c82d6d 100644 --- a/dfm/src/megatron/model/wan/inference/configs/wan_t2v_14B.py +++ b/dfm/src/megatron/model/wan/inference/configs/wan_t2v_14B.py @@ -17,17 +17,17 @@ from .shared_config import wan_shared_cfg -#------------------------ Wan T2V 14B ------------------------# +# ------------------------ Wan T2V 14B ------------------------# -t2v_14B = EasyDict(__name__='Config: Wan T2V 14B') +t2v_14B = EasyDict(__name__="Config: Wan T2V 14B") t2v_14B.update(wan_shared_cfg) # t5 -t2v_14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth' -t2v_14B.t5_tokenizer = 'google/umt5-xxl' +t2v_14B.t5_checkpoint = "models_t5_umt5-xxl-enc-bf16.pth" +t2v_14B.t5_tokenizer = "google/umt5-xxl" # vae -t2v_14B.vae_checkpoint = 'Wan2.1_VAE.pth' +t2v_14B.vae_checkpoint = "Wan2.1_VAE.pth" t2v_14B.vae_stride = (4, 8, 8) # transformer diff --git a/dfm/src/megatron/model/wan/inference/configs/wan_t2v_1_3B.py b/dfm/src/megatron/model/wan/inference/configs/wan_t2v_1_3B.py index 2fa292b4..66b5df12 100644 --- a/dfm/src/megatron/model/wan/inference/configs/wan_t2v_1_3B.py +++ b/dfm/src/megatron/model/wan/inference/configs/wan_t2v_1_3B.py @@ -17,17 +17,17 @@ from .shared_config import wan_shared_cfg -#------------------------ Wan T2V 1.3B ------------------------# +# ------------------------ Wan T2V 1.3B ------------------------# -t2v_1_3B = EasyDict(__name__='Config: Wan T2V 1.3B') +t2v_1_3B = EasyDict(__name__="Config: Wan T2V 1.3B") t2v_1_3B.update(wan_shared_cfg) # t5 -t2v_1_3B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth' -t2v_1_3B.t5_tokenizer = 'google/umt5-xxl' +t2v_1_3B.t5_checkpoint = "models_t5_umt5-xxl-enc-bf16.pth" +t2v_1_3B.t5_tokenizer = "google/umt5-xxl" # vae -t2v_1_3B.vae_checkpoint = 'Wan2.1_VAE.pth' +t2v_1_3B.vae_checkpoint = "Wan2.1_VAE.pth" t2v_1_3B.vae_stride = (4, 8, 8) # transformer diff --git a/dfm/src/megatron/model/wan/inference/utils.py b/dfm/src/megatron/model/wan/inference/utils.py index 199bf968..6ba11cc3 100644 --- a/dfm/src/megatron/model/wan/inference/utils.py +++ b/dfm/src/megatron/model/wan/inference/utils.py @@ -22,29 +22,21 @@ import torchvision -__all__ = ['cache_video', 'cache_image', 'str2bool'] +__all__ = ["cache_video", "cache_image", "str2bool"] -def rand_name(length=8, suffix=''): - name = binascii.b2a_hex(os.urandom(length)).decode('utf-8') +def rand_name(length=8, suffix=""): + name = binascii.b2a_hex(os.urandom(length)).decode("utf-8") if suffix: - if not suffix.startswith('.'): - suffix = '.' + suffix + if not suffix.startswith("."): + suffix = "." + suffix name += suffix return name -def cache_video(tensor, - save_file=None, - fps=30, - suffix='.mp4', - nrow=8, - normalize=True, - value_range=(-1, 1), - retry=5): +def cache_video(tensor, save_file=None, fps=30, suffix=".mp4", nrow=8, normalize=True, value_range=(-1, 1), retry=5): # cache file - cache_file = osp.join('/tmp', rand_name( - suffix=suffix)) if save_file is None else save_file + cache_file = osp.join("/tmp", rand_name(suffix=suffix)) if save_file is None else save_file # save to cache error = None @@ -52,17 +44,17 @@ def cache_video(tensor, try: # preprocess tensor = tensor.clamp(min(value_range), max(value_range)) - tensor = torch.stack([ - torchvision.utils.make_grid( - u, nrow=nrow, normalize=normalize, value_range=value_range) - for u in tensor.unbind(2) - ], - dim=1).permute(1, 2, 3, 0) + tensor = torch.stack( + [ + torchvision.utils.make_grid(u, nrow=nrow, normalize=normalize, value_range=value_range) + for u in tensor.unbind(2) + ], + dim=1, + ).permute(1, 2, 3, 0) tensor = (tensor * 255).type(torch.uint8).cpu() # write video - writer = imageio.get_writer( - cache_file, fps=fps, codec='libx264', quality=8) + writer = imageio.get_writer(cache_file, fps=fps, codec="libx264", quality=8) for frame in tensor.numpy(): writer.append_data(frame) writer.close() @@ -71,34 +63,22 @@ def cache_video(tensor, error = e continue else: - print(f'cache_video failed, error: {error}', flush=True) + print(f"cache_video failed, error: {error}", flush=True) return None -def cache_image(tensor, - save_file, - nrow=8, - normalize=True, - value_range=(-1, 1), - retry=5): +def cache_image(tensor, save_file, nrow=8, normalize=True, value_range=(-1, 1), retry=5): # cache file suffix = osp.splitext(save_file)[1] - if suffix.lower() not in [ - '.jpg', '.jpeg', '.png', '.tiff', '.gif', '.webp' - ]: - suffix = '.png' + if suffix.lower() not in [".jpg", ".jpeg", ".png", ".tiff", ".gif", ".webp"]: + suffix = ".png" # save to cache error = None for _ in range(retry): try: tensor = tensor.clamp(min(value_range), max(value_range)) - torchvision.utils.save_image( - tensor, - save_file, - nrow=nrow, - normalize=normalize, - value_range=value_range) + torchvision.utils.save_image(tensor, save_file, nrow=nrow, normalize=normalize, value_range=value_range) return save_file except Exception as e: error = e @@ -124,9 +104,9 @@ def str2bool(v): if isinstance(v, bool): return v v_lower = v.lower() - if v_lower in ('yes', 'true', 't', 'y', '1'): + if v_lower in ("yes", "true", "t", "y", "1"): return True - elif v_lower in ('no', 'false', 'f', 'n', '0'): + elif v_lower in ("no", "false", "f", "n", "0"): return False else: - raise argparse.ArgumentTypeError('Boolean value expected (True/False)') + raise argparse.ArgumentTypeError("Boolean value expected (True/False)") diff --git a/dfm/src/megatron/model/wan/rope_utils.py b/dfm/src/megatron/model/wan/rope_utils.py index 0b6fcf4a..86689393 100644 --- a/dfm/src/megatron/model/wan/rope_utils.py +++ b/dfm/src/megatron/model/wan/rope_utils.py @@ -24,22 +24,26 @@ class Wan3DRopeEmbeddings(torch.nn.Module): def __init__(self, dim_head, max_position_len): super().__init__() - self.freqs = torch.cat([ - self.rope_params(max_position_len, dim_head - 4 * (dim_head // 6)), - self.rope_params(max_position_len, 2 * (dim_head // 6)), - self.rope_params(max_position_len, 2 * (dim_head // 6)) - ], dim=1) + self.freqs = torch.cat( + [ + self.rope_params(max_position_len, dim_head - 4 * (dim_head // 6)), + self.rope_params(max_position_len, 2 * (dim_head // 6)), + self.rope_params(max_position_len, 2 * (dim_head // 6)), + ], + dim=1, + ) def rope_params(self, max_position_len, dim_head, theta=10000): assert dim_head % 2 == 0 freqs = torch.outer( - torch.arange(max_position_len), - 1.0 / torch.pow(theta, - torch.arange(0, dim_head, 2).div(dim_head))) + torch.arange(max_position_len), 1.0 / torch.pow(theta, torch.arange(0, dim_head, 2).div(dim_head)) + ) return freqs def forward(self, n_head, dim_head, max_seq_len, grid_sizes, device): - self.freqs = self.freqs.to(device) # ??? do we need to put this here, or the when we move WanModel to device, it also move freqs to device? + self.freqs = self.freqs.to( + device, + ) n, c = n_head, dim_head // 2 @@ -49,11 +53,14 @@ def forward(self, n_head, dim_head, max_seq_len, grid_sizes, device): freqs_real = [] for i, (f, h, w) in enumerate(grid_sizes.tolist()): seq_len = f * h * w - freqs_real_i = torch.cat([ - freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), - freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), - freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1) - ], dim=-1).reshape(seq_len, 1, 1, -1) # <-- add 1,1 for batch/head broadcasting + freqs_real_i = torch.cat( + [ + freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), + freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), + freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1), + ], + dim=-1, + ).reshape(seq_len, 1, 1, -1) # <-- add 1,1 for batch/head broadcasting # Double dimension from c -> 2c with rotating angles as (x0, x0, x1, x1, ...), for interleaving RoPE freqs_real_i = freqs_real_i.unsqueeze(-1).expand(-1, -1, -1, -1, 2).reshape(seq_len, 1, 1, dim_head) diff --git a/dfm/src/megatron/model/wan/utils.py b/dfm/src/megatron/model/wan/utils.py index 2a1a7025..9d719f96 100644 --- a/dfm/src/megatron/model/wan/utils.py +++ b/dfm/src/megatron/model/wan/utils.py @@ -55,8 +55,9 @@ def patchify(x, patch_size): for u in x: c, F_pF, H_pH, W_pW = u.shape pF, pH, pW = patch_size - assert F_pF % pF == 0 and H_pH % pH == 0 and W_pW % pW == 0, \ - "Spatial dimensions must be divisible by patch size." + assert F_pF % pF == 0 and H_pH % pH == 0 and W_pW % pW == 0, ( + "Spatial dimensions must be divisible by patch size." + ) F_patches, H_patches, W_patches = F_pF // pF, H_pH // pH, W_pW // pW @@ -72,7 +73,9 @@ def patchify(x, patch_size): return out -def unpatchify(x: list[torch.Tensor], grid_sizes: list[Tuple[int, int, int]], out_dim: int, patch_size: Tuple[int, int, int]) -> list[torch.Tensor]: +def unpatchify( + x: list[torch.Tensor], grid_sizes: list[Tuple[int, int, int]], out_dim: int, patch_size: Tuple[int, int, int] +) -> list[torch.Tensor]: """ Reconstruct video tensors from patch embeddings into a list of videotensors. @@ -90,8 +93,8 @@ def unpatchify(x: list[torch.Tensor], grid_sizes: list[Tuple[int, int, int]], ou c = out_dim out = [] for u, v in zip(x, grid_sizes): - u = u[:math.prod(v)].view(*v, *patch_size, c) - u = torch.einsum('fhwpqrc->cfphqwr', u) + u = u[: math.prod(v)].view(*v, *patch_size, c) + u = torch.einsum("fhwpqrc->cfphqwr", u) u = u.reshape(c, *[i * j for i, j in zip(v, patch_size)]) out.append(u) return out @@ -146,9 +149,9 @@ def cat_outputs_cp(x: torch.Tensor, seq_dim: int) -> torch.Tensor: return x -def thd_split_inputs_cp(x: torch.Tensor, - cu_seqlens_q_padded: torch.Tensor, - cp_group: dist.ProcessGroup) -> torch.Tensor: +def thd_split_inputs_cp( + x: torch.Tensor, cu_seqlens_q_padded: torch.Tensor, cp_group: dist.ProcessGroup +) -> torch.Tensor: """ Split a THD-packed tensor across CP ranks for inputs shaped [S, B, ...]. diff --git a/dfm/src/megatron/model/wan/wan_layer_spec.py b/dfm/src/megatron/model/wan/wan_layer_spec.py index 23a6f22f..9b47f593 100644 --- a/dfm/src/megatron/model/wan/wan_layer_spec.py +++ b/dfm/src/megatron/model/wan/wan_layer_spec.py @@ -1,4 +1,3 @@ - # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -58,9 +57,7 @@ class WanAdaLN(MegatronModule): Adaptive Layer Normalization Module for DiT. """ - def __init__( - self, config: TransformerConfig - ): + def __init__(self, config: TransformerConfig): super().__init__(config) # modulation self.modulation = nn.Parameter(torch.randn(1, 6, config.hidden_size) / config.hidden_size**0.5) @@ -123,7 +120,7 @@ def __init__( submodules.norm1, normalized_shape=config.hidden_size, eps=config.layernorm_epsilon, - elementwise_affine=False + elementwise_affine=False, ) self.norm3 = build_module( submodules.norm3, @@ -138,7 +135,6 @@ def __init__( elementwise_affine=False, ) - def forward( self, hidden_states, @@ -182,7 +178,7 @@ def forward( rotary_pos_emb=rope_emb, rotary_pos_cos=rotary_pos_cos, rotary_pos_sin=rotary_pos_sin, - packed_seq_params=packed_seq_params['self_attention'], + packed_seq_params=packed_seq_params["self_attention"], ) if bias is not None: attention_output = attention_output + bias @@ -195,7 +191,7 @@ def forward( self.norm3(hidden_states), attention_mask=context_mask, key_value_states=context, - packed_seq_params=packed_seq_params['cross_attention'], + packed_seq_params=packed_seq_params["cross_attention"], ) if bias is not None: attention_output = attention_output + bias @@ -212,7 +208,7 @@ def forward( mlp_output, bias = self.mlp(pre_mlp_layernorm_output_ada) if bias is not None: - mlp_output = mlp_output + bias + mlp_output = mlp_output + bias hidden_states = self.adaLN.scale_add(residual=hidden_states, x=mlp_output, gate=gate_mlp) @@ -244,7 +240,7 @@ def get_wan_block_with_transformer_engine_spec() -> ModuleSpec: core_attention=TEDotProductAttention, linear_proj=TERowParallelLinear, q_layernorm=TENorm, - k_layernorm=TENorm, + k_layernorm=TENorm, ), ), cross_attention=ModuleSpec( diff --git a/dfm/src/megatron/model/wan/wan_model.py b/dfm/src/megatron/model/wan/wan_model.py index b43a6a70..d721aad0 100644 --- a/dfm/src/megatron/model/wan/wan_model.py +++ b/dfm/src/megatron/model/wan/wan_model.py @@ -28,7 +28,8 @@ from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.utils import make_sharded_tensor_for_checkpoint from torch import Tensor - +from diffusers.models.embeddings import Timesteps +from nemo.collections.diffusion.models.dit.dit_embeddings import ParallelTimestepEmbedding from dfm.src.megatron.model.wan.wan_layer_spec import ( get_wan_block_with_transformer_engine_spec as WanLayerWithAdaLNspec, ) @@ -43,14 +44,12 @@ def sinusoidal_embedding_1d(dim, position): position = position # calculation - sinusoid = torch.outer( - position, torch.pow(10000, -torch.arange(half).to(position).div(half))) + sinusoid = torch.outer(position, torch.pow(10000, -torch.arange(half).to(position).div(half))) x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) return x class Head(nn.Module): - def __init__(self, dim, out_dim, patch_size, eps=1e-6): super().__init__() self.dim = dim @@ -73,7 +72,7 @@ def forward(self, x, e): e(Tensor): Shape [B, C] """ e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1) - x = (self.head(self.norm(x) * (1 + e[1]) + e[0])) + x = self.head(self.norm(x) * (1 + e[1]) + e[0]) return x @@ -131,21 +130,26 @@ def __init__( # embeddings if self.pre_process: self.patch_embedding = nn.Conv3d( - self.in_channels, self.config.hidden_size, kernel_size=self.patch_size, stride=self.patch_size) + self.in_channels, self.config.hidden_size, kernel_size=self.patch_size, stride=self.patch_size + ) self.text_embedding = nn.Sequential( - nn.Linear(self.config.text_dim, self.config.hidden_size), nn.GELU(approximate='tanh'), - nn.Linear(self.config.hidden_size, self.config.hidden_size)) + nn.Linear(self.config.text_dim, self.config.hidden_size), + nn.GELU(approximate="tanh"), + nn.Linear(self.config.hidden_size, self.config.hidden_size), + ) # As in diffuser's Wan implementation - from diffusers.models.embeddings import Timesteps - from nemo.collections.diffusion.models.dit.dit_embeddings import ParallelTimestepEmbedding self.timesteps_proj = Timesteps(num_channels=self.freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0) - self.time_embedder = ParallelTimestepEmbedding(in_channels=self.freq_dim, time_embed_dim=self.config.hidden_size) + self.time_embedder = ParallelTimestepEmbedding( + in_channels=self.freq_dim, time_embed_dim=self.config.hidden_size + ) self.time_proj_act_fn = nn.SiLU() self.time_proj = nn.Linear(self.config.hidden_size, self.config.hidden_size * 6) - self.rope_embeddings = Wan3DRopeEmbeddings(dim_head = self.config.hidden_size // self.num_heads, max_position_len = 1024) + self.rope_embeddings = Wan3DRopeEmbeddings( + dim_head=self.config.hidden_size // self.num_heads, max_position_len=1024 + ) # decoder blocks self.decoder = TransformerBlock( @@ -158,7 +162,7 @@ def __init__( # output head if self.post_process: - self.head = Head(self.config.hidden_size, self.out_channels, self.patch_size, eps = 1e-6) + self.head = Head(self.config.hidden_size, self.out_channels, self.patch_size, eps=1e-6) def forward( @@ -195,16 +199,18 @@ def forward( seq_len, batch_size, _ = x.shape c = self.out_channels pF, pH, pW = self.patch_size - x = x.reshape(seq_len * batch_size, pF, pH, pW, c) # output: x.shape [s * b, pF, pH, pW, c] - x = x.permute(0, 4, 1, 2, 3) # output: x.shape [s * b, c, pF, pH, pW] - x = self.patch_embedding(x) # output: x.shape [s * b, hidden_size, 1, 1, 1] - x = x.flatten(1) # output: x.shape [s * b, hidden_size] - x = x.reshape(seq_len, batch_size, -1) # output: x.shape [s, b, hidden_size] + x = x.reshape(seq_len * batch_size, pF, pH, pW, c) # output: x.shape [s * b, pF, pH, pW, c] + x = x.permute(0, 4, 1, 2, 3) # output: x.shape [s * b, c, pF, pH, pW] + x = self.patch_embedding(x) # output: x.shape [s * b, hidden_size, 1, 1, 1] + x = x.flatten(1) # output: x.shape [s * b, hidden_size] + x = x.reshape(seq_len, batch_size, -1) # output: x.shape [s, b, hidden_size] # split sequence for sequence_parallel # TODO: for PP, do we move scatter_to_sequence_parallel_region here or after "x = self.decoder.input_tensor" ??? if self.config.sequence_parallel: - x = tensor_parallel.scatter_to_sequence_parallel_region(x) # output: x.shape [s * b // tp_size, hidden_size] + x = tensor_parallel.scatter_to_sequence_parallel_region( + x + ) # output: x.shape [s * b // tp_size, hidden_size] else: # intermediate stage of pipeline @@ -215,13 +221,14 @@ def forward( e0 = self.time_proj(self.time_proj_act_fn(e)).unflatten(1, (6, self.config.hidden_size)) # context embeddings - context = self.text_embedding(context) # shape [text_len, b, hidden_size] - + context = self.text_embedding(context) # shape [text_len, b, hidden_size] # ============= decoder ============= # calculate rotary pos emb n_head, dim_head = self.num_heads, self.config.hidden_size // self.num_heads - rotary_pos_emb = self.rope_embeddings(n_head, dim_head, max_seq_len, grid_sizes, t.device) # output: rotary_pos_emb.shape [s, b, 1, dim_head] + rotary_pos_emb = self.rope_embeddings( + n_head, dim_head, max_seq_len, grid_sizes, t.device + ) # output: rotary_pos_emb.shape [s, b, 1, dim_head] # run decoder x = self.decoder( @@ -240,18 +247,17 @@ def forward( return x # head - x = x.transpose(0, 1) # head expects shape [b, s, hidden_size] - x = self.head(x, e) # output: x.shape [b, s, c * pF * pH * pW] - x = x.transpose(0, 1) # reshape back to shape [s, b, c * pF * pH * pW] + x = x.transpose(0, 1) # head expects shape [b, s, hidden_size] + x = self.head(x, e) # output: x.shape [b, s, c * pF * pH * pW] + x = x.transpose(0, 1) # reshape back to shape [s, b, c * pF * pH * pW] # gather outputs for sequence_parallel - # Note: in GPT models, because the vocab projection matrix is ColumnParallelLinear, the sequence is + # Note: in GPT models, because the vocab projection matrix is ColumnParallelLinear, the sequence is # automatically gathered in ColumnParallelLinear forward pass. # However, in Wan models, we need to gather the outputs manually. if self.config.sequence_parallel: x = tensor_parallel.gather_from_sequence_parallel_region(x) - - return x # output: x.shape [s, b, c * pF * pH * pW] + return x # output: x.shape [s, b, c * pF * pH * pW] def set_input_tensor(self, input_tensor: Tensor) -> None: @@ -270,7 +276,6 @@ def set_input_tensor(self, input_tensor: Tensor) -> None: assert len(input_tensor) == 1, "input_tensor should only be length 1 for gpt/bert" self.decoder.set_input_tensor(input_tensor[0]) - def sharded_state_dict( self, prefix: str = "module.", sharded_offsets: tuple = (), metadata: Optional[Dict] = None ) -> ShardedStateDict: @@ -302,7 +307,6 @@ def sharded_state_dict( return sharded_state_dict - def _set_embedder_weights_replica_id( self, tensor: Tensor, sharded_state_dict: ShardedStateDict, embedder_weight_key: str ) -> None: diff --git a/dfm/src/megatron/model/wan/wan_provider.py b/dfm/src/megatron/model/wan/wan_provider.py index 95f803e9..d349b6ad 100644 --- a/dfm/src/megatron/model/wan/wan_provider.py +++ b/dfm/src/megatron/model/wan/wan_provider.py @@ -27,6 +27,7 @@ logger = logging.getLogger(__name__) + @dataclass class WanModelProvider(TransformerConfig, ModelProviderMixin[VisionModule]): crossattn_emb_size: int = 1536 @@ -49,7 +50,7 @@ class WanModelProvider(TransformerConfig, ModelProviderMixin[VisionModule]): parallel_output: bool = True bf16: bool = False params_dtype: torch.dtype = torch.float32 - qkv_format: str = "sbhd" # "thd". NOTE: if we use context parallelism, we need to use "thd" + qkv_format: str = "sbhd" # "thd". NOTE: if we use context parallelism, we need to use "thd" # these attributes are unused for images/videos, we just set because bridge training requires for LLMs seq_length: int = 1024 share_embeddings_and_output_weights: bool = False diff --git a/dfm/src/megatron/model/wan/wan_step.py b/dfm/src/megatron/model/wan/wan_step.py index 999728d1..a5131a43 100644 --- a/dfm/src/megatron/model/wan/wan_step.py +++ b/dfm/src/megatron/model/wan/wan_step.py @@ -30,6 +30,7 @@ logger = logging.getLogger(__name__) + def wan_data_step(qkv_format, dataloader_iter): batch = next(iter(dataloader_iter.iterable)) @@ -75,16 +76,14 @@ def __call__( straggler_timer = state.straggler_timer config = get_model_config(model) - + timers("batch-generator", log_level=2).start() qkv_format = getattr(config, "qkv_format", "sbhd") with straggler_timer(bdata=True): - batch = wan_data_step( - qkv_format, data_iterator - ) + batch = wan_data_step(qkv_format, data_iterator) timers("batch-generator").stop() - + check_for_nan_in_loss = state.cfg.rerun_state_machine.check_for_nan_in_loss check_for_spiky_loss = state.cfg.rerun_state_machine.check_for_spiky_loss @@ -105,13 +104,15 @@ def __call__( if "loss_mask" not in batch or batch["loss_mask"] is None: loss_mask = torch.ones_like(loss) loss_mask = batch["loss_mask"] - + loss_function = self._create_loss_function(loss_mask, check_for_nan_in_loss, check_for_spiky_loss) return output_tensor, loss_function - def _create_loss_function(self, loss_mask: torch.Tensor, check_for_nan_in_loss: bool, check_for_spiky_loss: bool) -> partial: + def _create_loss_function( + self, loss_mask: torch.Tensor, check_for_nan_in_loss: bool, check_for_spiky_loss: bool + ) -> partial: """Create a partial loss function with the specified configuration. Args: diff --git a/dfm/src/megatron/recipes/wan/wan.py b/dfm/src/megatron/recipes/wan/wan.py index 306d9d8a..5f06e9d9 100644 --- a/dfm/src/megatron/recipes/wan/wan.py +++ b/dfm/src/megatron/recipes/wan/wan.py @@ -136,7 +136,6 @@ def pretrain_config( checkpoint_dir = os.path.join(run_output_dir, "checkpoints") tensorboard_dir = os.path.join(run_output_dir, "tb_logs") - model_cfg = model_config( tensor_parallelism=tensor_parallelism, pipeline_parallelism=pipeline_parallelism, @@ -159,7 +158,6 @@ def pretrain_config( precision_config.grad_reduce_in_fp32 = False - # Config Container cfg = ConfigContainer( model=model_cfg, @@ -184,13 +182,13 @@ def pretrain_config( use_distributed_optimizer=True, use_megatron_fsdp=use_megatron_fsdp, # need use_distributed_optimizer=True ), - dataset= WanDataModuleConfig( + dataset=WanDataModuleConfig( path=None, - seq_length=1024, # we don't need to use this value, just add because Bridge training requires for LLMs + seq_length=1024, # we don't need to use this value, just add because Bridge training requires for LLMs micro_batch_size=micro_batch_size, global_batch_size=global_batch_size, - num_workers=10) - , + num_workers=10, + ), logger=LoggerConfig( log_interval=10, tensorboard_dir=tensorboard_dir, diff --git a/examples/megatron/recipes/wan/inference_wan.py b/examples/megatron/recipes/wan/inference_wan.py index 1e92a2c6..e1c7776a 100644 --- a/examples/megatron/recipes/wan/inference_wan.py +++ b/examples/megatron/recipes/wan/inference_wan.py @@ -21,7 +21,7 @@ from datetime import datetime -warnings.filterwarnings('ignore') +warnings.filterwarnings("ignore") import random @@ -35,12 +35,10 @@ EXAMPLE_PROMPT = { "t2v-1.3B": { - "prompt": - "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.", + "prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.", }, "t2v-14B": { - "prompt": - "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.", + "prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.", }, } @@ -62,75 +60,59 @@ def _validate_args(args): # Frames default handled later; no single frame arg anymore - args.base_seed = args.base_seed if args.base_seed >= 0 else random.randint( - 0, sys.maxsize) + args.base_seed = args.base_seed if args.base_seed >= 0 else random.randint(0, sys.maxsize) # Size check: only validate provided --sizes; default handled later if args.sizes is not None and len(args.sizes) > 0: for s in args.sizes: assert s in SUPPORTED_SIZES[args.task], ( f"Unsupport size {s} for task {args.task}, supported sizes are: " - f"{', '.join(SUPPORTED_SIZES[args.task])}") + f"{', '.join(SUPPORTED_SIZES[args.task])}" def _parse_args(): - parser = argparse.ArgumentParser( - description="Generate a image or video from a text prompt or image using Wan" - ) + parser = argparse.ArgumentParser(description="Generate a image or video from a text prompt or image using Wan") parser.add_argument( - "--task", - type=str, - default="t2v-14B", - choices=list(WAN_CONFIGS.keys()), - help="The task to run.") + "--task", type=str, default="t2v-14B", choices=list(WAN_CONFIGS.keys()), help="The task to run." + ) parser.add_argument( "--sizes", type=str, nargs="+", default=None, choices=list(SIZE_CONFIGS.keys()), - help="A list of sizes to generate multiple images or videos (WIDTH*HEIGHT). Example: --sizes 1280*720 1920*1080" + help="A list of sizes to generate multiple images or videos (WIDTH*HEIGHT). Example: --sizes 1280*720 1920*1080", ) parser.add_argument( "--frame_nums", type=int, nargs="+", default=None, - help="List of frame counts (each should be 4n+1). Broadcasts if single value." + help="List of frame counts (each should be 4n+1). Broadcasts if single value.", ) parser.add_argument( - "--checkpoint_dir", - type=str, - default=None, - help="The path to the main WAN checkpoint directory.") + "--checkpoint_dir", type=str, default=None, help="The path to the main WAN checkpoint directory.", + ) parser.add_argument( "--checkpoint_step", type=int, default=None, help=( "Optional training step to load, e.g. 1800 -> iter_0001800. " - "If not provided, the latest (largest) step in --checkpoint_dir is used.") + "If not provided, the latest (largest) step in --checkpoint_dir is used.", + ), ) parser.add_argument( - "--t5_checkpoint_dir", - type=str, - default=None, - help="Optional directory containing T5 checkpoint/tokenizer") + "--t5_checkpoint_dir", type=str, default=None, help="Optional directory containing T5 checkpoint/tokenizer" + ) parser.add_argument( - "--vae_checkpoint_dir", - type=str, - default=None, - help="Optional directory containing VAE checkpoint") + "--vae_checkpoint_dir", type=str, default=None, help="Optional directory containing VAE checkpoint" + ) parser.add_argument( - "--offload_model", - type=str2bool, - default=None, - help="Whether to offload the model to CPU after each model forward, reducing GPU memory usage." + "--offload_model", type=str2bool, default=None, help="Whether to offload the model to CPU after each model forward, reducing GPU memory usage." ) parser.add_argument( - "--t5_cpu", - action="store_true", - default=False, - help="Whether to place T5 model on CPU.") + "--t5_cpu", action="store_true", default=False, help="Whether to place T5 model on CPU.", + ) parser.add_argument( "--save_file", type=str, @@ -141,7 +123,7 @@ def _parse_args(): type=str, nargs="+", default=None, - help="A list of prompts to generate multiple images or videos. Example: --prompts 'a cat' 'a dog'" + help="A list of prompts to generate multiple images or videos. Example: --prompts 'a cat' 'a dog'", ) parser.add_argument( "--base_seed", From e8de1ae7381b712c7d0456372118d6e3c7e2b08f Mon Sep 17 00:00:00 2001 From: Huy Vu2 Date: Wed, 5 Nov 2025 19:23:21 -0800 Subject: [PATCH 20/80] fix Ruff + Lint --- examples/megatron/recipes/wan/inference_wan.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/megatron/recipes/wan/inference_wan.py b/examples/megatron/recipes/wan/inference_wan.py index e1c7776a..c67f6fbd 100644 --- a/examples/megatron/recipes/wan/inference_wan.py +++ b/examples/megatron/recipes/wan/inference_wan.py @@ -67,7 +67,7 @@ def _validate_args(args): assert s in SUPPORTED_SIZES[args.task], ( f"Unsupport size {s} for task {args.task}, supported sizes are: " f"{', '.join(SUPPORTED_SIZES[args.task])}" - + ) def _parse_args(): parser = argparse.ArgumentParser(description="Generate a image or video from a text prompt or image using Wan") From 287ad34b20c809d9f4457acce0568a843b6460ed Mon Sep 17 00:00:00 2001 From: Huy Vu2 Date: Wed, 5 Nov 2025 19:31:54 -0800 Subject: [PATCH 21/80] fix Ruff + Lint --- dfm/src/megatron/data/dit/base.py | 18 ++++---- .../data/dit/diffusion_energon_datamodule.py | 9 ++-- .../data/wan/wan_energon_datamodule.py | 3 +- dfm/src/megatron/data/wan/wan_taskencoder.py | 44 ++++++++++-------- .../megatron/model/common/dit_attention.py | 20 ++++---- .../flow_matching/flow_inference_pipeline.py | 12 +++-- .../model/wan/inference/configs/__init__.py | 2 +- dfm/src/megatron/model/wan/utils.py | 2 +- dfm/src/megatron/model/wan/wan_model.py | 6 +-- dfm/src/megatron/model/wan/wan_step.py | 2 - .../megatron/recipes/wan/inference_wan.py | 46 +++++++++++-------- 11 files changed, 92 insertions(+), 72 deletions(-) diff --git a/dfm/src/megatron/data/dit/base.py b/dfm/src/megatron/data/dit/base.py index 81219486..e59f2a6b 100644 --- a/dfm/src/megatron/data/dit/base.py +++ b/dfm/src/megatron/data/dit/base.py @@ -106,7 +106,7 @@ def __init__( self.multimodal_sample_config = multimodal_sample_config self.shuffle_buffer_size = shuffle_buffer_size self.max_samples_per_sequence = max_samples_per_sequence - self.task_encoder = task_encoder + self.task_encoder = task_encoder self.init_global_step = 0 self.train_dataloader_object = None self.val_dataloader_object = None @@ -116,7 +116,7 @@ def __init__( self.kwargs = kwargs - def datasets_provider(self, worker_config, split: Literal['train', 'val'] = 'val'): + def datasets_provider(self, worker_config, split: Literal["train", "val"] = "val"): """ Provide the dataset for training or validation. @@ -131,7 +131,7 @@ def datasets_provider(self, worker_config, split: Literal['train', 'val'] = 'val Dataset: The dataset configured for the specified split. """ - if split not in {'train', 'val'}: + if split not in {"train", "val"}: raise ValueError("Invalid value for split. Allowed values are 'train' or 'val'.") if split == "train": @@ -193,7 +193,7 @@ def train_dataloader(self) -> Any: worker_debug_path=None, worker_log_level=0, ) - train_dataset = self.datasets_provider(worker_config, split='train') + train_dataset = self.datasets_provider(worker_config, split="train") energon_dataloader = get_savable_loader(train_dataset, worker_config=worker_config) self.train_dataloader_object = energon_dataloader return self.train_dataloader_object @@ -231,7 +231,7 @@ def val_dataloader(self): worker_debug_path=None, worker_log_level=0, ) - val_dataset = self.datasets_provider(worker_config, split='val') + val_dataset = self.datasets_provider(worker_config, split="val") energon_loader = get_savable_loader(val_dataset, worker_config=worker_config) self.val_dataloader_object = energon_loader return self.val_dataloader_object @@ -284,7 +284,7 @@ def state_dict(self) -> Dict[str, Any]: state = [] # Megatron core requires all the states on all the ranks to have same python # type. Energon sends the state as a list logger.info(f"Multimodal data loader saving dataloader state dict consumed samples {consumed_samples}") - return {'dataloader_state': state, 'consumed_samples': consumed_samples} + return {"dataloader_state": state, "consumed_samples": consumed_samples} logger.warning("trainer object not connected to data module object returning empty state") return {} @@ -299,14 +299,14 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: Parameters: state_dict (Dict[str, Any]): The state dictionary containing the saved state of the data module. """ - if not 'dataloader_state' in state_dict: + if not "dataloader_state" in state_dict: logger.warning( f"Data loader state cannot be resumed from state_dict, " f"it does not have the required key dataloader_state. It has {state_dict.keys()}" ) return - state = state_dict['dataloader_state'] + state = state_dict["dataloader_state"] try: if self.trainer: self.trainer.datamodule.train_dataloader().restore_state_global(state) @@ -330,7 +330,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: logger.warning("Megatron num_microbatches_calculator not found, using Apex version.") from apex.transformer.pipeline_parallel.utils import update_num_microbatches - consumed_samples = state_dict['consumed_samples'] + consumed_samples = state_dict["consumed_samples"] self.data_sampler.init_consumed_samples = consumed_samples self.data_sampler.prev_consumed_samples = consumed_samples logger.info(f"Multimodal dataloader load state dict with consumed_samples {consumed_samples}") diff --git a/dfm/src/megatron/data/dit/diffusion_energon_datamodule.py b/dfm/src/megatron/data/dit/diffusion_energon_datamodule.py index b55109ba..94902e45 100644 --- a/dfm/src/megatron/data/dit/diffusion_energon_datamodule.py +++ b/dfm/src/megatron/data/dit/diffusion_energon_datamodule.py @@ -18,12 +18,13 @@ from dataclasses import dataclass from typing import Any, Dict, Literal -from dfm.src.megatron.data.dit.base import EnergonMultiModalDataModule -from dfm.src.megatron.data.dit.diffusion_taskencoder import BasicDiffusionTaskEncoder from megatron.bridge.data.utils import DatasetBuildContext, DatasetProvider from megatron.energon import DefaultTaskEncoder, get_train_dataset from torch import int_repr +from dfm.src.megatron.data.dit.base import EnergonMultiModalDataModule +from dfm.src.megatron.data.dit.diffusion_taskencoder import BasicDiffusionTaskEncoder + @dataclass(kw_only=True) class DiffusionDataModuleConfig(DatasetProvider): path: str @@ -41,7 +42,8 @@ def __post_init__(self): task_encoder=BasicDiffusionTaskEncoder(seq_length=self.task_encoder_seq_length), micro_batch_size=self.micro_batch_size, global_batch_size=self.global_batch_size, - num_workers=self.num_workers) + num_workers=self.num_workers, + ) self.sequence_length = self.dataset.seq_length def build_datasets(self, context: DatasetBuildContext): @@ -49,7 +51,6 @@ def build_datasets(self, context: DatasetBuildContext): - class DiffusionDataModule(EnergonMultiModalDataModule): """ A PyTorch Lightning DataModule for handling multimodal datasets with images and text. diff --git a/dfm/src/megatron/data/wan/wan_energon_datamodule.py b/dfm/src/megatron/data/wan/wan_energon_datamodule.py index 5649c7ca..1b9b7751 100644 --- a/dfm/src/megatron/data/wan/wan_energon_datamodule.py +++ b/dfm/src/megatron/data/wan/wan_energon_datamodule.py @@ -39,7 +39,8 @@ def __post_init__(self): task_encoder=WanTaskEncoder(seq_length=self.seq_length), micro_batch_size=self.micro_batch_size, global_batch_size=self.global_batch_size, - num_workers=self.num_workers) + num_workers=self.num_workers, + ) self.sequence_length = self.dataset.seq_length def build_datasets(self, context: DatasetBuildContext): diff --git a/dfm/src/megatron/data/wan/wan_taskencoder.py b/dfm/src/megatron/data/wan/wan_taskencoder.py index 1921ee52..e8d9b76e 100644 --- a/dfm/src/megatron/data/wan/wan_taskencoder.py +++ b/dfm/src/megatron/data/wan/wan_taskencoder.py @@ -14,12 +14,13 @@ # pylint: disable=C0115,C0116,C0301 -from dfm.src.megatron.model.wan.utils import grid_sizes_calculation, patchify -from megatron.core import parallel_state -from megatron.energon import DefaultTaskEncoder, SkipSample -from megatron.energon.task_encoder.cooking import basic_sample_keys, Cooker import torch import torch.nn.functional as F +from megatron.core import parallel_state +from megatron.energon import DefaultTaskEncoder, SkipSample +from megatron.energon.task_encoder.cooking import Cooker, basic_sample_keys + +from dfm.src.megatron.model.wan.utils import grid_sizes_calculation, patchify def cook(sample: dict) -> dict: @@ -74,7 +75,6 @@ def __init__( ## actual encode_sample() for production def encode_sample(self, sample: dict) -> dict: - video_latent = sample["pth"] context_embeddings = sample["pickle"] video_metadata = sample["json"] @@ -87,8 +87,8 @@ def encode_sample(self, sample: dict) -> dict: # calculate grid size grid_size = grid_sizes_calculation( - input_shape = video_latent.shape[1:], - patch_size = (self.patch_temporal, self.patch_spatial, self.patch_spatial), + input_shape=video_latent.shape[1:], + patch_size=(self.patch_temporal, self.patch_spatial, self.patch_spatial), ) ### Note: shape of sample's values @@ -124,7 +124,6 @@ def encode_sample(self, sample: dict) -> dict: def batch(self, samples: list[dict]) -> dict: - # process video latents # do padding here for video latents self.patch_size = (self.patch_temporal, self.patch_spatial, self.patch_spatial) @@ -146,7 +145,9 @@ def batch(self, samples: list[dict]) -> dict: # because pipeline parallelism requires pre-specified sequence length to create buffer if parallel_state.get_pipeline_model_parallel_world_size() > 1: if max_video_seq_len > self.seq_length: - raise ValueError(f"max_video_seq_len {max_video_seq_len} is greater than DataModule's seq_length {self.seq_length}") + raise ValueError( + f"max_video_seq_len {max_video_seq_len} is greater than DataModule's seq_length {self.seq_length}" + ) else: # set max_video_seq_len to DataModule's seq_length max_video_seq_len = self.seq_length @@ -158,7 +159,9 @@ def batch(self, samples: list[dict]) -> dict: assert batch_size == 1, "Error: Batch size must be 1 when using context parallelism" sharding_factor = parallel_state.get_context_parallel_world_size() * 2 max_video_seq_len = ((max_video_seq_len + sharding_factor - 1) // sharding_factor) * sharding_factor - video_latents = [F.pad(video_latent, (0, 0, 0, max_video_seq_len - video_latent.shape[0])) for video_latent in video_latents] + video_latents = [ + F.pad(video_latent, (0, 0, 0, max_video_seq_len - video_latent.shape[0])) for video_latent in video_latents + ] video_latents = torch.stack(video_latents, dim=1) # pad and stack loss masks to shape [S_max, B] loss_masks = [F.pad(m, (0, max_video_seq_len - m.shape[0])) for m in loss_masks] @@ -172,7 +175,10 @@ def batch(self, samples: list[dict]) -> dict: # pad here for text embeddings context_max_len = 512 context_embeddings = [sample["context_embeddings"] for sample in samples] - context_embeddings = [F.pad(context_embedding, (0, 0, 0, context_max_len - context_embedding.shape[0])) for context_embedding in context_embeddings] + context_embeddings = [ + F.pad(context_embedding, (0, 0, 0, context_max_len - context_embedding.shape[0])) + for context_embedding in context_embeddings + ] # calculate all sequence lengths of context embeddings for cross-attention (for videos, we do this after padding to get padded seq len) seq_len_kv = [c.shape[0] for c in context_embeddings] seq_len_kv = torch.tensor(seq_len_kv, dtype=torch.int32) @@ -183,12 +189,12 @@ def batch(self, samples: list[dict]) -> dict: video_metadata = [sample["video_metadata"] for sample in samples] return dict( - video_latents = video_latents, - max_video_seq_len = max_video_seq_len, - grid_sizes = grid_sizes, - context_embeddings = context_embeddings, - loss_mask = loss_masks, - seq_len_q = seq_len_q, - seq_len_kv = seq_len_kv, - video_metadata = video_metadata, + video_latents=video_latents, + max_video_seq_len=max_video_seq_len, + grid_sizes=grid_sizes, + context_embeddings=context_embeddings, + loss_mask=loss_masks, + seq_len_q=seq_len_q, + seq_len_kv=seq_len_kv, + video_metadata=video_metadata, ) \ No newline at end of file diff --git a/dfm/src/megatron/model/common/dit_attention.py b/dfm/src/megatron/model/common/dit_attention.py index f667c0e7..2086055f 100644 --- a/dfm/src/megatron/model/common/dit_attention.py +++ b/dfm/src/megatron/model/common/dit_attention.py @@ -123,12 +123,10 @@ def get_query_key_value_tensors(self, hidden_states, key_value_states=None): ] if SplitAlongDim is not None: - # [sq, b, ng, (np/ng + 2) * hn] # --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn] (query, key, value) = SplitAlongDim(mixed_qkv, 3, split_arg_list) else: - # [sq, b, ng, (np/ng + 2) * hn] # --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn] (query, key, value) = torch.split(mixed_qkv, split_arg_list, dim=3) @@ -146,10 +144,12 @@ def get_query_key_value_tensors(self, hidden_states, key_value_states=None): key = key.transpose(-2, -1) if self.q_layernorm is not None: - if self.layernorm_across_heads: + if self.layernorm_across_heads: q_flat = query.reshape(query.size(0), query.size(1), -1).contiguous() # [sq, b, np*hn] q_flat = self.q_layernorm(q_flat) - query = q_flat.view(query.size(0), query.size(1), -1, self.hidden_size_per_attention_head) # [sq, b, np, hn] + query = q_flat.view( + query.size(0), query.size(1), -1, self.hidden_size_per_attention_head + ) # [sq, b, np, hn] else: query = self.q_layernorm(query.contiguous()) @@ -169,8 +169,8 @@ def get_query_key_value_tensors(self, hidden_states, key_value_states=None): key = tensor_parallel.scatter_to_tensor_model_parallel_region(key) query = query.transpose(-2, -1) key = key.transpose(-2, -1) - query = query.contiguous() # important becuase TE attention expects contiguous tensors - key = key.contiguous() # important becuase TE attention expects contiguous tensors + query = query.contiguous() # important becuase TE attention expects contiguous tensors + key = key.contiguous() # important becuase TE attention expects contiguous tensors if self.config.test_mode: self.run_realtime_tests() @@ -254,7 +254,9 @@ def get_query_key_value_tensors(self, hidden_states, key_value_states): if self.layernorm_across_heads: q_flat = query.reshape(query.size(0), query.size(1), -1).contiguous() # [sq, b, np*hn] q_flat = self.q_layernorm(q_flat) - query = q_flat.view(query.size(0), query.size(1), -1, self.hidden_size_per_attention_head) # [sq, b, np, hn] + query = q_flat.view( + query.size(0), query.size(1), -1, self.hidden_size_per_attention_head + ) # [sq, b, np, hn] else: query = self.q_layernorm(query.contiguous()) @@ -274,8 +276,8 @@ def get_query_key_value_tensors(self, hidden_states, key_value_states): key = tensor_parallel.scatter_to_tensor_model_parallel_region(key) query = query.transpose(-2, -1) key = key.transpose(-2, -1) - query = query.contiguous() # important becuase TE attention expects contiguous tensors - key = key.contiguous() # important becuase TE attention expects contiguous tensors + query = query.contiguous() # important becuase TE attention expects contiguous tensors + key = key.contiguous() # important becuase TE attention expects contiguous tensors return query, key, value diff --git a/dfm/src/megatron/model/wan/flow_matching/flow_inference_pipeline.py b/dfm/src/megatron/model/wan/flow_matching/flow_inference_pipeline.py index bd7e4c19..42320cdf 100644 --- a/dfm/src/megatron/model/wan/flow_matching/flow_inference_pipeline.py +++ b/dfm/src/megatron/model/wan/flow_matching/flow_inference_pipeline.py @@ -67,6 +67,7 @@ def _encode_text( outputs = outputs[0, :true_len, :] return outputs + class FlowInferencePipeline: def __init__( self, @@ -200,14 +201,17 @@ def _select_checkpoint_dir(self, base_dir: str, checkpoint_step) -> str: pattern = re.compile(r"^iter_(\d+)$") try: _, latest_path = max( - ((int(pattern.match(e.name).group(1)), e.path) - for e in os.scandir(base_dir) - if e.is_dir() and pattern.match(e.name)), + ( + (int(pattern.match(e.name).group(1)), e.path) + for e in os.scandir(base_dir) + if e.is_dir() and pattern.match(e.name) + ), key=lambda x: x[0], ) except ValueError: raise FileNotFoundError( - f"No checkpoints found under {base_dir}. Expected subdirectories named like 'iter_0001800'.") + f"No checkpoints found under {base_dir}. Expected subdirectories named like 'iter_0001800'." + ) logging.info(f"Auto-selected latest checkpoint: {latest_path}") diff --git a/dfm/src/megatron/model/wan/inference/configs/__init__.py b/dfm/src/megatron/model/wan/inference/configs/__init__.py index a82afb99..3769af08 100644 --- a/dfm/src/megatron/model/wan/inference/configs/__init__.py +++ b/dfm/src/megatron/model/wan/inference/configs/__init__.py @@ -1,7 +1,7 @@ import os -from .wan_t2v_14B import t2v_14B from .wan_t2v_1_3B import t2v_1_3B +from .wan_t2v_14B import t2v_14B os.environ["TOKENIZERS_PARALLELISM"] = "false" diff --git a/dfm/src/megatron/model/wan/utils.py b/dfm/src/megatron/model/wan/utils.py index 9d719f96..5a84eb14 100644 --- a/dfm/src/megatron/model/wan/utils.py +++ b/dfm/src/megatron/model/wan/utils.py @@ -56,7 +56,7 @@ def patchify(x, patch_size): c, F_pF, H_pH, W_pW = u.shape pF, pH, pW = patch_size assert F_pF % pF == 0 and H_pH % pH == 0 and W_pW % pW == 0, ( - "Spatial dimensions must be divisible by patch size." + "Spatial dimensions must be divisible by patch size." ) F_patches, H_patches, W_patches = F_pF // pF, H_pH // pH, W_pW // pW diff --git a/dfm/src/megatron/model/wan/wan_model.py b/dfm/src/megatron/model/wan/wan_model.py index d721aad0..76aeb004 100644 --- a/dfm/src/megatron/model/wan/wan_model.py +++ b/dfm/src/megatron/model/wan/wan_model.py @@ -19,6 +19,7 @@ import torch import torch.nn as nn +from diffusers.models.embeddings import Timesteps from megatron.core import parallel_state, tensor_parallel from megatron.core.dist_checkpointing.mapping import ShardedStateDict from megatron.core.models.common.vision_module.vision_module import VisionModule @@ -27,9 +28,8 @@ from megatron.core.transformer.transformer_block import TransformerBlock from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.utils import make_sharded_tensor_for_checkpoint -from torch import Tensor -from diffusers.models.embeddings import Timesteps from nemo.collections.diffusion.models.dit.dit_embeddings import ParallelTimestepEmbedding +from torch import Tensor from dfm.src.megatron.model.wan.wan_layer_spec import ( get_wan_block_with_transformer_engine_spec as WanLayerWithAdaLNspec, ) @@ -164,7 +164,6 @@ def __init__( if self.post_process: self.head = Head(self.config.hidden_size, self.out_channels, self.patch_size, eps=1e-6) - def forward( self, x: Tensor, @@ -259,7 +258,6 @@ def forward( x = tensor_parallel.gather_from_sequence_parallel_region(x) return x # output: x.shape [s, b, c * pF * pH * pW] - def set_input_tensor(self, input_tensor: Tensor) -> None: """Sets input tensor to the model. diff --git a/dfm/src/megatron/model/wan/wan_step.py b/dfm/src/megatron/model/wan/wan_step.py index a5131a43..c898221f 100644 --- a/dfm/src/megatron/model/wan/wan_step.py +++ b/dfm/src/megatron/model/wan/wan_step.py @@ -65,7 +65,6 @@ class WanForwardStep: def __init__(self): self.diffusion_pipeline = FlowPipeline() - def __call__( self, state: GlobalState, data_iterator: Iterable, model: VisionModule ) -> tuple[torch.Tensor, partial]: @@ -109,7 +108,6 @@ def __call__( return output_tensor, loss_function - def _create_loss_function( self, loss_mask: torch.Tensor, check_for_nan_in_loss: bool, check_for_spiky_loss: bool ) -> partial: diff --git a/examples/megatron/recipes/wan/inference_wan.py b/examples/megatron/recipes/wan/inference_wan.py index c67f6fbd..6e0bac00 100644 --- a/examples/megatron/recipes/wan/inference_wan.py +++ b/examples/megatron/recipes/wan/inference_wan.py @@ -90,7 +90,10 @@ def _parse_args(): help="List of frame counts (each should be 4n+1). Broadcasts if single value.", ) parser.add_argument( - "--checkpoint_dir", type=str, default=None, help="The path to the main WAN checkpoint directory.", + "--checkpoint_dir", + type=str, + default=None, + help="The path to the main WAN checkpoint directory.", ) parser.add_argument( "--checkpoint_step", @@ -108,16 +111,23 @@ def _parse_args(): "--vae_checkpoint_dir", type=str, default=None, help="Optional directory containing VAE checkpoint" ) parser.add_argument( - "--offload_model", type=str2bool, default=None, help="Whether to offload the model to CPU after each model forward, reducing GPU memory usage." + "--offload_model", + type=str2bool, + default=None, + help="Whether to offload the model to CPU after each model forward, reducing GPU memory usage.", ) parser.add_argument( - "--t5_cpu", action="store_true", default=False, help="Whether to place T5 model on CPU.", + "--t5_cpu", + action="store_true", + default=False, + help="Whether to place T5 model on CPU.", ) parser.add_argument( "--save_file", type=str, default=None, - help="The file to save the generated image or video to.") + help="The file to save the generated image or video to." + ) parser.add_argument( "--prompts", type=str, @@ -177,7 +187,8 @@ def _init_logging(rank): logging.basicConfig( level=logging.INFO, format="[%(asctime)s] %(levelname)s: %(message)s", - handlers=[logging.StreamHandler(stream=sys.stdout)]) + handlers=[logging.StreamHandler(stream=sys.stdout)], + ) else: logging.basicConfig(level=logging.ERROR) @@ -195,11 +206,7 @@ def generate(args): f"offload_model is not specified, set to {args.offload_model}.") if world_size > 1: torch.cuda.set_device(local_rank) - dist.init_process_group( - backend="nccl", - init_method="env://", - rank=rank, - world_size=world_size) + dist.init_process_group(backend="nccl", init_method="env://", rank=rank, world_size=world_size) cfg = WAN_CONFIGS[args.task] @@ -233,7 +240,8 @@ def generate(args): # Enforce 1:1 pairing across lists assert len(prompts) == len(size_keys) == len(frame_nums), ( f"prompts ({len(prompts)}), sizes ({len(size_keys)}), and frame_nums ({len(frame_nums)}) " - f"must have the same length") + f"must have the same length" + ) logging.info("Creating flow inference pipeline.") pipeline = FlowInferencePipeline( @@ -262,8 +270,7 @@ def generate(args): print("sequence_parallel:", args.sequence_parallel) print("\n\n\n") - logging.info( - "Generating videos ...") + logging.info("Generating videos ...") videos = pipeline.generate( prompts=prompts, sizes=[SIZE_CONFIGS[size] for size in size_keys], @@ -278,10 +285,12 @@ def generate(args): for i, video in enumerate(videos): formatted_experiment_name = (args.save_file) if args.save_file is not None else "DefaultExp" formatted_time = datetime.now().strftime("%Y%m%d_%H%M%S") - formatted_prompt = prompts[i].replace(" ", "_").replace("/", - "_")[:50] - suffix = '.mp4' - formatted_save_file = f"{args.task}_{formatted_experiment_name}_videoindex{int(i)}_size{size_keys[i].replace('*','x') if sys.platform=='win32' else size_keys[i]}_{formatted_prompt}_{formatted_time}" + suffix + formatted_prompt = prompts[i].replace(" ", "_").replace("/", "_")[:50] + suffix = ".mp4" + formatted_save_file = ( + f"{args.task}_{formatted_experiment_name}_videoindex{int(i)}_size{size_keys[i].replace('*', 'x') if sys.platform == 'win32' else size_keys[i]}_{formatted_prompt}_{formatted_time}" + + suffix + ) if "t2v" in args.task: logging.info(f"Saving generated video to {formatted_save_file}") @@ -291,7 +300,8 @@ def generate(args): fps=cfg.sample_fps, nrow=1, normalize=True, - value_range=(-1, 1)) + value_range=(-1, 1), + ) logging.info("Finished.") From 4464fd2b511e8f1b6a44903e3bfea9ed7ac34a09 Mon Sep 17 00:00:00 2001 From: Huy Vu2 Date: Wed, 5 Nov 2025 19:36:36 -0800 Subject: [PATCH 22/80] fix Ruff + Lint --- dfm/src/megatron/data/dit/base.py | 2 -- dfm/src/megatron/data/dit/diffusion_energon_datamodule.py | 4 ++-- dfm/src/megatron/data/wan/wan_energon_datamodule.py | 2 +- dfm/src/megatron/model/wan/inference/configs/__init__.py | 1 + dfm/src/megatron/model/wan/wan_model.py | 1 + examples/megatron/recipes/wan/prepare_energon_dataset_wan.py | 4 +--- 6 files changed, 6 insertions(+), 8 deletions(-) diff --git a/dfm/src/megatron/data/dit/base.py b/dfm/src/megatron/data/dit/base.py index e59f2a6b..6e12d203 100644 --- a/dfm/src/megatron/data/dit/base.py +++ b/dfm/src/megatron/data/dit/base.py @@ -338,5 +338,3 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: consumed_samples=consumed_samples, consistency_check=False, ) - - diff --git a/dfm/src/megatron/data/dit/diffusion_energon_datamodule.py b/dfm/src/megatron/data/dit/diffusion_energon_datamodule.py index 94902e45..ebbece18 100644 --- a/dfm/src/megatron/data/dit/diffusion_energon_datamodule.py +++ b/dfm/src/megatron/data/dit/diffusion_energon_datamodule.py @@ -25,6 +25,7 @@ from dfm.src.megatron.data.dit.base import EnergonMultiModalDataModule from dfm.src.megatron.data.dit.diffusion_taskencoder import BasicDiffusionTaskEncoder + @dataclass(kw_only=True) class DiffusionDataModuleConfig(DatasetProvider): path: str @@ -45,12 +46,11 @@ def __post_init__(self): num_workers=self.num_workers, ) self.sequence_length = self.dataset.seq_length - + def build_datasets(self, context: DatasetBuildContext): return self.dataset.train_dataloader(), self.dataset.train_dataloader(), self.dataset.train_dataloader() - class DiffusionDataModule(EnergonMultiModalDataModule): """ A PyTorch Lightning DataModule for handling multimodal datasets with images and text. diff --git a/dfm/src/megatron/data/wan/wan_energon_datamodule.py b/dfm/src/megatron/data/wan/wan_energon_datamodule.py index 1b9b7751..1969fbd8 100644 --- a/dfm/src/megatron/data/wan/wan_energon_datamodule.py +++ b/dfm/src/megatron/data/wan/wan_energon_datamodule.py @@ -42,6 +42,6 @@ def __post_init__(self): num_workers=self.num_workers, ) self.sequence_length = self.dataset.seq_length - + def build_datasets(self, context: DatasetBuildContext): return self.dataset.train_dataloader(), self.dataset.train_dataloader(), self.dataset.train_dataloader() \ No newline at end of file diff --git a/dfm/src/megatron/model/wan/inference/configs/__init__.py b/dfm/src/megatron/model/wan/inference/configs/__init__.py index 3769af08..3237e0f5 100644 --- a/dfm/src/megatron/model/wan/inference/configs/__init__.py +++ b/dfm/src/megatron/model/wan/inference/configs/__init__.py @@ -3,6 +3,7 @@ from .wan_t2v_1_3B import t2v_1_3B from .wan_t2v_14B import t2v_14B + os.environ["TOKENIZERS_PARALLELISM"] = "false" diff --git a/dfm/src/megatron/model/wan/wan_model.py b/dfm/src/megatron/model/wan/wan_model.py index 76aeb004..e08c4a37 100644 --- a/dfm/src/megatron/model/wan/wan_model.py +++ b/dfm/src/megatron/model/wan/wan_model.py @@ -30,6 +30,7 @@ from megatron.core.utils import make_sharded_tensor_for_checkpoint from nemo.collections.diffusion.models.dit.dit_embeddings import ParallelTimestepEmbedding from torch import Tensor + from dfm.src.megatron.model.wan.wan_layer_spec import ( get_wan_block_with_transformer_engine_spec as WanLayerWithAdaLNspec, ) diff --git a/examples/megatron/recipes/wan/prepare_energon_dataset_wan.py b/examples/megatron/recipes/wan/prepare_energon_dataset_wan.py index 748c2c53..98386c5a 100644 --- a/examples/megatron/recipes/wan/prepare_energon_dataset_wan.py +++ b/examples/megatron/recipes/wan/prepare_energon_dataset_wan.py @@ -351,7 +351,7 @@ def main(): for index, meta in enumerate(metadata_list): video_name = meta["file_name"] start_frame = int(meta["start_frame"]) # inclusive - end_frame = int(meta["end_frame"]) # inclusive + end_frame = int(meta["end_frame"]) # inclusive caption_text = meta.get("vila_caption", "") video_path = str(video_folder / video_name) @@ -412,5 +412,3 @@ def main(): if __name__ == "__main__": main() - - From 547339a72a3abe434b3fd098fabb77980a4f9f75 Mon Sep 17 00:00:00 2001 From: Huy Vu2 Date: Wed, 5 Nov 2025 19:40:30 -0800 Subject: [PATCH 23/80] fix Ruff + Lint --- dfm/src/megatron/data/dit/base.py | 1 - dfm/src/megatron/data/dit/diffusion_energon_datamodule.py | 1 - dfm/src/megatron/data/wan/wan_taskencoder.py | 3 --- 3 files changed, 5 deletions(-) diff --git a/dfm/src/megatron/data/dit/base.py b/dfm/src/megatron/data/dit/base.py index 6e12d203..f903ee7f 100644 --- a/dfm/src/megatron/data/dit/base.py +++ b/dfm/src/megatron/data/dit/base.py @@ -115,7 +115,6 @@ def __init__( self.num_val_workers = num_val_workers or self.num_workers self.kwargs = kwargs - def datasets_provider(self, worker_config, split: Literal["train", "val"] = "val"): """ Provide the dataset for training or validation. diff --git a/dfm/src/megatron/data/dit/diffusion_energon_datamodule.py b/dfm/src/megatron/data/dit/diffusion_energon_datamodule.py index ebbece18..303dd192 100644 --- a/dfm/src/megatron/data/dit/diffusion_energon_datamodule.py +++ b/dfm/src/megatron/data/dit/diffusion_energon_datamodule.py @@ -49,7 +49,6 @@ def __post_init__(self): def build_datasets(self, context: DatasetBuildContext): return self.dataset.train_dataloader(), self.dataset.train_dataloader(), self.dataset.train_dataloader() - class DiffusionDataModule(EnergonMultiModalDataModule): """ diff --git a/dfm/src/megatron/data/wan/wan_taskencoder.py b/dfm/src/megatron/data/wan/wan_taskencoder.py index e8d9b76e..5f822dd1 100644 --- a/dfm/src/megatron/data/wan/wan_taskencoder.py +++ b/dfm/src/megatron/data/wan/wan_taskencoder.py @@ -122,7 +122,6 @@ def encode_sample(self, sample: dict) -> dict: # video_metadata=video_metadata, # ) - def batch(self, samples: list[dict]) -> dict: # process video latents # do padding here for video latents @@ -136,8 +135,6 @@ def batch(self, samples: list[dict]) -> dict: # calculate all sequence lengths of video latents for self-attention (for videos, we do this before padding to get original seq len) seq_len_q = [v.shape[0] for v in video_latents] seq_len_q = torch.tensor(seq_len_q, dtype=torch.int32) - - # padding and stack video latents max_video_seq_len = max([video_latent.shape[0] for video_latent in video_latents]) # CAVEAT: From 9cd082b2ccae775ff6472c28ca0f1e8e173f5cac Mon Sep 17 00:00:00 2001 From: Huy Vu2 Date: Wed, 5 Nov 2025 19:52:55 -0800 Subject: [PATCH 24/80] fix Ruff + Lint --- .../data/dit/diffusion_energon_datamodule.py | 1 + .../data/wan/wan_energon_datamodule.py | 2 +- dfm/src/megatron/data/wan/wan_taskencoder.py | 2 +- .../megatron/model/common/dit_attention.py | 3 +- .../flow_matching/flow_inference_pipeline.py | 6 -- .../model/wan/flow_matching/flow_pipeline.py | 2 +- .../wan/flow_matching/time_shift_utils.py | 2 +- dfm/src/megatron/model/wan/rope_utils.py | 2 +- dfm/src/megatron/model/wan/wan_provider.py | 2 +- .../megatron/recipes/wan/README_perf_test.md | 2 - .../megatron/recipes/wan/inference_wan.py | 58 +++++-------------- 11 files changed, 22 insertions(+), 60 deletions(-) diff --git a/dfm/src/megatron/data/dit/diffusion_energon_datamodule.py b/dfm/src/megatron/data/dit/diffusion_energon_datamodule.py index 303dd192..bd5b7190 100644 --- a/dfm/src/megatron/data/dit/diffusion_energon_datamodule.py +++ b/dfm/src/megatron/data/dit/diffusion_energon_datamodule.py @@ -50,6 +50,7 @@ def __post_init__(self): def build_datasets(self, context: DatasetBuildContext): return self.dataset.train_dataloader(), self.dataset.train_dataloader(), self.dataset.train_dataloader() + class DiffusionDataModule(EnergonMultiModalDataModule): """ A PyTorch Lightning DataModule for handling multimodal datasets with images and text. diff --git a/dfm/src/megatron/data/wan/wan_energon_datamodule.py b/dfm/src/megatron/data/wan/wan_energon_datamodule.py index 1969fbd8..fd3c2a01 100644 --- a/dfm/src/megatron/data/wan/wan_energon_datamodule.py +++ b/dfm/src/megatron/data/wan/wan_energon_datamodule.py @@ -44,4 +44,4 @@ def __post_init__(self): self.sequence_length = self.dataset.seq_length def build_datasets(self, context: DatasetBuildContext): - return self.dataset.train_dataloader(), self.dataset.train_dataloader(), self.dataset.train_dataloader() \ No newline at end of file + return self.dataset.train_dataloader(), self.dataset.train_dataloader(), self.dataset.train_dataloader() diff --git a/dfm/src/megatron/data/wan/wan_taskencoder.py b/dfm/src/megatron/data/wan/wan_taskencoder.py index 5f822dd1..5bc387f8 100644 --- a/dfm/src/megatron/data/wan/wan_taskencoder.py +++ b/dfm/src/megatron/data/wan/wan_taskencoder.py @@ -194,4 +194,4 @@ def batch(self, samples: list[dict]) -> dict: seq_len_q=seq_len_q, seq_len_kv=seq_len_kv, video_metadata=video_metadata, - ) \ No newline at end of file + ) diff --git a/dfm/src/megatron/model/common/dit_attention.py b/dfm/src/megatron/model/common/dit_attention.py index 2086055f..0e088214 100644 --- a/dfm/src/megatron/model/common/dit_attention.py +++ b/dfm/src/megatron/model/common/dit_attention.py @@ -238,7 +238,7 @@ def get_query_key_value_tensors(self, hidden_states, key_value_states): Derives `query` tensor from `hidden_states`, and `key`/`value` tensors from `key_value_states`. """ - + query, key, value = super().get_query_key_value_tensors(hidden_states, key_value_states) # gather query and key heads across TP ranks if self.layernorm_across_heads is True @@ -280,4 +280,3 @@ def get_query_key_value_tensors(self, hidden_states, key_value_states): key = key.contiguous() # important becuase TE attention expects contiguous tensors return query, key, value - diff --git a/dfm/src/megatron/model/wan/flow_matching/flow_inference_pipeline.py b/dfm/src/megatron/model/wan/flow_matching/flow_inference_pipeline.py index 42320cdf..8a4cb685 100644 --- a/dfm/src/megatron/model/wan/flow_matching/flow_inference_pipeline.py +++ b/dfm/src/megatron/model/wan/flow_matching/flow_inference_pipeline.py @@ -152,7 +152,6 @@ def __init__( self.sample_neg_prompt = config.sample_neg_prompt - def setup_model_from_checkpoint(self, checkpoint_dir): provider = WanModelProvider() provider.tensor_model_parallel_size = self.tensor_parallel_size @@ -213,11 +212,9 @@ def _select_checkpoint_dir(self, base_dir: str, checkpoint_step) -> str: f"No checkpoints found under {base_dir}. Expected subdirectories named like 'iter_0001800'." ) - logging.info(f"Auto-selected latest checkpoint: {latest_path}") return latest_path - def forward_pp_step( self, latent_model_input: torch.Tensor, @@ -285,7 +282,6 @@ def forward_pp_step( noise_pred_pp = broadcast_from_last_pipeline_stage(noise_pred_pp_shape, dtype=torch.float32) return noise_pred_pp - def generate( self, prompts, @@ -392,7 +388,6 @@ def generate( contexts = torch.stack(contexts, dim=1) contexts_null = torch.stack(contexts_null, dim=1) - ## setup noise noises = [] for target_shape in target_shapes: @@ -408,7 +403,6 @@ def generate( ) ) - # calculate grid_sizes grid_sizes = [ grid_sizes_calculation( diff --git a/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py b/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py index 49bd5d72..25c9b93d 100644 --- a/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py +++ b/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py @@ -243,4 +243,4 @@ def training_step( packed_seq_params=packed_seq_params, ) - return hidden_states \ No newline at end of file + return hidden_states diff --git a/dfm/src/megatron/model/wan/flow_matching/time_shift_utils.py b/dfm/src/megatron/model/wan/flow_matching/time_shift_utils.py index cb337959..a221610d 100644 --- a/dfm/src/megatron/model/wan/flow_matching/time_shift_utils.py +++ b/dfm/src/megatron/model/wan/flow_matching/time_shift_utils.py @@ -113,4 +113,4 @@ def get_flow_match_loss_weight(sigma: torch.Tensor, shift: float = 3.0): # Flow matching weight: weight = 1 + shift * sigma # This gives more weight to noisier timesteps weight = 1.0 + shift * sigma - return weight \ No newline at end of file + return weight diff --git a/dfm/src/megatron/model/wan/rope_utils.py b/dfm/src/megatron/model/wan/rope_utils.py index 86689393..76e076eb 100644 --- a/dfm/src/megatron/model/wan/rope_utils.py +++ b/dfm/src/megatron/model/wan/rope_utils.py @@ -82,4 +82,4 @@ def forward(self, n_head, dim_head, max_seq_len, grid_sizes, device): # we don't need to scatter the freqs to the context parallel region, # because mcore rope_utils will automatically retrieve the correct freqs for each context parallel region - return freqs_real \ No newline at end of file + return freqs_real diff --git a/dfm/src/megatron/model/wan/wan_provider.py b/dfm/src/megatron/model/wan/wan_provider.py index d349b6ad..befecd8c 100644 --- a/dfm/src/megatron/model/wan/wan_provider.py +++ b/dfm/src/megatron/model/wan/wan_provider.py @@ -82,4 +82,4 @@ def provide(self, pre_process=None, post_process=None, vp_stage=None) -> WanMode post_process=parallel_state.is_pipeline_last_stage(), fp16_lm_cross_entropy=self.fp16_lm_cross_entropy, parallel_output=self.parallel_output, - ) \ No newline at end of file + ) diff --git a/examples/megatron/recipes/wan/README_perf_test.md b/examples/megatron/recipes/wan/README_perf_test.md index a89c8d40..f62b88c2 100644 --- a/examples/megatron/recipes/wan/README_perf_test.md +++ b/examples/megatron/recipes/wan/README_perf_test.md @@ -174,5 +174,3 @@ NVTE_FUSED_ATTN=1 torchrun --nproc_per_node=1 examples/megatron/recipes/wan/infe - Replace placeholders (tokens, account, dataset/checkpoint paths) with your own. - Keep the specified commit hashes for compatibility. - `NVTE_FUSED_ATTN=1` enables fused attention where supported. - - diff --git a/examples/megatron/recipes/wan/inference_wan.py b/examples/megatron/recipes/wan/inference_wan.py index 6e0bac00..5e8eda84 100644 --- a/examples/megatron/recipes/wan/inference_wan.py +++ b/examples/megatron/recipes/wan/inference_wan.py @@ -69,6 +69,7 @@ def _validate_args(args): f"{', '.join(SUPPORTED_SIZES[args.task])}" ) + def _parse_args(): parser = argparse.ArgumentParser(description="Generate a image or video from a text prompt or image using Wan") parser.add_argument( @@ -123,10 +124,7 @@ def _parse_args(): help="Whether to place T5 model on CPU.", ) parser.add_argument( - "--save_file", - type=str, - default=None, - help="The file to save the generated image or video to." + "--save_file", type=str, default=None, help="The file to save the generated image or video to." ) parser.add_argument( "--prompts", @@ -135,43 +133,16 @@ def _parse_args(): default=None, help="A list of prompts to generate multiple images or videos. Example: --prompts 'a cat' 'a dog'", ) + parser.add_argument("--base_seed", type=int, default=-1, help="The seed to use for generating the image or video.") + parser.add_argument("--sample_steps", type=int, default=None, help="The sampling steps.") parser.add_argument( - "--base_seed", - type=int, - default=-1, - help="The seed to use for generating the image or video.") - parser.add_argument( - "--sample_steps", type=int, default=None, help="The sampling steps.") - parser.add_argument( - "--sample_shift", - type=float, - default=None, - help="Sampling shift factor for flow matching schedulers.") - parser.add_argument( - "--sample_guide_scale", - type=float, - default=5.0, - help="Classifier free guidance scale.") - parser.add_argument( - "--tensor_parallel_size", - type=int, - default=1, - help="Tensor parallel size.") - parser.add_argument( - "--context_parallel_size", - type=int, - default=1, - help="Context parallel size.") - parser.add_argument( - "--pipeline_parallel_size", - type=int, - default=1, - help="Pipeline parallel size.") - parser.add_argument( - "--sequence_parallel", - type=str2bool, - default=False, - help="Sequence parallel.") + "--sample_shift", type=float, default=None, help="Sampling shift factor for flow matching schedulers." + ) + parser.add_argument("--sample_guide_scale", type=float, default=5.0, help="Classifier free guidance scale.") + parser.add_argument("--tensor_parallel_size", type=int, default=1, help="Tensor parallel size.") + parser.add_argument("--context_parallel_size", type=int, default=1, help="Context parallel size.") + parser.add_argument("--pipeline_parallel_size", type=int, default=1, help="Pipeline parallel size.") + parser.add_argument("--sequence_parallel", type=str2bool, default=False, help="Sequence parallel.") args = parser.parse_args() @@ -201,9 +172,7 @@ def generate(args): _init_logging(rank) if args.offload_model is None: - args.offload_model = False if world_size > 1 else True - logging.info( - f"offload_model is not specified, set to {args.offload_model}.") + logging.info(f"offload_model is not specified, set to {args.offload_model}.") if world_size > 1: torch.cuda.set_device(local_rank) dist.init_process_group(backend="nccl", init_method="env://", rank=rank, world_size=world_size) @@ -279,7 +248,8 @@ def generate(args): sampling_steps=args.sample_steps, guide_scale=args.sample_guide_scale, seed=args.base_seed, - offload_model=args.offload_model) + offload_model=args.offload_model, + ) if rank == 0: for i, video in enumerate(videos): From 4514eee3b11430f24378531a86d0a7bcf4535601 Mon Sep 17 00:00:00 2001 From: Huy Vu2 Date: Wed, 5 Nov 2025 19:57:34 -0800 Subject: [PATCH 25/80] fix Ruff + Lint --- dfm/src/megatron/model/wan/utils.py | 2 +- examples/megatron/recipes/wan/example_commands.md | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/dfm/src/megatron/model/wan/utils.py b/dfm/src/megatron/model/wan/utils.py index 5a84eb14..ac4de4e6 100644 --- a/dfm/src/megatron/model/wan/utils.py +++ b/dfm/src/megatron/model/wan/utils.py @@ -183,4 +183,4 @@ def thd_split_inputs_cp( # Return to [S, B, ...] x_local = x_local_bs.transpose(0, 1).contiguous() # [S_local, B, ...] - return x_local \ No newline at end of file + return x_local diff --git a/examples/megatron/recipes/wan/example_commands.md b/examples/megatron/recipes/wan/example_commands.md index 8ecb8a97..a8f531ad 100644 --- a/examples/megatron/recipes/wan/example_commands.md +++ b/examples/megatron/recipes/wan/example_commands.md @@ -95,4 +95,3 @@ NVTE_FUSED_ATTN=1 torchrun --nproc_per_node=1 examples/recipes/wan/inference_wan --base_seed 42 \ --sample_steps 50 ``` - From acd430d20f4af132f29637e030129cfb9b72413f Mon Sep 17 00:00:00 2001 From: Huy Vu2 Date: Wed, 5 Nov 2025 19:59:52 -0800 Subject: [PATCH 26/80] fix Ruff + Lint --- dfm/src/megatron/data/dit/diffusion_energon_datamodule.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dfm/src/megatron/data/dit/diffusion_energon_datamodule.py b/dfm/src/megatron/data/dit/diffusion_energon_datamodule.py index bd5b7190..eaa9aa73 100644 --- a/dfm/src/megatron/data/dit/diffusion_energon_datamodule.py +++ b/dfm/src/megatron/data/dit/diffusion_energon_datamodule.py @@ -50,7 +50,7 @@ def __post_init__(self): def build_datasets(self, context: DatasetBuildContext): return self.dataset.train_dataloader(), self.dataset.train_dataloader(), self.dataset.train_dataloader() - + class DiffusionDataModule(EnergonMultiModalDataModule): """ A PyTorch Lightning DataModule for handling multimodal datasets with images and text. From a147258508fa5b32b7e2d1dd0ea9ff983e463e6d Mon Sep 17 00:00:00 2001 From: Huy Vu2 Date: Thu, 6 Nov 2025 00:01:47 -0800 Subject: [PATCH 27/80] merged main + address comments --- dfm/src/megatron/model/wan/wan_model.py | 6 ------ dfm/src/megatron/model/wan/wan_step.py | 1 - examples/megatron/recipes/wan/inference_wan.py | 9 ++++----- 3 files changed, 4 insertions(+), 12 deletions(-) diff --git a/dfm/src/megatron/model/wan/wan_model.py b/dfm/src/megatron/model/wan/wan_model.py index e08c4a37..6c96bd54 100644 --- a/dfm/src/megatron/model/wan/wan_model.py +++ b/dfm/src/megatron/model/wan/wan_model.py @@ -290,12 +290,6 @@ def sharded_state_dict( """ sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata) - # DEBUGGING - # for module in ["t_embedder"]: - # for param_name, param in getattr(self, module).named_parameters(): - # weight_key = f"{prefix}{module}.{param_name}" - # self._set_embedder_weights_replica_id(param, sharded_state_dict, weight_key) - # DEBUGGING # Ensure replica ids for non-transformer embedder weights include pipeline dimension for module in ["text_embedding", "time_embedding", "time_projection"]: if hasattr(self, module): diff --git a/dfm/src/megatron/model/wan/wan_step.py b/dfm/src/megatron/model/wan/wan_step.py index c898221f..c7a155ca 100644 --- a/dfm/src/megatron/model/wan/wan_step.py +++ b/dfm/src/megatron/model/wan/wan_step.py @@ -95,7 +95,6 @@ def __call__( else: output_tensor = self.diffusion_pipeline.training_step(model, batch) - # DEBUGGING # TODO: do we need to gather output with sequence or context parallelism here # especially when we have pipeline parallelism diff --git a/examples/megatron/recipes/wan/inference_wan.py b/examples/megatron/recipes/wan/inference_wan.py index 5e8eda84..63c3d020 100644 --- a/examples/megatron/recipes/wan/inference_wan.py +++ b/examples/megatron/recipes/wan/inference_wan.py @@ -230,13 +230,12 @@ def generate(args): pipeline_dtype=torch.float32, ) - # DEBUGGING rank = dist.get_rank() if rank == 0: - print("tensor_parallel_size:", args.tensor_parallel_size) - print("context_parallel_size:", args.context_parallel_size) - print("pipeline_parallel_size:", args.pipeline_parallel_size) - print("sequence_parallel:", args.sequence_parallel) + print("Running inference with tensor_parallel_size:", args.tensor_parallel_size) + print("Running inference with context_parallel_size:", args.context_parallel_size) + print("Running inference with pipeline_parallel_size:", args.pipeline_parallel_size) + print("Running inference with sequence_parallel:", args.sequence_parallel) print("\n\n\n") logging.info("Generating videos ...") From f3828b03076ff16d4957705afedca91fd64491e2 Mon Sep 17 00:00:00 2001 From: Huy Vu2 Date: Thu, 6 Nov 2025 00:03:56 -0800 Subject: [PATCH 28/80] remove example_commands.md, Google waits until mid Nov --- .../megatron/recipes/wan/example_commands.md | 97 ------------------- 1 file changed, 97 deletions(-) delete mode 100644 examples/megatron/recipes/wan/example_commands.md diff --git a/examples/megatron/recipes/wan/example_commands.md b/examples/megatron/recipes/wan/example_commands.md deleted file mode 100644 index a8f531ad..00000000 --- a/examples/megatron/recipes/wan/example_commands.md +++ /dev/null @@ -1,97 +0,0 @@ -## WAN example commands - -### Launch container -Example command on EOS cluster: -``` -CONT="nvcr.io/nvidia/nemo:25.09.00" -MOUNT="/lustre/fsw/:/lustre/fsw/" -srun -t 02:00:00 --account coreai_dlalgo_llm -N 1 -J coreai_dlalgo_llm:* -p interactive --exclusive --container-image="${CONT}" --container-mounts="${MOUNT}" --pty bash -``` - - -### Set paths to Megatron-Bridge -Inside container: -```bash -DFM_PATH=/path/to/dfm -MBRIDGE_PATH=/path/to/megatron-bridge -export PYTHONPATH="${DFM_PATH}/.:${MBRIDGE_PATH}/src/.:/opt/NeMo-Framework-Launcher/launcher_scripts" -``` - -### Install dependencies -```bash -pip install --upgrade git+https://github.com/NVIDIA/Megatron-LM.git@ce8185cbbe04f38beb74360e878450f2e8525885 -python3 -m pip install --upgrade diffusers==0.35.1 -pip install easydict -pip install imageio -pip install imageio-ffmpeg -``` - -### Convert checkpoint -See `examples/conversion/convert_wan_checkpoints.py` under `MBRIDGE_PATH` for details. - -### Finetuning -Set environment variables and run training: -```bash -export HF_TOKEN=... -export WANDB_API_KEY=... -EXP_NAME=... -PRETRAINED_CHECKPOINT=/path/to/pretrained_checkpoint -CHECKPOINT_DIR=/path/to/checkpoint_dir -DATASET_PATH=/path/to/dataset -cd ${MBRIDGE_PATH} -NVTE_FUSED_ATTN=1 torchrun --nproc_per_node=8 examples/recipes/wan/pretrain_wan.py \ - model.tensor_model_parallel_size=1 \ - model.pipeline_model_parallel_size=1 \ - model.context_parallel_size=4 \ - model.sequence_parallel=false \ - model.qkv_format=thd \ - dataset.path=${DATASET_PATH} \ - checkpoint.save=${CHECKPOINT_DIR} \ - checkpoint.load=${PRETRAINED_CHECKPOINT} \ - checkpoint.load_optim=false \ - checkpoint.save_interval=200 \ - optimizer.lr=5e-6 \ - optimizer.min_lr=5e-6 \ - train.eval_iters=0 \ - scheduler.lr_decay_style=constant \ - scheduler.lr_warmup_iters=0 \ - model.seq_length=2048 \ - dataset.seq_length=2048 \ - train.global_batch_size=1 \ - train.micro_batch_size=1 \ - dataset.global_batch_size=1 \ - dataset.micro_batch_size=1 \ - logger.log_interval=1 \ - logger.wandb_project="wan" \ - logger.wandb_exp_name=${EXP_NAME} \ - logger.wandb_save_dir=${CHECKPOINT_DIR} -``` - -### Inference -Download T5 and VAE weights from the [Wan-AI/Wan2.1-T2V-1.3B repository](https://huggingface.co/Wan-AI/Wan2.1-T2V-1.3B/tree/main): -- T5: `models_t5_umt5-xxl-enc-bf16.pth`, provider `google` -- VAE: `Wan2.1_VAE.pth` - -Then run: -```bash -export HF_TOKEN=... -CHECKPOINT_DIR=/path/to/checkpoint_dir -T5_DIR=/path/to/t5_weights -VAE_DIR=/path/to/vae_weights -cd ${MBRIDGE_PATH} -NVTE_FUSED_ATTN=1 torchrun --nproc_per_node=1 examples/recipes/wan/inference_wan.py \ - --task t2v-1.3B \ - --sizes 832*480 \ - --checkpoint_dir ${CHECKPOINT_DIR} \ - --checkpoint_step 1000 \ - --t5_checkpoint_dir ${T5_DIR} \ - --vae_checkpoint_dir ${VAE_DIR} \ - --prompts "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \ - --frame_nums 81 \ - --tensor_parallel_size 1 \ - --context_parallel_size 1 \ - --pipeline_parallel_size 1 \ - --sequence_parallel False \ - --base_seed 42 \ - --sample_steps 50 -``` From 47274473c12f2d2d4417195a8028559f34dd7dde Mon Sep 17 00:00:00 2001 From: root Date: Thu, 6 Nov 2025 12:34:34 -0800 Subject: [PATCH 29/80] refactor inference_configs + mockdatamodule --- dfm/src/megatron/data/wan/wan_taskencoder.py | 20 +------ .../flow_matching/flow_inference_pipeline.py | 25 ++++----- .../model/wan/inference/configs/__init__.py | 33 ------------ .../wan/inference/configs/shared_config.py | 34 ------------ .../wan/inference/configs/wan_t2v_14B.py | 43 --------------- .../wan/inference/configs/wan_t2v_1_3B.py | 43 --------------- dfm/src/megatron/model/wan/wan_model.py | 2 +- dfm/src/megatron/recipes/wan/wan.py | 31 ++++++++--- .../megatron/recipes/wan/inference_wan.py | 53 +++++++++++-------- examples/megatron/recipes/wan/pretrain_wan.py | 7 ++- 10 files changed, 77 insertions(+), 214 deletions(-) delete mode 100644 dfm/src/megatron/model/wan/inference/configs/__init__.py delete mode 100644 dfm/src/megatron/model/wan/inference/configs/shared_config.py delete mode 100644 dfm/src/megatron/model/wan/inference/configs/wan_t2v_14B.py delete mode 100644 dfm/src/megatron/model/wan/inference/configs/wan_t2v_1_3B.py diff --git a/dfm/src/megatron/data/wan/wan_taskencoder.py b/dfm/src/megatron/data/wan/wan_taskencoder.py index 5bc387f8..3bfe71bb 100644 --- a/dfm/src/megatron/data/wan/wan_taskencoder.py +++ b/dfm/src/megatron/data/wan/wan_taskencoder.py @@ -73,7 +73,7 @@ def __init__( self.patch_temporal = patch_temporal self.seq_length = seq_length - ## actual encode_sample() for production + def encode_sample(self, sample: dict) -> dict: video_latent = sample["pth"] context_embeddings = sample["pickle"] @@ -103,24 +103,6 @@ def encode_sample(self, sample: dict) -> dict: video_metadata=video_metadata, ) - ## mock encode_sample() for debugging - # def encode_sample(self, sample: dict) -> dict: - - # # mock encode sample - # F_latents = 24 - # H_latents = 104 - # W_latents = 60 - # video_latent = torch.tensor(torch.randn(16, F_latents, H_latents, W_latents), dtype=torch.float32) - # grid_size = torch.tensor([video_latent.shape[1] // self.patch_temporal, video_latent.shape[2] // self.patch_spatial, video_latent.shape[3] // self.patch_spatial], dtype=torch.int32) - # context_embeddings = torch.tensor(torch.randn(512, 4096), dtype=torch.float32) - # video_metadata = {} - - # return dict( - # video_latent=video_latent, - # grid_size=grid_size, - # context_embeddings=context_embeddings, - # video_metadata=video_metadata, - # ) def batch(self, samples: list[dict]) -> dict: # process video latents diff --git a/dfm/src/megatron/model/wan/flow_matching/flow_inference_pipeline.py b/dfm/src/megatron/model/wan/flow_matching/flow_inference_pipeline.py index 8a4cb685..32893e58 100644 --- a/dfm/src/megatron/model/wan/flow_matching/flow_inference_pipeline.py +++ b/dfm/src/megatron/model/wan/flow_matching/flow_inference_pipeline.py @@ -71,7 +71,7 @@ def _encode_text( class FlowInferencePipeline: def __init__( self, - config, + inference_cfg, model_id="Wan-AI/Wan2.1-T2V-14B-Diffusers", checkpoint_dir=None, checkpoint_step=None, @@ -90,8 +90,8 @@ def __init__( Initializes the FlowInferencePipeline with the given parameters. Args: - config (EasyDict): - Object containing model parameters initialized from config.py + inference_cfg (dict): + Object containing inference configuration. checkpoint_dir (`str`): Path to directory containing model checkpoints t5_checkpoint_dir (`str`, *optional*, defaults to None): @@ -106,7 +106,7 @@ def __init__( Whether to place T5 model on CPU. Only works without t5_fsdp. """ self.device = torch.device(f"cuda:{device_id}") - self.config = config + self.inference_cfg = inference_cfg self.model_id = model_id self.rank = rank self.t5_cpu = t5_cpu @@ -115,25 +115,26 @@ def __init__( self.pipeline_parallel_size = pipeline_parallel_size self.sequence_parallel = sequence_parallel self.pipeline_dtype = pipeline_dtype - self.num_train_timesteps = config.num_train_timesteps - self.param_dtype = config.param_dtype + self.num_train_timesteps = inference_cfg.num_train_timesteps + self.param_dtype = inference_cfg.param_dtype + self.text_len = inference_cfg.text_len self.text_encoder = UMT5EncoderModel.from_pretrained( model_id, subfolder="text_encoder", - torch_dtype=config.t5_dtype, + torch_dtype=inference_cfg.t5_dtype, ) self.tokenizer = AutoTokenizer.from_pretrained( model_id, subfolder="tokenizer", ) - self.vae_stride = config.vae_stride - self.patch_size = config.patch_size + self.vae_stride = inference_cfg.vae_stride + self.patch_size = inference_cfg.patch_size self.vae = AutoencoderKLWan.from_pretrained( model_id, subfolder="vae", - torch_dtype=config.param_dtype, + torch_dtype=inference_cfg.param_dtype, ) self.vae.to(self.device) @@ -150,7 +151,7 @@ def __init__( dist.barrier() self.model.to(self.device) - self.sample_neg_prompt = config.sample_neg_prompt + self.sample_neg_prompt = inference_cfg.sample_neg_prompt def setup_model_from_checkpoint(self, checkpoint_dir): provider = WanModelProvider() @@ -362,7 +363,7 @@ def generate( # we implement similar to Wan's diffuser setup # (https://github.com/huggingface/diffusers/blob/0f252be0ed42006c125ef4429156cb13ae6c1d60/src/diffusers/pipelines/wan/pipeline_wan.py#L157) # in which we pad the text to 512, pass through text encoder, and truncate to the actual tokens, then pad with 0s to 512. - context_max_len = 512 + context_max_len = self.text_len context_lens = [] contexts = [] contexts_null = [] diff --git a/dfm/src/megatron/model/wan/inference/configs/__init__.py b/dfm/src/megatron/model/wan/inference/configs/__init__.py deleted file mode 100644 index 3237e0f5..00000000 --- a/dfm/src/megatron/model/wan/inference/configs/__init__.py +++ /dev/null @@ -1,33 +0,0 @@ -import os - -from .wan_t2v_1_3B import t2v_1_3B -from .wan_t2v_14B import t2v_14B - - -os.environ["TOKENIZERS_PARALLELISM"] = "false" - - -WAN_CONFIGS = { - "t2v-14B": t2v_14B, - "t2v-1.3B": t2v_1_3B, -} - -SIZE_CONFIGS = { - "720*1280": (720, 1280), - "1280*720": (1280, 720), - "480*832": (480, 832), - "832*480": (832, 480), - "1024*1024": (1024, 1024), -} - -MAX_AREA_CONFIGS = { - "720*1280": 720 * 1280, - "1280*720": 1280 * 720, - "480*832": 480 * 832, - "832*480": 832 * 480, -} - -SUPPORTED_SIZES = { - "t2v-14B": ("720*1280", "1280*720", "480*832", "832*480"), - "t2v-1.3B": ("480*832", "832*480"), -} diff --git a/dfm/src/megatron/model/wan/inference/configs/shared_config.py b/dfm/src/megatron/model/wan/inference/configs/shared_config.py deleted file mode 100644 index 7b2e9432..00000000 --- a/dfm/src/megatron/model/wan/inference/configs/shared_config.py +++ /dev/null @@ -1,34 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import torch -from easydict import EasyDict - - -# ------------------------ Wan shared config ------------------------# -wan_shared_cfg = EasyDict() - -# t5 -wan_shared_cfg.t5_model = "umt5_xxl" -wan_shared_cfg.t5_dtype = torch.bfloat16 -wan_shared_cfg.text_len = 512 - -# transformer -wan_shared_cfg.param_dtype = torch.bfloat16 - -# inference -wan_shared_cfg.num_train_timesteps = 1000 -wan_shared_cfg.sample_fps = 16 -wan_shared_cfg.sample_neg_prompt = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" diff --git a/dfm/src/megatron/model/wan/inference/configs/wan_t2v_14B.py b/dfm/src/megatron/model/wan/inference/configs/wan_t2v_14B.py deleted file mode 100644 index 00c82d6d..00000000 --- a/dfm/src/megatron/model/wan/inference/configs/wan_t2v_14B.py +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from easydict import EasyDict - -from .shared_config import wan_shared_cfg - - -# ------------------------ Wan T2V 14B ------------------------# - -t2v_14B = EasyDict(__name__="Config: Wan T2V 14B") -t2v_14B.update(wan_shared_cfg) - -# t5 -t2v_14B.t5_checkpoint = "models_t5_umt5-xxl-enc-bf16.pth" -t2v_14B.t5_tokenizer = "google/umt5-xxl" - -# vae -t2v_14B.vae_checkpoint = "Wan2.1_VAE.pth" -t2v_14B.vae_stride = (4, 8, 8) - -# transformer -t2v_14B.patch_size = (1, 2, 2) -t2v_14B.dim = 5120 -t2v_14B.ffn_dim = 13824 -t2v_14B.freq_dim = 256 -t2v_14B.num_heads = 40 -t2v_14B.num_layers = 40 -t2v_14B.window_size = (-1, -1) -t2v_14B.qk_norm = True -t2v_14B.cross_attn_norm = True -t2v_14B.eps = 1e-6 diff --git a/dfm/src/megatron/model/wan/inference/configs/wan_t2v_1_3B.py b/dfm/src/megatron/model/wan/inference/configs/wan_t2v_1_3B.py deleted file mode 100644 index 66b5df12..00000000 --- a/dfm/src/megatron/model/wan/inference/configs/wan_t2v_1_3B.py +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from easydict import EasyDict - -from .shared_config import wan_shared_cfg - - -# ------------------------ Wan T2V 1.3B ------------------------# - -t2v_1_3B = EasyDict(__name__="Config: Wan T2V 1.3B") -t2v_1_3B.update(wan_shared_cfg) - -# t5 -t2v_1_3B.t5_checkpoint = "models_t5_umt5-xxl-enc-bf16.pth" -t2v_1_3B.t5_tokenizer = "google/umt5-xxl" - -# vae -t2v_1_3B.vae_checkpoint = "Wan2.1_VAE.pth" -t2v_1_3B.vae_stride = (4, 8, 8) - -# transformer -t2v_1_3B.patch_size = (1, 2, 2) -t2v_1_3B.dim = 1536 -t2v_1_3B.ffn_dim = 8960 -t2v_1_3B.freq_dim = 256 -t2v_1_3B.num_heads = 12 -t2v_1_3B.num_layers = 30 -t2v_1_3B.window_size = (-1, -1) -t2v_1_3B.qk_norm = True -t2v_1_3B.cross_attn_norm = True -t2v_1_3B.eps = 1e-6 diff --git a/dfm/src/megatron/model/wan/wan_model.py b/dfm/src/megatron/model/wan/wan_model.py index 6c96bd54..86784d3a 100644 --- a/dfm/src/megatron/model/wan/wan_model.py +++ b/dfm/src/megatron/model/wan/wan_model.py @@ -28,7 +28,7 @@ from megatron.core.transformer.transformer_block import TransformerBlock from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.utils import make_sharded_tensor_for_checkpoint -from nemo.collections.diffusion.models.dit.dit_embeddings import ParallelTimestepEmbedding +from dfm.src.megatron.model.common.dit_embeddings import ParallelTimestepEmbedding from torch import Tensor from dfm.src.megatron.model.wan.wan_layer_spec import ( diff --git a/dfm/src/megatron/recipes/wan/wan.py b/dfm/src/megatron/recipes/wan/wan.py index 5f06e9d9..6f41e0a6 100644 --- a/dfm/src/megatron/recipes/wan/wan.py +++ b/dfm/src/megatron/recipes/wan/wan.py @@ -31,6 +31,7 @@ from megatron.core.distributed import DistributedDataParallelConfig from dfm.src.megatron.data.wan.wan_energon_datamodule import WanDataModuleConfig +from dfm.src.megatron.data.wan.wan_mock_energon_datamodule import WanMockDataModuleConfig from dfm.src.megatron.model.wan.wan_provider import WanModelProvider @@ -158,6 +159,28 @@ def pretrain_config( precision_config.grad_reduce_in_fp32 = False + if mock: + dataset = WanMockDataModuleConfig( + path=None, + seq_length=1024, # we don't need to use this value, just add because Bridge training requires for LLMs + F_latents=3, + H_latents=104, + W_latents=60, + context_seq_len=512, + context_embeddings_dim=4096, + micro_batch_size=micro_batch_size, + global_batch_size=global_batch_size, + num_workers=10, + ) + else: + dataset = WanDataModuleConfig( + path=None, + seq_length=1024, # we don't need to use this value, just add because Bridge training requires for LLMs + micro_batch_size=micro_batch_size, + global_batch_size=global_batch_size, + num_workers=10, + ) + # Config Container cfg = ConfigContainer( model=model_cfg, @@ -182,13 +205,7 @@ def pretrain_config( use_distributed_optimizer=True, use_megatron_fsdp=use_megatron_fsdp, # need use_distributed_optimizer=True ), - dataset=WanDataModuleConfig( - path=None, - seq_length=1024, # we don't need to use this value, just add because Bridge training requires for LLMs - micro_batch_size=micro_batch_size, - global_batch_size=global_batch_size, - num_workers=10, - ), + dataset=dataset, logger=LoggerConfig( log_interval=10, tensorboard_dir=tensorboard_dir, diff --git a/examples/megatron/recipes/wan/inference_wan.py b/examples/megatron/recipes/wan/inference_wan.py index 63c3d020..3568d211 100644 --- a/examples/megatron/recipes/wan/inference_wan.py +++ b/examples/megatron/recipes/wan/inference_wan.py @@ -19,6 +19,7 @@ import sys import warnings from datetime import datetime +from easydict import EasyDict warnings.filterwarnings("ignore") @@ -29,7 +30,7 @@ import torch.distributed as dist from dfm.src.megatron.model.wan.flow_matching.flow_inference_pipeline import FlowInferencePipeline -from dfm.src.megatron.model.wan.inference.configs import SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS +from dfm.src.megatron.model.wan.inference import SIZE_CONFIGS, SUPPORTED_SIZES from dfm.src.megatron.model.wan.inference.utils import cache_video, str2bool @@ -45,10 +46,7 @@ def _validate_args(args): # Basic check - assert args.checkpoint_dir is not None, "Please specify the checkpoint directory." - assert args.t5_checkpoint_dir is not None, "Please specify the T5 checkpoint directory." - assert args.vae_checkpoint_dir is not None, "Please specify the VAE checkpoint directory." - assert args.task in WAN_CONFIGS, f"Unsupport task: {args.task}" + assert args.task in SUPPORTED_SIZES, f"Unsupport task: {args.task}" assert args.task in EXAMPLE_PROMPT, f"Unsupport task: {args.task}" # The default sampling steps are 40 for image-to-video tasks and 50 for text-to-video tasks. @@ -73,7 +71,7 @@ def _validate_args(args): def _parse_args(): parser = argparse.ArgumentParser(description="Generate a image or video from a text prompt or image using Wan") parser.add_argument( - "--task", type=str, default="t2v-14B", choices=list(WAN_CONFIGS.keys()), help="The task to run." + "--task", type=str, default="t2v-14B", choices=list(SUPPORTED_SIZES.keys()), help="The task to run." ) parser.add_argument( "--sizes", @@ -170,6 +168,7 @@ def generate(args): local_rank = int(os.getenv("LOCAL_RANK", 0)) device = local_rank _init_logging(rank) + videos = [] if args.offload_model is None: logging.info(f"offload_model is not specified, set to {args.offload_model}.") @@ -177,10 +176,23 @@ def generate(args): torch.cuda.set_device(local_rank) dist.init_process_group(backend="nccl", init_method="env://", rank=rank, world_size=world_size) - cfg = WAN_CONFIGS[args.task] + inference_cfg = EasyDict({ + # t5 + "t5_dtype": torch.bfloat16, + "text_len": 512, + # vae + "vae_stride": (4, 8, 8), + # transformer + "param_dtype": torch.bfloat16, + "patch_size": (1, 2, 2), + # others + "num_train_timesteps": 1000, + "sample_fps": 16, + "sample_neg_prompt": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + }) logging.info(f"Generation job args: {args}") - logging.info(f"Generation model config: {cfg}") + logging.info(f"Generation model config: {inference_cfg}") if dist.is_initialized(): base_seed = [args.base_seed] if rank == 0 else [None] @@ -214,7 +226,7 @@ def generate(args): logging.info("Creating flow inference pipeline.") pipeline = FlowInferencePipeline( - config=cfg, + inference_cfg=inference_cfg, checkpoint_dir=args.checkpoint_dir, model_id="Wan-AI/Wan2.1-T2V-14B-Diffusers", checkpoint_step=args.checkpoint_step, @@ -250,23 +262,22 @@ def generate(args): offload_model=args.offload_model, ) - if rank == 0: - for i, video in enumerate(videos): - formatted_experiment_name = (args.save_file) if args.save_file is not None else "DefaultExp" - formatted_time = datetime.now().strftime("%Y%m%d_%H%M%S") - formatted_prompt = prompts[i].replace(" ", "_").replace("/", "_")[:50] - suffix = ".mp4" - formatted_save_file = ( - f"{args.task}_{formatted_experiment_name}_videoindex{int(i)}_size{size_keys[i].replace('*', 'x') if sys.platform == 'win32' else size_keys[i]}_{formatted_prompt}_{formatted_time}" - + suffix - ) + if rank == 0: + for i, video in enumerate(videos): + formatted_experiment_name = (args.save_file) if args.save_file is not None else "DefaultExp" + formatted_time = datetime.now().strftime("%Y%m%d_%H%M%S") + formatted_prompt = prompts[i].replace(" ", "_").replace("/", "_")[:50] + suffix = ".mp4" + formatted_save_file = ( + f"{args.task}_{formatted_experiment_name}_videoindex{int(i)}_size{size_keys[i].replace('*', 'x') if sys.platform == 'win32' else size_keys[i]}_{formatted_prompt}_{formatted_time}" + + suffix + ) - if "t2v" in args.task: logging.info(f"Saving generated video to {formatted_save_file}") cache_video( tensor=video[None], save_file=formatted_save_file, - fps=cfg.sample_fps, + fps=inference_cfg.sample_fps, nrow=1, normalize=True, value_range=(-1, 1), diff --git a/examples/megatron/recipes/wan/pretrain_wan.py b/examples/megatron/recipes/wan/pretrain_wan.py index 6eae3e31..27b69145 100644 --- a/examples/megatron/recipes/wan/pretrain_wan.py +++ b/examples/megatron/recipes/wan/pretrain_wan.py @@ -86,6 +86,11 @@ def parse_cli_args() -> Tuple[argparse.Namespace, list[str]]: description="Pretrain Wan model using Megatron-Bridge with YAML and CLI overrides", formatter_class=argparse.RawTextHelpFormatter, ) + parser.add_argument( + "--mock", + action="store_true", + help="Whether to use mock data." + ) parser.add_argument( "--config-file", type=str, @@ -131,7 +136,7 @@ def main() -> None: logger.info("------------------------------------------------------------------") # Load base configuration from the recipe as a Python dataclass - cfg: ConfigContainer = pretrain_config() + cfg: ConfigContainer = pretrain_config(mock=args.mock) logger.info("Loaded base configuration") # Print configuration on rank 0 From 8f49e23d6977656d4ea5a2cc02e9891e088ffaea Mon Sep 17 00:00:00 2001 From: Huy Vu2 Date: Thu, 6 Nov 2025 12:41:47 -0800 Subject: [PATCH 30/80] add dit_embeddings.py --- .../data/wan/wan_mock_energon_datamodule.py | 119 ++++++++++++++++++ .../megatron/model/common/dit_embeddings.py | 71 +++++++++++ .../megatron/model/wan/inference/__init__.py | 25 ++++ 3 files changed, 215 insertions(+) create mode 100644 dfm/src/megatron/data/wan/wan_mock_energon_datamodule.py create mode 100644 dfm/src/megatron/model/common/dit_embeddings.py create mode 100644 dfm/src/megatron/model/wan/inference/__init__.py diff --git a/dfm/src/megatron/data/wan/wan_mock_energon_datamodule.py b/dfm/src/megatron/data/wan/wan_mock_energon_datamodule.py new file mode 100644 index 00000000..675c4992 --- /dev/null +++ b/dfm/src/megatron/data/wan/wan_mock_energon_datamodule.py @@ -0,0 +1,119 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=C0115,C0116,C0301 + +from dataclasses import dataclass +import torch + +from megatron.bridge.data.utils import DatasetBuildContext, DatasetProvider + +from dfm.src.megatron.data.dit.diffusion_energon_datamodule import DiffusionDataModule +from dfm.src.megatron.data.wan.wan_taskencoder import WanTaskEncoder + + +class WanMockTaskEncoder(WanTaskEncoder): + """ + Mock task encoder for Wan dataset. + Attributes: + cookers (list): A list of Cooker objects used for processing. + patch_spatial (int): The spatial patch size. Defaults to 2. + patch_temporal (int): The temporal patch size. Defaults to 1. + seq_length (int): The sequence length. Defaults to 1024. + """ + + F_latents: int + H_latents: int + W_latents: int + context_seq_len: int + context_embeddings_dim: int + + def __init__( + self, + *args, + F_latents: int, + H_latents: int, + W_latents: int, + context_seq_len: int, + context_embeddings_dim: int, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.F_latents = F_latents + self.H_latents = H_latents + self.W_latents = W_latents + self.context_seq_len = context_seq_len + self.context_embeddings_dim = context_embeddings_dim + + # mock encode_sample() for debugging + def encode_sample(self, sample: dict) -> dict: + + # mock encode sample + video_latent = torch.tensor(torch.randn(16, self.F_latents, self.H_latents, self.W_latents), dtype=torch.float32) + grid_size = torch.tensor([video_latent.shape[1] // self.patch_temporal, video_latent.shape[2] // self.patch_spatial, video_latent.shape[3] // self.patch_spatial], dtype=torch.int32) + context_embeddings = torch.tensor(torch.randn(self.context_seq_len, self.context_embeddings_dim), dtype=torch.float32) + video_metadata = {} + + # DEBUGGING + output = "" + output += "----------------------------------------\n" + output += f"video_latent.shape: {video_latent.shape}\n" + output += f"grid_size.shape: {grid_size.shape}\n" + output += f"context_embeddings.shape: {context_embeddings.shape}\n" + output += f"video_metadata: {video_metadata}\n" + output += "----------------------------------------\n" + print(output) + + return dict( + video_latent=video_latent, + grid_size=grid_size, + context_embeddings=context_embeddings, + video_metadata=video_metadata, + ) + + +@dataclass(kw_only=True) +class WanMockDataModuleConfig(DatasetProvider): + path: str + seq_length: int + micro_batch_size: int + global_batch_size: int + num_workers: int + dataloader_type: str = "external" + F_latents: int = 3 + H_latents: int = 104 + W_latents: int = 60 + context_seq_len: int = 512 + context_embeddings_dim: int = 4096 + + def __post_init__(self): + self.dataset = DiffusionDataModule( + path=self.path, + seq_length=self.seq_length, + task_encoder=WanMockTaskEncoder( + seq_length=self.seq_length, + F_latents=self.F_latents, + H_latents=self.H_latents, + W_latents=self.W_latents, + context_seq_len=self.context_seq_len, + context_embeddings_dim=self.context_embeddings_dim, + ), + micro_batch_size=self.micro_batch_size, + global_batch_size=self.global_batch_size, + num_workers=self.num_workers, + ) + self.sequence_length = self.dataset.seq_length + + def build_datasets(self, context: DatasetBuildContext): + return self.dataset.train_dataloader(), self.dataset.train_dataloader(), self.dataset.train_dataloader() diff --git a/dfm/src/megatron/model/common/dit_embeddings.py b/dfm/src/megatron/model/common/dit_embeddings.py new file mode 100644 index 00000000..4fb17573 --- /dev/null +++ b/dfm/src/megatron/model/common/dit_embeddings.py @@ -0,0 +1,71 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=C0115,C0116,C0301 + + +import logging + +import torch +from diffusers.models.embeddings import TimestepEmbedding +from megatron.core import parallel_state + + +log = logging.getLogger(__name__) + + +class ParallelTimestepEmbedding(TimestepEmbedding): + """ + ParallelTimestepEmbedding is a subclass of TimestepEmbedding that initializes + the embedding layers with an optional random seed for syncronization. + + Args: + in_channels (int): Number of input channels. + time_embed_dim (int): Dimension of the time embedding. + seed (int, optional): Random seed for initializing the embedding layers. + If None, no specific seed is set. + + Attributes: + linear_1 (nn.Module): First linear layer for the embedding. + linear_2 (nn.Module): Second linear layer for the embedding. + + Methods: + __init__(in_channels, time_embed_dim, seed=None): Initializes the embedding layers. + """ + + def __init__(self, in_channels: int, time_embed_dim: int, seed=None): + super().__init__(in_channels=in_channels, time_embed_dim=time_embed_dim) + if seed is not None: + with torch.random.fork_rng(): + torch.manual_seed(seed) + self.linear_1.reset_parameters() + self.linear_2.reset_parameters() + + if parallel_state.get_pipeline_model_parallel_world_size() > 1: + setattr(self.linear_1.weight, "pipeline_parallel", True) + setattr(self.linear_1.bias, "pipeline_parallel", True) + setattr(self.linear_2.weight, "pipeline_parallel", True) + setattr(self.linear_2.bias, "pipeline_parallel", True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Computes the positional embeddings for the input tensor. + + Args: + x (torch.Tensor): Input tensor of shape (B, T, H, W, C). + + Returns: + torch.Tensor: Positional embeddings of shape (B, T, H, W, C). + """ + return super().forward(x.to(torch.bfloat16, non_blocking=True)) \ No newline at end of file diff --git a/dfm/src/megatron/model/wan/inference/__init__.py b/dfm/src/megatron/model/wan/inference/__init__.py new file mode 100644 index 00000000..e477e446 --- /dev/null +++ b/dfm/src/megatron/model/wan/inference/__init__.py @@ -0,0 +1,25 @@ +import os + + +os.environ["TOKENIZERS_PARALLELISM"] = "false" + + +SIZE_CONFIGS = { + "720*1280": (720, 1280), + "1280*720": (1280, 720), + "480*832": (480, 832), + "832*480": (832, 480), + "1024*1024": (1024, 1024), +} + +MAX_AREA_CONFIGS = { + "720*1280": 720 * 1280, + "1280*720": 1280 * 720, + "480*832": 480 * 832, + "832*480": 832 * 480, +} + +SUPPORTED_SIZES = { + "t2v-14B": ("720*1280", "1280*720", "480*832", "832*480"), + "t2v-1.3B": ("480*832", "832*480"), +} From 4766b1b56506290ff446fb23daf0ac1f44d0fd04 Mon Sep 17 00:00:00 2001 From: Huy Vu2 Date: Thu, 6 Nov 2025 12:52:49 -0800 Subject: [PATCH 31/80] fix lint ruff --- .../data/wan/wan_mock_energon_datamodule.py | 20 +++++++++--- dfm/src/megatron/data/wan/wan_taskencoder.py | 2 -- .../megatron/model/common/dit_embeddings.py | 2 +- dfm/src/megatron/model/wan/wan_model.py | 2 +- dfm/src/megatron/recipes/wan/wan.py | 4 +-- .../megatron/recipes/wan/inference_wan.py | 31 ++++++++++--------- examples/megatron/recipes/wan/pretrain_wan.py | 6 +--- 7 files changed, 37 insertions(+), 30 deletions(-) diff --git a/dfm/src/megatron/data/wan/wan_mock_energon_datamodule.py b/dfm/src/megatron/data/wan/wan_mock_energon_datamodule.py index 675c4992..45430232 100644 --- a/dfm/src/megatron/data/wan/wan_mock_energon_datamodule.py +++ b/dfm/src/megatron/data/wan/wan_mock_energon_datamodule.py @@ -15,8 +15,8 @@ # pylint: disable=C0115,C0116,C0301 from dataclasses import dataclass -import torch +import torch from megatron.bridge.data.utils import DatasetBuildContext, DatasetProvider from dfm.src.megatron.data.dit.diffusion_energon_datamodule import DiffusionDataModule @@ -58,11 +58,21 @@ def __init__( # mock encode_sample() for debugging def encode_sample(self, sample: dict) -> dict: - # mock encode sample - video_latent = torch.tensor(torch.randn(16, self.F_latents, self.H_latents, self.W_latents), dtype=torch.float32) - grid_size = torch.tensor([video_latent.shape[1] // self.patch_temporal, video_latent.shape[2] // self.patch_spatial, video_latent.shape[3] // self.patch_spatial], dtype=torch.int32) - context_embeddings = torch.tensor(torch.randn(self.context_seq_len, self.context_embeddings_dim), dtype=torch.float32) + video_latent = torch.tensor( + torch.randn(16, self.F_latents, self.H_latents, self.W_latents), dtype=torch.float32 + ) + grid_size = torch.tensor( + [ + video_latent.shape[1] // self.patch_temporal, + video_latent.shape[2] // self.patch_spatial, + video_latent.shape[3] // self.patch_spatial, + ], + dtype=torch.int32, + ) + context_embeddings = torch.tensor( + torch.randn(self.context_seq_len, self.context_embeddings_dim), dtype=torch.float32 + ) video_metadata = {} # DEBUGGING diff --git a/dfm/src/megatron/data/wan/wan_taskencoder.py b/dfm/src/megatron/data/wan/wan_taskencoder.py index 3bfe71bb..bb3025b0 100644 --- a/dfm/src/megatron/data/wan/wan_taskencoder.py +++ b/dfm/src/megatron/data/wan/wan_taskencoder.py @@ -73,7 +73,6 @@ def __init__( self.patch_temporal = patch_temporal self.seq_length = seq_length - def encode_sample(self, sample: dict) -> dict: video_latent = sample["pth"] context_embeddings = sample["pickle"] @@ -103,7 +102,6 @@ def encode_sample(self, sample: dict) -> dict: video_metadata=video_metadata, ) - def batch(self, samples: list[dict]) -> dict: # process video latents # do padding here for video latents diff --git a/dfm/src/megatron/model/common/dit_embeddings.py b/dfm/src/megatron/model/common/dit_embeddings.py index 4fb17573..fea46480 100644 --- a/dfm/src/megatron/model/common/dit_embeddings.py +++ b/dfm/src/megatron/model/common/dit_embeddings.py @@ -68,4 +68,4 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: Returns: torch.Tensor: Positional embeddings of shape (B, T, H, W, C). """ - return super().forward(x.to(torch.bfloat16, non_blocking=True)) \ No newline at end of file + return super().forward(x.to(torch.bfloat16, non_blocking=True)) diff --git a/dfm/src/megatron/model/wan/wan_model.py b/dfm/src/megatron/model/wan/wan_model.py index 86784d3a..b339dbc6 100644 --- a/dfm/src/megatron/model/wan/wan_model.py +++ b/dfm/src/megatron/model/wan/wan_model.py @@ -28,9 +28,9 @@ from megatron.core.transformer.transformer_block import TransformerBlock from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.utils import make_sharded_tensor_for_checkpoint -from dfm.src.megatron.model.common.dit_embeddings import ParallelTimestepEmbedding from torch import Tensor +from dfm.src.megatron.model.common.dit_embeddings import ParallelTimestepEmbedding from dfm.src.megatron.model.wan.wan_layer_spec import ( get_wan_block_with_transformer_engine_spec as WanLayerWithAdaLNspec, ) diff --git a/dfm/src/megatron/recipes/wan/wan.py b/dfm/src/megatron/recipes/wan/wan.py index 6f41e0a6..b092b96a 100644 --- a/dfm/src/megatron/recipes/wan/wan.py +++ b/dfm/src/megatron/recipes/wan/wan.py @@ -162,7 +162,7 @@ def pretrain_config( if mock: dataset = WanMockDataModuleConfig( path=None, - seq_length=1024, # we don't need to use this value, just add because Bridge training requires for LLMs + seq_length=1024, # we don't need to use this value, just add because Bridge training requires for LLMs F_latents=3, H_latents=104, W_latents=60, @@ -175,7 +175,7 @@ def pretrain_config( else: dataset = WanDataModuleConfig( path=None, - seq_length=1024, # we don't need to use this value, just add because Bridge training requires for LLMs + seq_length=1024, # we don't need to use this value, just add because Bridge training requires for LLMs micro_batch_size=micro_batch_size, global_batch_size=global_batch_size, num_workers=10, diff --git a/examples/megatron/recipes/wan/inference_wan.py b/examples/megatron/recipes/wan/inference_wan.py index 3568d211..b4d1cfb4 100644 --- a/examples/megatron/recipes/wan/inference_wan.py +++ b/examples/megatron/recipes/wan/inference_wan.py @@ -19,6 +19,7 @@ import sys import warnings from datetime import datetime + from easydict import EasyDict @@ -176,20 +177,22 @@ def generate(args): torch.cuda.set_device(local_rank) dist.init_process_group(backend="nccl", init_method="env://", rank=rank, world_size=world_size) - inference_cfg = EasyDict({ - # t5 - "t5_dtype": torch.bfloat16, - "text_len": 512, - # vae - "vae_stride": (4, 8, 8), - # transformer - "param_dtype": torch.bfloat16, - "patch_size": (1, 2, 2), - # others - "num_train_timesteps": 1000, - "sample_fps": 16, - "sample_neg_prompt": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", - }) + inference_cfg = EasyDict( + { + # t5 + "t5_dtype": torch.bfloat16, + "text_len": 512, + # vae + "vae_stride": (4, 8, 8), + # transformer + "param_dtype": torch.bfloat16, + "patch_size": (1, 2, 2), + # others + "num_train_timesteps": 1000, + "sample_fps": 16, + "sample_neg_prompt": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + } + ) logging.info(f"Generation job args: {args}") logging.info(f"Generation model config: {inference_cfg}") diff --git a/examples/megatron/recipes/wan/pretrain_wan.py b/examples/megatron/recipes/wan/pretrain_wan.py index 27b69145..950c4cb6 100644 --- a/examples/megatron/recipes/wan/pretrain_wan.py +++ b/examples/megatron/recipes/wan/pretrain_wan.py @@ -86,11 +86,7 @@ def parse_cli_args() -> Tuple[argparse.Namespace, list[str]]: description="Pretrain Wan model using Megatron-Bridge with YAML and CLI overrides", formatter_class=argparse.RawTextHelpFormatter, ) - parser.add_argument( - "--mock", - action="store_true", - help="Whether to use mock data." - ) + parser.add_argument("--mock", action="store_true", help="Whether to use mock data.") parser.add_argument( "--config-file", type=str, From c4004ea88eb56389ed9bde623d76b354869869e0 Mon Sep 17 00:00:00 2001 From: Huy Vu2 Date: Fri, 7 Nov 2025 14:27:13 -0800 Subject: [PATCH 32/80] add 'average_gradients_across_tp_domain' to torch.nn for when running sequence_parallelism --- .../data/wan/wan_mock_energon_datamodule.py | 10 --------- dfm/src/megatron/model/wan/wan_layer_spec.py | 12 +++++++++++ dfm/src/megatron/model/wan/wan_model.py | 21 +++++++++++++++++++ 3 files changed, 33 insertions(+), 10 deletions(-) diff --git a/dfm/src/megatron/data/wan/wan_mock_energon_datamodule.py b/dfm/src/megatron/data/wan/wan_mock_energon_datamodule.py index 45430232..7ae7cac3 100644 --- a/dfm/src/megatron/data/wan/wan_mock_energon_datamodule.py +++ b/dfm/src/megatron/data/wan/wan_mock_energon_datamodule.py @@ -75,16 +75,6 @@ def encode_sample(self, sample: dict) -> dict: ) video_metadata = {} - # DEBUGGING - output = "" - output += "----------------------------------------\n" - output += f"video_latent.shape: {video_latent.shape}\n" - output += f"grid_size.shape: {grid_size.shape}\n" - output += f"context_embeddings.shape: {context_embeddings.shape}\n" - output += f"video_metadata: {video_metadata}\n" - output += "----------------------------------------\n" - print(output) - return dict( video_latent=video_latent, grid_size=grid_size, diff --git a/dfm/src/megatron/model/wan/wan_layer_spec.py b/dfm/src/megatron/model/wan/wan_layer_spec.py index 9b47f593..053e3e94 100644 --- a/dfm/src/megatron/model/wan/wan_layer_spec.py +++ b/dfm/src/megatron/model/wan/wan_layer_spec.py @@ -135,6 +135,18 @@ def __init__( elementwise_affine=False, ) + # set attributes "average_gradients_across_tp_domain" for nn.Parameter objects + # this is used for gradient averaging across TP domain with sequence parallelism + self._mark_trainable_params_for_tp_grad_avg([self.norm3, self.adaLN]) + + def _mark_trainable_params_for_tp_grad_avg(self, modules: Optional[list] = None) -> None: + """Mark selected modules' trainable parameters to average gradients across TP domain.""" + target_modules = modules if modules is not None else [self] + for module in target_modules: + for _name, param in module.named_parameters(recurse=True): + if isinstance(param, nn.Parameter) and param.requires_grad: + setattr(param, "average_gradients_across_tp_domain", True) + def forward( self, hidden_states, diff --git a/dfm/src/megatron/model/wan/wan_model.py b/dfm/src/megatron/model/wan/wan_model.py index b339dbc6..f39e6b61 100644 --- a/dfm/src/megatron/model/wan/wan_model.py +++ b/dfm/src/megatron/model/wan/wan_model.py @@ -165,6 +165,19 @@ def __init__( if self.post_process: self.head = Head(self.config.hidden_size, self.out_channels, self.patch_size, eps=1e-6) + + # set attributes "average_gradients_across_tp_domain" for nn.Parameter objects + # this is used for gradient averaging across TP domain with sequence parallelism + self._mark_trainable_params_for_tp_grad_avg( + [ + self.patch_embedding, + self.text_embedding, + self.time_embedder, + self.time_proj, + self.head, + ] + ) + def forward( self, x: Tensor, @@ -300,6 +313,14 @@ def sharded_state_dict( return sharded_state_dict + def _mark_trainable_params_for_tp_grad_avg(self, modules: Optional[list] = None) -> None: + """Mark selected modules' trainable parameters to average gradients across TP domain.""" + target_modules = modules if modules is not None else [self] + for module in target_modules: + for _name, param in module.named_parameters(recurse=True): + if isinstance(param, nn.Parameter) and param.requires_grad: + setattr(param, "average_gradients_across_tp_domain", True) + def _set_embedder_weights_replica_id( self, tensor: Tensor, sharded_state_dict: ShardedStateDict, embedder_weight_key: str ) -> None: From e332cb2482317af3277a13e14994f6cdae0cc983 Mon Sep 17 00:00:00 2001 From: Huy Vu2 Date: Fri, 7 Nov 2025 14:54:58 -0800 Subject: [PATCH 33/80] add english negative prompt --- .../wan/flow_matching/flow_inference_pipeline.py | 2 +- examples/megatron/recipes/wan/inference_wan.py | 3 ++- pyproject.toml | 14 ++------------ 3 files changed, 5 insertions(+), 14 deletions(-) diff --git a/dfm/src/megatron/model/wan/flow_matching/flow_inference_pipeline.py b/dfm/src/megatron/model/wan/flow_matching/flow_inference_pipeline.py index 32893e58..e246b48c 100644 --- a/dfm/src/megatron/model/wan/flow_matching/flow_inference_pipeline.py +++ b/dfm/src/megatron/model/wan/flow_matching/flow_inference_pipeline.py @@ -151,7 +151,7 @@ def __init__( dist.barrier() self.model.to(self.device) - self.sample_neg_prompt = inference_cfg.sample_neg_prompt + self.sample_neg_prompt = inference_cfg.english_sample_neg_prompt def setup_model_from_checkpoint(self, checkpoint_dir): provider = WanModelProvider() diff --git a/examples/megatron/recipes/wan/inference_wan.py b/examples/megatron/recipes/wan/inference_wan.py index b4d1cfb4..7392e9b7 100644 --- a/examples/megatron/recipes/wan/inference_wan.py +++ b/examples/megatron/recipes/wan/inference_wan.py @@ -190,7 +190,8 @@ def generate(args): # others "num_train_timesteps": 1000, "sample_fps": 16, - "sample_neg_prompt": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + "chinese_sample_neg_prompt": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + "english_sample_neg_prompt": "Bright and vivid tones, overexposed, static, blurry details, subtitles, style, artwork, painting, image, still, overall grayish, worst quality, low quality, JPEG compression artifacts, ugly, defective, extra fingers, poorly drawn hands, poorly drawn face, deformed, disfigured, misshapen limbs, fused fingers, motionless image, messy background, three legs, crowded background, walking backward.", } ) diff --git a/pyproject.toml b/pyproject.toml index 2cdb905e..ee72098a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,21 +55,11 @@ classifiers = [ "Topic :: Utilities", ] dependencies = [ + "nemo-automodel @ git+https://github.com/NVIDIA-NeMo/Automodel@main", "diffusers==0.35.1", "easydict", - "imageio", - "imageio-ffmpeg", -] - -[build-system] -requires = ["setuptools>=61"] -build-backend = "setuptools.build_meta" - -[dependency-groups] -automodel = [ - "nemo-automodel @ git+https://github.com/NVIDIA-NeMo/Automodel@main", - "diffusers", "ftfy", + "imageio", "imageio-ffmpeg", "opencv-python-headless==4.10.0.84", ] From bc0372767fc63edd3ebcbd56b1323c5fe6179a4c Mon Sep 17 00:00:00 2001 From: Huy Vu2 Date: Fri, 7 Nov 2025 14:57:09 -0800 Subject: [PATCH 34/80] fix ruff lint --- dfm/src/megatron/model/wan/wan_model.py | 1 - 1 file changed, 1 deletion(-) diff --git a/dfm/src/megatron/model/wan/wan_model.py b/dfm/src/megatron/model/wan/wan_model.py index f39e6b61..3154affd 100644 --- a/dfm/src/megatron/model/wan/wan_model.py +++ b/dfm/src/megatron/model/wan/wan_model.py @@ -165,7 +165,6 @@ def __init__( if self.post_process: self.head = Head(self.config.hidden_size, self.out_channels, self.patch_size, eps=1e-6) - # set attributes "average_gradients_across_tp_domain" for nn.Parameter objects # this is used for gradient averaging across TP domain with sequence parallelism self._mark_trainable_params_for_tp_grad_avg( From d7c1acbfc4865064e1bdaddb5adf07e21349e85f Mon Sep 17 00:00:00 2001 From: root Date: Fri, 7 Nov 2025 15:18:32 -0800 Subject: [PATCH 35/80] Update uv.lock for deps: diffusers==0.35.1, easydict, imageio --- uv.lock | 35 +++++++++++++++++++++++++++++++---- 1 file changed, 31 insertions(+), 4 deletions(-) diff --git a/uv.lock b/uv.lock index 7bbc0b45..475d65b5 100644 --- a/uv.lock +++ b/uv.lock @@ -524,7 +524,7 @@ wheels = [ [[package]] name = "diffusers" -version = "0.35.2" +version = "0.35.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "filelock" }, @@ -537,9 +537,9 @@ dependencies = [ { name = "requests" }, { name = "safetensors" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/03/68/288ca23c7c05c73e87ffe5efffc282400ac9b017f7a9bb03883f4310ea15/diffusers-0.35.2.tar.gz", hash = "sha256:30ecd552303edfcfe1724573c3918a8462ee3ab4d529bdbd4c0045f763affded", size = 3366711, upload-time = "2025-10-15T04:05:17.213Z" } +sdist = { url = "https://files.pythonhosted.org/packages/49/05/c4c8736c14e0efe9a835fb91c6ff5e1abddf9894a2f2a28fffe6429378a6/diffusers-0.35.1.tar.gz", hash = "sha256:6f4dc0c9d309a4c4914a2179646f2bc801b5e395a43295fff3b5f9dbd3e28fd3", size = 3369127, upload-time = "2025-08-20T04:16:10.668Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/2a/2e/38d9824f8c6bb048c5ba21c6d4da54c29c162a46b58b3ef907a360a76d3e/diffusers-0.35.2-py3-none-any.whl", hash = "sha256:d50d5e74fdd6dcf55e5c1d304bc52cc7c2659abd1752740d736d7b54078b4db5", size = 4121649, upload-time = "2025-10-15T04:05:14.391Z" }, + { url = "https://files.pythonhosted.org/packages/06/a7/c53f294f34d9e1584388721b3d7aa024ea1ac46e86d0c302fc3db40ed960/diffusers-0.35.1-py3-none-any.whl", hash = "sha256:fe29ff10200970c7c5934c6488c213e2a77a03dad5e6fa00bbd8e1d04234cb0e", size = 4121424, upload-time = "2025-08-20T04:16:08.359Z" }, ] [[package]] @@ -569,6 +569,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/8f/d7/9322c609343d929e75e7e5e6255e614fcc67572cfd083959cdef3b7aad79/docutils-0.21.2-py3-none-any.whl", hash = "sha256:dafca5b9e384f0e419294eb4d2ff9fa826435bf15f15b7bd45723e8ad76811b2", size = 587408, upload-time = "2024-04-23T18:57:14.835Z" }, ] +[[package]] +name = "easydict" +version = "1.13" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/24/9f/d18d6b5e19244788a6d09c14a8406376b4f4bfcc008e6d17a4f4c15362e8/easydict-1.13.tar.gz", hash = "sha256:b1135dedbc41c8010e2bc1f77ec9744c7faa42bce1a1c87416791449d6c87780", size = 6809, upload-time = "2024-03-04T12:04:41.251Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/05/ec/fa6963f1198172c2b75c9ab6ecefb3045991f92f75f5eb41b6621b198123/easydict-1.13-py3-none-any.whl", hash = "sha256:6b787daf4dcaf6377b4ad9403a5cee5a86adbc0ca9a5bcf5410e9902002aeac2", size = 6804, upload-time = "2024-03-04T12:04:39.508Z" }, +] + [[package]] name = "einops" version = "0.8.1" @@ -887,6 +896,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0e/61/66938bbb5fc52dbdf84594873d5b51fb1f7c7794e9c0f5bd885f30bc507b/idna-3.11-py3-none-any.whl", hash = "sha256:771a87f49d9defaf64091e6e6fe9c18d4833f140bd19464795bc32d966ca37ea", size = 71008, upload-time = "2025-10-12T14:55:18.883Z" }, ] +[[package]] +name = "imageio" +version = "2.37.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "numpy", version = "2.3.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "pillow" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a3/6f/606be632e37bf8d05b253e8626c2291d74c691ddc7bcdf7d6aaf33b32f6a/imageio-2.37.2.tar.gz", hash = "sha256:0212ef2727ac9caa5ca4b2c75ae89454312f440a756fcfc8ef1993e718f50f8a", size = 389600, upload-time = "2025-11-04T14:29:39.898Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fb/fe/301e0936b79bcab4cacc7548bf2853fc28dced0a578bab1f7ef53c9aa75b/imageio-2.37.2-py3-none-any.whl", hash = "sha256:ad9adfb20335d718c03de457358ed69f141021a333c40a53e57273d8a5bd0b9b", size = 317646, upload-time = "2025-11-04T14:29:37.948Z" }, +] + [[package]] name = "imageio-ffmpeg" version = "0.6.0" @@ -1372,7 +1395,9 @@ name = "nemo-vfm" source = { editable = "." } dependencies = [ { name = "diffusers" }, + { name = "easydict" }, { name = "ftfy" }, + { name = "imageio" }, { name = "imageio-ffmpeg" }, { name = "nemo-automodel" }, { name = "opencv-python-headless" }, @@ -1405,8 +1430,10 @@ test = [ [package.metadata] requires-dist = [ - { name = "diffusers" }, + { name = "diffusers", specifier = "==0.35.1" }, + { name = "easydict" }, { name = "ftfy" }, + { name = "imageio" }, { name = "imageio-ffmpeg" }, { name = "nemo-automodel", git = "https://github.com/NVIDIA-NeMo/Automodel?rev=main" }, { name = "opencv-python-headless", specifier = "==4.10.0.84" }, From c5250139b18e279946ebbbe3212c576763dc6644 Mon Sep 17 00:00:00 2001 From: Huy Vu2 Date: Fri, 7 Nov 2025 15:48:52 -0800 Subject: [PATCH 36/80] update dfm/src/megatron/data/dit --- dfm/src/megatron/data/dit/__init__.py | 13 - dfm/src/megatron/data/dit/base.py | 339 ------------------ .../data/dit/diffusion_energon_datamodule.py | 17 +- .../data/dit/diffusion_taskencoder.py | 256 ------------- dfm/src/megatron/data/dit/utils.py | 203 ----------- 5 files changed, 13 insertions(+), 815 deletions(-) delete mode 100644 dfm/src/megatron/data/dit/base.py delete mode 100644 dfm/src/megatron/data/dit/diffusion_taskencoder.py delete mode 100644 dfm/src/megatron/data/dit/utils.py diff --git a/dfm/src/megatron/data/dit/__init__.py b/dfm/src/megatron/data/dit/__init__.py index d9155f92..e69de29b 100644 --- a/dfm/src/megatron/data/dit/__init__.py +++ b/dfm/src/megatron/data/dit/__init__.py @@ -1,13 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/dfm/src/megatron/data/dit/base.py b/dfm/src/megatron/data/dit/base.py deleted file mode 100644 index f903ee7f..00000000 --- a/dfm/src/megatron/data/dit/base.py +++ /dev/null @@ -1,339 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -from typing import Any, Dict, Literal, Optional - -from megatron.core import parallel_state -from megatron.energon import WorkerConfig, get_savable_loader, get_train_dataset - - -logger = logging.getLogger(__name__) - - -class EnergonMultiModalDataModule: - """ - A PyTorch Lightning DataModule for handling multimodal datasets with images and text. - - This data module is designed to work with multimodal datasets that involve both images and text. - It provides a seamless interface to load training and validation data, manage batching, and handle - the state of the data pipeline across training epochs. The module integrates with the Megatron-Energon - framework for efficient data handling in large-scale distributed training. - - Attributes: - path (str): Path to the energon dataset. - tokenizer (Tokenizer): The tokenizer used for processing text. - image_processor (ImageProcessor): The image processor used for preprocessing images. - seq_length (int): The maximum sequence length for tokenized text. - micro_batch_size (int): The batch size for training and validation. - num_workers (int): Number of workers for data loading. - pin_memory (bool): Whether to pin memory in the DataLoader. - multimodal_sample_config (MultiModalSampleConfig): Configuration object for multimodal samples. - task_encoder (MultiModalTaskEncoder): Encoder responsible for encoding and batching samples. - init_global_step (int): The initial global step for the trainer, used for resuming training. - data_sampler (SequentialMegatronSampler): Sampler responsible for generating sequential samples. - train_dataloader_object (Optional): The DataLoader object for training data. - val_dataloader_object (Optional): The DataLoader object for validation data. - """ - - def __init__( - self, - path: str, - tokenizer, - image_processor, - seq_length: int = 2048, - micro_batch_size: int = 1, - global_batch_size: int = 1, - num_workers: int = 1, - num_val_workers: int | None = None, - pin_memory: bool = True, - shuffle_buffer_size: int = 100, - max_samples_per_sequence: int | None = None, - multimodal_sample_config: Optional[Any] = None, - task_encoder: Optional[Any] = None, - decoder_seq_length: Optional[int] = None, - packing_buffer_size: Optional[int] = None, - validation_task_encoder: Optional[Any] = None, - **kwargs, - ) -> None: - """ - Initialize the EnergonMultiModalDataModule. - - Parameters: - path (str): Path to the dataset. - tokenizer (Tokenizer): The tokenizer used for processing text. - image_processor (ImageProcessor): The image processor used for preprocessing images. - seq_length (int, optional): The maximum sequence length for tokenized text. Defaults to 2048. - micro_batch_size (int, optional): The batch size for training and validation. Defaults to 1. - num_workers (int, optional): Number of workers for data loading. Defaults to 1. - num_val_workers (int, optional): Number of workers for validation data loading. Defaults to num_workers. - pin_memory (bool, optional): Whether to pin memory in the DataLoader. Defaults to True. - multimodal_sample_config (MultiModalSampleConfig, optional): Configuration object for multimodal samples. - Defaults to MultiModalSampleConfig(). - shuffle_buffer_size (int, optional): Size of the shuffle buffer. Defaults to 100. - max_samples_per_sequence (int, optional): Maximum number of samples per sequence to load from memory. - Defaults to None (loads the whole tar file at once). - task_encoder (MultiModalTaskEncoder, optional): Encoder responsible for encoding and batching samples. - If not provided, a default (MultimodalTaskEncoder) encoder will be created. Defaults to None. - decoder_seq_length (int, optional): The max sequence length for the decoder. Used in encoder-decoder models - packing_buffer_size (int, optional): Size of the packing buffer for batched samples. Defaults to None. - validation_task_encoder (MultiModalTaskEncoder, optional): Encoder responsible for encoding - and batching samples for validation. Defaults to None and will be the same as task_encoder. - **kwargs: Additional keyword arguments. Will be passed to get_train_dataset() of Energon - """ - - super().__init__() - self.path = path - self.tokenizer = tokenizer - self.image_processor = image_processor - self.seq_length = seq_length - self.decoder_seq_length = decoder_seq_length - self.micro_batch_size = micro_batch_size - self.global_batch_size = global_batch_size - self.num_workers = num_workers - self.pin_memory = pin_memory - self.multimodal_sample_config = multimodal_sample_config - self.shuffle_buffer_size = shuffle_buffer_size - self.max_samples_per_sequence = max_samples_per_sequence - self.task_encoder = task_encoder - self.init_global_step = 0 - self.train_dataloader_object = None - self.val_dataloader_object = None - self.packing_buffer_size = packing_buffer_size - self.validation_task_encoder = validation_task_encoder or self.task_encoder - self.num_val_workers = num_val_workers or self.num_workers - self.kwargs = kwargs - - def datasets_provider(self, worker_config, split: Literal["train", "val"] = "val"): - """ - Provide the dataset for training or validation. - - This method retrieves the dataset for the specified split (either 'train' or 'val') and configures - it according to the worker configuration. - - Parameters: - worker_config: Configuration for the data loader workers. - split (Literal['train', 'val'], optional): The data split to retrieve ('train' or 'val'). Defaults to 'val'. - - Returns: - Dataset: The dataset configured for the specified split. - """ - - if split not in {"train", "val"}: - raise ValueError("Invalid value for split. Allowed values are 'train' or 'val'.") - - if split == "train": - task_encoder = self.task_encoder - else: - task_encoder = self.validation_task_encoder - - _dataset = get_train_dataset( - self.path, - batch_size=self.micro_batch_size, - task_encoder=task_encoder, - worker_config=worker_config, - packing_buffer_size=self.packing_buffer_size, - split_part=split, - shuffle_buffer_size=self.shuffle_buffer_size, - max_samples_per_sequence=self.max_samples_per_sequence, - **self.kwargs, - ) - - return _dataset - - def build(self): - return self.train_dataloader(), self.val_dataloader() - - def train_dataloader(self) -> Any: - """ - Initialize and return the training DataLoader. - - This method initializes the DataLoader for the training dataset. It uses the global step - from the trainer to configure the data sampler and ensures that the parallel state is initialized - correctly for distributed training. - - Returns: - TRAIN_DATALOADERS: The DataLoader for the training dataset. - """ - - logger.info(f"Multimodal train dataloader initializing with init_global_step {self.init_global_step}") - if self.train_dataloader_object: - return self.train_dataloader_object - if not parallel_state.is_initialized(): - logger.info( - f"Muiltimodal data loader parallel state is not initialized," - f"using default worker config with no_workers {self.num_workers}" - ) - worker_config = WorkerConfig.default_worker_config(self.num_workers) - else: - rank = parallel_state.get_data_parallel_rank() - world_size = parallel_state.get_data_parallel_world_size() - data_parallel_group = parallel_state.get_data_parallel_group() - logger.info( - f" Multimodal train dataloader initializing with" - f"rank {rank} world_size {world_size} data_parallel_group {data_parallel_group} ****** " - ) - worker_config = WorkerConfig( - rank=rank, - world_size=world_size, - num_workers=self.num_workers, - data_parallel_group=data_parallel_group, - worker_debug_path=None, - worker_log_level=0, - ) - train_dataset = self.datasets_provider(worker_config, split="train") - energon_dataloader = get_savable_loader(train_dataset, worker_config=worker_config) - self.train_dataloader_object = energon_dataloader - return self.train_dataloader_object - - def val_dataloader(self): - """ - Initialize and return the validation DataLoader. - - This method initializes the DataLoader for the validation dataset. It ensures that the parallel state - is initialized correctly for distributed training and returns a configured DataLoader object. - - Returns: - EVAL_DATALOADERS: The DataLoader for the validation dataset. - """ - if self.val_dataloader_object: - return self.val_dataloader_object - - if not parallel_state.is_initialized(): - logger.info( - f"Muiltimodal val data loader parallel state is not initialized," - f"using default worker config with no_workers {self.num_workers}" - ) - worker_config = WorkerConfig.default_worker_config(self.num_val_workers) - else: - rank = parallel_state.get_data_parallel_rank() - world_size = parallel_state.get_data_parallel_world_size() - data_parallel_group = parallel_state.get_data_parallel_group() - - logger.info(f"rank {rank} world_size {world_size} data_parallel_group {data_parallel_group}") - worker_config = WorkerConfig( - rank=rank, - world_size=world_size, - num_workers=self.num_workers, - data_parallel_group=data_parallel_group, - worker_debug_path=None, - worker_log_level=0, - ) - val_dataset = self.datasets_provider(worker_config, split="val") - energon_loader = get_savable_loader(val_dataset, worker_config=worker_config) - self.val_dataloader_object = energon_loader - return self.val_dataloader_object - - def test_dataloader(self) -> None: - """ - Return None as test dataset split does not exist. - - This method overrides the test_dataloader method and returns None since the test dataset split - is not defined or used in this module. - - Returns: - None - """ - logger.warning("Multimodal dataloader test dataset split does not exist") - return None - - def state_dict(self) -> Dict[str, Any]: - """ - Save the state of the data module. - - This method is called when saving a checkpoint. It generates and saves the state of the data module, - including the state of the dataloader and the number of consumed samples. - - Returns: - Dict[str, Any]: A dictionary containing the state of the data module. - """ - - if self.trainer: - dataloader_obj = self.trainer.train_dataloader - - state = [] - # All ranks should be zero except the dp rank. - if ( - parallel_state.get_context_parallel_rank() - or parallel_state.get_pipeline_model_parallel_rank() - or parallel_state.get_tensor_model_parallel_rank() - or parallel_state.get_expert_model_parallel_rank() - ) == 0: - # Save_state_global in energon assumes that we call it for only the first rank within each group that - # shares the same dataloader state. By making sure that current rank is the first rank in a model - # parallel group, we ensure this. - state = dataloader_obj.save_state_global(global_dst_rank=0) - - consumed_samples = self.data_sampler.compute_consumed_samples( - self.trainer.global_step - self.init_global_step - ) - - if state is None: - state = [] # Megatron core requires all the states on all the ranks to have same python - # type. Energon sends the state as a list - logger.info(f"Multimodal data loader saving dataloader state dict consumed samples {consumed_samples}") - return {"dataloader_state": state, "consumed_samples": consumed_samples} - - logger.warning("trainer object not connected to data module object returning empty state") - return {} - - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: - """ - Load the state of the data module from a checkpoint. - - This method is called when loading a checkpoint. It restores the state of the data module, - including the state of the dataloader and the number of consumed samples. - - Parameters: - state_dict (Dict[str, Any]): The state dictionary containing the saved state of the data module. - """ - if not "dataloader_state" in state_dict: - logger.warning( - f"Data loader state cannot be resumed from state_dict, " - f"it does not have the required key dataloader_state. It has {state_dict.keys()}" - ) - return - - state = state_dict["dataloader_state"] - try: - if self.trainer: - self.trainer.datamodule.train_dataloader().restore_state_global(state) - logger.info("Multimodal dataloader state restored") - else: - logger.error(f"Cannot restore state from state_dict {state_dict}") - raise ValueError( - "Cannot restore state from state_dict: " - "Is the trainer object is initialized and attached to datamodule???" - ) - except Exception as e: - logger.warning( - f"Failed to dataloader restore state due to [Please ensure you are using same version " - f"of energon while saving and loading, Continuing without restoring data loader] : {e}" - ) - - try: - from megatron.core.num_microbatches_calculator import update_num_microbatches - - except (ImportError, ModuleNotFoundError): - logger.warning("Megatron num_microbatches_calculator not found, using Apex version.") - from apex.transformer.pipeline_parallel.utils import update_num_microbatches - - consumed_samples = state_dict["consumed_samples"] - self.data_sampler.init_consumed_samples = consumed_samples - self.data_sampler.prev_consumed_samples = consumed_samples - logger.info(f"Multimodal dataloader load state dict with consumed_samples {consumed_samples}") - update_num_microbatches( - consumed_samples=consumed_samples, - consistency_check=False, - ) diff --git a/dfm/src/megatron/data/dit/diffusion_energon_datamodule.py b/dfm/src/megatron/data/dit/diffusion_energon_datamodule.py index eaa9aa73..4fb5785e 100644 --- a/dfm/src/megatron/data/dit/diffusion_energon_datamodule.py +++ b/dfm/src/megatron/data/dit/diffusion_energon_datamodule.py @@ -14,6 +14,7 @@ # pylint: disable=C0115,C0116,C0301 + import logging from dataclasses import dataclass from typing import Any, Dict, Literal @@ -22,8 +23,8 @@ from megatron.energon import DefaultTaskEncoder, get_train_dataset from torch import int_repr -from dfm.src.megatron.data.dit.base import EnergonMultiModalDataModule -from dfm.src.megatron.data.dit.diffusion_taskencoder import BasicDiffusionTaskEncoder +from dfm.src.megatron.data.dit.base_energon_datamodule import EnergonMultiModalDataModule +from dfm.src.megatron.data.dit.dit_taskencoder import DiTTaskEncoder @dataclass(kw_only=True) @@ -32,6 +33,7 @@ class DiffusionDataModuleConfig(DatasetProvider): seq_length: int micro_batch_size: int task_encoder_seq_length: int + packing_buffer_size: int global_batch_size: int num_workers: int_repr dataloader_type: str = "external" @@ -40,14 +42,18 @@ def __post_init__(self): self.dataset = DiffusionDataModule( path=self.path, seq_length=self.seq_length, - task_encoder=BasicDiffusionTaskEncoder(seq_length=self.task_encoder_seq_length), + task_encoder=DiTTaskEncoder( + seq_length=self.task_encoder_seq_length, packing_buffer_size=self.packing_buffer_size + ), micro_batch_size=self.micro_batch_size, + packing_buffer_size=self.packing_buffer_size, global_batch_size=self.global_batch_size, num_workers=self.num_workers, ) self.sequence_length = self.dataset.seq_length def build_datasets(self, context: DatasetBuildContext): + # TODO: add validation and test datasets return self.dataset.train_dataloader(), self.dataset.train_dataloader(), self.dataset.train_dataloader() @@ -84,6 +90,7 @@ def __init__( global_batch_size: int = 8, num_workers: int = 1, pin_memory: bool = True, + packing_buffer_size: int = None, task_encoder: DefaultTaskEncoder = None, use_train_split_for_val: bool = False, ) -> None: @@ -108,6 +115,7 @@ def __init__( micro_batch_size=micro_batch_size, global_batch_size=global_batch_size, num_workers=num_workers, + packing_buffer_size=packing_buffer_size, pin_memory=pin_memory, task_encoder=task_encoder, ) @@ -134,6 +142,7 @@ def datasets_provider(self, worker_config, split: Literal["train", "val"] = "val _dataset = get_train_dataset( self.path, batch_size=self.micro_batch_size, + packing_buffer_size=self.packing_buffer_size, task_encoder=self.task_encoder, worker_config=worker_config, max_samples_per_sequence=None, @@ -173,4 +182,4 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: try: super().load_state_dict(state_dict) except Exception as e: - logging.warning(f"datamodule.load_state_dict failed {e}") + logging.warning(f"datamodule.load_state_dict failed {e}") \ No newline at end of file diff --git a/dfm/src/megatron/data/dit/diffusion_taskencoder.py b/dfm/src/megatron/data/dit/diffusion_taskencoder.py deleted file mode 100644 index 7faa1aaa..00000000 --- a/dfm/src/megatron/data/dit/diffusion_taskencoder.py +++ /dev/null @@ -1,256 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# pylint: disable=C0115,C0116,C0301 - -import torch -import torch.nn.functional as F -from einops import rearrange -from megatron.energon import DefaultTaskEncoder, SkipSample -from megatron.energon.task_encoder.cooking import Cooker, basic_sample_keys - - -def cook(sample: dict) -> dict: - """ - Processes a raw sample dictionary from energon dataset and returns a new dictionary with specific keys. - - Args: - sample (dict): The input dictionary containing the raw sample data. - - Returns: - dict: A new dictionary containing the processed sample data with the following keys: - - All keys from the result of `basic_sample_keys(sample)` - - 'json': The contains meta data like resolution, aspect ratio, fps, etc. - - 'pth': contains video latent tensor - - 'pickle': contains text embeddings - """ - return dict( - **basic_sample_keys(sample), - json=sample[".json"], - pth=sample[".pth"], - pickle=sample[".pickle"], - ) - - -class BasicDiffusionTaskEncoder(DefaultTaskEncoder): - """ - BasicDiffusionTaskEncoder is a class that encodes image/video samples for diffusion tasks. - Attributes: - cookers (list): A list of Cooker objects used for processing. - max_frames (int, optional): The maximum number of frames to consider from the video. Defaults to None. - text_embedding_padding_size (int): The padding size for text embeddings. Defaults to 512. - Methods: - __init__(*args, max_frames=None, text_embedding_padding_size=512, **kwargs): - Initializes the BasicDiffusionTaskEncoder with optional maximum frames and text embedding padding size. - encode_sample(sample: dict) -> dict: - Encodes a given sample dictionary containing video and text data. - Args: - sample (dict): A dictionary containing 'pth' for video latent and 'json' for additional info. - Returns: - dict: A dictionary containing encoded video, text embeddings, text mask, and loss mask. - Raises: - SkipSample: If the video latent contains NaNs, Infs, or is not divisible by the tensor parallel size. - """ - - cookers = [ - Cooker(cook), - ] - - def __init__( - self, - *args, - max_frames: int = None, - text_embedding_padding_size: int = 512, - seq_length: int = None, - patch_spatial: int = 2, - patch_temporal: int = 1, - **kwargs, - ): - super().__init__(*args, **kwargs) - self.max_frames = max_frames - self.text_embedding_padding_size = text_embedding_padding_size - self.seq_length = seq_length - self.patch_spatial = patch_spatial - self.patch_temporal = patch_temporal - - def encode_sample(self, sample: dict) -> dict: - video_latent = sample["pth"] - - if torch.isnan(video_latent).any() or torch.isinf(video_latent).any(): - raise SkipSample() - if torch.max(torch.abs(video_latent)) > 1e3: - raise SkipSample() - - info = sample["json"] - # remove batch dimension - video_latent = video_latent.squeeze(0) - # print(f"video_latent shape at start: {video_latent.shape}") - C, T, H, W = video_latent.shape - seq_len = ( - video_latent.shape[-1] - * video_latent.shape[-2] - * video_latent.shape[-3] - // self.patch_spatial**2 - // self.patch_temporal - ) - # seq_len = 1536 - is_image = T == 1 - - # print(f"Skipping sample {sample['__key__']} because seq_len {seq_len} > self.seq_length {self.seq_length}") - if seq_len > self.seq_length: - print(f"Skipping sample {sample['__key__']} because seq_len {seq_len} > self.seq_length {self.seq_length}") - raise SkipSample() - - if self.max_frames is not None: - video_latent = video_latent[:, : self.max_frames, :, :] - - # tpcp_size = parallel_state.get_tensor_model_parallel_world_size() - # if parallel_state.get_context_parallel_world_size() > 1: - # tpcp_size *= parallel_state.get_context_parallel_world_size() * 2 - # if (T * H * W) % tpcp_size != 0: - # warnings.warn(f'skipping {video_latent.shape=} not divisible by {tpcp_size=}') - # raise SkipSample() - # print(f"video_latent shape before rearrange: {video_latent.shape}") - # video_latent shape before rearrange: torch.Size([16, 1, 64, 96]) - video_latent = rearrange( - video_latent, - "C (T pt) (H ph) (W pw) -> (T H W) (ph pw pt C)", - ph=self.patch_spatial, - pw=self.patch_spatial, - pt=self.patch_temporal, - ) - # print(f"video_latent shape after rearrange: {video_latent.shape}") - # After reaaranging: video_latent shape after rearrange: torch.Size([1536, 64]) - # convert sample["pickle"] to numpy, and remove batch dimension - sample["pickle"] = sample["pickle"].cpu().float().numpy().squeeze(0) - if is_image: - t5_text_embeddings = torch.from_numpy(sample["pickle"]).to(torch.bfloat16) - else: - t5_text_embeddings = torch.from_numpy(sample["pickle"][0]).to(torch.bfloat16) - t5_text_embeddings_seq_length = t5_text_embeddings.shape[0] - - if t5_text_embeddings_seq_length > self.text_embedding_padding_size: - t5_text_embeddings = t5_text_embeddings[: self.text_embedding_padding_size] - else: - t5_text_embeddings = F.pad( - t5_text_embeddings, - ( - 0, - 0, - 0, - self.text_embedding_padding_size - t5_text_embeddings_seq_length, - ), - ) - t5_text_mask = torch.ones(t5_text_embeddings_seq_length, dtype=torch.bfloat16) - - if is_image: - h, w = info["image_height"], info["image_width"] - fps = torch.tensor([30] * 1, dtype=torch.bfloat16) - num_frames = torch.tensor([1] * 1, dtype=torch.bfloat16) - else: - h, w = info["height"], info["width"] - fps = torch.tensor([info["framerate"]] * 1, dtype=torch.bfloat16) - num_frames = torch.tensor([info["num_frames"]] * 1, dtype=torch.bfloat16) - image_size = torch.tensor([[h, w, h, w]] * 1, dtype=torch.bfloat16) - - pos_ids = rearrange( - pos_id_3d.get_pos_id_3d(t=T // self.patch_temporal, h=H // self.patch_spatial, w=W // self.patch_spatial), - "T H W d -> (T H W) d", - ) - - if self.seq_length is not None: - pos_ids = F.pad(pos_ids, (0, 0, 0, self.seq_length - seq_len)) - loss_mask = torch.zeros(self.seq_length, dtype=torch.bfloat16) - loss_mask[:seq_len] = 1 - video_latent = F.pad(video_latent, (0, 0, 0, self.seq_length - seq_len)) - else: - loss_mask = torch.ones(seq_len, dtype=torch.bfloat16) - - print(f"Loss mask shape: {loss_mask.shape}") - print(f"video_latent shape final: {video_latent.shape}") - return dict( - video=video_latent, - t5_text_embeddings=t5_text_embeddings, - t5_text_mask=t5_text_mask, - image_size=image_size, - fps=fps, - num_frames=num_frames, - loss_mask=loss_mask, - seq_len_q=torch.tensor(seq_len, dtype=torch.int32), - seq_len_kv=torch.tensor(self.text_embedding_padding_size, dtype=torch.int32), - pos_ids=pos_ids, - latent_shape=torch.tensor([C, T, H, W], dtype=torch.int32), - ) - - -class PosID3D: - def __init__(self, *, max_t=32, max_h=128, max_w=128): - self.max_t = max_t - self.max_h = max_h - self.max_w = max_w - self.generate_pos_id() - - def generate_pos_id(self): - self.grid = torch.stack( - torch.meshgrid( - torch.arange(self.max_t, device="cpu"), - torch.arange(self.max_h, device="cpu"), - torch.arange(self.max_w, device="cpu"), - ), - dim=-1, - ) - - def get_pos_id_3d(self, *, t, h, w): - if t > self.max_t or h > self.max_h or w > self.max_w: - self.max_t = max(self.max_t, t) - self.max_h = max(self.max_h, h) - self.max_w = max(self.max_w, w) - self.generate_pos_id() - return self.grid[:t, :h, :w] - - -pos_id_3d = PosID3D() - - -def cook_raw_iamges(sample: dict) -> dict: - """ - Processes a raw sample dictionary from energon dataset and returns a new dictionary with specific keys. - - Args: - sample (dict): The input dictionary containing the raw sample data. - - Returns: - dict: A new dictionary containing the processed sample data with the following keys: - - All keys from the result of `basic_sample_keys(sample)` - - 'jpg': original images - - 'png': contains control images - - 'txt': contains raw text - """ - return dict( - **basic_sample_keys(sample), - images=sample["jpg"], - hint=sample["png"], - txt=sample["txt"], - ) - - -class RawImageDiffusionTaskEncoder(DefaultTaskEncoder): - """ - Dummy task encoder takes raw image input on CrudeDataset. - """ - - cookers = [ - # Cooker(cook), - Cooker(cook_raw_iamges), - ] diff --git a/dfm/src/megatron/data/dit/utils.py b/dfm/src/megatron/data/dit/utils.py deleted file mode 100644 index dbe8ebad..00000000 --- a/dfm/src/megatron/data/dit/utils.py +++ /dev/null @@ -1,203 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# pylint: disable=C0115,C0116,C0301 - -import numpy as np - - -def minimal_crop(tensor, target_divisor): - """ - Crops the input tensor minimally so that the total number of elements - (T * H * W) is divisible by the specified target_divisor. - - Parameters: - - tensor: NumPy array of shape (C, T, H, W) - - target_divisor: Positive integer specifying the desired divisor - - Returns: - - cropped_tensor: Cropped tensor meeting the divisibility requirement - - Raises: - - ValueError: If it's impossible to meet the divisibility requirement - """ - if not isinstance(target_divisor, int) or target_divisor <= 0: - raise ValueError("target_divisor must be a positive integer greater than zero.") - - C, T, H, W = tensor.shape - total_elements = T * H * W - remainder = total_elements % target_divisor - - if remainder == 0: - return tensor # No cropping needed - - # Elements per unit length in each dimension - elements_per_T = H * W - elements_per_H = T * W - elements_per_W = T * H - - min_elements_removed = None - optimal_deltas = None - - # Limit the search range to avoid unnecessary computations - max_delta_T = min(T - 1, (remainder // elements_per_T) + 1) - max_delta_H = min(H - 1, (remainder // elements_per_H) + 1) - max_delta_W = min(W - 1, (remainder // elements_per_W) + 1) - - for delta_T in range(0, max_delta_T + 1): - for delta_H in range(0, max_delta_H + 1): - for delta_W in range(0, max_delta_W + 1): - if delta_T == delta_H == delta_W == 0: - continue # No cropping - - new_T = T - delta_T - new_H = H - delta_H - new_W = W - delta_W - - if new_T <= 0 or new_H <= 0 or new_W <= 0: - continue # Invalid dimensions - - new_total_elements = new_T * new_H * new_W - if new_total_elements % target_divisor == 0: - elements_removed = delta_T * elements_per_T + delta_H * elements_per_H + delta_W * elements_per_W - if min_elements_removed is None or elements_removed < min_elements_removed: - min_elements_removed = elements_removed - optimal_deltas = (delta_T, delta_H, delta_W) - - if optimal_deltas is None: - raise ValueError("Cannot crop tensor to meet divisibility requirement.") - - delta_T, delta_H, delta_W = optimal_deltas - - # Perform the cropping - # T dimension: crop from the end - end_T = T - delta_T - - # H dimension: center crop - start_H = delta_H // 2 - end_H = H - (delta_H - delta_H // 2) - - # W dimension: center crop - start_W = delta_W // 2 - end_W = W - (delta_W - delta_W // 2) - - cropped_tensor = tensor[:, :end_T, start_H:end_H, start_W:end_W] - return cropped_tensor - - -def test_no_cropping_needed(): - """Test when the tensor already meets the divisibility requirement.""" - C, T, H, W = 3, 8, 8, 8 - target_divisor = 8 - tensor = np.zeros((C, T, H, W)) - cropped_tensor = minimal_crop(tensor, target_divisor) - assert cropped_tensor.shape == (C, T, H, W) - assert (T * H * W) % target_divisor == 0 - - -def test_minimal_cropping_T_dimension(): - """Test minimal cropping along the T dimension.""" - C, T, H, W = 3, 9, 7, 6 - target_divisor = 8 - tensor = np.zeros((C, T, H, W)) - cropped_tensor = minimal_crop(tensor, target_divisor) - new_T = cropped_tensor.shape[1] - assert new_T == T - 1, cropped_tensor.shape - assert (new_T * H * W) % target_divisor == 0 - - -def test_minimal_cropping_H_dimension(): - """Test minimal cropping along the H dimension.""" - C, T, H, W = 3, 7, 9, 6 - target_divisor = 8 - tensor = np.zeros((C, T, H, W)) - cropped_tensor = minimal_crop(tensor, target_divisor) - new_H = cropped_tensor.shape[2] - assert new_H == H - 1, cropped_tensor.shape - assert (T * new_H * W) % target_divisor == 0 - - -def test_minimal_cropping_W_dimension(): - """Test minimal cropping along the W dimension.""" - C, T, H, W = 3, 4, 3, 9 - target_divisor = 8 - tensor = np.zeros((C, T, H, W)) - cropped_tensor = minimal_crop(tensor, target_divisor) - new_W = cropped_tensor.shape[3] - assert new_W == W - 1, cropped_tensor.shape - assert (T * H * new_W) % target_divisor == 0 - - -def test_cropping_multiple_dimensions(): - """Test when minimal cropping requires adjustments on multiple dimensions.""" - C, T, H, W = 3, 9, 9, 8 - target_divisor = 16 - tensor = np.zeros((C, T, H, W)) - cropped_tensor = minimal_crop(tensor, target_divisor) - new_T, new_H, new_W = cropped_tensor.shape[1:] - assert new_T <= T and new_H <= H and new_W <= W - assert (new_T * new_H * new_W) % target_divisor == 0 - - -def test_large_tensor_high_divisor(): - """Test with a larger tensor and higher target_divisor.""" - C, T, H, W = 3, 50, 50, 50 - target_divisor = 1024 - tensor = np.zeros((C, T, H, W)) - cropped_tensor = minimal_crop(tensor, target_divisor) - total_elements = cropped_tensor.shape[1] * cropped_tensor.shape[2] * cropped_tensor.shape[3] - assert total_elements % target_divisor == 0 - - -def test_impossible_cropping(): - """Test that an error is raised when it's impossible to meet the requirement.""" - C, T, H, W = 3, 1, 1, 1 - target_divisor = 2 - tensor = np.zeros((C, T, H, W)) - try: - minimal_crop(tensor, target_divisor) - except ValueError: - pass - - -def test_invalid_target_divisor(): - """Test that an error is raised when target_divisor is invalid.""" - C, T, H, W = 3, 8, 8, 8 - tensor = np.zeros((C, T, H, W)) - try: - minimal_crop(tensor, -1) - except ValueError: - pass - - -def test_minimal_elements_removed(): - """Test that the minimal number of elements are removed.""" - C, T, H, W = 3, 7, 7, 7 - target_divisor = 8 - tensor = np.zeros((C, T, H, W)) - cropped_tensor = minimal_crop(tensor, target_divisor) - elements_removed = (T * H * W) - (cropped_tensor.shape[1] * cropped_tensor.shape[2] * cropped_tensor.shape[3]) - print(cropped_tensor.shape) - assert elements_removed > 0 - assert (cropped_tensor.shape[1] * cropped_tensor.shape[2] * cropped_tensor.shape[3]) % target_divisor == 0 - - -test_no_cropping_needed() -test_minimal_elements_removed() -test_cropping_multiple_dimensions() -test_minimal_cropping_T_dimension() -test_minimal_cropping_H_dimension() -test_minimal_cropping_W_dimension() -test_impossible_cropping() -test_invalid_target_divisor() From 0f57585c2b30263f2388b381f8deecab277cbf59 Mon Sep 17 00:00:00 2001 From: Huy Vu2 Date: Fri, 7 Nov 2025 16:13:15 -0800 Subject: [PATCH 37/80] change english negative prompt --- .../data/dit/base_energon_datamodule.py | 339 ++++++++++++++++++ dfm/src/megatron/data/dit/diffusion_sample.py | 96 +++++ .../dit/diffusion_task_encoder_with_sp.py | 98 +++++ dfm/src/megatron/data/dit/dit_taskencoder.py | 235 ++++++++++++ .../data/dit/sequence_packing_utils.py | 105 ++++++ .../megatron/recipes/wan/inference_wan.py | 2 +- 6 files changed, 874 insertions(+), 1 deletion(-) create mode 100644 dfm/src/megatron/data/dit/base_energon_datamodule.py create mode 100644 dfm/src/megatron/data/dit/diffusion_sample.py create mode 100644 dfm/src/megatron/data/dit/diffusion_task_encoder_with_sp.py create mode 100644 dfm/src/megatron/data/dit/dit_taskencoder.py create mode 100644 dfm/src/megatron/data/dit/sequence_packing_utils.py diff --git a/dfm/src/megatron/data/dit/base_energon_datamodule.py b/dfm/src/megatron/data/dit/base_energon_datamodule.py new file mode 100644 index 00000000..1429790a --- /dev/null +++ b/dfm/src/megatron/data/dit/base_energon_datamodule.py @@ -0,0 +1,339 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import Any, Dict, Literal, Optional + +from megatron.core import parallel_state +from megatron.energon import WorkerConfig, get_savable_loader, get_train_dataset + + +logger = logging.getLogger(__name__) + + +class EnergonMultiModalDataModule: + """ + A PyTorch Lightning DataModule for handling multimodal datasets with images and text. + + This data module is designed to work with multimodal datasets that involve both images and text. + It provides a seamless interface to load training and validation data, manage batching, and handle + the state of the data pipeline across training epochs. The module integrates with the Megatron-Energon + framework for efficient data handling in large-scale distributed training. + + Attributes: + path (str): Path to the energon dataset. + tokenizer (Tokenizer): The tokenizer used for processing text. + image_processor (ImageProcessor): The image processor used for preprocessing images. + seq_length (int): The maximum sequence length for tokenized text. + micro_batch_size (int): The batch size for training and validation. + num_workers (int): Number of workers for data loading. + pin_memory (bool): Whether to pin memory in the DataLoader. + multimodal_sample_config (MultiModalSampleConfig): Configuration object for multimodal samples. + task_encoder (MultiModalTaskEncoder): Encoder responsible for encoding and batching samples. + init_global_step (int): The initial global step for the trainer, used for resuming training. + data_sampler (SequentialMegatronSampler): Sampler responsible for generating sequential samples. + train_dataloader_object (Optional): The DataLoader object for training data. + val_dataloader_object (Optional): The DataLoader object for validation data. + """ + + def __init__( + self, + path: str, + tokenizer, + image_processor, + seq_length: int = 2048, + micro_batch_size: int = 1, + global_batch_size: int = 1, + num_workers: int = 1, + num_val_workers: int | None = None, + pin_memory: bool = True, + shuffle_buffer_size: int = 100, + max_samples_per_sequence: int | None = None, + multimodal_sample_config: Optional[Any] = None, + task_encoder: Optional[Any] = None, + decoder_seq_length: Optional[int] = None, + packing_buffer_size: Optional[int] = None, + validation_task_encoder: Optional[Any] = None, + **kwargs, + ) -> None: + """ + Initialize the EnergonMultiModalDataModule. + + Parameters: + path (str): Path to the dataset. + tokenizer (Tokenizer): The tokenizer used for processing text. + image_processor (ImageProcessor): The image processor used for preprocessing images. + seq_length (int, optional): The maximum sequence length for tokenized text. Defaults to 2048. + micro_batch_size (int, optional): The batch size for training and validation. Defaults to 1. + num_workers (int, optional): Number of workers for data loading. Defaults to 1. + num_val_workers (int, optional): Number of workers for validation data loading. Defaults to num_workers. + pin_memory (bool, optional): Whether to pin memory in the DataLoader. Defaults to True. + multimodal_sample_config (MultiModalSampleConfig, optional): Configuration object for multimodal samples. + Defaults to MultiModalSampleConfig(). + shuffle_buffer_size (int, optional): Size of the shuffle buffer. Defaults to 100. + max_samples_per_sequence (int, optional): Maximum number of samples per sequence to load from memory. + Defaults to None (loads the whole tar file at once). + task_encoder (MultiModalTaskEncoder, optional): Encoder responsible for encoding and batching samples. + If not provided, a default (MultimodalTaskEncoder) encoder will be created. Defaults to None. + decoder_seq_length (int, optional): The max sequence length for the decoder. Used in encoder-decoder models + packing_buffer_size (int, optional): Size of the packing buffer for batched samples. Defaults to None. + validation_task_encoder (MultiModalTaskEncoder, optional): Encoder responsible for encoding + and batching samples for validation. Defaults to None and will be the same as task_encoder. + **kwargs: Additional keyword arguments. Will be passed to get_train_dataset() of Energon + """ + + super().__init__() + self.path = path + self.tokenizer = tokenizer + self.image_processor = image_processor + self.seq_length = seq_length + self.decoder_seq_length = decoder_seq_length + self.micro_batch_size = micro_batch_size + self.global_batch_size = global_batch_size + self.num_workers = num_workers + self.pin_memory = pin_memory + self.multimodal_sample_config = multimodal_sample_config + self.shuffle_buffer_size = shuffle_buffer_size + self.max_samples_per_sequence = max_samples_per_sequence + self.task_encoder = task_encoder + self.init_global_step = 0 + self.train_dataloader_object = None + self.val_dataloader_object = None + self.packing_buffer_size = packing_buffer_size + self.validation_task_encoder = validation_task_encoder or self.task_encoder + self.num_val_workers = num_val_workers or self.num_workers + self.kwargs = kwargs + + def datasets_provider(self, worker_config, split: Literal["train", "val"] = "val"): + """ + Provide the dataset for training or validation. + + This method retrieves the dataset for the specified split (either 'train' or 'val') and configures + it according to the worker configuration. + + Parameters: + worker_config: Configuration for the data loader workers. + split (Literal['train', 'val'], optional): The data split to retrieve ('train' or 'val'). Defaults to 'val'. + + Returns: + Dataset: The dataset configured for the specified split. + """ + + if split not in {"train", "val"}: + raise ValueError("Invalid value for split. Allowed values are 'train' or 'val'.") + + if split == "train": + task_encoder = self.task_encoder + else: + task_encoder = self.validation_task_encoder + + _dataset = get_train_dataset( + self.path, + batch_size=self.micro_batch_size, + task_encoder=task_encoder, + worker_config=worker_config, + packing_buffer_size=self.packing_buffer_size, + split_part=split, + shuffle_buffer_size=self.shuffle_buffer_size, + max_samples_per_sequence=self.max_samples_per_sequence, + **self.kwargs, + ) + + return _dataset + + def build(self): + return self.train_dataloader(), self.val_dataloader() + + def train_dataloader(self) -> Any: + """ + Initialize and return the training DataLoader. + + This method initializes the DataLoader for the training dataset. It uses the global step + from the trainer to configure the data sampler and ensures that the parallel state is initialized + correctly for distributed training. + + Returns: + TRAIN_DATALOADERS: The DataLoader for the training dataset. + """ + + logger.info(f"Multimodal train dataloader initializing with init_global_step {self.init_global_step}") + if self.train_dataloader_object: + return self.train_dataloader_object + if not parallel_state.is_initialized(): + logger.info( + f"Muiltimodal data loader parallel state is not initialized," + f"using default worker config with no_workers {self.num_workers}" + ) + worker_config = WorkerConfig.default_worker_config(self.num_workers) + else: + rank = parallel_state.get_data_parallel_rank() + world_size = parallel_state.get_data_parallel_world_size() + data_parallel_group = parallel_state.get_data_parallel_group() + logger.info( + f" Multimodal train dataloader initializing with" + f"rank {rank} world_size {world_size} data_parallel_group {data_parallel_group} ****** " + ) + worker_config = WorkerConfig( + rank=rank, + world_size=world_size, + num_workers=self.num_workers, + data_parallel_group=data_parallel_group, + worker_debug_path=None, + worker_log_level=0, + ) + train_dataset = self.datasets_provider(worker_config, split="train") + energon_dataloader = get_savable_loader(train_dataset, worker_config=worker_config) + self.train_dataloader_object = energon_dataloader + return self.train_dataloader_object + + def val_dataloader(self): + """ + Initialize and return the validation DataLoader. + + This method initializes the DataLoader for the validation dataset. It ensures that the parallel state + is initialized correctly for distributed training and returns a configured DataLoader object. + + Returns: + EVAL_DATALOADERS: The DataLoader for the validation dataset. + """ + if self.val_dataloader_object: + return self.val_dataloader_object + + if not parallel_state.is_initialized(): + logger.info( + f"Muiltimodal val data loader parallel state is not initialized," + f"using default worker config with no_workers {self.num_workers}" + ) + worker_config = WorkerConfig.default_worker_config(self.num_val_workers) + else: + rank = parallel_state.get_data_parallel_rank() + world_size = parallel_state.get_data_parallel_world_size() + data_parallel_group = parallel_state.get_data_parallel_group() + + logger.info(f"rank {rank} world_size {world_size} data_parallel_group {data_parallel_group}") + worker_config = WorkerConfig( + rank=rank, + world_size=world_size, + num_workers=self.num_workers, + data_parallel_group=data_parallel_group, + worker_debug_path=None, + worker_log_level=0, + ) + val_dataset = self.datasets_provider(worker_config, split="val") + energon_loader = get_savable_loader(val_dataset, worker_config=worker_config) + self.val_dataloader_object = energon_loader + return self.val_dataloader_object + + def test_dataloader(self) -> None: + """ + Return None as test dataset split does not exist. + + This method overrides the test_dataloader method and returns None since the test dataset split + is not defined or used in this module. + + Returns: + None + """ + logger.warning("Multimodal dataloader test dataset split does not exist") + return None + + def state_dict(self) -> Dict[str, Any]: + """ + Save the state of the data module. + + This method is called when saving a checkpoint. It generates and saves the state of the data module, + including the state of the dataloader and the number of consumed samples. + + Returns: + Dict[str, Any]: A dictionary containing the state of the data module. + """ + + if self.trainer: + dataloader_obj = self.trainer.train_dataloader + + state = [] + # All ranks should be zero except the dp rank. + if ( + parallel_state.get_context_parallel_rank() + or parallel_state.get_pipeline_model_parallel_rank() + or parallel_state.get_tensor_model_parallel_rank() + or parallel_state.get_expert_model_parallel_rank() + ) == 0: + # Save_state_global in energon assumes that we call it for only the first rank within each group that + # shares the same dataloader state. By making sure that current rank is the first rank in a model + # parallel group, we ensure this. + state = dataloader_obj.save_state_global(global_dst_rank=0) + + consumed_samples = self.data_sampler.compute_consumed_samples( + self.trainer.global_step - self.init_global_step + ) + + if state is None: + state = [] # Megatron core requires all the states on all the ranks to have same python + # type. Energon sends the state as a list + logger.info(f"Multimodal data loader saving dataloader state dict consumed samples {consumed_samples}") + return {"dataloader_state": state, "consumed_samples": consumed_samples} + + logger.warning("trainer object not connected to data module object returning empty state") + return {} + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + """ + Load the state of the data module from a checkpoint. + + This method is called when loading a checkpoint. It restores the state of the data module, + including the state of the dataloader and the number of consumed samples. + + Parameters: + state_dict (Dict[str, Any]): The state dictionary containing the saved state of the data module. + """ + if not "dataloader_state" in state_dict: + logger.warning( + f"Data loader state cannot be resumed from state_dict, " + f"it does not have the required key dataloader_state. It has {state_dict.keys()}" + ) + return + + state = state_dict["dataloader_state"] + try: + if self.trainer: + self.trainer.datamodule.train_dataloader().restore_state_global(state) + logger.info("Multimodal dataloader state restored") + else: + logger.error(f"Cannot restore state from state_dict {state_dict}") + raise ValueError( + "Cannot restore state from state_dict: " + "Is the trainer object is initialized and attached to datamodule???" + ) + except Exception as e: + logger.warning( + f"Failed to dataloader restore state due to [Please ensure you are using same version " + f"of energon while saving and loading, Continuing without restoring data loader] : {e}" + ) + + try: + from megatron.core.num_microbatches_calculator import update_num_microbatches + + except (ImportError, ModuleNotFoundError): + logger.warning("Megatron num_microbatches_calculator not found, using Apex version.") + from apex.transformer.pipeline_parallel.utils import update_num_microbatches + + consumed_samples = state_dict["consumed_samples"] + self.data_sampler.init_consumed_samples = consumed_samples + self.data_sampler.prev_consumed_samples = consumed_samples + logger.info(f"Multimodal dataloader load state dict with consumed_samples {consumed_samples}") + update_num_microbatches( + consumed_samples=consumed_samples, + consistency_check=False, + ) \ No newline at end of file diff --git a/dfm/src/megatron/data/dit/diffusion_sample.py b/dfm/src/megatron/data/dit/diffusion_sample.py new file mode 100644 index 00000000..20f76857 --- /dev/null +++ b/dfm/src/megatron/data/dit/diffusion_sample.py @@ -0,0 +1,96 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import Any, Optional + +import torch +from megatron.energon import Sample + + +@dataclass +class DiffusionSample(Sample): + """ + Data class representing a sample for diffusion tasks. + + Attributes: + video (torch.Tensor): Video latents (C T H W). + t5_text_embeddings (torch.Tensor): Text embeddings (S D). + t5_text_mask (torch.Tensor): Mask for text embeddings. + loss_mask (torch.Tensor): Mask indicating valid positions for loss computation. + image_size (Optional[torch.Tensor]): Tensor containing image dimensions. + fps (Optional[torch.Tensor]): Frame rate of the video. + num_frames (Optional[torch.Tensor]): Number of frames in the video. + padding_mask (Optional[torch.Tensor]): Mask indicating padding positions. + seq_len_q (Optional[torch.Tensor]): Sequence length for query embeddings. + seq_len_kv (Optional[torch.Tensor]): Sequence length for key/value embeddings. + pos_ids (Optional[torch.Tensor]): Positional IDs. + latent_shape (Optional[torch.Tensor]): Shape of the latent tensor. + """ + + video: torch.Tensor # video latents (C T H W) + context_embeddings: torch.Tensor # (S D) + context_mask: torch.Tensor = None # 1 + image_size: Optional[torch.Tensor] = None + loss_mask: torch.Tensor = None + fps: Optional[torch.Tensor] = None + num_frames: Optional[torch.Tensor] = None + padding_mask: Optional[torch.Tensor] = None + seq_len_q: Optional[torch.Tensor] = None + seq_len_kv: Optional[torch.Tensor] = None + pos_ids: Optional[torch.Tensor] = None + latent_shape: Optional[torch.Tensor] = None + + def to_dict(self) -> dict: + """Converts the sample to a dictionary.""" + return dict( + video=self.video, + context_embeddings=self.context_embeddings, + context_mask=self.context_mask, + loss_mask=self.loss_mask, + image_size=self.image_size, + fps=self.fps, + num_frames=self.num_frames, + padding_mask=self.padding_mask, + seq_len_q=self.seq_len_q, + seq_len_kv=self.seq_len_kv, + pos_ids=self.pos_ids, + latent_shape=self.latent_shape, + ) + + def __add__(self, other: Any) -> int: + """Adds the sequence length of this sample with another sample or integer.""" + if isinstance(other, DiffusionSample): + # Combine the values of the two instances + return self.seq_len_q.item() + other.seq_len_q.item() + elif isinstance(other, int): + # Add an integer to the value + return self.seq_len_q.item() + other + raise NotImplementedError + + def __radd__(self, other: Any) -> int: + """Handles reverse addition for summing with integers.""" + # This is called if sum or other operations start with a non-DiffusionSample object. + # e.g., sum([DiffusionSample(1), DiffusionSample(2)]) -> the 0 + DiffusionSample(1) calls __radd__. + if isinstance(other, int): + return self.seq_len_q.item() + other + raise NotImplementedError + + def __lt__(self, other: Any) -> bool: + """Compares this sample's sequence length with another sample or integer.""" + if isinstance(other, DiffusionSample): + return self.seq_len_q.item() < other.seq_len_q.item() + elif isinstance(other, int): + return self.seq_len_q.item() < other + raise NotImplementedError \ No newline at end of file diff --git a/dfm/src/megatron/data/dit/diffusion_task_encoder_with_sp.py b/dfm/src/megatron/data/dit/diffusion_task_encoder_with_sp.py new file mode 100644 index 00000000..55786c68 --- /dev/null +++ b/dfm/src/megatron/data/dit/diffusion_task_encoder_with_sp.py @@ -0,0 +1,98 @@ +import random +from abc import ABC, abstractmethod +from typing import List + +import torch +from megatron.energon import DefaultTaskEncoder +from megatron.energon.task_encoder.base import stateless +from megatron.energon.task_encoder.cooking import Cooker, basic_sample_keys + +from dfm.src.megatron.data.dit.diffusion_sample import DiffusionSample +from dfm.src.megatron.data.dit.sequence_packing_utils import first_fit_decreasing + + +def cook(sample: dict) -> dict: + """ + Processes a raw sample dictionary from energon dataset and returns a new dictionary with specific keys. + + Args: + sample (dict): The input dictionary containing the raw sample data. + + Returns: + dict: A new dictionary containing the processed sample data with the following keys: + - All keys from the result of `basic_sample_keys(sample)` + - 'json': The contains meta data like resolution, aspect ratio, fps, etc. + - 'pth': contains video latent tensor + - 'pickle': contains text embeddings + """ + return dict( + **basic_sample_keys(sample), + json=sample[".json"], + pth=sample[".pth"], + pickle=sample[".pickle"], + ) + + +class DiffusionTaskEncoderWithSequencePacking(DefaultTaskEncoder, ABC): + cookers = [ + Cooker(cook), + ] + + def __init__( + self, + *args, + max_frames: int = None, + text_embedding_padding_size: int = 512, + seq_length: int = None, + patch_spatial: int = 2, + patch_temporal: int = 1, + packing_buffer_size: int = None, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.max_frames = max_frames + self.text_embedding_padding_size = text_embedding_padding_size + self.seq_length = seq_length + self.patch_spatial = patch_spatial + self.patch_temporal = patch_temporal + self.packing_buffer_size = packing_buffer_size + + @abstractmethod + def encode_sample(self, sample: dict) -> dict: + raise NotImplementedError + + def select_samples_to_pack(self, samples: List[DiffusionSample]) -> List[List[DiffusionSample]]: + """ + Selects sequences to pack for mixed image-video training. + """ + results = first_fit_decreasing(samples, self.packing_buffer_size) + random.shuffle(results) + return results + + @stateless + def pack_selected_samples(self, samples: List[DiffusionSample]) -> DiffusionSample: + """Construct a new Diffusion sample by concatenating the sequences.""" + + def stack(attr): + return torch.stack([getattr(sample, attr) for sample in samples], dim=0) + + def cat(attr): + return torch.cat([getattr(sample, attr) for sample in samples], dim=0) + + return DiffusionSample( + __key__=",".join([s.__key__ for s in samples]), + __restore_key__=(), # Will be set by energon based on `samples` + __subflavor__=None, + __subflavors__=samples[0].__subflavors__, + video=cat("video"), + context_embeddings=cat("context_embeddings"), + loss_mask=cat("loss_mask"), + seq_len_q=cat("seq_len_q"), + seq_len_kv=cat("seq_len_kv"), + pos_ids=cat("pos_ids"), + latent_shape=stack("latent_shape"), + ) + + @stateless + def batch(self, samples: List[DiffusionSample]) -> dict: + raise NotImplementedError \ No newline at end of file diff --git a/dfm/src/megatron/data/dit/dit_taskencoder.py b/dfm/src/megatron/data/dit/dit_taskencoder.py new file mode 100644 index 00000000..a3748672 --- /dev/null +++ b/dfm/src/megatron/data/dit/dit_taskencoder.py @@ -0,0 +1,235 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from typing import List + +import torch +import torch.nn.functional as F +from einops import rearrange +from megatron.core import parallel_state +from megatron.energon import SkipSample, stateless + +from dfm.src.megatron.data.dit.diffusion_sample import DiffusionSample +from dfm.src.megatron.data.dit.diffusion_task_encoder_with_sp import DiffusionTaskEncoderWithSequencePacking + + +class DiTTaskEncoder(DiffusionTaskEncoderWithSequencePacking): + """ + BasicDiffusionTaskEncoder is a class that encodes image/video samples for diffusion tasks. + Attributes: + cookers (list): A list of Cooker objects used for processing. + max_frames (int, optional): The maximum number of frames to consider from the video. Defaults to None. + text_embedding_padding_size (int): The padding size for text embeddings. Defaults to 512. + Methods: + __init__(*args, max_frames=None, text_embedding_padding_size=512, **kwargs): + Initializes the BasicDiffusionTaskEncoder with optional maximum frames and text embedding padding size. + encode_sample(sample: dict) -> dict: + Encodes a given sample dictionary containing video and text data. + Args: + sample (dict): A dictionary containing 'pth' for video latent and 'json' for additional info. + Returns: + dict: A dictionary containing encoded video, text embeddings, text mask, and loss mask. + Raises: + SkipSample: If the video latent contains NaNs, Infs, or is not divisible by the tensor parallel size. + """ + + def __init__( + self, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + + @stateless(restore_seeds=True) + def encode_sample(self, sample: dict) -> DiffusionSample: + video_latent = sample["pth"] + + if torch.isnan(video_latent).any() or torch.isinf(video_latent).any(): + raise SkipSample() + if torch.max(torch.abs(video_latent)) > 1e3: + raise SkipSample() + + info = sample["json"] + video_latent = video_latent.squeeze(0) + C, T, H, W = video_latent.shape + seq_len = ( + video_latent.shape[-1] + * video_latent.shape[-2] + * video_latent.shape[-3] + // self.patch_spatial**2 + // self.patch_temporal + ) + is_image = T == 1 + + if seq_len > self.seq_length: + print(f"Skipping sample {sample['__key__']} because seq_len {seq_len} > self.seq_length {self.seq_length}") + raise SkipSample() + + if self.max_frames is not None: + video_latent = video_latent[:, : self.max_frames, :, :] + + tpcp_size = parallel_state.get_tensor_model_parallel_world_size() + if parallel_state.get_context_parallel_world_size() > 1: + tpcp_size *= parallel_state.get_context_parallel_world_size() * 2 + if (T * H * W) % tpcp_size != 0: + warnings.warn(f"skipping {video_latent.shape=} not divisible by {tpcp_size=}") + raise SkipSample() + + video_latent = rearrange( + video_latent, + "C (T pt) (H ph) (W pw) -> (T H W) (ph pw pt C)", + ph=self.patch_spatial, + pw=self.patch_spatial, + pt=self.patch_temporal, + ) + sample["pickle"] = sample["pickle"].cpu().float().numpy().squeeze(0) + if is_image: + t5_text_embeddings = torch.from_numpy(sample["pickle"]).to(torch.bfloat16) + else: + t5_text_embeddings = torch.from_numpy(sample["pickle"][0]).to(torch.bfloat16) + t5_text_embeddings_seq_length = t5_text_embeddings.shape[0] + + if t5_text_embeddings_seq_length > self.text_embedding_padding_size: + t5_text_embeddings = t5_text_embeddings[: self.text_embedding_padding_size] + else: + t5_text_embeddings = F.pad( + t5_text_embeddings, + ( + 0, + 0, + 0, + self.text_embedding_padding_size - t5_text_embeddings_seq_length, + ), + ) + t5_text_mask = torch.ones(t5_text_embeddings_seq_length, dtype=torch.bfloat16) + + if is_image: + h, w = info["image_height"], info["image_width"] + fps = torch.tensor([30] * 1, dtype=torch.bfloat16) + num_frames = torch.tensor([1] * 1, dtype=torch.bfloat16) + else: + h, w = info["height"], info["width"] + fps = torch.tensor([info["framerate"]] * 1, dtype=torch.bfloat16) + num_frames = torch.tensor([info["num_frames"]] * 1, dtype=torch.bfloat16) + image_size = torch.tensor([[h, w, h, w]] * 1, dtype=torch.bfloat16) + + pos_ids = rearrange( + pos_id_3d.get_pos_id_3d(t=T // self.patch_temporal, h=H // self.patch_spatial, w=W // self.patch_spatial), + "T H W d -> (T H W) d", + ) + + if self.packing_buffer_size is None: + pos_ids = F.pad(pos_ids, (0, 0, 0, self.seq_length - seq_len)) + loss_mask = torch.zeros(self.seq_length, dtype=torch.bfloat16) + loss_mask[:seq_len] = 1 + video_latent = F.pad(video_latent, (0, 0, 0, self.seq_length - seq_len)) + else: + loss_mask = torch.ones(seq_len, dtype=torch.bfloat16) + + return DiffusionSample( + __key__=sample["__key__"], + __restore_key__=sample["__restore_key__"], + __subflavor__=None, + __subflavors__=sample["__subflavors__"], + video=video_latent, + context_embeddings=t5_text_embeddings, + context_mask=t5_text_mask, + loss_mask=loss_mask, + seq_len_q=torch.tensor([seq_len], dtype=torch.int32), + seq_len_kv=torch.tensor([self.text_embedding_padding_size], dtype=torch.int32), + pos_ids=pos_ids, + latent_shape=torch.tensor([C, T, H, W], dtype=torch.int32), + ) + + @stateless + def batch(self, samples: List[DiffusionSample]) -> dict: + """Return dictionary with data for batch.""" + if self.packing_buffer_size is None: + # no packing + return super().batch(samples).to_dict() + + # packing + sample = samples[0] + return dict( + video=sample.video.unsqueeze_(0), + context_embeddings=sample.context_embeddings.unsqueeze_(0), + context_mask=sample.context_mask.unsqueeze_(0) if sample.context_mask is not None else None, + loss_mask=sample.loss_mask.unsqueeze_(0) if sample.loss_mask is not None else None, + seq_len_q=sample.seq_len_q, + seq_len_kv=sample.seq_len_kv, + pos_ids=sample.pos_ids.unsqueeze_(0) if sample.pos_ids is not None else None, + latent_shape=sample.latent_shape, + ) + + +class PosID3D: + def __init__(self, *, max_t=32, max_h=128, max_w=128): + self.max_t = max_t + self.max_h = max_h + self.max_w = max_w + self.generate_pos_id() + + def generate_pos_id(self): + self.grid = torch.stack( + torch.meshgrid( + torch.arange(self.max_t, device="cpu"), + torch.arange(self.max_h, device="cpu"), + torch.arange(self.max_w, device="cpu"), + ), + dim=-1, + ) + + def get_pos_id_3d(self, *, t, h, w): + if t > self.max_t or h > self.max_h or w > self.max_w: + self.max_t = max(self.max_t, t) + self.max_h = max(self.max_h, h) + self.max_w = max(self.max_w, w) + self.generate_pos_id() + return self.grid[:t, :h, :w] + + +pos_id_3d = PosID3D() + + +# def cook_raw_iamges(sample: dict) -> dict: +# """ +# Processes a raw sample dictionary from energon dataset and returns a new dictionary with specific keys. + +# Args: +# sample (dict): The input dictionary containing the raw sample data. + +# Returns: +# dict: A new dictionary containing the processed sample data with the following keys: +# - All keys from the result of `basic_sample_keys(sample)` +# - 'jpg': original images +# - 'png': contains control images +# - 'txt': contains raw text +# """ +# return dict( +# **basic_sample_keys(sample), +# images=sample["jpg"], +# hint=sample["png"], +# txt=sample["txt"], +# ) + + +# class RawImageDiffusionTaskEncoder(DefaultTaskEncoder): +# """ +# Dummy task encoder takes raw image input on CrudeDataset. +# """ + +# cookers = [ +# Cooker(cook_raw_iamges), +# ] \ No newline at end of file diff --git a/dfm/src/megatron/data/dit/sequence_packing_utils.py b/dfm/src/megatron/data/dit/sequence_packing_utils.py new file mode 100644 index 00000000..4a643b5f --- /dev/null +++ b/dfm/src/megatron/data/dit/sequence_packing_utils.py @@ -0,0 +1,105 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import List + + +def find_first_bin_that_fits(bins: List[List[int]], s: int, bin_size: int) -> int: + """ + Finds the first bin in a list of bins that has enough space to fit a sequence of size 's'. + + Args: + bins: A list of lists, where each inner list represents a bin and contains the current elements in that bin. + s: The size of the sequence to be placed in a bin. + bin_size: The maximum capacity of each bin. + + Returns: + The index of the first bin that can fit the sequence 's', or -1 if no such bin exists. + """ + for i, abin in enumerate(bins): + if sum(abin) + s <= bin_size: + return i + return -1 + + +def first_fit(seqlens: List[int], pack_size: int) -> List[List[int]]: + """ + Packs sequences of varying lengths into bins using the First-Fit algorithm. + + Args: + seqlens: A list of integers, representing the lengths of the sequences to be packed. + pack_size: The maximum capacity of each bin. + + Returns: + A list of lists, where each inner list represents a bin and contains the indices + of the sequences assigned to that bin. + """ + res = [] + for s in seqlens: + first_bin = find_first_bin_that_fits(res, s, pack_size) + if first_bin == -1: # open a new bin + res.append([s]) + else: + res[first_bin].append(s) + return res + + +def first_fit_decreasing(seqlens: List[int], pack_size: int) -> List[List[int]]: + """ + Packs sequences of varying lengths into bins using the First-Fit Decreasing algorithm. + + This is a variation of the First-Fit algorithm where the sequences are sorted by decreasing length before packing. + + Args: + seqlens: A list of integers, representing the lengths of the sequences to be packed. + pack_size: The maximum capacity of each bin. + + Returns: + A list of lists, similar to the output of the 'first_fit' function. + """ + sorted_seqlens = sorted(seqlens, reverse=True) + return first_fit(sorted_seqlens, pack_size) + + +def concat_pad(tensor_list, max_seq_length): + """ + Efficiently concatenates a list of tensors along the first dimension and pads with zeros + to reach max_seq_length. + + Args: + tensor_list (list of torch.Tensor): List of tensors to concatenate and pad. + max_seq_length (int): The desired size of the first dimension of the output tensor. + + Returns: + torch.Tensor: A tensor of shape [max_seq_length, ...], where ... represents the remaining dimensions. + """ + import torch + + # Get common properties from the first tensor + other_shape = tensor_list[0].shape[1:] + dtype = tensor_list[0].dtype + device = tensor_list[0].device + + # Initialize the result tensor with zeros + result = torch.zeros((max_seq_length, *other_shape), dtype=dtype, device=device) + + current_index = 0 + for tensor in tensor_list: + length = tensor.shape[0] + # Directly assign the tensor to the result tensor without checks + result[current_index : current_index + length] = tensor + current_index += length + + return result \ No newline at end of file diff --git a/examples/megatron/recipes/wan/inference_wan.py b/examples/megatron/recipes/wan/inference_wan.py index 7392e9b7..34d10bc9 100644 --- a/examples/megatron/recipes/wan/inference_wan.py +++ b/examples/megatron/recipes/wan/inference_wan.py @@ -191,7 +191,7 @@ def generate(args): "num_train_timesteps": 1000, "sample_fps": 16, "chinese_sample_neg_prompt": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", - "english_sample_neg_prompt": "Bright and vivid tones, overexposed, static, blurry details, subtitles, style, artwork, painting, image, still, overall grayish, worst quality, low quality, JPEG compression artifacts, ugly, defective, extra fingers, poorly drawn hands, poorly drawn face, deformed, disfigured, misshapen limbs, fused fingers, motionless image, messy background, three legs, crowded background, walking backward.", + "english_sample_neg_prompt": "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards", } ) From d17286d81b718728af92cc2959cf0114b342e68b Mon Sep 17 00:00:00 2001 From: Huy Vu2 Date: Mon, 10 Nov 2025 10:23:21 -0800 Subject: [PATCH 38/80] seem to workable seq_packing --- .../dit/diffusion_task_encoder_with_sp.py | 2 +- .../data/wan/wan_energon_datamodule.py | 12 +- dfm/src/megatron/data/wan/wan_sample.py | 22 ++ dfm/src/megatron/data/wan/wan_taskencoder.py | 196 ++++++++++------ .../flow_matching/flow_inference_pipeline.py | 3 + .../model/wan/flow_matching/flow_pipeline.py | 32 +-- dfm/src/megatron/model/wan/rope_utils.py | 24 +- dfm/src/megatron/model/wan/wan_model.py | 5 +- dfm/src/megatron/model/wan/wan_step.py | 7 + dfm/src/megatron/recipes/wan/wan.py | 2 + .../wan/prepare_energon_dataset_wan.py | 220 ++++++++++++++---- 11 files changed, 364 insertions(+), 161 deletions(-) create mode 100644 dfm/src/megatron/data/wan/wan_sample.py diff --git a/dfm/src/megatron/data/dit/diffusion_task_encoder_with_sp.py b/dfm/src/megatron/data/dit/diffusion_task_encoder_with_sp.py index 55786c68..d97ab9aa 100644 --- a/dfm/src/megatron/data/dit/diffusion_task_encoder_with_sp.py +++ b/dfm/src/megatron/data/dit/diffusion_task_encoder_with_sp.py @@ -65,7 +65,7 @@ def select_samples_to_pack(self, samples: List[DiffusionSample]) -> List[List[Di """ Selects sequences to pack for mixed image-video training. """ - results = first_fit_decreasing(samples, self.packing_buffer_size) + results = first_fit_decreasing(samples, self.seq_length) random.shuffle(results) return results diff --git a/dfm/src/megatron/data/wan/wan_energon_datamodule.py b/dfm/src/megatron/data/wan/wan_energon_datamodule.py index fd3c2a01..031d4395 100644 --- a/dfm/src/megatron/data/wan/wan_energon_datamodule.py +++ b/dfm/src/megatron/data/wan/wan_energon_datamodule.py @@ -19,24 +19,28 @@ from megatron.bridge.data.utils import DatasetBuildContext, DatasetProvider from torch import int_repr -from dfm.src.megatron.data.dit.diffusion_energon_datamodule import DiffusionDataModule +from dfm.src.megatron.data.dit.diffusion_energon_datamodule import DiffusionDataModuleConfig, DiffusionDataModule from dfm.src.megatron.data.wan.wan_taskencoder import WanTaskEncoder @dataclass(kw_only=True) -class WanDataModuleConfig(DatasetProvider): +class WanDataModuleConfig(DiffusionDataModuleConfig): path: str seq_length: int + packing_buffer_size: int micro_batch_size: int global_batch_size: int num_workers: int_repr dataloader_type: str = "external" - + def __post_init__(self): self.dataset = DiffusionDataModule( path=self.path, seq_length=self.seq_length, - task_encoder=WanTaskEncoder(seq_length=self.seq_length), + packing_buffer_size=self.packing_buffer_size, + task_encoder=WanTaskEncoder( + seq_length=self.seq_length, packing_buffer_size=self.packing_buffer_size + ), micro_batch_size=self.micro_batch_size, global_batch_size=self.global_batch_size, num_workers=self.num_workers, diff --git a/dfm/src/megatron/data/wan/wan_sample.py b/dfm/src/megatron/data/wan/wan_sample.py new file mode 100644 index 00000000..a17cfb51 --- /dev/null +++ b/dfm/src/megatron/data/wan/wan_sample.py @@ -0,0 +1,22 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from dfm.src.megatron.data.dit.diffusion_sample import DiffusionSample + + +@dataclass +class WanSample(DiffusionSample): + video_metadata: dict = None + seq_len_q_padded: int = None \ No newline at end of file diff --git a/dfm/src/megatron/data/wan/wan_taskencoder.py b/dfm/src/megatron/data/wan/wan_taskencoder.py index bb3025b0..02c56b09 100644 --- a/dfm/src/megatron/data/wan/wan_taskencoder.py +++ b/dfm/src/megatron/data/wan/wan_taskencoder.py @@ -14,12 +14,16 @@ # pylint: disable=C0115,C0116,C0301 +from torch._tensor import Tensor import torch import torch.nn.functional as F +from megatron.energon.task_encoder.base import stateless from megatron.core import parallel_state -from megatron.energon import DefaultTaskEncoder, SkipSample +from typing import List +from megatron.energon import SkipSample +from dfm.src.megatron.data.dit.diffusion_task_encoder_with_sp import DiffusionTaskEncoderWithSequencePacking from megatron.energon.task_encoder.cooking import Cooker, basic_sample_keys - +from dfm.src.megatron.data.wan.wan_sample import WanSample from dfm.src.megatron.model.wan.utils import grid_sizes_calculation, patchify @@ -45,7 +49,7 @@ def cook(sample: dict) -> dict: ) -class WanTaskEncoder(DefaultTaskEncoder): +class WanTaskEncoder(DiffusionTaskEncoderWithSequencePacking): """ Task encoder for Wan dataset. Attributes: @@ -73,6 +77,7 @@ def __init__( self.patch_temporal = patch_temporal self.seq_length = seq_length + @stateless(restore_seeds=True) def encode_sample(self, sample: dict) -> dict: video_latent = sample["pth"] context_embeddings = sample["pickle"] @@ -90,88 +95,125 @@ def encode_sample(self, sample: dict) -> dict: patch_size=(self.patch_temporal, self.patch_spatial, self.patch_spatial), ) - ### Note: shape of sample's values - # video_latent: [latents_channels, F_latents, W_latents, H_latents] - # grid_size: [F_patches, W_patches, H_patches] - # context_embeddings: [context_seq_len, text_embedding_dim] + # patchify video_latent + video_latent = patchify([video_latent], (self.patch_temporal, self.patch_spatial, self.patch_spatial))[0] - return dict( - video_latent=video_latent, - grid_size=grid_size, - context_embeddings=context_embeddings, - video_metadata=video_metadata, - ) + # process text embeddings + # pad here for text embeddings + context_max_len = 512 + context_embeddings = F.pad(context_embeddings, (0, 0, 0, context_max_len - context_embeddings.shape[0])) + + # calculate sequence length + seq_len_q = video_latent.shape[0] + seq_len_kv = context_embeddings.shape[0] + + # loss mask + loss_mask = torch.ones(seq_len_q, dtype=torch.bfloat16) - def batch(self, samples: list[dict]) -> dict: - # process video latents - # do padding here for video latents - self.patch_size = (self.patch_temporal, self.patch_spatial, self.patch_spatial) - - # running patchify - video_latents = patchify([sample["video_latent"] for sample in samples], self.patch_size) - - # build per-sample loss masks (1 for valid tokens pre-padding) - loss_masks = [torch.ones(v.shape[0]) for v in video_latents] - # calculate all sequence lengths of video latents for self-attention (for videos, we do this before padding to get original seq len) - seq_len_q = [v.shape[0] for v in video_latents] - seq_len_q = torch.tensor(seq_len_q, dtype=torch.int32) - # padding and stack video latents - max_video_seq_len = max([video_latent.shape[0] for video_latent in video_latents]) - # CAVEAT: - # when using pipeline parallelism, we need to set batch sequence length to DataModule's seq_length because - # because pipeline parallelism requires pre-specified sequence length to create buffer - if parallel_state.get_pipeline_model_parallel_world_size() > 1: - if max_video_seq_len > self.seq_length: - raise ValueError( - f"max_video_seq_len {max_video_seq_len} is greater than DataModule's seq_length {self.seq_length}" - ) - else: - # set max_video_seq_len to DataModule's seq_length - max_video_seq_len = self.seq_length # CAVEAT: # when using context parallelism, we need to pad batch sequence length to be divisible by [cp_rank*2] # (because TransformerEngine's context parallelism requires "AssertionError: Sequence length per GPU needs to be divisible by 2!") if parallel_state.get_context_parallel_world_size() > 1: - batch_size = len(video_latents) - assert batch_size == 1, "Error: Batch size must be 1 when using context parallelism" sharding_factor = parallel_state.get_context_parallel_world_size() * 2 - max_video_seq_len = ((max_video_seq_len + sharding_factor - 1) // sharding_factor) * sharding_factor - video_latents = [ - F.pad(video_latent, (0, 0, 0, max_video_seq_len - video_latent.shape[0])) for video_latent in video_latents - ] - video_latents = torch.stack(video_latents, dim=1) - # pad and stack loss masks to shape [S_max, B] - loss_masks = [F.pad(m, (0, max_video_seq_len - m.shape[0])) for m in loss_masks] - loss_masks = torch.stack(loss_masks, dim=1) - - # process grid sizes - grid_sizes = [torch.tensor(sample["grid_size"], dtype=torch.int32) for sample in samples] - grid_sizes = torch.stack(grid_sizes, dim=0) + seq_len_q_padded = ((seq_len_q + sharding_factor - 1) // sharding_factor) * sharding_factor + else: + seq_len_q_padded = seq_len_q - # process text embeddings - # pad here for text embeddings - context_max_len = 512 - context_embeddings = [sample["context_embeddings"] for sample in samples] - context_embeddings = [ - F.pad(context_embedding, (0, 0, 0, context_max_len - context_embedding.shape[0])) - for context_embedding in context_embeddings - ] - # calculate all sequence lengths of context embeddings for cross-attention (for videos, we do this after padding to get padded seq len) - seq_len_kv = [c.shape[0] for c in context_embeddings] - seq_len_kv = torch.tensor(seq_len_kv, dtype=torch.int32) - # stack context embeddings - context_embeddings = torch.stack(context_embeddings, dim=1) - - # process video metadata - video_metadata = [sample["video_metadata"] for sample in samples] - - return dict( - video_latents=video_latents, - max_video_seq_len=max_video_seq_len, - grid_sizes=grid_sizes, + # padding + if seq_len_q < seq_len_q_padded: + video_latent = F.pad(video_latent, (0, 0, 0, seq_len_q_padded - seq_len_q)) + loss_mask = F.pad(loss_mask, (0, seq_len_q_padded - seq_len_q)) + + ### Note: shape of sample's values + # video_latent: [num_patches, latents_channels * pF * pH * pW] + # grid_size: [F_patches, W_patches, H_patches] + # context_embeddings: [context_seq_len, text_embedding_dim] + + return WanSample( + __key__=sample["__key__"], + __restore_key__=sample["__restore_key__"], + __subflavor__=None, + __subflavors__=sample["__subflavors__"], + video=video_latent, context_embeddings=context_embeddings, - loss_mask=loss_masks, - seq_len_q=seq_len_q, - seq_len_kv=seq_len_kv, + latent_shape=torch.tensor(grid_size, dtype=torch.int32), + loss_mask=loss_mask, + seq_len_q=torch.tensor([seq_len_q], dtype=torch.int32), + seq_len_q_padded=torch.tensor([seq_len_q_padded], dtype=torch.int32), + seq_len_kv=torch.tensor([seq_len_kv], dtype=torch.int32), video_metadata=video_metadata, ) + + # NOTE: + # the method select_samples_to_pack() is inherited from the parent + # class DiffusionTaskEncoderWithSequencePacking + + @stateless + def pack_selected_samples(self, samples: List[WanSample]) -> WanSample: + """Construct a new Wan sample by concatenating the sequences.""" + + def stack(attr): + return torch.stack([getattr(sample, attr) for sample in samples], dim=0) + + def cat(attr): + return torch.cat([getattr(sample, attr) for sample in samples], dim=0) + + return WanSample( + __key__=",".join([s.__key__ for s in samples]), + __restore_key__=(), # Will be set by energon based on `samples` + __subflavor__=None, + __subflavors__=samples[0].__subflavors__, + video=cat("video"), + context_embeddings=cat("context_embeddings"), + loss_mask=cat("loss_mask"), + seq_len_q=cat("seq_len_q"), + seq_len_q_padded=cat("seq_len_q_padded"), + seq_len_kv=cat("seq_len_kv"), + latent_shape=stack("latent_shape"), + video_metadata=[sample.video_metadata for sample in samples], + ) + + @stateless + def batch(self, samples: List[WanSample]) -> dict: + """Return dictionary with data for batch.""" + if self.packing_buffer_size is None: + # no packing + return super().batch(samples).to_dict() + + # packing + sample = samples[0] + + # # CAVEAT: + # # when using pipeline parallelism, we need to set batch sequence length to DataModule's seq_length because + # # because pipeline parallelism requires pre-specified sequence length to create buffer + # if parallel_state.get_pipeline_model_parallel_world_size() > 1: + # if sample.video.shape[0] > self.seq_length: + # raise ValueError( + # f"video sequence length {sample.video.shape[0]} is greater than DataModule's seq_length {self.seq_length}" + # ) + # else: + # # set max_video_seq_len to DataModule's seq_length + # padded_seq_len = self.seq_length + + batch = dict( + video_latents=sample.video.unsqueeze(1), + context_embeddings=sample.context_embeddings.unsqueeze(1), + loss_mask=sample.loss_mask.unsqueeze(1) if sample.loss_mask is not None else None, + seq_len_q=sample.seq_len_q, + seq_len_q_padded=sample.seq_len_q_padded, + seq_len_kv=sample.seq_len_kv, + grid_sizes=sample.latent_shape, + video_metadata=sample.video_metadata, + ) + + ### Note: shape of batch's values + # video_latents: [seq_len, 1, latents_channels * pF * pH * pW] + # context_embeddings: [seq_len, 1, text_embedding_dim] + # loss_mask: [seq_len, 1] + # seq_len_q: [num_samples] + # seq_len_q_padded: [num_samples] + # seq_len_kv: [num_samples] + # grid_sizes: [num_samples, 3] + # video_metadata: [num_samples] + + return batch diff --git a/dfm/src/megatron/model/wan/flow_matching/flow_inference_pipeline.py b/dfm/src/megatron/model/wan/flow_matching/flow_inference_pipeline.py index e246b48c..2bbb0eb3 100644 --- a/dfm/src/megatron/model/wan/flow_matching/flow_inference_pipeline.py +++ b/dfm/src/megatron/model/wan/flow_matching/flow_inference_pipeline.py @@ -444,11 +444,14 @@ def noop_no_sync(): packed_seq_params = { "self_attention": PackedSeqParams( cu_seqlens_q=cu_q, + cu_seqlens_q_padded=cu_q, cu_seqlens_kv=cu_kv_self, + cu_seqlens_kv_padded=cu_kv_self, qkv_format=self.model.config.qkv_format, ), "cross_attention": PackedSeqParams( cu_seqlens_q=cu_q, + cu_seqlens_q_padded=cu_q, cu_seqlens_kv=cu_kv_cross, qkv_format=self.model.config.qkv_format, ), diff --git a/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py b/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py index 25c9b93d..d06d2993 100644 --- a/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py +++ b/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py @@ -54,7 +54,6 @@ def training_step( """ video_latents = data_batch["video_latents"] - max_video_seq_len = data_batch["max_video_seq_len"] context_embeddings = data_batch["context_embeddings"] loss_mask = data_batch["loss_mask"] grid_sizes = data_batch["grid_sizes"] @@ -113,7 +112,7 @@ def training_step( in_channels = model.config.in_channels patch_spatial = model.config.patch_spatial patch_temporal = model.config.patch_temporal - for grid_size in grid_sizes: + for i, grid_size in enumerate(grid_sizes): sample_noise = torch.randn( 1, in_channels, @@ -127,22 +126,27 @@ def training_step( 0 ] # shape [noise_seq, c * ( pF * pH * pW)] + # because video_latents might be padded, we need to make sure noise also be padded to have the same shape - noise_seq = sample_noise.shape[0] - video_seq = video_latents.shape[0] - if noise_seq < video_seq: - pad_len = video_seq - noise_seq + sample_noise_seq_len = sample_noise.shape[0] + cu_seqlens_q_padded = packed_seq_params["self_attention"].cu_seqlens_q_padded + seq_len_q_padded = cu_seqlens_q_padded[i+1] - cu_seqlens_q_padded[i] + if sample_noise_seq_len < seq_len_q_padded: + pad_len = seq_len_q_padded - sample_noise_seq_len pad = torch.zeros( (pad_len, sample_noise.shape[1]), device=sample_noise.device, dtype=sample_noise.dtype ) - sample_noise = torch.cat([sample_noise, pad], dim=0) + sample_noise = torch.cat([sample_noise, pad], dim=0) # shape [padded_noise_seq, c * ( pF * pH * pW)] + noise.append(sample_noise) - noise = torch.stack(noise, dim=1) # shape [noise_seq, batch_size, c * ( pF * pH * pW)] + noise = torch.cat(noise, dim=0) # shape [concatenated_noise_seq, c * ( pF * pH * pW)] + noise = noise.unsqueeze(1) # shape [concatenated_noise_seq, 1, c * ( pF * pH * pW)] + # CRITICAL: Manual flow matching (NOT scheduler.add_noise!) # x_t = (1 - σ) * x_0 + σ * ε - sigma_reshaped = sigma.view(1, batch_size, 1) - noisy_latents = (1.0 - sigma_reshaped) * video_latents.float() + sigma_reshaped * noise + # since we use sequence packing, the batch_size is 1) + noisy_latents = (1.0 - sigma) * video_latents.float() + sigma * noise # Timesteps for model [0, 1000] timesteps = sigma * num_train_timesteps @@ -201,7 +205,6 @@ def training_step( grid_sizes=grid_sizes, t=timesteps, context=context_embeddings, - max_seq_len=max_video_seq_len, packed_seq_params=packed_seq_params, ) @@ -219,10 +222,10 @@ def training_step( loss = torch.nn.functional.mse_loss(model_pred.float(), target.float(), reduction="none") # Flow weight: w = 1 + shift * σ - loss_weight = 1.0 + flow_shift * sigma # shape [batch_size] - loss_weight = loss_weight.view(1, batch_size, 1).to(device) # shape [1, batch_size, 1] + loss_weight = 1.0 + flow_shift * sigma unweighted_loss = loss - weighted_loss = loss * loss_weight # shape [seq_length / cp_size, batch_size, -1] + # since we use sequence packing, the batch_size is 1 + weighted_loss = loss * loss_weight # shape [seq_length / cp_size, 1, -1] # Safety check mean_weighted_loss = weighted_loss.mean() @@ -239,7 +242,6 @@ def training_step( grid_sizes=grid_sizes, t=timesteps, context=context_embeddings, - max_seq_len=max_video_seq_len, packed_seq_params=packed_seq_params, ) diff --git a/dfm/src/megatron/model/wan/rope_utils.py b/dfm/src/megatron/model/wan/rope_utils.py index 76e076eb..8d1eca14 100644 --- a/dfm/src/megatron/model/wan/rope_utils.py +++ b/dfm/src/megatron/model/wan/rope_utils.py @@ -40,7 +40,7 @@ def rope_params(self, max_position_len, dim_head, theta=10000): ) return freqs - def forward(self, n_head, dim_head, max_seq_len, grid_sizes, device): + def forward(self, n_head, dim_head, cu_seqlens_q_padded, grid_sizes, device): self.freqs = self.freqs.to( device, ) @@ -65,21 +65,25 @@ def forward(self, n_head, dim_head, max_seq_len, grid_sizes, device): # Double dimension from c -> 2c with rotating angles as (x0, x0, x1, x1, ...), for interleaving RoPE freqs_real_i = freqs_real_i.unsqueeze(-1).expand(-1, -1, -1, -1, 2).reshape(seq_len, 1, 1, dim_head) - # Pad freqs_real_i to (max_seq_len, 1, 1, dim_head) with 0s - if freqs_real_i.shape[0] < max_seq_len: - pad_shape = (max_seq_len - freqs_real_i.shape[0], 1, 1, dim_head) + freqs_real.append(freqs_real_i) + + # Pad freqs_real_i to (padded_seq_len, 1, 1, dim_head) with 0s + for i, freqs_real_i in enumerate(freqs_real): + seq_len_q_padded = cu_seqlens_q_padded[i+1] - cu_seqlens_q_padded[i] + if freqs_real_i.shape[0] < seq_len_q_padded: + pad_shape = (seq_len_q_padded - freqs_real_i.shape[0], 1, 1, dim_head) freqs_real_i = torch.cat( - [freqs_real_i, torch.zeros(pad_shape, dtype=freqs_real_i.dtype, device=freqs_real_i.device)] + [freqs_real_i, torch.zeros(pad_shape, dtype=freqs_real_i.dtype, device=freqs_real_i.device)], dim=0 ) - freqs_real.append(freqs_real_i) + freqs_real[i] = freqs_real_i - # Each freqs_real[i] is (max_seq_len, 1, 1, dim_head) - # We concatenate them along dim=1 to get (max_seq_len, batch_size, 1, dim_head) - freqs_real = torch.cat(freqs_real, dim=1) + # Each freqs_real[i] is (seq_len, 1, 1, dim_head) + # We concatenate them along dim=0 to get (concatenated_seq_len, 1, 1, dim_head) + freqs_real = torch.cat(freqs_real, dim=0) # Note: # when running context_parallel, which must use "thd" for qkv_format, # we don't need to scatter the freqs to the context parallel region, # because mcore rope_utils will automatically retrieve the correct freqs for each context parallel region - return freqs_real + return freqs_real \ No newline at end of file diff --git a/dfm/src/megatron/model/wan/wan_model.py b/dfm/src/megatron/model/wan/wan_model.py index 3154affd..b75a57d2 100644 --- a/dfm/src/megatron/model/wan/wan_model.py +++ b/dfm/src/megatron/model/wan/wan_model.py @@ -183,7 +183,6 @@ def forward( grid_sizes: list[Tuple[int, int, int]], t: Tensor, context: Tensor, - max_seq_len: int, packed_seq_params: PackedSeqParams = None, **kwargs, ) -> Tensor: @@ -194,7 +193,6 @@ def forward( grid_sizes List[Tuple[int, int, int]]: list of grid sizes (f, h, w) t Tensor: timesteps context List[Tensor]: list of context (text_len, hidden_size) - max_seq_len int: maximum sequence length packed_seq_params PackedSeqParams: packed sequence parameters Returns: @@ -238,8 +236,9 @@ def forward( # ============= decoder ============= # calculate rotary pos emb n_head, dim_head = self.num_heads, self.config.hidden_size // self.num_heads + cu_seqlens_q_padded = packed_seq_params["self_attention"].cu_seqlens_q_padded rotary_pos_emb = self.rope_embeddings( - n_head, dim_head, max_seq_len, grid_sizes, t.device + n_head, dim_head, cu_seqlens_q_padded, grid_sizes, t.device ) # output: rotary_pos_emb.shape [s, b, 1, dim_head] # run decoder diff --git a/dfm/src/megatron/model/wan/wan_step.py b/dfm/src/megatron/model/wan/wan_step.py index c7a155ca..f42bef8e 100644 --- a/dfm/src/megatron/model/wan/wan_step.py +++ b/dfm/src/megatron/model/wan/wan_step.py @@ -42,17 +42,24 @@ def wan_data_step(qkv_format, dataloader_iter): zero = torch.zeros(1, dtype=torch.int32, device="cuda") cu_seqlens = torch.cat((zero, cu_seqlens)) + cu_seqlens_padded = batch["seq_len_q_padded"].cumsum(dim=0).to(torch.int32) + zero = torch.zeros(1, dtype=torch.int32, device="cuda") + cu_seqlens_padded = torch.cat((zero, cu_seqlens_padded)) + cu_seqlens_kv = batch["seq_len_kv"].cumsum(dim=0).to(torch.int32) cu_seqlens_kv = torch.cat((zero, cu_seqlens_kv)) batch["packed_seq_params"] = { "self_attention": PackedSeqParams( cu_seqlens_q=cu_seqlens, + cu_seqlens_q_padded=cu_seqlens_padded, cu_seqlens_kv=cu_seqlens, + cu_seqlens_kv_padded=cu_seqlens_padded, qkv_format=qkv_format, ), "cross_attention": PackedSeqParams( cu_seqlens_q=cu_seqlens, + cu_seqlens_q_padded=cu_seqlens_padded, cu_seqlens_kv=cu_seqlens_kv, qkv_format=qkv_format, ), diff --git a/dfm/src/megatron/recipes/wan/wan.py b/dfm/src/megatron/recipes/wan/wan.py index b092b96a..b1f1e4e5 100644 --- a/dfm/src/megatron/recipes/wan/wan.py +++ b/dfm/src/megatron/recipes/wan/wan.py @@ -179,6 +179,8 @@ def pretrain_config( micro_batch_size=micro_batch_size, global_batch_size=global_batch_size, num_workers=10, + task_encoder_seq_length=None, + packing_buffer_size=131072, # 131,072 = 2^17 tokens, each 5 secs of 832*480 is about 45k tokens ) # Config Container diff --git a/examples/megatron/recipes/wan/prepare_energon_dataset_wan.py b/examples/megatron/recipes/wan/prepare_energon_dataset_wan.py index 98386c5a..3a9c91c0 100644 --- a/examples/megatron/recipes/wan/prepare_energon_dataset_wan.py +++ b/examples/megatron/recipes/wan/prepare_energon_dataset_wan.py @@ -21,6 +21,7 @@ import numpy as np import torch import webdataset as wds + from diffusers import AutoencoderKLWan from transformers import AutoTokenizer, UMT5EncoderModel @@ -204,6 +205,54 @@ def _load_frames_cv2( return video_tensor +def _extract_evenly_spaced_frames( + video_path: str, + start_frame: int, + end_frame: int, + num_frames: int, + target_size: Optional[Tuple[int, int]], + resize_mode: str, + maintain_aspect_ratio: bool, + center_crop: bool, +) -> List[np.ndarray]: + cap = cv2.VideoCapture(video_path) + total_frames = max(0, end_frame - start_frame + 1) + if total_frames <= 0: + cap.release() + raise ValueError(f"Invalid frame range [{start_frame}, {end_frame}] for {video_path}") + + if num_frames <= 1: + frame_indices = [start_frame] + else: + frame_indices = np.linspace(start_frame, end_frame, num_frames, dtype=int).tolist() + + extracted_frames: List[np.ndarray] = [] + for frame_idx in frame_indices: + cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx) + ret, frame = cap.read() + if not ret: + continue + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + frame = _resize_frame(frame, target_size, resize_mode, maintain_aspect_ratio, center_crop) + extracted_frames.append(frame) + + cap.release() + if not extracted_frames: + raise ValueError(f"Could not extract any frames from {video_path}") + return extracted_frames + + +def _frame_to_video_tensor(frame: np.ndarray, target_dtype: torch.dtype) -> torch.Tensor: + # frame: RGB numpy array (H, W, C), uint8 or float + if frame.dtype != np.float32: + frame = frame.astype(np.float32) + frame = frame / 255.0 if frame.max() > 1.0 else frame + tensor = torch.from_numpy(frame) # H, W, C + tensor = tensor.permute(2, 0, 1).unsqueeze(0).unsqueeze(2) # 1, C, 1, H, W + tensor = tensor.to(dtype=target_dtype) + return tensor + + @torch.no_grad() def _init_hf_models( model_id: str, @@ -309,6 +358,18 @@ def main(): ) parser.add_argument("--no-memory-optimization", action="store_true", help="Disable VAE slicing/tiling") parser.add_argument("--shard_maxcount", type=int, default=10000, help="Max samples per shard") + parser.add_argument( + "--mode", + default="video", + choices=["video", "frames"], + help="Processing mode: 'video' for full videos, 'frames' to extract frames and treat each as a 1-frame video", + ) + parser.add_argument( + "--num-frames", + type=int, + default=10, + help="Number of evenly-spaced frames to extract per video when using --mode frames", + ) # Resize arguments (match automodel) parser.add_argument("--height", type=int, default=None, help="Target height for video frames") @@ -351,61 +412,118 @@ def main(): for index, meta in enumerate(metadata_list): video_name = meta["file_name"] start_frame = int(meta["start_frame"]) # inclusive - end_frame = int(meta["end_frame"]) # inclusive + end_frame = int(meta["end_frame"]) # inclusive caption_text = meta.get("vila_caption", "") video_path = str(video_folder / video_name) - # Load frames using the same OpenCV + resize path as automodel - video_tensor = _load_frames_cv2( - video_path=video_path, - start_frame=start_frame, - end_frame=end_frame, - target_size=target_size, - resize_mode=args.resize_mode, - maintain_aspect_ratio=not args.no_aspect_ratio, - center_crop=args.center_crop, - target_dtype=model_dtype, - ) - - # Encode text and video with HF models exactly like automodel - text_embed = _encode_text(tokenizer, text_encoder, args.device, caption_text) - latents = _encode_video_latents(vae, args.device, video_tensor, deterministic_latents=not args.stochastic) - - # Move to CPU without changing dtype; keep exact values to match automodel outputs - text_embed_cpu = text_embed.detach().to(device="cpu") - latents_cpu = latents.detach().to(device="cpu") - - # Reshape to match Mcore's Wan input format - text_embed_cpu = text_embed_cpu[0] - latents_cpu = latents_cpu[0] - - # Build JSON side-info similar to prepare_energon script - C, T, H, W = video_tensor.shape[1:] # 1,C,T,H,W - json_data = { - "video_path": video_path, - "processed_frames": int(T), - "processed_height": int(H), - "processed_width": int(W), - "caption": caption_text, - "deterministic_latents": bool(not args.stochastic), - "memory_optimization": bool(not args.no_memory_optimization), - "model_version": "wan2.1", - "resize_settings": { - "target_size": target_size, - "resize_mode": args.resize_mode, - "maintain_aspect_ratio": bool(not args.no_aspect_ratio), - "center_crop": bool(args.center_crop), - }, - } + if args.mode == "video": + # Load frames using the same OpenCV + resize path as automodel + video_tensor = _load_frames_cv2( + video_path=video_path, + start_frame=start_frame, + end_frame=end_frame, + target_size=target_size, + resize_mode=args.resize_mode, + maintain_aspect_ratio=not args.no_aspect_ratio, + center_crop=args.center_crop, + target_dtype=model_dtype, + ) - sample = { - "__key__": f"{index:06}", - "pth": latents_cpu, - "pickle": pickle.dumps(text_embed_cpu), - "json": json_data, - } - sink.write(sample) - written += 1 + # Encode text and video with HF models exactly like automodel + text_embed = _encode_text(tokenizer, text_encoder, args.device, caption_text) + latents = _encode_video_latents(vae, args.device, video_tensor, deterministic_latents=not args.stochastic) + + # Move to CPU without changing dtype; keep exact values to match automodel outputs + text_embed_cpu = text_embed.detach().to(device="cpu") + latents_cpu = latents.detach().to(device="cpu") + + # Reshape to match Mcore's Wan input format + text_embed_cpu = text_embed_cpu[0] + latents_cpu = latents_cpu[0] + + # Build JSON side-info similar to prepare_energon script + C, T, H, W = video_tensor.shape[1:] # 1,C,T,H,W + json_data = { + "video_path": video_path, + "processed_frames": int(T), + "processed_height": int(H), + "processed_width": int(W), + "caption": caption_text, + "deterministic_latents": bool(not args.stochastic), + "memory_optimization": bool(not args.no_memory_optimization), + "model_version": "wan2.1", + "processing_mode": "video", + "resize_settings": { + "target_size": target_size, + "resize_mode": args.resize_mode, + "maintain_aspect_ratio": bool(not args.no_aspect_ratio), + "center_crop": bool(args.center_crop), + }, + } + + sample = { + "__key__": f"{index:06}", + "pth": latents_cpu, + "pickle": pickle.dumps(text_embed_cpu), + "json": json_data, + } + sink.write(sample) + written += 1 + else: + # Frames mode: extract evenly-spaced frames, treat each as a 1-frame video + frames = _extract_evenly_spaced_frames( + video_path=video_path, + start_frame=start_frame, + end_frame=end_frame, + num_frames=max(1, int(args.num_frames)), + target_size=target_size, + resize_mode=args.resize_mode, + maintain_aspect_ratio=not args.no_aspect_ratio, + center_crop=args.center_crop, + ) + + # Encode text once and reuse for all frames of this video + text_embed = _encode_text(tokenizer, text_encoder, args.device, caption_text) + text_embed_cpu = text_embed.detach().to(device="cpu")[0] + + total_extracted = len(frames) + for frame_idx, frame in enumerate(frames, start=1): + video_tensor = _frame_to_video_tensor(frame, target_dtype=model_dtype) + latents = _encode_video_latents( + vae, args.device, video_tensor, deterministic_latents=not args.stochastic + ) + latents_cpu = latents.detach().to(device="cpu")[0] + + # Frame shape after resize + H, W = frame.shape[:2] + json_data = { + "video_path": video_path, + "processed_frames": 1, + "processed_height": int(H), + "processed_width": int(W), + "caption": caption_text, + "deterministic_latents": bool(not args.stochastic), + "memory_optimization": bool(not args.no_memory_optimization), + "model_version": "wan2.1", + "processing_mode": "frames", + "frame_index": int(frame_idx), + "total_frames_in_video": int(total_extracted), + "resize_settings": { + "target_size": target_size, + "resize_mode": args.resize_mode, + "maintain_aspect_ratio": bool(not args.no_aspect_ratio), + "center_crop": bool(args.center_crop), + }, + } + + sample = { + "__key__": f"{index:06}_{frame_idx:02}", + "pth": latents_cpu, + "pickle": pickle.dumps(text_embed_cpu), + "json": json_data, + } + sink.write(sample) + written += 1 print("Done writing shards using HF automodel encoders.") From e93690709c2784852a1a30dc28a155c5625e5bb2 Mon Sep 17 00:00:00 2001 From: Huy Vu2 Date: Mon, 10 Nov 2025 11:32:26 -0800 Subject: [PATCH 39/80] refactor with Sajad's PR - DiT data to common dir --- .../megatron/data/{dit => common}/base_energon_datamodule.py | 0 .../data/{dit => common}/diffusion_energon_datamodule.py | 2 +- dfm/src/megatron/data/{dit => common}/diffusion_sample.py | 0 .../data/{dit => common}/diffusion_task_encoder_with_sp.py | 4 ++-- .../megatron/data/{dit => common}/sequence_packing_utils.py | 0 dfm/src/megatron/data/dit/dit_taskencoder.py | 4 ++-- dfm/src/megatron/data/wan/wan_energon_datamodule.py | 2 +- dfm/src/megatron/data/wan/wan_mock_energon_datamodule.py | 2 +- dfm/src/megatron/data/wan/wan_sample.py | 2 +- dfm/src/megatron/data/wan/wan_taskencoder.py | 2 +- 10 files changed, 9 insertions(+), 9 deletions(-) rename dfm/src/megatron/data/{dit => common}/base_energon_datamodule.py (100%) rename dfm/src/megatron/data/{dit => common}/diffusion_energon_datamodule.py (98%) rename dfm/src/megatron/data/{dit => common}/diffusion_sample.py (100%) rename dfm/src/megatron/data/{dit => common}/diffusion_task_encoder_with_sp.py (95%) rename dfm/src/megatron/data/{dit => common}/sequence_packing_utils.py (100%) diff --git a/dfm/src/megatron/data/dit/base_energon_datamodule.py b/dfm/src/megatron/data/common/base_energon_datamodule.py similarity index 100% rename from dfm/src/megatron/data/dit/base_energon_datamodule.py rename to dfm/src/megatron/data/common/base_energon_datamodule.py diff --git a/dfm/src/megatron/data/dit/diffusion_energon_datamodule.py b/dfm/src/megatron/data/common/diffusion_energon_datamodule.py similarity index 98% rename from dfm/src/megatron/data/dit/diffusion_energon_datamodule.py rename to dfm/src/megatron/data/common/diffusion_energon_datamodule.py index 4fb5785e..eac2e604 100644 --- a/dfm/src/megatron/data/dit/diffusion_energon_datamodule.py +++ b/dfm/src/megatron/data/common/diffusion_energon_datamodule.py @@ -23,7 +23,7 @@ from megatron.energon import DefaultTaskEncoder, get_train_dataset from torch import int_repr -from dfm.src.megatron.data.dit.base_energon_datamodule import EnergonMultiModalDataModule +from dfm.src.megatron.data.common.base_energon_datamodule import EnergonMultiModalDataModule from dfm.src.megatron.data.dit.dit_taskencoder import DiTTaskEncoder diff --git a/dfm/src/megatron/data/dit/diffusion_sample.py b/dfm/src/megatron/data/common/diffusion_sample.py similarity index 100% rename from dfm/src/megatron/data/dit/diffusion_sample.py rename to dfm/src/megatron/data/common/diffusion_sample.py diff --git a/dfm/src/megatron/data/dit/diffusion_task_encoder_with_sp.py b/dfm/src/megatron/data/common/diffusion_task_encoder_with_sp.py similarity index 95% rename from dfm/src/megatron/data/dit/diffusion_task_encoder_with_sp.py rename to dfm/src/megatron/data/common/diffusion_task_encoder_with_sp.py index d97ab9aa..fe8b9a93 100644 --- a/dfm/src/megatron/data/dit/diffusion_task_encoder_with_sp.py +++ b/dfm/src/megatron/data/common/diffusion_task_encoder_with_sp.py @@ -7,8 +7,8 @@ from megatron.energon.task_encoder.base import stateless from megatron.energon.task_encoder.cooking import Cooker, basic_sample_keys -from dfm.src.megatron.data.dit.diffusion_sample import DiffusionSample -from dfm.src.megatron.data.dit.sequence_packing_utils import first_fit_decreasing +from dfm.src.megatron.data.common.diffusion_sample import DiffusionSample +from dfm.src.megatron.data.common.sequence_packing_utils import first_fit_decreasing def cook(sample: dict) -> dict: diff --git a/dfm/src/megatron/data/dit/sequence_packing_utils.py b/dfm/src/megatron/data/common/sequence_packing_utils.py similarity index 100% rename from dfm/src/megatron/data/dit/sequence_packing_utils.py rename to dfm/src/megatron/data/common/sequence_packing_utils.py diff --git a/dfm/src/megatron/data/dit/dit_taskencoder.py b/dfm/src/megatron/data/dit/dit_taskencoder.py index a3748672..d5ef366f 100644 --- a/dfm/src/megatron/data/dit/dit_taskencoder.py +++ b/dfm/src/megatron/data/dit/dit_taskencoder.py @@ -21,8 +21,8 @@ from megatron.core import parallel_state from megatron.energon import SkipSample, stateless -from dfm.src.megatron.data.dit.diffusion_sample import DiffusionSample -from dfm.src.megatron.data.dit.diffusion_task_encoder_with_sp import DiffusionTaskEncoderWithSequencePacking +from dfm.src.megatron.data.common.diffusion_sample import DiffusionSample +from dfm.src.megatron.data.common.diffusion_task_encoder_with_sp import DiffusionTaskEncoderWithSequencePacking class DiTTaskEncoder(DiffusionTaskEncoderWithSequencePacking): diff --git a/dfm/src/megatron/data/wan/wan_energon_datamodule.py b/dfm/src/megatron/data/wan/wan_energon_datamodule.py index 031d4395..9abcae1a 100644 --- a/dfm/src/megatron/data/wan/wan_energon_datamodule.py +++ b/dfm/src/megatron/data/wan/wan_energon_datamodule.py @@ -19,7 +19,7 @@ from megatron.bridge.data.utils import DatasetBuildContext, DatasetProvider from torch import int_repr -from dfm.src.megatron.data.dit.diffusion_energon_datamodule import DiffusionDataModuleConfig, DiffusionDataModule +from dfm.src.megatron.data.common.diffusion_energon_datamodule import DiffusionDataModuleConfig, DiffusionDataModule from dfm.src.megatron.data.wan.wan_taskencoder import WanTaskEncoder diff --git a/dfm/src/megatron/data/wan/wan_mock_energon_datamodule.py b/dfm/src/megatron/data/wan/wan_mock_energon_datamodule.py index 7ae7cac3..f92af5e9 100644 --- a/dfm/src/megatron/data/wan/wan_mock_energon_datamodule.py +++ b/dfm/src/megatron/data/wan/wan_mock_energon_datamodule.py @@ -19,7 +19,7 @@ import torch from megatron.bridge.data.utils import DatasetBuildContext, DatasetProvider -from dfm.src.megatron.data.dit.diffusion_energon_datamodule import DiffusionDataModule +from dfm.src.megatron.data.common.diffusion_energon_datamodule import DiffusionDataModule from dfm.src.megatron.data.wan.wan_taskencoder import WanTaskEncoder diff --git a/dfm/src/megatron/data/wan/wan_sample.py b/dfm/src/megatron/data/wan/wan_sample.py index a17cfb51..60e060e7 100644 --- a/dfm/src/megatron/data/wan/wan_sample.py +++ b/dfm/src/megatron/data/wan/wan_sample.py @@ -13,7 +13,7 @@ # limitations under the License. from dataclasses import dataclass -from dfm.src.megatron.data.dit.diffusion_sample import DiffusionSample +from dfm.src.megatron.data.common.diffusion_sample import DiffusionSample @dataclass diff --git a/dfm/src/megatron/data/wan/wan_taskencoder.py b/dfm/src/megatron/data/wan/wan_taskencoder.py index 02c56b09..25b55a46 100644 --- a/dfm/src/megatron/data/wan/wan_taskencoder.py +++ b/dfm/src/megatron/data/wan/wan_taskencoder.py @@ -21,7 +21,7 @@ from megatron.core import parallel_state from typing import List from megatron.energon import SkipSample -from dfm.src.megatron.data.dit.diffusion_task_encoder_with_sp import DiffusionTaskEncoderWithSequencePacking +from dfm.src.megatron.data.common.diffusion_task_encoder_with_sp import DiffusionTaskEncoderWithSequencePacking from megatron.energon.task_encoder.cooking import Cooker, basic_sample_keys from dfm.src.megatron.data.wan.wan_sample import WanSample from dfm.src.megatron.model.wan.utils import grid_sizes_calculation, patchify From 66796b52451d9e1c47c5ab1bd17a08dc0fc20a7c Mon Sep 17 00:00:00 2001 From: Huy Vu2 Date: Mon, 10 Nov 2025 11:56:16 -0800 Subject: [PATCH 40/80] fix Ruff, lint --- .../data/common/diffusion_task_encoder_with_sp.py | 14 ++++++++++++++ .../megatron/data/wan/wan_energon_datamodule.py | 4 ++-- dfm/src/megatron/data/wan/wan_sample.py | 1 + dfm/src/megatron/data/wan/wan_taskencoder.py | 14 +++++++------- dfm/src/megatron/model/wan/rope_utils.py | 2 +- dfm/src/megatron/recipes/wan/wan.py | 2 +- .../recipes/wan/prepare_energon_dataset_wan.py | 11 ++++++----- 7 files changed, 32 insertions(+), 16 deletions(-) diff --git a/dfm/src/megatron/data/common/diffusion_task_encoder_with_sp.py b/dfm/src/megatron/data/common/diffusion_task_encoder_with_sp.py index fe8b9a93..9dfa51e4 100644 --- a/dfm/src/megatron/data/common/diffusion_task_encoder_with_sp.py +++ b/dfm/src/megatron/data/common/diffusion_task_encoder_with_sp.py @@ -1,3 +1,17 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import random from abc import ABC, abstractmethod from typing import List diff --git a/dfm/src/megatron/data/wan/wan_energon_datamodule.py b/dfm/src/megatron/data/wan/wan_energon_datamodule.py index 9abcae1a..e53b066f 100644 --- a/dfm/src/megatron/data/wan/wan_energon_datamodule.py +++ b/dfm/src/megatron/data/wan/wan_energon_datamodule.py @@ -16,10 +16,10 @@ from dataclasses import dataclass -from megatron.bridge.data.utils import DatasetBuildContext, DatasetProvider +from megatron.bridge.data.utils import DatasetBuildContext from torch import int_repr -from dfm.src.megatron.data.common.diffusion_energon_datamodule import DiffusionDataModuleConfig, DiffusionDataModule +from dfm.src.megatron.data.common.diffusion_energon_datamodule import DiffusionDataModule, DiffusionDataModuleConfig from dfm.src.megatron.data.wan.wan_taskencoder import WanTaskEncoder diff --git a/dfm/src/megatron/data/wan/wan_sample.py b/dfm/src/megatron/data/wan/wan_sample.py index 60e060e7..5ac8538f 100644 --- a/dfm/src/megatron/data/wan/wan_sample.py +++ b/dfm/src/megatron/data/wan/wan_sample.py @@ -13,6 +13,7 @@ # limitations under the License. from dataclasses import dataclass + from dfm.src.megatron.data.common.diffusion_sample import DiffusionSample diff --git a/dfm/src/megatron/data/wan/wan_taskencoder.py b/dfm/src/megatron/data/wan/wan_taskencoder.py index 25b55a46..2ddca34e 100644 --- a/dfm/src/megatron/data/wan/wan_taskencoder.py +++ b/dfm/src/megatron/data/wan/wan_taskencoder.py @@ -14,17 +14,17 @@ # pylint: disable=C0115,C0116,C0301 -from torch._tensor import Tensor -import torch -import torch.nn.functional as F -from megatron.energon.task_encoder.base import stateless -from megatron.core import parallel_state from typing import List -from megatron.energon import SkipSample + from dfm.src.megatron.data.common.diffusion_task_encoder_with_sp import DiffusionTaskEncoderWithSequencePacking -from megatron.energon.task_encoder.cooking import Cooker, basic_sample_keys from dfm.src.megatron.data.wan.wan_sample import WanSample from dfm.src.megatron.model.wan.utils import grid_sizes_calculation, patchify +from megatron.core import parallel_state +from megatron.energon import SkipSample +from megatron.energon.task_encoder.base import stateless +from megatron.energon.task_encoder.cooking import Cooker, basic_sample_keys +import torch +import torch.nn.functional as F def cook(sample: dict) -> dict: diff --git a/dfm/src/megatron/model/wan/rope_utils.py b/dfm/src/megatron/model/wan/rope_utils.py index 8d1eca14..773e7ac3 100644 --- a/dfm/src/megatron/model/wan/rope_utils.py +++ b/dfm/src/megatron/model/wan/rope_utils.py @@ -69,7 +69,7 @@ def forward(self, n_head, dim_head, cu_seqlens_q_padded, grid_sizes, device): # Pad freqs_real_i to (padded_seq_len, 1, 1, dim_head) with 0s for i, freqs_real_i in enumerate(freqs_real): - seq_len_q_padded = cu_seqlens_q_padded[i+1] - cu_seqlens_q_padded[i] + seq_len_q_padded = cu_seqlens_q_padded[i + 1] - cu_seqlens_q_padded[i] if freqs_real_i.shape[0] < seq_len_q_padded: pad_shape = (seq_len_q_padded - freqs_real_i.shape[0], 1, 1, dim_head) freqs_real_i = torch.cat( diff --git a/dfm/src/megatron/recipes/wan/wan.py b/dfm/src/megatron/recipes/wan/wan.py index b1f1e4e5..6cbf60ec 100644 --- a/dfm/src/megatron/recipes/wan/wan.py +++ b/dfm/src/megatron/recipes/wan/wan.py @@ -180,7 +180,7 @@ def pretrain_config( global_batch_size=global_batch_size, num_workers=10, task_encoder_seq_length=None, - packing_buffer_size=131072, # 131,072 = 2^17 tokens, each 5 secs of 832*480 is about 45k tokens + packing_buffer_size=131072, # 131,072 = 2^17 tokens, each 5 secs of 832*480 is about 45k tokens ) # Config Container diff --git a/examples/megatron/recipes/wan/prepare_energon_dataset_wan.py b/examples/megatron/recipes/wan/prepare_energon_dataset_wan.py index 3a9c91c0..f75d7144 100644 --- a/examples/megatron/recipes/wan/prepare_energon_dataset_wan.py +++ b/examples/megatron/recipes/wan/prepare_energon_dataset_wan.py @@ -18,12 +18,11 @@ from typing import Dict, List, Optional, Tuple import cv2 +from diffusers import AutoencoderKLWan import numpy as np import torch -import webdataset as wds - -from diffusers import AutoencoderKLWan from transformers import AutoTokenizer, UMT5EncoderModel +import webdataset as wds def _map_interpolation(resize_mode: str) -> int: @@ -412,7 +411,7 @@ def main(): for index, meta in enumerate(metadata_list): video_name = meta["file_name"] start_frame = int(meta["start_frame"]) # inclusive - end_frame = int(meta["end_frame"]) # inclusive + end_frame = int(meta["end_frame"]) # inclusive caption_text = meta.get("vila_caption", "") video_path = str(video_folder / video_name) @@ -431,7 +430,9 @@ def main(): # Encode text and video with HF models exactly like automodel text_embed = _encode_text(tokenizer, text_encoder, args.device, caption_text) - latents = _encode_video_latents(vae, args.device, video_tensor, deterministic_latents=not args.stochastic) + latents = _encode_video_latents( + vae, args.device, video_tensor, deterministic_latents=not args.stochastic + ) # Move to CPU without changing dtype; keep exact values to match automodel outputs text_embed_cpu = text_embed.detach().to(device="cpu") From 7d8e64f4b067af4bfa2babdff0b1bd06b088f4db Mon Sep 17 00:00:00 2001 From: Huy Vu2 Date: Mon, 10 Nov 2025 12:00:11 -0800 Subject: [PATCH 41/80] fix Ruff, lint --- dfm/src/megatron/data/common/base_energon_datamodule.py | 2 +- dfm/src/megatron/data/common/diffusion_energon_datamodule.py | 2 +- dfm/src/megatron/data/common/diffusion_sample.py | 2 +- .../megatron/data/common/diffusion_task_encoder_with_sp.py | 2 +- dfm/src/megatron/data/common/sequence_packing_utils.py | 2 +- dfm/src/megatron/data/wan/wan_taskencoder.py | 2 +- dfm/src/megatron/model/wan/rope_utils.py | 2 +- examples/megatron/recipes/wan/prepare_energon_dataset_wan.py | 4 ++-- 8 files changed, 9 insertions(+), 9 deletions(-) diff --git a/dfm/src/megatron/data/common/base_energon_datamodule.py b/dfm/src/megatron/data/common/base_energon_datamodule.py index 1429790a..f903ee7f 100644 --- a/dfm/src/megatron/data/common/base_energon_datamodule.py +++ b/dfm/src/megatron/data/common/base_energon_datamodule.py @@ -336,4 +336,4 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: update_num_microbatches( consumed_samples=consumed_samples, consistency_check=False, - ) \ No newline at end of file + ) diff --git a/dfm/src/megatron/data/common/diffusion_energon_datamodule.py b/dfm/src/megatron/data/common/diffusion_energon_datamodule.py index eac2e604..e25cd944 100644 --- a/dfm/src/megatron/data/common/diffusion_energon_datamodule.py +++ b/dfm/src/megatron/data/common/diffusion_energon_datamodule.py @@ -182,4 +182,4 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: try: super().load_state_dict(state_dict) except Exception as e: - logging.warning(f"datamodule.load_state_dict failed {e}") \ No newline at end of file + logging.warning(f"datamodule.load_state_dict failed {e}") diff --git a/dfm/src/megatron/data/common/diffusion_sample.py b/dfm/src/megatron/data/common/diffusion_sample.py index 20f76857..3d5f2384 100644 --- a/dfm/src/megatron/data/common/diffusion_sample.py +++ b/dfm/src/megatron/data/common/diffusion_sample.py @@ -93,4 +93,4 @@ def __lt__(self, other: Any) -> bool: return self.seq_len_q.item() < other.seq_len_q.item() elif isinstance(other, int): return self.seq_len_q.item() < other - raise NotImplementedError \ No newline at end of file + raise NotImplementedError diff --git a/dfm/src/megatron/data/common/diffusion_task_encoder_with_sp.py b/dfm/src/megatron/data/common/diffusion_task_encoder_with_sp.py index 9dfa51e4..fe82794a 100644 --- a/dfm/src/megatron/data/common/diffusion_task_encoder_with_sp.py +++ b/dfm/src/megatron/data/common/diffusion_task_encoder_with_sp.py @@ -109,4 +109,4 @@ def cat(attr): @stateless def batch(self, samples: List[DiffusionSample]) -> dict: - raise NotImplementedError \ No newline at end of file + raise NotImplementedError diff --git a/dfm/src/megatron/data/common/sequence_packing_utils.py b/dfm/src/megatron/data/common/sequence_packing_utils.py index 4a643b5f..f551a292 100644 --- a/dfm/src/megatron/data/common/sequence_packing_utils.py +++ b/dfm/src/megatron/data/common/sequence_packing_utils.py @@ -102,4 +102,4 @@ def concat_pad(tensor_list, max_seq_length): result[current_index : current_index + length] = tensor current_index += length - return result \ No newline at end of file + return result diff --git a/dfm/src/megatron/data/wan/wan_taskencoder.py b/dfm/src/megatron/data/wan/wan_taskencoder.py index 2ddca34e..d6e56699 100644 --- a/dfm/src/megatron/data/wan/wan_taskencoder.py +++ b/dfm/src/megatron/data/wan/wan_taskencoder.py @@ -22,7 +22,7 @@ from megatron.core import parallel_state from megatron.energon import SkipSample from megatron.energon.task_encoder.base import stateless -from megatron.energon.task_encoder.cooking import Cooker, basic_sample_keys +from megatron.energon.task_encoder.cooking import basic_sample_keys, Cooker import torch import torch.nn.functional as F diff --git a/dfm/src/megatron/model/wan/rope_utils.py b/dfm/src/megatron/model/wan/rope_utils.py index 773e7ac3..2b64fdaa 100644 --- a/dfm/src/megatron/model/wan/rope_utils.py +++ b/dfm/src/megatron/model/wan/rope_utils.py @@ -86,4 +86,4 @@ def forward(self, n_head, dim_head, cu_seqlens_q_padded, grid_sizes, device): # we don't need to scatter the freqs to the context parallel region, # because mcore rope_utils will automatically retrieve the correct freqs for each context parallel region - return freqs_real \ No newline at end of file + return freqs_real diff --git a/examples/megatron/recipes/wan/prepare_energon_dataset_wan.py b/examples/megatron/recipes/wan/prepare_energon_dataset_wan.py index f75d7144..529853e1 100644 --- a/examples/megatron/recipes/wan/prepare_energon_dataset_wan.py +++ b/examples/megatron/recipes/wan/prepare_energon_dataset_wan.py @@ -18,11 +18,11 @@ from typing import Dict, List, Optional, Tuple import cv2 -from diffusers import AutoencoderKLWan import numpy as np import torch -from transformers import AutoTokenizer, UMT5EncoderModel import webdataset as wds +from diffusers import AutoencoderKLWan +from transformers import AutoTokenizer, UMT5EncoderModel def _map_interpolation(resize_mode: str) -> int: From 6263299290d1aa0865639c81d195ae956417d089 Mon Sep 17 00:00:00 2001 From: Huy Vu2 Date: Mon, 10 Nov 2025 12:10:09 -0800 Subject: [PATCH 42/80] fix Ruff, lint --- dfm/src/megatron/data/dit/dit_taskencoder.py | 2 +- .../megatron/data/wan/wan_energon_datamodule.py | 6 ++---- dfm/src/megatron/data/wan/wan_sample.py | 2 +- dfm/src/megatron/data/wan/wan_taskencoder.py | 15 ++++++++------- .../model/wan/flow_matching/flow_pipeline.py | 4 +--- 5 files changed, 13 insertions(+), 16 deletions(-) diff --git a/dfm/src/megatron/data/dit/dit_taskencoder.py b/dfm/src/megatron/data/dit/dit_taskencoder.py index d5ef366f..fe3e6180 100644 --- a/dfm/src/megatron/data/dit/dit_taskencoder.py +++ b/dfm/src/megatron/data/dit/dit_taskencoder.py @@ -232,4 +232,4 @@ def get_pos_id_3d(self, *, t, h, w): # cookers = [ # Cooker(cook_raw_iamges), -# ] \ No newline at end of file +# ] diff --git a/dfm/src/megatron/data/wan/wan_energon_datamodule.py b/dfm/src/megatron/data/wan/wan_energon_datamodule.py index e53b066f..fdd0871e 100644 --- a/dfm/src/megatron/data/wan/wan_energon_datamodule.py +++ b/dfm/src/megatron/data/wan/wan_energon_datamodule.py @@ -32,15 +32,13 @@ class WanDataModuleConfig(DiffusionDataModuleConfig): global_batch_size: int num_workers: int_repr dataloader_type: str = "external" - + def __post_init__(self): self.dataset = DiffusionDataModule( path=self.path, seq_length=self.seq_length, packing_buffer_size=self.packing_buffer_size, - task_encoder=WanTaskEncoder( - seq_length=self.seq_length, packing_buffer_size=self.packing_buffer_size - ), + task_encoder=WanTaskEncoder(seq_length=self.seq_length, packing_buffer_size=self.packing_buffer_size), micro_batch_size=self.micro_batch_size, global_batch_size=self.global_batch_size, num_workers=self.num_workers, diff --git a/dfm/src/megatron/data/wan/wan_sample.py b/dfm/src/megatron/data/wan/wan_sample.py index 5ac8538f..705682de 100644 --- a/dfm/src/megatron/data/wan/wan_sample.py +++ b/dfm/src/megatron/data/wan/wan_sample.py @@ -20,4 +20,4 @@ @dataclass class WanSample(DiffusionSample): video_metadata: dict = None - seq_len_q_padded: int = None \ No newline at end of file + seq_len_q_padded: int = None diff --git a/dfm/src/megatron/data/wan/wan_taskencoder.py b/dfm/src/megatron/data/wan/wan_taskencoder.py index d6e56699..5cb686d3 100644 --- a/dfm/src/megatron/data/wan/wan_taskencoder.py +++ b/dfm/src/megatron/data/wan/wan_taskencoder.py @@ -16,15 +16,16 @@ from typing import List -from dfm.src.megatron.data.common.diffusion_task_encoder_with_sp import DiffusionTaskEncoderWithSequencePacking -from dfm.src.megatron.data.wan.wan_sample import WanSample -from dfm.src.megatron.model.wan.utils import grid_sizes_calculation, patchify +import torch +import torch.nn.functional as F from megatron.core import parallel_state from megatron.energon import SkipSample from megatron.energon.task_encoder.base import stateless -from megatron.energon.task_encoder.cooking import basic_sample_keys, Cooker -import torch -import torch.nn.functional as F +from megatron.energon.task_encoder.cooking import Cooker, basic_sample_keys + +from dfm.src.megatron.data.common.diffusion_task_encoder_with_sp import DiffusionTaskEncoderWithSequencePacking +from dfm.src.megatron.data.wan.wan_sample import WanSample +from dfm.src.megatron.model.wan.utils import grid_sizes_calculation, patchify def cook(sample: dict) -> dict: @@ -145,7 +146,7 @@ def encode_sample(self, sample: dict) -> dict: ) # NOTE: - # the method select_samples_to_pack() is inherited from the parent + # the method select_samples_to_pack() is inherited from the parent # class DiffusionTaskEncoderWithSequencePacking @stateless diff --git a/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py b/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py index d06d2993..f341ae3c 100644 --- a/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py +++ b/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py @@ -126,11 +126,10 @@ def training_step( 0 ] # shape [noise_seq, c * ( pF * pH * pW)] - # because video_latents might be padded, we need to make sure noise also be padded to have the same shape sample_noise_seq_len = sample_noise.shape[0] cu_seqlens_q_padded = packed_seq_params["self_attention"].cu_seqlens_q_padded - seq_len_q_padded = cu_seqlens_q_padded[i+1] - cu_seqlens_q_padded[i] + seq_len_q_padded = cu_seqlens_q_padded[i + 1] - cu_seqlens_q_padded[i] if sample_noise_seq_len < seq_len_q_padded: pad_len = seq_len_q_padded - sample_noise_seq_len pad = torch.zeros( @@ -142,7 +141,6 @@ def training_step( noise = torch.cat(noise, dim=0) # shape [concatenated_noise_seq, c * ( pF * pH * pW)] noise = noise.unsqueeze(1) # shape [concatenated_noise_seq, 1, c * ( pF * pH * pW)] - # CRITICAL: Manual flow matching (NOT scheduler.add_noise!) # x_t = (1 - σ) * x_0 + σ * ε # since we use sequence packing, the batch_size is 1) From 377ff5bfc5f8269875a634542b6e23f4cae3937f Mon Sep 17 00:00:00 2001 From: Huy Vu2 Date: Tue, 11 Nov 2025 09:12:53 -0800 Subject: [PATCH 43/80] workable mock datamodule (doesn't need setting path); updated training algo + hyper-parameters aligning with Linnan; tested training with anime dataset finetung --- .../data/wan/wan_mock_energon_datamodule.py | 159 ++++++++++++++---- .../model/wan/flow_matching/flow_pipeline.py | 19 ++- .../megatron/model/wan/inference/__init__.py | 3 +- dfm/src/megatron/model/wan/wan_step.py | 33 +++- dfm/src/megatron/recipes/wan/wan.py | 3 +- examples/megatron/recipes/wan/pretrain_wan.py | 28 ++- 6 files changed, 206 insertions(+), 39 deletions(-) diff --git a/dfm/src/megatron/data/wan/wan_mock_energon_datamodule.py b/dfm/src/megatron/data/wan/wan_mock_energon_datamodule.py index f92af5e9..c8566de4 100644 --- a/dfm/src/megatron/data/wan/wan_mock_energon_datamodule.py +++ b/dfm/src/megatron/data/wan/wan_mock_energon_datamodule.py @@ -17,13 +17,51 @@ from dataclasses import dataclass import torch +from torch.utils.data import DataLoader, Dataset from megatron.bridge.data.utils import DatasetBuildContext, DatasetProvider +from typing import List +from dfm.src.megatron.data.common.base_energon_datamodule import EnergonMultiModalDataModule +from megatron.energon import DefaultTaskEncoder +from dfm.src.megatron.model.wan.utils import patchify +from megatron.energon.task_encoder.cooking import Cooker, basic_sample_keys -from dfm.src.megatron.data.common.diffusion_energon_datamodule import DiffusionDataModule -from dfm.src.megatron.data.wan.wan_taskencoder import WanTaskEncoder +def cook(sample: dict) -> dict: + """ + Produces a cooked sample without requiring any real filesystem-backed fields. + This mock version ignores missing keys and synthesizes placeholders so that + the pipeline can operate entirely in-memory. + + Args: + sample (dict): The input dictionary containing the raw sample data. + + Returns: + dict: A new dictionary containing the processed sample data with the following keys: + - A best-effort subset of basic keys when available + - 'json': mock metadata dict + - 'pth': empty tensor placeholder for video latent + - 'pickle': empty tensor placeholder for text embeddings + """ + base_keys = {} + try: + # Attempt to extract any common keys if present; tolerate absence. + base_keys = basic_sample_keys(sample) + except Exception: + base_keys = {} + return { + **base_keys + } +class _MockDataset(Dataset): + def __init__(self, length: int): + self.length = max(int(length), 1) -class WanMockTaskEncoder(WanTaskEncoder): + def __len__(self) -> int: + return self.length + + def __getitem__(self, idx: int) -> dict: + return {} + +class WanMockTaskEncoder(DefaultTaskEncoder): """ Mock task encoder for Wan dataset. Attributes: @@ -33,18 +71,30 @@ class WanMockTaskEncoder(WanTaskEncoder): seq_length (int): The sequence length. Defaults to 1024. """ + patch_spatial: int + patch_temporal: int F_latents: int H_latents: int W_latents: int + seq_length: int + number_packed_samples: int context_seq_len: int context_embeddings_dim: int + cookers = [ + Cooker(cook), + ] + def __init__( self, *args, F_latents: int, H_latents: int, W_latents: int, + seq_length: int, + patch_spatial: int, + patch_temporal: int, + number_packed_samples: int, context_seq_len: int, context_embeddings_dim: int, **kwargs, @@ -53,15 +103,20 @@ def __init__( self.F_latents = F_latents self.H_latents = H_latents self.W_latents = W_latents + self.seq_length = seq_length + self.patch_spatial = patch_spatial + self.patch_temporal = patch_temporal + self.number_packed_samples = number_packed_samples self.context_seq_len = context_seq_len self.context_embeddings_dim = context_embeddings_dim - # mock encode_sample() for debugging - def encode_sample(self, sample: dict) -> dict: - # mock encode sample - video_latent = torch.tensor( - torch.randn(16, self.F_latents, self.H_latents, self.W_latents), dtype=torch.float32 - ) + def encode_sample(self, _sample: dict) -> dict: + return {} + + def batch(self, samples: List[dict]) -> dict: + + # set mock values for one video sample + video_latent = torch.randn(16, self.F_latents, self.H_latents, self.W_latents, dtype=torch.float32) grid_size = torch.tensor( [ video_latent.shape[1] // self.patch_temporal, @@ -70,50 +125,88 @@ def encode_sample(self, sample: dict) -> dict: ], dtype=torch.int32, ) - context_embeddings = torch.tensor( - torch.randn(self.context_seq_len, self.context_embeddings_dim), dtype=torch.float32 - ) + video_latent = patchify([video_latent], (self.patch_temporal, self.patch_spatial, self.patch_spatial))[0] + video_latent = torch.as_tensor(video_latent, dtype=torch.float32) + seq_len_q = video_latent.shape[0] + seq_len_q_padded = seq_len_q + loss_mask = torch.ones(seq_len_q, dtype=torch.bfloat16) + context_embeddings = torch.randn(self.context_seq_len, self.context_embeddings_dim, dtype=torch.float32) + seq_len_kv = context_embeddings.shape[0] video_metadata = {} - return dict( - video_latent=video_latent, - grid_size=grid_size, - context_embeddings=context_embeddings, + # set mock values for packed video samples + video_latents_packed = [video_latent for _ in range(self.number_packed_samples)] + video_latents_packed = torch.cat(video_latents_packed, dim=0) + loss_masks_packed = [loss_mask for _ in range(self.number_packed_samples)] + loss_masks_packed = torch.cat(loss_masks_packed, dim=0) + seq_len_q_packed = torch.tensor([seq_len_q for _ in range(self.number_packed_samples)], dtype=torch.int32) + seq_len_q_padded_packed = torch.tensor([seq_len_q_padded for _ in range(self.number_packed_samples)], dtype=torch.int32) + seq_len_kv_packed = torch.tensor([seq_len_kv for _ in range(self.number_packed_samples)], dtype=torch.int32) + grid_sizes_packed = torch.stack([grid_size for _ in range(self.number_packed_samples)], dim=0) + context_embeddings_packed = [context_embeddings for _ in range(self.number_packed_samples)] + context_embeddings_packed = torch.cat(context_embeddings_packed, dim=0) + + + ### Note: shape of sample's values + # video_latent: [num_patches, latents_channels * pF * pH * pW] + # grid_size: [F_patches, W_patches, H_patches] + # context_embeddings: [context_seq_len, text_embedding_dim] + + batch = dict( + video_latents=video_latents_packed.unsqueeze(1), + context_embeddings=context_embeddings_packed.unsqueeze(1), + loss_mask=loss_masks_packed.unsqueeze(1), + seq_len_q=seq_len_q_packed, + seq_len_q_padded=seq_len_q_padded_packed, + seq_len_kv=seq_len_kv_packed, + grid_sizes=grid_sizes_packed, video_metadata=video_metadata, ) + return batch @dataclass(kw_only=True) class WanMockDataModuleConfig(DatasetProvider): - path: str + path: str = "" seq_length: int + packing_buffer_size: int micro_batch_size: int global_batch_size: int num_workers: int dataloader_type: str = "external" - F_latents: int = 3 + F_latents: int = 24 H_latents: int = 104 W_latents: int = 60 + patch_spatial: int = 2 + patch_temporal: int = 1 + number_packed_samples: int = 3 context_seq_len: int = 512 context_embeddings_dim: int = 4096 def __post_init__(self): - self.dataset = DiffusionDataModule( - path=self.path, + self._task_encoder = WanMockTaskEncoder( + patch_spatial=self.patch_spatial, + patch_temporal=self.patch_temporal, + F_latents=self.F_latents, + H_latents=self.H_latents, + W_latents=self.W_latents, seq_length=self.seq_length, - task_encoder=WanMockTaskEncoder( - seq_length=self.seq_length, - F_latents=self.F_latents, - H_latents=self.H_latents, - W_latents=self.W_latents, - context_seq_len=self.context_seq_len, - context_embeddings_dim=self.context_embeddings_dim, - ), - micro_batch_size=self.micro_batch_size, - global_batch_size=self.global_batch_size, + number_packed_samples=self.number_packed_samples, + context_seq_len=self.context_seq_len, + context_embeddings_dim=self.context_embeddings_dim, + ) + mock_ds = _MockDataset(length=1024) + self._train_dl = DataLoader( + mock_ds, + batch_size=self.micro_batch_size, num_workers=self.num_workers, + collate_fn=self._task_encoder.batch, + shuffle=False, + drop_last=False, ) - self.sequence_length = self.dataset.seq_length + self.sequence_length = self.seq_length - def build_datasets(self, context: DatasetBuildContext): - return self.dataset.train_dataloader(), self.dataset.train_dataloader(), self.dataset.train_dataloader() + def build_datasets(self, _context: DatasetBuildContext): + if hasattr(self, "dataset"): + return self.dataset.train_dataloader(), self.dataset.train_dataloader(), self.dataset.train_dataloader() + return self._train_dl, self._train_dl, self._train_dl diff --git a/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py b/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py index f341ae3c..11a9baa6 100644 --- a/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py +++ b/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py @@ -43,6 +43,8 @@ def training_step( logit_std: float = 1.0, flow_shift: float = 3.0, mix_uniform_ratio: float = 0.1, + sigma_min: float = 0.0, # Default: no clamping (pretrain) + sigma_max: float = 1.0, # Default: no clamping (pretrain) ) -> tuple[dict[str, torch.Tensor], torch.Tensor]: """ Performs a single training step using flow matching algorithm. @@ -97,10 +99,21 @@ def training_step( sigma = flow_shift / (flow_shift + (1.0 / u_clamped - 1.0)) sigma = torch.clamp(sigma, 0.0, 1.0) + # Clamp sigma (only if not full range [0,1]) + # Pretrain uses [0, 1], finetune uses [0.02, 0.55] + if sigma_min > 0.0 or sigma_max < 1.0: + sigma = torch.clamp(sigma, sigma_min, sigma_max) + else: + sigma = torch.clamp(sigma, 0.0, 1.0) + else: # Simple uniform without shift u = torch.rand(size=(batch_size,), device=device) - sigma = u + # Clamp sigma (only if not full range [0,1]) + if sigma_min > 0.0 or sigma_max < 1.0: + sigma = torch.clamp(u, sigma_min, sigma_max) + else: + sigma = u sampling_method = "uniform_no_shift" # ======================================================================== @@ -156,6 +169,10 @@ def training_step( video_latents = video_latents.to(torch.bfloat16) noisy_latents = noisy_latents.to(torch.bfloat16) context_embeddings = context_embeddings.to(torch.bfloat16) + + # NOTE: investigate the affect of bf16 timesteps on embedding precision + # CRITICAL: Keep timesteps in fp32 for embedding precision + # timesteps = timesteps.float() # NOT bf16! timesteps = timesteps.to(torch.bfloat16) # ======================================================================== diff --git a/dfm/src/megatron/model/wan/inference/__init__.py b/dfm/src/megatron/model/wan/inference/__init__.py index e477e446..1395bc93 100644 --- a/dfm/src/megatron/model/wan/inference/__init__.py +++ b/dfm/src/megatron/model/wan/inference/__init__.py @@ -5,6 +5,7 @@ SIZE_CONFIGS = { + "416*240": (416, 240), "720*1280": (720, 1280), "1280*720": (1280, 720), "480*832": (480, 832), @@ -21,5 +22,5 @@ SUPPORTED_SIZES = { "t2v-14B": ("720*1280", "1280*720", "480*832", "832*480"), - "t2v-1.3B": ("480*832", "832*480"), + "t2v-1.3B": ("416*240", "480*832", "832*480"), } diff --git a/dfm/src/megatron/model/wan/wan_step.py b/dfm/src/megatron/model/wan/wan_step.py index f42bef8e..c82d0f5e 100644 --- a/dfm/src/megatron/model/wan/wan_step.py +++ b/dfm/src/megatron/model/wan/wan_step.py @@ -69,8 +69,26 @@ def wan_data_step(qkv_format, dataloader_iter): class WanForwardStep: - def __init__(self): + def __init__( + self, + use_sigma_noise: bool = True, + timestep_sampling: str = "uniform", + logit_mean: float = 0.0, + logit_std: float = 1.0, + flow_shift: float = 3.0, + mix_uniform_ratio: float = 0.1, + sigma_min: float = 0.0, # Default: no clamping (pretrain) + sigma_max: float = 1.0, # Default: no clamping (pretrain) + ): self.diffusion_pipeline = FlowPipeline() + self.use_sigma_noise = use_sigma_noise + self.timestep_sampling = timestep_sampling + self.logit_mean = logit_mean + self.logit_std = logit_std + self.flow_shift = flow_shift + self.mix_uniform_ratio = mix_uniform_ratio + self.sigma_min = sigma_min + self.sigma_max = sigma_max def __call__( self, state: GlobalState, data_iterator: Iterable, model: VisionModule @@ -96,7 +114,18 @@ def __call__( # run diffusion training step with straggler_timer: if parallel_state.is_pipeline_last_stage(): - output_batch, loss, split_loss_mask = self.diffusion_pipeline.training_step(model, batch) + output_batch, loss, split_loss_mask = self.diffusion_pipeline.training_step( + model, + batch, + use_sigma_noise=self.use_sigma_noise, + timestep_sampling=self.timestep_sampling, + logit_mean=self.logit_mean, + logit_std=self.logit_std, + flow_shift=self.flow_shift, + mix_uniform_ratio=self.mix_uniform_ratio, + sigma_min=self.sigma_min, + sigma_max=self.sigma_max, + ) output_tensor = torch.mean(loss, dim=-1) batch["loss_mask"] = split_loss_mask else: diff --git a/dfm/src/megatron/recipes/wan/wan.py b/dfm/src/megatron/recipes/wan/wan.py index 6cbf60ec..21bd1b50 100644 --- a/dfm/src/megatron/recipes/wan/wan.py +++ b/dfm/src/megatron/recipes/wan/wan.py @@ -163,7 +163,7 @@ def pretrain_config( dataset = WanMockDataModuleConfig( path=None, seq_length=1024, # we don't need to use this value, just add because Bridge training requires for LLMs - F_latents=3, + F_latents=24, H_latents=104, W_latents=60, context_seq_len=512, @@ -171,6 +171,7 @@ def pretrain_config( micro_batch_size=micro_batch_size, global_batch_size=global_batch_size, num_workers=10, + packing_buffer_size=None, ) else: dataset = WanDataModuleConfig( diff --git a/examples/megatron/recipes/wan/pretrain_wan.py b/examples/megatron/recipes/wan/pretrain_wan.py index 950c4cb6..1509565f 100644 --- a/examples/megatron/recipes/wan/pretrain_wan.py +++ b/examples/megatron/recipes/wan/pretrain_wan.py @@ -87,6 +87,12 @@ def parse_cli_args() -> Tuple[argparse.Namespace, list[str]]: formatter_class=argparse.RawTextHelpFormatter, ) parser.add_argument("--mock", action="store_true", help="Whether to use mock data.") + parser.add_argument( + "--training-mode", + choices=["pretrain", "finetune"], + default="finetune", + help="Set training mode, 'pretrain' or 'finetune'." + ) parser.add_argument( "--config-file", type=str, @@ -164,6 +170,26 @@ def main() -> None: # Apply overrides while preserving excluded fields apply_overrides(cfg, final_overrides_as_dict, excluded_fields) + # Config FlowPipeline based on training mode + if args.training_mode == "pretrain": + wan_forward_step = WanForwardStep( + timestep_sampling="logit_normal", + logit_std=1.5, + flow_shift=2.5, + mix_uniform_ratio=0.2, + sigma_min=0.0, + sigma_max=1.0, + ) + elif args.training_mode == "finetune": + wan_forward_step = WanForwardStep( + timestep_sampling="uniform", + logit_std=1.0, + flow_shift=3.0, + mix_uniform_ratio=0.1, + sigma_min=0.0, + sigma_max=1.0, + ) + # Display final configuration if get_rank_safe() == 0: logger.info("--- Final Merged Configuration ---") @@ -172,7 +198,7 @@ def main() -> None: # Start training logger.debug("Starting pretraining...") - pretrain(config=cfg, forward_step_func=WanForwardStep()) + pretrain(config=cfg, forward_step_func=wan_forward_step) if __name__ == "__main__": From d8550c4e49f0922d726d32add6bfa369f2760340 Mon Sep 17 00:00:00 2001 From: Huy Vu2 Date: Tue, 11 Nov 2025 11:20:47 -0800 Subject: [PATCH 44/80] bring wan_task encoders features to common, sharing with dit --- .../megatron/data/common/diffusion_sample.py | 8 +- .../common/diffusion_task_encoder_with_sp.py | 2 + .../data/wan/wan_mock_energon_datamodule.py | 82 ------------------- dfm/src/megatron/data/wan/wan_sample.py | 23 ------ dfm/src/megatron/data/wan/wan_taskencoder.py | 34 ++------ 5 files changed, 14 insertions(+), 135 deletions(-) delete mode 100644 dfm/src/megatron/data/wan/wan_sample.py diff --git a/dfm/src/megatron/data/common/diffusion_sample.py b/dfm/src/megatron/data/common/diffusion_sample.py index 3d5f2384..ccf515c8 100644 --- a/dfm/src/megatron/data/common/diffusion_sample.py +++ b/dfm/src/megatron/data/common/diffusion_sample.py @@ -34,9 +34,11 @@ class DiffusionSample(Sample): num_frames (Optional[torch.Tensor]): Number of frames in the video. padding_mask (Optional[torch.Tensor]): Mask indicating padding positions. seq_len_q (Optional[torch.Tensor]): Sequence length for query embeddings. + seq_len_q_padded (Optional[torch.Tensor]): Sequence length for query embeddings after padding. seq_len_kv (Optional[torch.Tensor]): Sequence length for key/value embeddings. pos_ids (Optional[torch.Tensor]): Positional IDs. latent_shape (Optional[torch.Tensor]): Shape of the latent tensor. + video_metadata (Optional[dict]): Metadata of the video. """ video: torch.Tensor # video latents (C T H W) @@ -48,10 +50,12 @@ class DiffusionSample(Sample): num_frames: Optional[torch.Tensor] = None padding_mask: Optional[torch.Tensor] = None seq_len_q: Optional[torch.Tensor] = None + seq_len_q_padded: Optional[torch.Tensor] = None seq_len_kv: Optional[torch.Tensor] = None pos_ids: Optional[torch.Tensor] = None latent_shape: Optional[torch.Tensor] = None - + video_metadata: Optional[dict] = None + def to_dict(self) -> dict: """Converts the sample to a dictionary.""" return dict( @@ -64,9 +68,11 @@ def to_dict(self) -> dict: num_frames=self.num_frames, padding_mask=self.padding_mask, seq_len_q=self.seq_len_q, + seq_len_q_padded=self.seq_len_q_padded, seq_len_kv=self.seq_len_kv, pos_ids=self.pos_ids, latent_shape=self.latent_shape, + video_metadata=self.video_metadata, ) def __add__(self, other: Any) -> int: diff --git a/dfm/src/megatron/data/common/diffusion_task_encoder_with_sp.py b/dfm/src/megatron/data/common/diffusion_task_encoder_with_sp.py index 537d4651..49e2b707 100644 --- a/dfm/src/megatron/data/common/diffusion_task_encoder_with_sp.py +++ b/dfm/src/megatron/data/common/diffusion_task_encoder_with_sp.py @@ -102,9 +102,11 @@ def cat(attr): context_embeddings=cat("context_embeddings"), loss_mask=cat("loss_mask"), seq_len_q=cat("seq_len_q"), + seq_len_q_padded=cat("seq_len_q_padded"), seq_len_kv=cat("seq_len_kv"), pos_ids=cat("pos_ids"), latent_shape=stack("latent_shape"), + video_metadata=[sample.video_metadata for sample in samples], ) @stateless diff --git a/dfm/src/megatron/data/wan/wan_mock_energon_datamodule.py b/dfm/src/megatron/data/wan/wan_mock_energon_datamodule.py index e78be417..8f2d452c 100644 --- a/dfm/src/megatron/data/wan/wan_mock_energon_datamodule.py +++ b/dfm/src/megatron/data/wan/wan_mock_energon_datamodule.py @@ -33,88 +33,6 @@ def __len__(self) -> int: def __getitem__(self, idx: int) -> dict: return {} -# class WanMockTaskEncoder(): -# """ -# Mock task encoder for Wan dataset. -# """ - -# def __init__( -# self, -# *args, -# F_latents: int, -# H_latents: int, -# W_latents: int, -# seq_length: int, -# patch_spatial: int, -# patch_temporal: int, -# number_packed_samples: int, -# context_seq_len: int, -# context_embeddings_dim: int, -# **kwargs, -# ): -# super().__init__(*args, **kwargs) -# self.F_latents = F_latents -# self.H_latents = H_latents -# self.W_latents = W_latents -# self.seq_length = seq_length -# self.patch_spatial = patch_spatial -# self.patch_temporal = patch_temporal -# self.number_packed_samples = number_packed_samples -# self.context_seq_len = context_seq_len -# self.context_embeddings_dim = context_embeddings_dim - -# def batch(self, samples: List[dict]) -> dict: - -# # set mock values for one video sample -# video_latent = torch.randn(16, self.F_latents, self.H_latents, self.W_latents, dtype=torch.float32) -# grid_size = torch.tensor( -# [ -# video_latent.shape[1] // self.patch_temporal, -# video_latent.shape[2] // self.patch_spatial, -# video_latent.shape[3] // self.patch_spatial, -# ], -# dtype=torch.int32, -# ) -# video_latent = patchify([video_latent], (self.patch_temporal, self.patch_spatial, self.patch_spatial))[0] -# video_latent = torch.as_tensor(video_latent, dtype=torch.float32) -# seq_len_q = video_latent.shape[0] -# seq_len_q_padded = seq_len_q -# loss_mask = torch.ones(seq_len_q, dtype=torch.bfloat16) -# context_embeddings = torch.randn(self.context_seq_len, self.context_embeddings_dim, dtype=torch.float32) -# seq_len_kv = context_embeddings.shape[0] -# video_metadata = {} - -# # set mock values for packed video samples -# video_latents_packed = [video_latent for _ in range(self.number_packed_samples)] -# video_latents_packed = torch.cat(video_latents_packed, dim=0) -# loss_masks_packed = [loss_mask for _ in range(self.number_packed_samples)] -# loss_masks_packed = torch.cat(loss_masks_packed, dim=0) -# seq_len_q_packed = torch.tensor([seq_len_q for _ in range(self.number_packed_samples)], dtype=torch.int32) -# seq_len_q_padded_packed = torch.tensor([seq_len_q_padded for _ in range(self.number_packed_samples)], dtype=torch.int32) -# seq_len_kv_packed = torch.tensor([seq_len_kv for _ in range(self.number_packed_samples)], dtype=torch.int32) -# grid_sizes_packed = torch.stack([grid_size for _ in range(self.number_packed_samples)], dim=0) -# context_embeddings_packed = [context_embeddings for _ in range(self.number_packed_samples)] -# context_embeddings_packed = torch.cat(context_embeddings_packed, dim=0) - - -# ### Note: shape of sample's values -# # video_latent: [num_patches, latents_channels * pF * pH * pW] -# # grid_size: [F_patches, W_patches, H_patches] -# # context_embeddings: [context_seq_len, text_embedding_dim] - -# batch = dict( -# video_latents=video_latents_packed.unsqueeze(1), -# context_embeddings=context_embeddings_packed.unsqueeze(1), -# loss_mask=loss_masks_packed.unsqueeze(1), -# seq_len_q=seq_len_q_packed, -# seq_len_q_padded=seq_len_q_padded_packed, -# seq_len_kv=seq_len_kv_packed, -# grid_sizes=grid_sizes_packed, -# video_metadata=video_metadata, -# ) - -# return batch - def mock_batch( F_latents: int, H_latents: int, diff --git a/dfm/src/megatron/data/wan/wan_sample.py b/dfm/src/megatron/data/wan/wan_sample.py deleted file mode 100644 index 705682de..00000000 --- a/dfm/src/megatron/data/wan/wan_sample.py +++ /dev/null @@ -1,23 +0,0 @@ -# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from dataclasses import dataclass - -from dfm.src.megatron.data.common.diffusion_sample import DiffusionSample - - -@dataclass -class WanSample(DiffusionSample): - video_metadata: dict = None - seq_len_q_padded: int = None diff --git a/dfm/src/megatron/data/wan/wan_taskencoder.py b/dfm/src/megatron/data/wan/wan_taskencoder.py index 5cb686d3..3dff057b 100644 --- a/dfm/src/megatron/data/wan/wan_taskencoder.py +++ b/dfm/src/megatron/data/wan/wan_taskencoder.py @@ -24,7 +24,7 @@ from megatron.energon.task_encoder.cooking import Cooker, basic_sample_keys from dfm.src.megatron.data.common.diffusion_task_encoder_with_sp import DiffusionTaskEncoderWithSequencePacking -from dfm.src.megatron.data.wan.wan_sample import WanSample +from dfm.src.megatron.data.common.diffusion_sample import DiffusionSample from dfm.src.megatron.model.wan.utils import grid_sizes_calculation, patchify @@ -130,7 +130,7 @@ def encode_sample(self, sample: dict) -> dict: # grid_size: [F_patches, W_patches, H_patches] # context_embeddings: [context_seq_len, text_embedding_dim] - return WanSample( + return DiffusionSample( __key__=sample["__key__"], __restore_key__=sample["__restore_key__"], __subflavor__=None, @@ -142,40 +142,16 @@ def encode_sample(self, sample: dict) -> dict: seq_len_q=torch.tensor([seq_len_q], dtype=torch.int32), seq_len_q_padded=torch.tensor([seq_len_q_padded], dtype=torch.int32), seq_len_kv=torch.tensor([seq_len_kv], dtype=torch.int32), + pos_ids=torch.zeros(1, dtype=torch.bfloat16), # dummy pos_ids video_metadata=video_metadata, ) # NOTE: - # the method select_samples_to_pack() is inherited from the parent + # the method select_samples_to_pack() and pack_selected_samples() are inherited from the parent # class DiffusionTaskEncoderWithSequencePacking @stateless - def pack_selected_samples(self, samples: List[WanSample]) -> WanSample: - """Construct a new Wan sample by concatenating the sequences.""" - - def stack(attr): - return torch.stack([getattr(sample, attr) for sample in samples], dim=0) - - def cat(attr): - return torch.cat([getattr(sample, attr) for sample in samples], dim=0) - - return WanSample( - __key__=",".join([s.__key__ for s in samples]), - __restore_key__=(), # Will be set by energon based on `samples` - __subflavor__=None, - __subflavors__=samples[0].__subflavors__, - video=cat("video"), - context_embeddings=cat("context_embeddings"), - loss_mask=cat("loss_mask"), - seq_len_q=cat("seq_len_q"), - seq_len_q_padded=cat("seq_len_q_padded"), - seq_len_kv=cat("seq_len_kv"), - latent_shape=stack("latent_shape"), - video_metadata=[sample.video_metadata for sample in samples], - ) - - @stateless - def batch(self, samples: List[WanSample]) -> dict: + def batch(self, samples: List[DiffusionSample]) -> dict: """Return dictionary with data for batch.""" if self.packing_buffer_size is None: # no packing From a13d0c05f3325c1ddce960ac618e2f3361baffb5 Mon Sep 17 00:00:00 2001 From: Huy Vu2 Date: Tue, 11 Nov 2025 11:28:37 -0800 Subject: [PATCH 45/80] lint, ruff --- .../megatron/data/common/diffusion_sample.py | 2 +- .../data/wan/wan_mock_energon_datamodule.py | 22 +++++++++---------- dfm/src/megatron/data/wan/wan_taskencoder.py | 2 +- dfm/src/megatron/model/wan/wan_step.py | 4 ++-- examples/megatron/recipes/wan/pretrain_wan.py | 4 ++-- 5 files changed, 16 insertions(+), 18 deletions(-) diff --git a/dfm/src/megatron/data/common/diffusion_sample.py b/dfm/src/megatron/data/common/diffusion_sample.py index ccf515c8..eb392d38 100644 --- a/dfm/src/megatron/data/common/diffusion_sample.py +++ b/dfm/src/megatron/data/common/diffusion_sample.py @@ -55,7 +55,7 @@ class DiffusionSample(Sample): pos_ids: Optional[torch.Tensor] = None latent_shape: Optional[torch.Tensor] = None video_metadata: Optional[dict] = None - + def to_dict(self) -> dict: """Converts the sample to a dictionary.""" return dict( diff --git a/dfm/src/megatron/data/wan/wan_mock_energon_datamodule.py b/dfm/src/megatron/data/wan/wan_mock_energon_datamodule.py index 8f2d452c..d15ebb56 100644 --- a/dfm/src/megatron/data/wan/wan_mock_energon_datamodule.py +++ b/dfm/src/megatron/data/wan/wan_mock_energon_datamodule.py @@ -17,11 +17,10 @@ from dataclasses import dataclass import torch -from torch.utils.data import DataLoader, Dataset from megatron.bridge.data.utils import DatasetBuildContext, DatasetProvider -from typing import List +from torch.utils.data import DataLoader, Dataset + from dfm.src.megatron.model.wan.utils import patchify -from megatron.energon.task_encoder.cooking import Cooker, basic_sample_keys class _MockDataset(Dataset): def __init__(self, length: int): @@ -34,14 +33,14 @@ def __getitem__(self, idx: int) -> dict: return {} def mock_batch( - F_latents: int, - H_latents: int, - W_latents: int, - patch_temporal: int, - patch_spatial: int, - number_packed_samples: int, - context_seq_len: int, - context_embeddings_dim: int, + F_latents: int, + H_latents: int, + W_latents: int, + patch_temporal: int, + patch_spatial: int, + number_packed_samples: int, + context_seq_len: int, + context_embeddings_dim: int, ) -> dict: # set mock values for one video sample @@ -75,7 +74,6 @@ def mock_batch( context_embeddings_packed = [context_embeddings for _ in range(number_packed_samples)] context_embeddings_packed = torch.cat(context_embeddings_packed, dim=0) - ### Note: shape of sample's values # video_latent: [num_patches, latents_channels * pF * pH * pW] # grid_size: [F_patches, W_patches, H_patches] diff --git a/dfm/src/megatron/data/wan/wan_taskencoder.py b/dfm/src/megatron/data/wan/wan_taskencoder.py index 3dff057b..2ac3712e 100644 --- a/dfm/src/megatron/data/wan/wan_taskencoder.py +++ b/dfm/src/megatron/data/wan/wan_taskencoder.py @@ -23,8 +23,8 @@ from megatron.energon.task_encoder.base import stateless from megatron.energon.task_encoder.cooking import Cooker, basic_sample_keys -from dfm.src.megatron.data.common.diffusion_task_encoder_with_sp import DiffusionTaskEncoderWithSequencePacking from dfm.src.megatron.data.common.diffusion_sample import DiffusionSample +from dfm.src.megatron.data.common.diffusion_task_encoder_with_sp import DiffusionTaskEncoderWithSequencePacking from dfm.src.megatron.model.wan.utils import grid_sizes_calculation, patchify diff --git a/dfm/src/megatron/model/wan/wan_step.py b/dfm/src/megatron/model/wan/wan_step.py index c82d0f5e..32a5b5e6 100644 --- a/dfm/src/megatron/model/wan/wan_step.py +++ b/dfm/src/megatron/model/wan/wan_step.py @@ -77,8 +77,8 @@ def __init__( logit_std: float = 1.0, flow_shift: float = 3.0, mix_uniform_ratio: float = 0.1, - sigma_min: float = 0.0, # Default: no clamping (pretrain) - sigma_max: float = 1.0, # Default: no clamping (pretrain) + sigma_min: float = 0.0, # Default: no clamping (pretrain) + sigma_max: float = 1.0, # Default: no clamping (pretrain) ): self.diffusion_pipeline = FlowPipeline() self.use_sigma_noise = use_sigma_noise diff --git a/examples/megatron/recipes/wan/pretrain_wan.py b/examples/megatron/recipes/wan/pretrain_wan.py index 1509565f..8bbd5e03 100644 --- a/examples/megatron/recipes/wan/pretrain_wan.py +++ b/examples/megatron/recipes/wan/pretrain_wan.py @@ -88,10 +88,10 @@ def parse_cli_args() -> Tuple[argparse.Namespace, list[str]]: ) parser.add_argument("--mock", action="store_true", help="Whether to use mock data.") parser.add_argument( - "--training-mode", + "--training-mode", choices=["pretrain", "finetune"], default="finetune", - help="Set training mode, 'pretrain' or 'finetune'." + help="Set training mode, 'pretrain' or 'finetune'.", ) parser.add_argument( "--config-file", From 39b0e7321c2acf6ceff9e4aae7aac4ae1f1c22d9 Mon Sep 17 00:00:00 2001 From: Huy Vu2 Date: Tue, 11 Nov 2025 11:42:58 -0800 Subject: [PATCH 46/80] lint, ruff --- dfm/src/megatron/data/wan/wan_mock_energon_datamodule.py | 5 ++++- dfm/src/megatron/data/wan/wan_taskencoder.py | 2 +- dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py | 4 ++-- dfm/src/megatron/model/wan/wan_provider.py | 2 +- 4 files changed, 8 insertions(+), 5 deletions(-) diff --git a/dfm/src/megatron/data/wan/wan_mock_energon_datamodule.py b/dfm/src/megatron/data/wan/wan_mock_energon_datamodule.py index d15ebb56..7f3e1a19 100644 --- a/dfm/src/megatron/data/wan/wan_mock_energon_datamodule.py +++ b/dfm/src/megatron/data/wan/wan_mock_energon_datamodule.py @@ -17,11 +17,12 @@ from dataclasses import dataclass import torch -from megatron.bridge.data.utils import DatasetBuildContext, DatasetProvider from torch.utils.data import DataLoader, Dataset +from megatron.bridge.data.utils import DatasetBuildContext, DatasetProvider from dfm.src.megatron.model.wan.utils import patchify + class _MockDataset(Dataset): def __init__(self, length: int): self.length = max(int(length), 1) @@ -32,6 +33,7 @@ def __len__(self) -> int: def __getitem__(self, idx: int) -> dict: return {} + def mock_batch( F_latents: int, H_latents: int, @@ -92,6 +94,7 @@ def mock_batch( return batch + @dataclass(kw_only=True) class WanMockDataModuleConfig(DatasetProvider): path: str = "" diff --git a/dfm/src/megatron/data/wan/wan_taskencoder.py b/dfm/src/megatron/data/wan/wan_taskencoder.py index 2ac3712e..778a7051 100644 --- a/dfm/src/megatron/data/wan/wan_taskencoder.py +++ b/dfm/src/megatron/data/wan/wan_taskencoder.py @@ -142,7 +142,7 @@ def encode_sample(self, sample: dict) -> dict: seq_len_q=torch.tensor([seq_len_q], dtype=torch.int32), seq_len_q_padded=torch.tensor([seq_len_q_padded], dtype=torch.int32), seq_len_kv=torch.tensor([seq_len_kv], dtype=torch.int32), - pos_ids=torch.zeros(1, dtype=torch.bfloat16), # dummy pos_ids + pos_ids=torch.zeros(1, dtype=torch.bfloat16), # dummy pos_ids video_metadata=video_metadata, ) diff --git a/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py b/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py index 11a9baa6..c90cda91 100644 --- a/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py +++ b/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py @@ -43,8 +43,8 @@ def training_step( logit_std: float = 1.0, flow_shift: float = 3.0, mix_uniform_ratio: float = 0.1, - sigma_min: float = 0.0, # Default: no clamping (pretrain) - sigma_max: float = 1.0, # Default: no clamping (pretrain) + sigma_min: float = 0.0, # Default: no clamping (pretrain) + sigma_max: float = 1.0, # Default: no clamping (pretrain) ) -> tuple[dict[str, torch.Tensor], torch.Tensor]: """ Performs a single training step using flow matching algorithm. diff --git a/dfm/src/megatron/model/wan/wan_provider.py b/dfm/src/megatron/model/wan/wan_provider.py index df591345..24e8c87d 100644 --- a/dfm/src/megatron/model/wan/wan_provider.py +++ b/dfm/src/megatron/model/wan/wan_provider.py @@ -30,7 +30,7 @@ @dataclass class WanModelProvider(TransformerConfig, ModelProviderMixin[VisionModule]): - crossattn_emb_size: int = 1536 # cross attention emebedding size after linear projection + crossattn_emb_size: int = 1536 # cross attention emebedding size after linear projection add_bias_linear: bool = True gated_linear_unit: bool = False From 4647d8931bd58cfc6e596386d4e75a9bd0a4951a Mon Sep 17 00:00:00 2001 From: Huy Vu2 Date: Tue, 11 Nov 2025 11:45:26 -0800 Subject: [PATCH 47/80] lint, ruff --- dfm/src/megatron/data/wan/wan_mock_energon_datamodule.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/dfm/src/megatron/data/wan/wan_mock_energon_datamodule.py b/dfm/src/megatron/data/wan/wan_mock_energon_datamodule.py index 7f3e1a19..e8db2d67 100644 --- a/dfm/src/megatron/data/wan/wan_mock_energon_datamodule.py +++ b/dfm/src/megatron/data/wan/wan_mock_energon_datamodule.py @@ -17,8 +17,8 @@ from dataclasses import dataclass import torch -from torch.utils.data import DataLoader, Dataset from megatron.bridge.data.utils import DatasetBuildContext, DatasetProvider +from torch.utils.data import DataLoader, Dataset from dfm.src.megatron.model.wan.utils import patchify @@ -44,7 +44,6 @@ def mock_batch( context_seq_len: int, context_embeddings_dim: int, ) -> dict: - # set mock values for one video sample video_latent = torch.randn(16, F_latents, H_latents, W_latents, dtype=torch.float32) grid_size = torch.tensor( From 174bb7b34de002ebbbcae1ba8e2b12363c7dee01 Mon Sep 17 00:00:00 2001 From: Huy Vu2 Date: Wed, 12 Nov 2025 08:21:10 -0800 Subject: [PATCH 48/80] fix CP error (input of thd_split_inputs_cp to be cu_seqlens_q_padded instead of cu_seqlens_q) --- .../megatron/data/common/diffusion_sample.py | 2 + .../common/diffusion_task_encoder_with_sp.py | 1 + ...n_datamodule.py => wan_mock_datamodule.py} | 3 ++ dfm/src/megatron/data/wan/wan_taskencoder.py | 6 +++ .../model/wan/flow_matching/flow_pipeline.py | 12 +++-- dfm/src/megatron/model/wan/utils.py | 50 ------------------- dfm/src/megatron/model/wan/wan_layer_spec.py | 5 +- dfm/src/megatron/model/wan/wan_step.py | 8 ++- dfm/src/megatron/recipes/wan/wan.py | 2 +- .../megatron/recipes/wan/README_perf_test.md | 11 ++-- 10 files changed, 35 insertions(+), 65 deletions(-) rename dfm/src/megatron/data/wan/{wan_mock_energon_datamodule.py => wan_mock_datamodule.py} (96%) diff --git a/dfm/src/megatron/data/common/diffusion_sample.py b/dfm/src/megatron/data/common/diffusion_sample.py index eb392d38..8efcc79c 100644 --- a/dfm/src/megatron/data/common/diffusion_sample.py +++ b/dfm/src/megatron/data/common/diffusion_sample.py @@ -52,6 +52,7 @@ class DiffusionSample(Sample): seq_len_q: Optional[torch.Tensor] = None seq_len_q_padded: Optional[torch.Tensor] = None seq_len_kv: Optional[torch.Tensor] = None + seq_len_kv_padded: Optional[torch.Tensor] = None pos_ids: Optional[torch.Tensor] = None latent_shape: Optional[torch.Tensor] = None video_metadata: Optional[dict] = None @@ -70,6 +71,7 @@ def to_dict(self) -> dict: seq_len_q=self.seq_len_q, seq_len_q_padded=self.seq_len_q_padded, seq_len_kv=self.seq_len_kv, + seq_len_kv_padded=self.seq_len_kv_padded, pos_ids=self.pos_ids, latent_shape=self.latent_shape, video_metadata=self.video_metadata, diff --git a/dfm/src/megatron/data/common/diffusion_task_encoder_with_sp.py b/dfm/src/megatron/data/common/diffusion_task_encoder_with_sp.py index 49e2b707..f3fae0b0 100644 --- a/dfm/src/megatron/data/common/diffusion_task_encoder_with_sp.py +++ b/dfm/src/megatron/data/common/diffusion_task_encoder_with_sp.py @@ -104,6 +104,7 @@ def cat(attr): seq_len_q=cat("seq_len_q"), seq_len_q_padded=cat("seq_len_q_padded"), seq_len_kv=cat("seq_len_kv"), + seq_len_kv_padded=cat("seq_len_kv_padded"), pos_ids=cat("pos_ids"), latent_shape=stack("latent_shape"), video_metadata=[sample.video_metadata for sample in samples], diff --git a/dfm/src/megatron/data/wan/wan_mock_energon_datamodule.py b/dfm/src/megatron/data/wan/wan_mock_datamodule.py similarity index 96% rename from dfm/src/megatron/data/wan/wan_mock_energon_datamodule.py rename to dfm/src/megatron/data/wan/wan_mock_datamodule.py index e8db2d67..87bb8dc6 100644 --- a/dfm/src/megatron/data/wan/wan_mock_energon_datamodule.py +++ b/dfm/src/megatron/data/wan/wan_mock_datamodule.py @@ -61,6 +61,7 @@ def mock_batch( loss_mask = torch.ones(seq_len_q, dtype=torch.bfloat16) context_embeddings = torch.randn(context_seq_len, context_embeddings_dim, dtype=torch.float32) seq_len_kv = context_embeddings.shape[0] + seq_len_kv_padded = seq_len_kv video_metadata = {} # set mock values for packed video samples @@ -71,6 +72,7 @@ def mock_batch( seq_len_q_packed = torch.tensor([seq_len_q for _ in range(number_packed_samples)], dtype=torch.int32) seq_len_q_padded_packed = torch.tensor([seq_len_q_padded for _ in range(number_packed_samples)], dtype=torch.int32) seq_len_kv_packed = torch.tensor([seq_len_kv for _ in range(number_packed_samples)], dtype=torch.int32) + seq_len_kv_padded_packed = torch.tensor([seq_len_kv_padded for _ in range(number_packed_samples)], dtype=torch.int32) grid_sizes_packed = torch.stack([grid_size for _ in range(number_packed_samples)], dim=0) context_embeddings_packed = [context_embeddings for _ in range(number_packed_samples)] context_embeddings_packed = torch.cat(context_embeddings_packed, dim=0) @@ -87,6 +89,7 @@ def mock_batch( seq_len_q=seq_len_q_packed, seq_len_q_padded=seq_len_q_padded_packed, seq_len_kv=seq_len_kv_packed, + seq_len_kv_padded=seq_len_kv_padded_packed, grid_sizes=grid_sizes_packed, video_metadata=video_metadata, ) diff --git a/dfm/src/megatron/data/wan/wan_taskencoder.py b/dfm/src/megatron/data/wan/wan_taskencoder.py index 778a7051..1fb7a0ae 100644 --- a/dfm/src/megatron/data/wan/wan_taskencoder.py +++ b/dfm/src/megatron/data/wan/wan_taskencoder.py @@ -117,13 +117,16 @@ def encode_sample(self, sample: dict) -> dict: if parallel_state.get_context_parallel_world_size() > 1: sharding_factor = parallel_state.get_context_parallel_world_size() * 2 seq_len_q_padded = ((seq_len_q + sharding_factor - 1) // sharding_factor) * sharding_factor + seq_len_kv_padded = ((seq_len_kv + sharding_factor - 1) // sharding_factor) * sharding_factor else: seq_len_q_padded = seq_len_q + seq_len_kv_padded = seq_len_kv # padding if seq_len_q < seq_len_q_padded: video_latent = F.pad(video_latent, (0, 0, 0, seq_len_q_padded - seq_len_q)) loss_mask = F.pad(loss_mask, (0, seq_len_q_padded - seq_len_q)) + context_embeddings = F.pad(context_embeddings, (0, 0, 0, seq_len_kv_padded - seq_len_kv)) ### Note: shape of sample's values # video_latent: [num_patches, latents_channels * pF * pH * pW] @@ -142,6 +145,7 @@ def encode_sample(self, sample: dict) -> dict: seq_len_q=torch.tensor([seq_len_q], dtype=torch.int32), seq_len_q_padded=torch.tensor([seq_len_q_padded], dtype=torch.int32), seq_len_kv=torch.tensor([seq_len_kv], dtype=torch.int32), + seq_len_kv_padded=torch.tensor([seq_len_kv_padded], dtype=torch.int32), pos_ids=torch.zeros(1, dtype=torch.bfloat16), # dummy pos_ids video_metadata=video_metadata, ) @@ -179,6 +183,7 @@ def batch(self, samples: List[DiffusionSample]) -> dict: seq_len_q=sample.seq_len_q, seq_len_q_padded=sample.seq_len_q_padded, seq_len_kv=sample.seq_len_kv, + seq_len_kv_padded=sample.seq_len_kv_padded, grid_sizes=sample.latent_shape, video_metadata=sample.video_metadata, ) @@ -190,6 +195,7 @@ def batch(self, samples: List[DiffusionSample]) -> dict: # seq_len_q: [num_samples] # seq_len_q_padded: [num_samples] # seq_len_kv: [num_samples] + # seq_len_kv_padded: [num_samples] # grid_sizes: [num_samples, 3] # video_metadata: [num_samples] diff --git a/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py b/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py index c90cda91..9fdc0695 100644 --- a/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py +++ b/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py @@ -182,25 +182,27 @@ def training_step( if parallel_state.get_context_parallel_world_size() > 1: video_latents = thd_split_inputs_cp( video_latents, - packed_seq_params["self_attention"].cu_seqlens_q, + packed_seq_params["self_attention"].cu_seqlens_q_padded, parallel_state.get_context_parallel_group(), ) noisy_latents = thd_split_inputs_cp( noisy_latents, - packed_seq_params["self_attention"].cu_seqlens_q, + packed_seq_params["self_attention"].cu_seqlens_q_padded, parallel_state.get_context_parallel_group(), ) noise = thd_split_inputs_cp( - noise, packed_seq_params["self_attention"].cu_seqlens_q, parallel_state.get_context_parallel_group() + noise, + packed_seq_params["self_attention"].cu_seqlens_q_padded, + parallel_state.get_context_parallel_group() ) context_embeddings = thd_split_inputs_cp( context_embeddings, - packed_seq_params["cross_attention"].cu_seqlens_kv, + packed_seq_params["cross_attention"].cu_seqlens_kv_padded, parallel_state.get_context_parallel_group(), ) split_loss_mask = thd_split_inputs_cp( loss_mask, - packed_seq_params["self_attention"].cu_seqlens_q, + packed_seq_params["self_attention"].cu_seqlens_q_padded, parallel_state.get_context_parallel_group(), ) else: diff --git a/dfm/src/megatron/model/wan/utils.py b/dfm/src/megatron/model/wan/utils.py index ac4de4e6..998ee7ed 100644 --- a/dfm/src/megatron/model/wan/utils.py +++ b/dfm/src/megatron/model/wan/utils.py @@ -99,56 +99,6 @@ def unpatchify( out.append(u) return out - -def split_inputs_cp(x: torch.Tensor, seq_dim: int = 0) -> torch.Tensor: - """ - Split input tensor along the sequence dimension for context parallelism. - - Args: - x: Input tensor to be split. (e.g. shape [seq_len, batch_size, ...]) - seq_dim: The dimension along which to split the input (sequence dimension). - - Returns: - A slice of the input tensor corresponding to the current rank. (e.g. shape [seq_len/cp_size, batch_size, ...]) - """ - - cp_size = parallel_state.get_context_parallel_world_size() - if cp_size > 1: - cp_rank = parallel_state.get_context_parallel_rank() - assert x.shape[seq_dim] % cp_size == 0, f"{x.shape[seq_dim]} cannot divide cp_size {cp_size}" - x = x.view(*x.shape[:seq_dim], cp_size, x.shape[seq_dim] // cp_size, *x.shape[(seq_dim + 1) :]) - seq_idx = torch.tensor([cp_rank], device=x.device) - x = x.index_select(seq_dim, seq_idx) - # Note that the new sequence length is the original sequence length / cp_size - x = x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :]) - return x - - -def cat_outputs_cp(x: torch.Tensor, seq_dim: int) -> torch.Tensor: - """ - Concatenate tensors from multiple processes along a specified dimension. - - Args: - x: Input tensor to be concatenated. (e.g. shape [seq_len/cp_size, batch_size, ...]) - seq_dim: The dimension along which to concatenate the input tensors. - - Returns: - A tensor with the concatenated tensors. (e.g. shape [seq_len, batch_size, ...]) - """ - - cp_group = parallel_state.get_context_parallel_group() - cp_size = parallel_state.get_context_parallel_world_size() - if cp_size > 1: - gathered_tensors = [torch.zeros_like(x) for _ in range(cp_size)] - # Attempt to gather tensors from all ranks - # PyTorch’s all_gather orders outputs by rank within the group, which matches how chunks were selected by cp_rank - all_gather(gathered_tensors, x, group=cp_group) - gathered_tensors = torch.cat(gathered_tensors, dim=seq_dim) - return gathered_tensors - else: - return x - - def thd_split_inputs_cp( x: torch.Tensor, cu_seqlens_q_padded: torch.Tensor, cp_group: dist.ProcessGroup ) -> torch.Tensor: diff --git a/dfm/src/megatron/model/wan/wan_layer_spec.py b/dfm/src/megatron/model/wan/wan_layer_spec.py index 053e3e94..2b355930 100644 --- a/dfm/src/megatron/model/wan/wan_layer_spec.py +++ b/dfm/src/megatron/model/wan/wan_layer_spec.py @@ -20,6 +20,7 @@ import torch import torch.nn as nn from megatron.core.extensions.transformer_engine import TENorm +from megatron.core.jit import jit_fuser from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.transformer.attention import SelfAttentionSubmodules from megatron.core.transformer.custom_layers.transformer_engine import ( @@ -68,11 +69,11 @@ def forward(self, timestep_emb): e = (self.modulation + timestep_emb).chunk(6, dim=1) return e - # @jit_fuser + @jit_fuser def modulate(self, x, shift, scale): return x * (1 + scale) + shift - # @jit_fuser + @jit_fuser def scale_add(self, residual, x, gate): return residual + gate * x diff --git a/dfm/src/megatron/model/wan/wan_step.py b/dfm/src/megatron/model/wan/wan_step.py index 32a5b5e6..d60a00b9 100644 --- a/dfm/src/megatron/model/wan/wan_step.py +++ b/dfm/src/megatron/model/wan/wan_step.py @@ -38,17 +38,20 @@ def wan_data_step(qkv_format, dataloader_iter): # Construct packed sequence parameters if ("seq_len_q" in batch) and ("seq_len_kv" in batch): - cu_seqlens = batch["seq_len_q"].cumsum(dim=0).to(torch.int32) zero = torch.zeros(1, dtype=torch.int32, device="cuda") + + cu_seqlens = batch["seq_len_q"].cumsum(dim=0).to(torch.int32) cu_seqlens = torch.cat((zero, cu_seqlens)) cu_seqlens_padded = batch["seq_len_q_padded"].cumsum(dim=0).to(torch.int32) - zero = torch.zeros(1, dtype=torch.int32, device="cuda") cu_seqlens_padded = torch.cat((zero, cu_seqlens_padded)) cu_seqlens_kv = batch["seq_len_kv"].cumsum(dim=0).to(torch.int32) cu_seqlens_kv = torch.cat((zero, cu_seqlens_kv)) + cu_seqlens_kv_padded = batch["seq_len_kv_padded"].cumsum(dim=0).to(torch.int32) + cu_seqlens_kv_padded = torch.cat((zero, cu_seqlens_kv_padded)) + batch["packed_seq_params"] = { "self_attention": PackedSeqParams( cu_seqlens_q=cu_seqlens, @@ -61,6 +64,7 @@ def wan_data_step(qkv_format, dataloader_iter): cu_seqlens_q=cu_seqlens, cu_seqlens_q_padded=cu_seqlens_padded, cu_seqlens_kv=cu_seqlens_kv, + cu_seqlens_kv_padded=cu_seqlens_kv_padded, qkv_format=qkv_format, ), } diff --git a/dfm/src/megatron/recipes/wan/wan.py b/dfm/src/megatron/recipes/wan/wan.py index 21bd1b50..f150363d 100644 --- a/dfm/src/megatron/recipes/wan/wan.py +++ b/dfm/src/megatron/recipes/wan/wan.py @@ -31,7 +31,7 @@ from megatron.core.distributed import DistributedDataParallelConfig from dfm.src.megatron.data.wan.wan_energon_datamodule import WanDataModuleConfig -from dfm.src.megatron.data.wan.wan_mock_energon_datamodule import WanMockDataModuleConfig +from dfm.src.megatron.data.wan.wan_mock_datamodule import WanMockDataModuleConfig from dfm.src.megatron.model.wan.wan_provider import WanModelProvider diff --git a/examples/megatron/recipes/wan/README_perf_test.md b/examples/megatron/recipes/wan/README_perf_test.md index f62b88c2..7e03e0c2 100644 --- a/examples/megatron/recipes/wan/README_perf_test.md +++ b/examples/megatron/recipes/wan/README_perf_test.md @@ -28,7 +28,7 @@ cd /opt/ # DFM (pinned) git clone --no-checkout https://github.com/NVIDIA-NeMo/DFM.git -git -C DFM checkout aa2050466b8b9b8844d754cc61ea93c1f7a0e90e +git -C DFM checkout 4647d8931bd58cfc6e596386d4e75a9bd0a4951a export DFM_PATH=/opt/DFM # Megatron-Bridge (pinned) @@ -67,6 +67,7 @@ cd ${DFM_PATH} ```bash NVTE_FUSED_ATTN=1 torchrun --nproc_per_node=8 examples/megatron/recipes/wan/pretrain_wan.py \ + --training-mode pretrain \ model.tensor_model_parallel_size=1 \ model.pipeline_model_parallel_size=1 \ model.context_parallel_size=4 \ @@ -102,6 +103,7 @@ NVTE_FUSED_ATTN=1 torchrun --nproc_per_node=8 examples/megatron/recipes/wan/pret ```bash NVTE_FUSED_ATTN=1 torchrun --nproc_per_node=8 examples/megatron/recipes/wan/pretrain_wan.py \ + --training-mode pretrain \ model.tensor_model_parallel_size=2 \ model.pipeline_model_parallel_size=1 \ model.context_parallel_size=4 \ @@ -138,9 +140,8 @@ NVTE_FUSED_ATTN=1 torchrun --nproc_per_node=8 examples/megatron/recipes/wan/pret ### Using mock data (optional, for debugging) -- Edit `dfm/src/megatron/data/wan/wan_taskencoder.py`. -- Comment out the production `encode_sample()` and uncomment the mock version. -- Adjust `video_size` (F_latents, H_latents, W_latents). Total `seq_len = F * H * W`. +- Using `--mock` argument. +- Adjust `video_size` (F_latents, H_latents, W_latents) and `number_packed_samples` of `WanMockDataModuleConfig` in `wan.py`. Total `seq_len = F * H * W * number_packed_samples`. ## Inference @@ -152,7 +153,7 @@ T5_DIR="/lustre/fsw/coreai_dlalgo_genai/huvu/data/nemo_vfm/wan_checkpoints/t5" VAE_DIR="/lustre/fsw/coreai_dlalgo_genai/huvu/data/nemo_vfm/wan_checkpoints/vae" CKPT_DIR="/lustre/fsw/coreai_dlalgo_genai/huvu/data/nemo_vfm/datasets/shared_checkpoints/megatron_checkpoint_1.3B" -NVTE_FUSED_ATTN=1 torchrun --nproc_per_node=1 examples/megatron/recipes/wan/inference_wan.py \ +NVTE_FUSED_ATTN=1 torchrun --nproc_per_node=1 examples/megatron/recipes/wan/inference_wan.py \ --task t2v-1.3B \ --sizes 480*832 \ --checkpoint_dir "${CKPT_DIR}" \ From 462638af88d80353bd55ea1afc4737b17e98676c Mon Sep 17 00:00:00 2001 From: Huy Vu2 Date: Wed, 12 Nov 2025 08:22:35 -0800 Subject: [PATCH 49/80] udpate README_perf_test.md --- examples/megatron/recipes/wan/README_perf_test.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/megatron/recipes/wan/README_perf_test.md b/examples/megatron/recipes/wan/README_perf_test.md index 7e03e0c2..0344e6e9 100644 --- a/examples/megatron/recipes/wan/README_perf_test.md +++ b/examples/megatron/recipes/wan/README_perf_test.md @@ -28,7 +28,7 @@ cd /opt/ # DFM (pinned) git clone --no-checkout https://github.com/NVIDIA-NeMo/DFM.git -git -C DFM checkout 4647d8931bd58cfc6e596386d4e75a9bd0a4951a +git -C DFM checkout 174bb7b34de002ebbbcae1ba8e2b12363c7dee01 export DFM_PATH=/opt/DFM # Megatron-Bridge (pinned) From f5c10a1a5bcd069eb0db4dab97ef27facc64fa51 Mon Sep 17 00:00:00 2001 From: Huy Vu2 Date: Wed, 12 Nov 2025 08:25:03 -0800 Subject: [PATCH 50/80] fix lint, ruff --- dfm/src/megatron/data/wan/wan_mock_datamodule.py | 4 +++- dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py | 2 +- dfm/src/megatron/model/wan/utils.py | 3 +-- examples/megatron/recipes/wan/README_perf_test.md | 2 +- 4 files changed, 6 insertions(+), 5 deletions(-) diff --git a/dfm/src/megatron/data/wan/wan_mock_datamodule.py b/dfm/src/megatron/data/wan/wan_mock_datamodule.py index 87bb8dc6..844ecd9b 100644 --- a/dfm/src/megatron/data/wan/wan_mock_datamodule.py +++ b/dfm/src/megatron/data/wan/wan_mock_datamodule.py @@ -72,7 +72,9 @@ def mock_batch( seq_len_q_packed = torch.tensor([seq_len_q for _ in range(number_packed_samples)], dtype=torch.int32) seq_len_q_padded_packed = torch.tensor([seq_len_q_padded for _ in range(number_packed_samples)], dtype=torch.int32) seq_len_kv_packed = torch.tensor([seq_len_kv for _ in range(number_packed_samples)], dtype=torch.int32) - seq_len_kv_padded_packed = torch.tensor([seq_len_kv_padded for _ in range(number_packed_samples)], dtype=torch.int32) + seq_len_kv_padded_packed = torch.tensor( + [seq_len_kv_padded for _ in range(number_packed_samples)], dtype=torch.int32 + ) grid_sizes_packed = torch.stack([grid_size for _ in range(number_packed_samples)], dim=0) context_embeddings_packed = [context_embeddings for _ in range(number_packed_samples)] context_embeddings_packed = torch.cat(context_embeddings_packed, dim=0) diff --git a/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py b/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py index 9fdc0695..d5eeda47 100644 --- a/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py +++ b/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py @@ -193,7 +193,7 @@ def training_step( noise = thd_split_inputs_cp( noise, packed_seq_params["self_attention"].cu_seqlens_q_padded, - parallel_state.get_context_parallel_group() + parallel_state.get_context_parallel_group(), ) context_embeddings = thd_split_inputs_cp( context_embeddings, diff --git a/dfm/src/megatron/model/wan/utils.py b/dfm/src/megatron/model/wan/utils.py index 998ee7ed..1de1fe34 100644 --- a/dfm/src/megatron/model/wan/utils.py +++ b/dfm/src/megatron/model/wan/utils.py @@ -15,11 +15,9 @@ import math from typing import Tuple -import megatron.core.parallel_state as parallel_state import torch import torch.distributed as dist import transformer_engine_torch as tex -from torch.distributed import all_gather def grid_sizes_calculation( @@ -99,6 +97,7 @@ def unpatchify( out.append(u) return out + def thd_split_inputs_cp( x: torch.Tensor, cu_seqlens_q_padded: torch.Tensor, cp_group: dist.ProcessGroup ) -> torch.Tensor: diff --git a/examples/megatron/recipes/wan/README_perf_test.md b/examples/megatron/recipes/wan/README_perf_test.md index 0344e6e9..d95a4c14 100644 --- a/examples/megatron/recipes/wan/README_perf_test.md +++ b/examples/megatron/recipes/wan/README_perf_test.md @@ -140,7 +140,7 @@ NVTE_FUSED_ATTN=1 torchrun --nproc_per_node=8 examples/megatron/recipes/wan/pret ### Using mock data (optional, for debugging) -- Using `--mock` argument. +- Using `--mock` argument. - Adjust `video_size` (F_latents, H_latents, W_latents) and `number_packed_samples` of `WanMockDataModuleConfig` in `wan.py`. Total `seq_len = F * H * W * number_packed_samples`. ## Inference From 13968fc7cdcb9d8279b9894c40ed4389e886c367 Mon Sep 17 00:00:00 2001 From: Huy Vu2 Date: Wed, 12 Nov 2025 10:15:12 -0800 Subject: [PATCH 51/80] update uv.lock, merge main --- 3rdparty/Automodel | 2 +- 3rdparty/Megatron-Bridge | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/3rdparty/Automodel b/3rdparty/Automodel index 8134b0c0..a5f06522 160000 --- a/3rdparty/Automodel +++ b/3rdparty/Automodel @@ -1 +1 @@ -Subproject commit 8134b0c039802fb3f6161571400ef7085dd1e9cb +Subproject commit a5f06522d4f8ef67bb9bbdd9502e50ae27d2fee5 diff --git a/3rdparty/Megatron-Bridge b/3rdparty/Megatron-Bridge index 97441068..8e21f81a 160000 --- a/3rdparty/Megatron-Bridge +++ b/3rdparty/Megatron-Bridge @@ -1 +1 @@ -Subproject commit 9744106801797ffbf2c79a1554bc6462aaf2b45e +Subproject commit 8e21f81ab961bdb0ad99a275074fe50aae15d2f9 From 46aa6d80dbe15a03f26025a0a7fe9cab58f0c6ed Mon Sep 17 00:00:00 2001 From: Huy Vu2 Date: Wed, 12 Nov 2025 11:13:59 -0800 Subject: [PATCH 52/80] uv.lock --- uv.lock | 68 ++++++++++++++++++++++++++++----------------------------- 1 file changed, 34 insertions(+), 34 deletions(-) diff --git a/uv.lock b/uv.lock index d8357b84..5161ef7b 100644 --- a/uv.lock +++ b/uv.lock @@ -1217,40 +1217,40 @@ wheels = [ [[package]] name = "cython" -version = "3.2.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/52/82/01f0b63287cb922e5ba96c5147c30f1e51f541ce91bd178025bb3518b1ba/cython-3.2.0.tar.gz", hash = "sha256:41fdce8237baee2d961c292ed0386903dfe126f131e450a62de0fd7a5280d4b2", size = 3267264, upload-time = "2025-11-05T13:35:04.231Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/57/8d/b2e9578d960d38b1b04a278bf66e13008486aa73e73967186f2015d63d1c/cython-3.2.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ee408125b2d218ec7d7a061e09d24715fcab9bf7ea1a4ac01907c3f8ec8730b3", size = 2953775, upload-time = "2025-11-05T13:35:22.291Z" }, - { url = "https://files.pythonhosted.org/packages/19/dd/cfd684f98bac9e0f505af1cbb7998498c59d713275e920a72b40dab03bfa/cython-3.2.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c93ce307b05fcd86a5bb0e4a7d7fab238e2f0e9936636097a60bc0e21f2def30", size = 3361627, upload-time = "2025-11-05T13:35:24.519Z" }, - { url = "https://files.pythonhosted.org/packages/9c/c1/75acdbe9f6292514f0bb92ab1b78df5eedd7049235f4cbd194d2c6c46bfc/cython-3.2.0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:191cfc2fa84642ad41a52d5abaacfb330d9a6653a465e4bf0a5681f66197a967", size = 3529751, upload-time = "2025-11-05T13:35:26.341Z" }, - { url = "https://files.pythonhosted.org/packages/f2/ce/d0468eb6d87b956902b02909f5007ad61e3839d4c07ab235b514911d869b/cython-3.2.0-cp310-cp310-win_amd64.whl", hash = "sha256:a259053037ef82959b743b7fde238bd191ee43f88eb8e51101d5f3d8849f1e32", size = 2758839, upload-time = "2025-11-05T13:35:28.36Z" }, - { url = "https://files.pythonhosted.org/packages/ff/2b/904493fceda95747ba83971b40a66c8cc29ff009313429903f38ee620140/cython-3.2.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e9e4b2248dc3a98b86aeba65e9862d2cc881d072c163c0fb31b511d4d72e93c8", size = 2946248, upload-time = "2025-11-05T13:35:30.406Z" }, - { url = "https://files.pythonhosted.org/packages/89/fe/abe926699fe6c580967e30bc4035da54b5e31355ba9b1f4c0cf574228a84/cython-3.2.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:02fb4990a83d5d6f780dda18ed8baa8d587cb6523f57b4d72bc0b41ad3766c96", size = 3236384, upload-time = "2025-11-05T13:35:32.233Z" }, - { url = "https://files.pythonhosted.org/packages/1b/36/6b6266549802234286438298d494152deb19922a94928d9dcd256659ebd1/cython-3.2.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8a98925517819d62ea25d2cf40057df60a9bcf75fdd1d6ed3882e6ae0730d82f", size = 3372915, upload-time = "2025-11-05T13:35:34.082Z" }, - { url = "https://files.pythonhosted.org/packages/29/fa/5cf15466b428f9248e38a28515cf0fd98078ae869aa395cfb300315964c4/cython-3.2.0-cp311-cp311-win_amd64.whl", hash = "sha256:4c959a5d4cd6331e8498822ba47200bd2ff4bf74517c0c91475d5bc21da3b4d5", size = 2762735, upload-time = "2025-11-05T13:35:35.806Z" }, - { url = "https://files.pythonhosted.org/packages/57/d3/2e6f5f2552c860bb9c00653d092103521846114f6a2ae0648ecf84c0816c/cython-3.2.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:511d823d9f8a1b850178ec355d6df0a1731b9c20b08ee6d1a780f68215e9013f", size = 2959932, upload-time = "2025-11-05T13:35:37.518Z" }, - { url = "https://files.pythonhosted.org/packages/dd/bf/7bdc7f231fff6780f78586f939c1740475adecaa03bf256fcb62b2353952/cython-3.2.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:bbadeedcb2d135655bcce7380fb28c9e2a75b6810426c12b6e5a6fe6106fafb4", size = 3218588, upload-time = "2025-11-05T13:35:39.642Z" }, - { url = "https://files.pythonhosted.org/packages/be/81/7d7a81010897dc5abee59691f5fc85849dcc4c8a7687b22ed01bc8d86a7a/cython-3.2.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:92d2394a3e3fe704210b5324eb8118333b514af72c98b1e02a6503945825b231", size = 3381940, upload-time = "2025-11-05T13:35:41.886Z" }, - { url = "https://files.pythonhosted.org/packages/4f/9d/35e7fb7b591bd9912685a772fcc773d7bb951a8feb6fb9be20addbc38928/cython-3.2.0-cp312-cp312-win_amd64.whl", hash = "sha256:73435e56654a34ece57d4c3304a4556a8402cc4ae2d0e30f71c237a985dc5246", size = 2750886, upload-time = "2025-11-05T13:35:43.629Z" }, - { url = "https://files.pythonhosted.org/packages/5d/d0/dc4b260e8fde81b23ab4dca56948b3e69617ef470247ec6a3e09370a9849/cython-3.2.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:d900e58e826f9a5a27b0e2b50e33473e9986a5bae375c39b0f2e19f2c545fa23", size = 2950437, upload-time = "2025-11-05T13:35:45.427Z" }, - { url = "https://files.pythonhosted.org/packages/c8/53/c322bf0486a938ad954a645866b67e978777d79183cf0a042bda6bea11de/cython-3.2.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a9d38cd3aab720d21fa6d6ee168228352f69aea0a95bd4fb84e8879c6ed38fbb", size = 3209331, upload-time = "2025-11-05T13:35:47.278Z" }, - { url = "https://files.pythonhosted.org/packages/cd/48/55d02dba0606768d3450afd088e2bbcd6f8a54977dce041c2c3c1894631c/cython-3.2.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:92b31d0b7b0a49b3d2aa94faaf75d44a03174cff2616b341a8853c919e511d51", size = 3370974, upload-time = "2025-11-05T13:35:49.534Z" }, - { url = "https://files.pythonhosted.org/packages/ce/bd/6dab19652b68464572b7a137d07a91ebe86db2a81c35842ff5e49ef23403/cython-3.2.0-cp313-cp313-win_amd64.whl", hash = "sha256:2847b74e76dbad612f6fc7182c12a5f78cffb0d05808fd2c4b638cf02d1aade6", size = 2746274, upload-time = "2025-11-05T13:35:51.522Z" }, - { url = "https://files.pythonhosted.org/packages/e2/db/de5331ca6489da1761078825709257e1f24e543b4040f86a2502a4b841f9/cython-3.2.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:a0a8274959d538d12f865193dcd67bb5630906e020190c890d2b7c13d31713c6", size = 2961164, upload-time = "2025-11-05T13:35:53.826Z" }, - { url = "https://files.pythonhosted.org/packages/54/3e/64e37e419331f7c4c540ad25c0b3e6d8f44d597f21ab8861afbc66aa7e02/cython-3.2.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0a1c800833c25195833805c7c3626a2c30b3baaaa9ba361a1af3bbc379662a8d", size = 3249627, upload-time = "2025-11-05T13:35:55.524Z" }, - { url = "https://files.pythonhosted.org/packages/9b/fc/9faedfcc2de807f77115d97a4910c260dd4693f4fa9e0e3be0d9ae89e260/cython-3.2.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:df15af08c21c18a2e848df5954d6fd3310735089b60405132fa4111e2cf7482a", size = 3375458, upload-time = "2025-11-05T13:35:57.279Z" }, - { url = "https://files.pythonhosted.org/packages/31/e0/30d449cd97ee0d6395aba18f2646b61b52ab3dc5a3851a346e2d363a7d85/cython-3.2.0-cp314-cp314-win_amd64.whl", hash = "sha256:9d6876af2132757fff1b42a2f4eaa72482f991863160e3f0dc8f2c812b300ebf", size = 2783210, upload-time = "2025-11-05T13:35:59.54Z" }, - { url = "https://files.pythonhosted.org/packages/dd/6b/9e1e171fe19274465d84dffa4610d46f434b1ae945e946802db396695d67/cython-3.2.0-cp39-abi3-macosx_10_9_x86_64.whl", hash = "sha256:04821ce06598a3aa5c9e0270d98960cfe6556dedbd1418c65e4479162b8ae74a", size = 2869249, upload-time = "2025-11-05T13:36:08.944Z" }, - { url = "https://files.pythonhosted.org/packages/c4/f1/f461726f664668a96072b2a245bdfae566d68e2eb1393ec72780cc59c21e/cython-3.2.0-cp39-abi3-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:54b5b1c72a63da822b3f4739a0e31546c0a19f8e834b174906bf817ed5f9d65f", size = 3204332, upload-time = "2025-11-05T13:36:11.386Z" }, - { url = "https://files.pythonhosted.org/packages/78/d8/73c07ce64cae496e5f5a6dfe3e53574af1a8ef777e2a834d10dae8b67a4e/cython-3.2.0-cp39-abi3-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:6155a6c360e32af1aaa16fa10b0119b49deeadff42a1958973324150870af1b5", size = 2851317, upload-time = "2025-11-05T13:36:13.14Z" }, - { url = "https://files.pythonhosted.org/packages/bc/d9/d9f321637b8034b5028fa5fe7d1085ffa9351fea350af6510d5cb924c014/cython-3.2.0-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:861258ac3878b76c57b9b5a379787d772a0bc47fec9167b43986777de542c474", size = 2987155, upload-time = "2025-11-05T13:36:15.018Z" }, - { url = "https://files.pythonhosted.org/packages/f8/b5/9f9e7d261f083b4066d734b27a7872b0c584fd4c3578196652dbf72b3f62/cython-3.2.0-cp39-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:85dbf955e3193893d0288105afa0fa5f4e835ff587061681f240a4f0487c44fb", size = 2884219, upload-time = "2025-11-05T13:36:17.334Z" }, - { url = "https://files.pythonhosted.org/packages/88/64/5aeb6e43e0ded9efedc5a516f87a487fdca8e434491cc352e5a805380459/cython-3.2.0-cp39-abi3-musllinux_1_2_i686.whl", hash = "sha256:3b3f13822526726bac43275c0e92916bbcc2c30e9f559edc4c1132670b70498d", size = 3218067, upload-time = "2025-11-05T13:36:19.493Z" }, - { url = "https://files.pythonhosted.org/packages/c4/a0/1958f54cd79d8251a330b9c9652b2a5ceba6a3fcec10782dd03e2a23c74f/cython-3.2.0-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:ab18d09673d219008be5b6174bcbb6dbfd50904e66371f104a8a4698b791472d", size = 3108277, upload-time = "2025-11-05T13:36:21.203Z" }, - { url = "https://files.pythonhosted.org/packages/9c/84/9b8112160cab922b97edef00616ed18771567d88b5ba9d30d1736880c345/cython-3.2.0-cp39-abi3-win32.whl", hash = "sha256:c9fd986413fc52929b916187630a9abab9f876299951488c4b905ad5346afee6", size = 2430852, upload-time = "2025-11-05T13:36:23.049Z" }, - { url = "https://files.pythonhosted.org/packages/8f/57/65d3de140b51c45dd6892846bfabdfaaa032e2418f1cb1a2f46058c1fe42/cython-3.2.0-cp39-abi3-win_arm64.whl", hash = "sha256:ee2ea79ddeb721f912e7efea039b9db059c81767ff04fbf9a995f64e1187df99", size = 2435793, upload-time = "2025-11-05T13:36:25.139Z" }, - { url = "https://files.pythonhosted.org/packages/20/58/1f798ddb7fe6bfddf85f4f97d2d4ad63a491a7b643e85c1e274d0f09138e/cython-3.2.0-py3-none-any.whl", hash = "sha256:73f7f4c75acde5b5b4df05b11fdc2705ec637b99241d1bc2f4ebf345f7a2ea90", size = 1252818, upload-time = "2025-11-05T13:35:00.391Z" }, +version = "3.2.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/83/36/cce2972e13e83ffe58bc73bfd9d37340b5e5113e8243841a57511c7ae1c2/cython-3.2.1.tar.gz", hash = "sha256:2be1e4d0cbdf7f4cd4d9b8284a034e1989b59fd060f6bd4d24bf3729394d2ed8", size = 3270455, upload-time = "2025-11-12T19:02:59.847Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/87/74/f9fe9e7034f24aef407e7816880c012d8e863bedaa6b42b9ff33e79ea139/cython-3.2.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f1d10b3731171a33563ba81fdcba39c229e45087269dfbe07a1c00e7dcb2537f", size = 2957374, upload-time = "2025-11-12T19:03:10.132Z" }, + { url = "https://files.pythonhosted.org/packages/65/47/f9dd519117f520aaf4d723c88fd9e9139262a0379edc01e71a1e9825e082/cython-3.2.1-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:92b814b6066d178a5057b557d372e2a03854e947e41cb9dec21db732fbd14c3c", size = 3366838, upload-time = "2025-11-12T19:03:11.742Z" }, + { url = "https://files.pythonhosted.org/packages/5d/3e/d967acfafef00056c3ba832692b9bb358ede2919f641e4a2d24828adacc6/cython-3.2.1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c9fc6abd0532007827d8c6143b2bfedf80c7cb89a3c1c12f058336663489ed2e", size = 3535901, upload-time = "2025-11-12T19:03:13.545Z" }, + { url = "https://files.pythonhosted.org/packages/68/79/bc46e714ecb010f80a8aa7f7eaf412c53cbabbe7489590d6aba5f4478ba5/cython-3.2.1-cp310-cp310-win_amd64.whl", hash = "sha256:14f1ed135347587cfddcd3c3219667cac4f0ea0b66aa1c4c0187d50a1b92c222", size = 2764043, upload-time = "2025-11-12T19:03:15.584Z" }, + { url = "https://files.pythonhosted.org/packages/48/d4/ba7b9f341ec168de78bd659600e04bb7de3b2d069bf98b2178a135e88ea4/cython-3.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3cb32c650e7f4476941d1f735cae75a2067d5e3279576273bb8802e8ea907222", size = 2949720, upload-time = "2025-11-12T19:03:17.492Z" }, + { url = "https://files.pythonhosted.org/packages/ad/47/c42417f424c0b928361f48d7dd0ae72716ee21f647b73ceb16f66b98663e/cython-3.2.1-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8a2b306813d7f28aa0a2c3e4e63ada1427a8109917532df942cd5429db228252", size = 3242127, upload-time = "2025-11-12T19:03:19.227Z" }, + { url = "https://files.pythonhosted.org/packages/e6/fc/1040460889129551649ec35be45e05169871fbcf71bd8e13c533e86f9468/cython-3.2.1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0959d9a36d4f004ce63acc1474b3c606745af98b65e8ae709efd0c10988e9d6b", size = 3377094, upload-time = "2025-11-12T19:03:21.25Z" }, + { url = "https://files.pythonhosted.org/packages/f8/f2/8c754298eefa40e21af0ae3592837c6e71254900d5aea1c8859e96b11de5/cython-3.2.1-cp311-cp311-win_amd64.whl", hash = "sha256:60c62e734421365135cc2842013d883136054a26c617c001be494235edfc447a", size = 2767824, upload-time = "2025-11-12T19:03:23.317Z" }, + { url = "https://files.pythonhosted.org/packages/ee/0e/19d5041b87f98ed19c94c388607cd27c1f7458078c3bad5de2dead55b2e1/cython-3.2.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ea5097d97afd2ab14e98637b7033eba5146de29a5dedf89f5e946076396ab891", size = 2966736, upload-time = "2025-11-12T19:03:25.064Z" }, + { url = "https://files.pythonhosted.org/packages/84/b8/bcc36d9d2464348106984956608a52a42a01ab44ea64031207dffdebc078/cython-3.2.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a4bf12de0475bb6a21e2336a4a04dc4a2b4dd0507a2a3c703e045f3484266605", size = 3221633, upload-time = "2025-11-12T19:03:26.754Z" }, + { url = "https://files.pythonhosted.org/packages/79/20/7d4807fe4ebcef9f20f2e5f93312d0f5d02f9f76524fd4e37706d04e83f7/cython-3.2.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:18c64a0f69a1b8164de70ec7efc72250c589fec21519170de21582300f6aaed9", size = 3389542, upload-time = "2025-11-12T19:03:28.656Z" }, + { url = "https://files.pythonhosted.org/packages/2a/92/b06ba6721299293bc41e89732070132c453bdbaaeabb8f8cc76851b75345/cython-3.2.1-cp312-cp312-win_amd64.whl", hash = "sha256:5ba14907d5826d8010e82306ce279a0d3650f5b50a4813c80836a17b2213c520", size = 2755307, upload-time = "2025-11-12T19:03:30.684Z" }, + { url = "https://files.pythonhosted.org/packages/40/28/c6e36c214baeb27ae45b518552e74457536c7c964b1a55b5900b047fa467/cython-3.2.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:b4e850fc7a2f72d19679dd083fe4d20bf66860fceabb4f3207112f240249d708", size = 2957307, upload-time = "2025-11-12T19:03:32.471Z" }, + { url = "https://files.pythonhosted.org/packages/c8/c8/b0b9ba64f81f2875c42aab5c0979d6454cd1ac6b3c1e2373ad552701565d/cython-3.2.1-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3d20ca4afe993f7dccad3aeddbf4c3536cb0fd3ad6dc7a225935a666a5655af2", size = 3210919, upload-time = "2025-11-12T19:03:34.274Z" }, + { url = "https://files.pythonhosted.org/packages/f9/33/5d9ca6abba0e77e1851b843dd1b3c4095fbc6373166935e83c4414f80e88/cython-3.2.1-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f5a54a757d01ca6a260b02ce5baf17d9db1c2253566ab5844ee4966ff2a69c19", size = 3373350, upload-time = "2025-11-12T19:03:35.927Z" }, + { url = "https://files.pythonhosted.org/packages/e4/29/4408c3486ff380a2d6ae0d4b71da5195efcef3c4360017113ee7d1cb7335/cython-3.2.1-cp313-cp313-win_amd64.whl", hash = "sha256:1b81e56584727a328e00d91c164f8f0f2c59b02bf6857c3f000cd830fa571453", size = 2753425, upload-time = "2025-11-12T19:03:38.157Z" }, + { url = "https://files.pythonhosted.org/packages/f0/32/c1aa03ccadda89487ff31b90d8651c3706ce2744bf4f2c2ae213147e89bd/cython-3.2.1-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:d7af6ad01c0fe1965d1d3badaeb6df53c1f37383ebae1ccb405b73f628f87713", size = 2967833, upload-time = "2025-11-12T19:03:40.233Z" }, + { url = "https://files.pythonhosted.org/packages/ff/dc/3488d3ade0635408a2ebb05561a3009e2f54616bfefd1f107088dfeb2c4c/cython-3.2.1-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e3ea7cd085b62acb67c0fbde5cd17a7d9e47992c965e81ec977cf9ea7c59cd65", size = 3256237, upload-time = "2025-11-12T19:03:42.005Z" }, + { url = "https://files.pythonhosted.org/packages/7b/ba/f3d35d3803c9a424fa8812893847114deb9e2440c1bc67a31ab9ec4b9355/cython-3.2.1-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:986aea38fdf231e78d73745f83271c5654852c822dc5141a1d3fba64429a6aa6", size = 3383100, upload-time = "2025-11-12T19:03:43.675Z" }, + { url = "https://files.pythonhosted.org/packages/86/dc/d72dbb2f8e7ca95d2d18fd86f32b2e385996576230e7ecddd7d250786825/cython-3.2.1-cp314-cp314-win_amd64.whl", hash = "sha256:4960e26cd34c1385f21646339f2e0361fcdd2ed3c01cdb50fe734add577ec56a", size = 2790322, upload-time = "2025-11-12T19:03:45.373Z" }, + { url = "https://files.pythonhosted.org/packages/5a/7e/1194f4ba98b981bbdca945a292e4f49e87ea09d69516b24445409e7cf611/cython-3.2.1-cp39-abi3-macosx_10_9_x86_64.whl", hash = "sha256:4e9167316bf6ecfea33dcca62f074605648fb93cc053ef46b5deb3e5d12fc0d3", size = 2872858, upload-time = "2025-11-12T19:03:55.074Z" }, + { url = "https://files.pythonhosted.org/packages/6b/1a/393ca8ffec7ad3f02b8e4bffaba3dba4fb62c4a1c4c0b6dbf3b80e709fe3/cython-3.2.1-cp39-abi3-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:3095df6cd470064742f428c937bed7200c5123b9e19ee04aa09ec61281e565a3", size = 3209664, upload-time = "2025-11-12T19:03:56.771Z" }, + { url = "https://files.pythonhosted.org/packages/37/57/f209f64c609d3d8fac60a572e56da2f621dc1789e399c58db61d5645a31f/cython-3.2.1-cp39-abi3-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:db3f53b2d9afb206075a2605f1150aa019f0733c7795a38eccc6119c2e9c3f7b", size = 2854607, upload-time = "2025-11-12T19:03:59.413Z" }, + { url = "https://files.pythonhosted.org/packages/fc/af/1e5c73fe52423f40776130b0be914fd9f9f8dc26c4f6ea4c2ed04772d558/cython-3.2.1-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:0fc5e7687ac8f8e2b2fb95648f43e9e074ebaa72fd5cb3d8e20e5f1e8b8e02d9", size = 2991567, upload-time = "2025-11-12T19:04:02.209Z" }, + { url = "https://files.pythonhosted.org/packages/39/2c/3ea175b6b1fdfb429f9e9c395240d894155b3c0615caced05fef43264cba/cython-3.2.1-cp39-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:bbb3bc152bc0de82b031c8d355418fa4890a92424209d59366c2c0bc9e6cf53c", size = 2889178, upload-time = "2025-11-12T19:04:05.272Z" }, + { url = "https://files.pythonhosted.org/packages/f1/88/b2ab22a3a3feac78c62354a823c5c0c33659909e9918f53aa05904532b4b/cython-3.2.1-cp39-abi3-musllinux_1_2_i686.whl", hash = "sha256:a2022bc48ad0c2c0e0485bf0b54902913a3d81086b7d435f4437620c667799f6", size = 3223755, upload-time = "2025-11-12T19:04:07.262Z" }, + { url = "https://files.pythonhosted.org/packages/0b/56/9ba58629a03cbffb5965a3c65ccd91fa683d95d588c21a875da72fdc249b/cython-3.2.1-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:99fdd4ffc2dcb513f4be9ce71c6fedd895b96b1f814655b6bbab196df497b090", size = 3113456, upload-time = "2025-11-12T19:04:09.175Z" }, + { url = "https://files.pythonhosted.org/packages/56/5b/148c1a7ea5aebe460a70cad716a77e5fd0205be2de9fc5250491eb13ad8c/cython-3.2.1-cp39-abi3-win32.whl", hash = "sha256:06071f85bd5ce040464d43b2f9f287742a79f905e81b709fe904567230f1ed51", size = 2434223, upload-time = "2025-11-12T19:04:11.294Z" }, + { url = "https://files.pythonhosted.org/packages/7a/54/bb9b0c9db2a92a5e93747ca3027cfc645741411f8f1c6af2fb2a7b82df5d/cython-3.2.1-cp39-abi3-win_arm64.whl", hash = "sha256:e87c131d59480aee1ebac622b64f287c0e1d665ad1a1b7d498ac48accdb36c6b", size = 2439268, upload-time = "2025-11-12T19:04:12.931Z" }, + { url = "https://files.pythonhosted.org/packages/aa/30/373775b8d933d781d055c1dd0f110f275a101f320dab724c8c63a7c1b945/cython-3.2.1-py3-none-any.whl", hash = "sha256:cd72c46e7bffe8250c52d400e72c8d5d3086437b6aeec5b0eca99ccd337f5834", size = 1254219, upload-time = "2025-11-12T19:02:56.14Z" }, ] [[package]] From 6b553ec7eaf9773575fcda206ded1c9728b42b50 Mon Sep 17 00:00:00 2001 From: Huy Vu2 Date: Wed, 12 Nov 2025 11:15:15 -0800 Subject: [PATCH 53/80] uv.lock --- .gitmodules | 1 - 1 file changed, 1 deletion(-) diff --git a/.gitmodules b/.gitmodules index 514ac54f..8ad240e1 100644 --- a/.gitmodules +++ b/.gitmodules @@ -4,4 +4,3 @@ [submodule "3rdparty/Megatron-Bridge"] path = 3rdparty/Megatron-Bridge url = https://github.com/NVIDIA-NeMo/Megatron-Bridge.git - branch = main From b1c41fc3be1f04365c5f4e83c3b75759b608171f Mon Sep 17 00:00:00 2001 From: Huy Vu2 Date: Wed, 12 Nov 2025 11:34:10 -0800 Subject: [PATCH 54/80] uv.lock --- pyproject.toml | 4 ---- uv.lock | 12 +----------- 2 files changed, 1 insertion(+), 15 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index fb34cc51..c2f5d7f6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -93,10 +93,6 @@ dev = [ build = ["setuptools", "wheel", "torch", "pybind11", "Cython>=3.0.0", "numpy<2.0.0", "ninja", "packaging", "nvidia-mathdx"] automodel = [ "nemo-automodel", - "diffusers", - "ftfy", - "imageio-ffmpeg", - "opencv-python-headless==4.10.0.84", ] megatron-bridge = ["megatron-bridge"] diff --git a/uv.lock b/uv.lock index 5161ef7b..6998b6f2 100644 --- a/uv.lock +++ b/uv.lock @@ -3401,11 +3401,7 @@ dependencies = [ [package.dev-dependencies] automodel = [ - { name = "diffusers" }, - { name = "ftfy" }, - { name = "imageio-ffmpeg" }, { name = "nemo-automodel" }, - { name = "opencv-python-headless" }, ] build = [ { name = "cython" }, @@ -3457,13 +3453,7 @@ requires-dist = [ ] [package.metadata.requires-dev] -automodel = [ - { name = "diffusers" }, - { name = "ftfy" }, - { name = "imageio-ffmpeg" }, - { name = "nemo-automodel", directory = "3rdparty/Automodel" }, - { name = "opencv-python-headless", specifier = "==4.10.0.84" }, -] +automodel = [{ name = "nemo-automodel", directory = "3rdparty/Automodel" }] build = [ { name = "cython", specifier = ">=3.0.0" }, { name = "ninja" }, From 681145bc8f3ea63b0df507a50a7b7f94c7682d5b Mon Sep 17 00:00:00 2001 From: Pablo Garay Date: Wed, 12 Nov 2025 14:27:07 -0800 Subject: [PATCH 55/80] update uv.lock [using ci] --- uv.lock | 89 ++++++++++++++++++++++++++++++++++++++------------------- 1 file changed, 60 insertions(+), 29 deletions(-) diff --git a/uv.lock b/uv.lock index 6998b6f2..5481eaff 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.10" resolution-markers = [ "python_full_version >= '3.13' and sys_platform == 'darwin'", @@ -522,7 +522,7 @@ name = "bitsandbytes" version = "0.45.5" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "numpy" }, + { name = "numpy", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, { name = "torch", marker = "sys_platform == 'never'" }, ] wheels = [ @@ -1400,6 +1400,16 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/87/62/9773de14fe6c45c23649e98b83231fffd7b9892b6cf863251dc2afa73643/einops-0.8.1-py3-none-any.whl", hash = "sha256:919387eb55330f5757c6bea9165c5ff5cfe63a642682ea788a6d472576d81737", size = 64359, upload-time = "2025-02-09T03:17:01.998Z" }, ] +[[package]] +name = "emerging-optimizers" +version = "0.1.0" +source = { git = "https://github.com/NVIDIA-NeMo/Emerging-Optimizers.git?rev=cf9909b777ffac18e05b67a6708282cadc000942#cf9909b777ffac18e05b67a6708282cadc000942" } +dependencies = [ + { name = "absl-py" }, + { name = "torch", marker = "sys_platform == 'never'" }, + { name = "typing-extensions" }, +] + [[package]] name = "exceptiongroup" version = "1.3.0" @@ -1478,6 +1488,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/18/79/1b8fa1bb3568781e84c9200f951c735f3f157429f44be0495da55894d620/filetype-1.2.0-py2.py3-none-any.whl", hash = "sha256:7ce71b6880181241cf7ac8697a2f1eb6a8bd9b429f7ad6d27b8db9ba5f1c2d25", size = 19970, upload-time = "2022-11-02T17:34:01.425Z" }, ] +[[package]] +name = "fla-core" +version = "0.3.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "einops" }, + { name = "torch", marker = "sys_platform == 'never'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/67/c6/10a1149b07e6bab45b2cb2d07f6b827716c2baf5f3404161753f25c6389b/fla_core-0.3.2.tar.gz", hash = "sha256:d38db16bc4e1c6fa8c04df442f246da1e6926a209426bc6ef703d41bfbc37c92", size = 296725, upload-time = "2025-09-10T07:43:40.155Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7e/f5/74947b33c07682280e65adbdf17c4ee94b30232df2f728bafecf13d1d820/fla_core-0.3.2-py3-none-any.whl", hash = "sha256:e751d5a41e33eee721a6fb6588bd857f6f36e0d14719a23b1ebdbd617d307209", size = 413594, upload-time = "2025-09-10T07:43:37.786Z" }, +] + [[package]] name = "flake8" version = "7.3.0" @@ -1492,6 +1515,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9f/56/13ab06b4f93ca7cac71078fbe37fcea175d3216f31f85c3168a6bbd0bb9a/flake8-7.3.0-py2.py3-none-any.whl", hash = "sha256:b9696257b9ce8beb888cdbe31cf885c90d31928fe202be0889a7cdafad32f01e", size = 57922, upload-time = "2025-06-20T19:31:34.425Z" }, ] +[[package]] +name = "flash-linear-attention" +version = "0.3.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "datasets" }, + { name = "fla-core" }, + { name = "pytest" }, + { name = "transformers" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/84/f6/e62c1e562a288557eba7f06f168a7615813d1a227327b8beb8ba426da2c5/flash_linear_attention-0.3.2.tar.gz", hash = "sha256:9147747316c2951fed4ebeb4fa87977c05d807dc70c93b46250b68a6eb1183e2", size = 150880, upload-time = "2025-09-10T07:43:41.37Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a0/d0/35ce9eac5f52c72005095aaa12a393d2656ed7ffedf925b2381a6b76d10c/flash_linear_attention-0.3.2-py3-none-any.whl", hash = "sha256:604e73361437ba786420ab195e2caa3fd19280503761e703fa353c5ce5c65376", size = 274592, upload-time = "2025-09-10T07:43:39.107Z" }, +] + [[package]] name = "flashinfer-python" version = "0.5.2" @@ -2697,6 +2735,8 @@ dev = [ { name = "av" }, { name = "causal-conv1d" }, { name = "einops" }, + { name = "emerging-optimizers" }, + { name = "flash-linear-attention" }, { name = "flashinfer-python" }, { name = "mamba-ssm" }, { name = "megatron-energon", extra = ["av-decode"] }, @@ -2726,46 +2766,41 @@ mlm = [ [package.metadata] requires-dist = [ { name = "av", marker = "extra == 'dev'", specifier = "<16.0.0" }, - { name = "av", marker = "extra == 'lts'", specifier = "<16.0.0" }, { name = "causal-conv1d", marker = "extra == 'dev'", specifier = "~=1.5" }, - { name = "causal-conv1d", marker = "extra == 'lts'", specifier = "~=1.5" }, { name = "einops", marker = "extra == 'dev'", specifier = "~=0.8" }, - { name = "einops", marker = "extra == 'lts'", specifier = "~=0.8" }, + { name = "einops", marker = "extra == 'lts'" }, + { name = "emerging-optimizers", marker = "extra == 'dev'", git = "https://github.com/NVIDIA-NeMo/Emerging-Optimizers.git?rev=cf9909b777ffac18e05b67a6708282cadc000942" }, + { name = "flash-linear-attention", marker = "extra == 'dev'", specifier = "~=0.3.2" }, { name = "flashinfer-python", marker = "extra == 'dev'" }, - { name = "flashinfer-python", marker = "extra == 'lts'" }, { name = "flask-restful", marker = "extra == 'mlm'" }, { name = "mamba-ssm", marker = "extra == 'dev'", specifier = "~=2.2" }, - { name = "mamba-ssm", marker = "extra == 'lts'", specifier = "~=2.2" }, { name = "megatron-energon", extras = ["av-decode"], marker = "extra == 'dev'", specifier = "~=6.0" }, - { name = "megatron-energon", extras = ["av-decode"], marker = "extra == 'lts'", specifier = "~=6.0" }, { name = "multi-storage-client", marker = "extra == 'dev'", specifier = "~=0.27" }, - { name = "multi-storage-client", marker = "extra == 'lts'", specifier = "~=0.27" }, { name = "numpy", specifier = "<2.0.0" }, { name = "nv-grouped-gemm", marker = "extra == 'dev'", specifier = "~=1.1" }, - { name = "nv-grouped-gemm", marker = "extra == 'lts'", specifier = "~=1.1" }, - { name = "nvidia-modelopt", extras = ["torch"], marker = "sys_platform != 'darwin' and extra == 'dev'" }, - { name = "nvidia-resiliency-ext", marker = "extra == 'dev'" }, + { name = "nvidia-modelopt", extras = ["torch"], marker = "sys_platform != 'darwin' and extra == 'dev'", specifier = ">=0.33.0a0,<0.34.0" }, + { name = "nvidia-resiliency-ext", marker = "extra == 'dev'", specifier = ">=0.4.0a0,<0.5.0" }, { name = "nvtx", marker = "extra == 'dev'", specifier = "~=0.2" }, - { name = "nvtx", marker = "extra == 'lts'", specifier = "~=0.2" }, + { name = "nvtx", marker = "extra == 'lts'" }, { name = "onnxscript", marker = "extra == 'dev'" }, - { name = "onnxscript", marker = "extra == 'lts'" }, { name = "opentelemetry-api", marker = "extra == 'dev'", specifier = "~=1.33.1" }, - { name = "opentelemetry-api", marker = "extra == 'lts'", specifier = "~=1.33.1" }, { name = "packaging", specifier = ">=24.2" }, { name = "sentencepiece", marker = "extra == 'mlm'" }, { name = "setuptools", marker = "extra == 'dev'", specifier = "<80.0.0" }, { name = "setuptools", marker = "extra == 'lts'", specifier = "<80.0.0" }, { name = "tensorstore", marker = "extra == 'dev'", specifier = "~=0.1,!=0.1.46,!=0.1.72" }, - { name = "tensorstore", marker = "extra == 'lts'", specifier = "~=0.1,!=0.1.46,!=0.1.72" }, + { name = "tensorstore", marker = "extra == 'lts'", specifier = "!=0.1.46,!=0.1.72" }, { name = "tiktoken", marker = "extra == 'mlm'" }, { name = "torch" }, { name = "tqdm", marker = "extra == 'dev'" }, { name = "tqdm", marker = "extra == 'lts'" }, { name = "transformer-engine", extras = ["pytorch"], marker = "extra == 'dev'", git = "https://github.com/NVIDIA/TransformerEngine.git?rev=release_v2.9" }, + { name = "transformers", marker = "extra == 'lts'" }, { name = "transformers", marker = "extra == 'mlm'" }, { name = "wandb", marker = "extra == 'mlm'" }, { name = "wget", marker = "extra == 'dev'" }, { name = "wget", marker = "extra == 'lts'" }, + { name = "zarr", marker = "extra == 'lts'" }, ] provides-extras = ["mlm", "dev", "lts"] @@ -5898,7 +5933,7 @@ name = "sympy" version = "1.14.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "mpmath" }, + { name = "mpmath", marker = "sys_platform != 'darwin' and sys_platform != 'linux'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/83/d3/803453b36afefb7c2bb238361cd4ae6125a569b4db67cd9e79846ba2d68c/sympy-1.14.0.tar.gz", hash = "sha256:d3d3fe8df1e5a0b42f0e7bdf50541697dbe7d23746e894990c030e2b05e72517", size = 7793921, upload-time = "2025-04-27T18:05:01.611Z" } wheels = [ @@ -6027,10 +6062,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/6a/0e/c38f079f3933cc284aab53d52976f6cb4f1ad43bb6a704ac27e0b710f176/tensorstore-0.1.79-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:847982652273fb7b2d694b789205747aaf3e50ae64738c5cb7b5eb03d86a9947", size = 18949282, upload-time = "2025-11-11T22:05:07.562Z" }, { url = "https://files.pythonhosted.org/packages/6f/99/03479deea5bfd27a0d8a8c75d5f1d85417a7bbc9c6c7a90fb85b4a4e347a/tensorstore-0.1.79-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7af9422269c2bfcdecf9dd55309060665ab9c2d7f6c892377ed32c032400feea", size = 20931601, upload-time = "2025-11-11T22:05:10.098Z" }, { url = "https://files.pythonhosted.org/packages/26/36/2617edf6c6d6fc73b3ff96d9d0b97332adf0d0c56fa2014a226bf4f7dfa6/tensorstore-0.1.79-cp314-cp314-win_amd64.whl", hash = "sha256:bbd8c1ab7d2e3c03ded3d40bb373ee9a67668e33a564484927865ce43b210386", size = 13599766, upload-time = "2025-11-11T22:05:12.265Z" }, - { url = "https://files.pythonhosted.org/packages/2b/c3/8ca0dbced6b5578bf0acc9f9f52b7ba35b24be783aa4be60d374719e2e45/tensorstore-0.1.79-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:9bdfb7f422eae1b976bee20e8c758c22910b0610e0c59eb9f6fe748330c1f8e5", size = 16565957, upload-time = "2025-11-11T22:05:14.206Z" }, - { url = "https://files.pythonhosted.org/packages/a6/a7/68b5e7b08ef4ddf11efcc1f90b626636ee7ff786634840e657c920c467e7/tensorstore-0.1.79-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:e100001a9e1b452650ab489ae6a75f4a7e85e127fbe3e33969c1e0157b66f8a6", size = 14593418, upload-time = "2025-11-11T22:05:16.278Z" }, - { url = "https://files.pythonhosted.org/packages/db/34/12fd39faf4aff50440949f91dc97bd90cd4a8dd0c95d9021682ea10e87f5/tensorstore-0.1.79-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8e0044d15b574506645061efdb204f00a85b2fda95c097049a67cfee135bcf74", size = 18959327, upload-time = "2025-11-11T22:05:18.7Z" }, - { url = "https://files.pythonhosted.org/packages/e4/d6/99f83285ae873d03db5f1189ed7173ba96c216d3dc2521e0b8f513aa0514/tensorstore-0.1.79-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:fa4bfbd6718126669ba83e37e401fa84832772124862c3c7ec8a6bce90931492", size = 20938781, upload-time = "2025-11-11T22:05:21.309Z" }, ] [[package]] @@ -6191,15 +6222,15 @@ name = "torch" version = "2.9.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "filelock" }, - { name = "fsspec" }, - { name = "jinja2" }, - { name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, - { name = "networkx", version = "3.6rc0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, - { name = "setuptools", marker = "python_full_version >= '3.12'" }, - { name = "sympy" }, + { name = "filelock", marker = "sys_platform != 'darwin' and sys_platform != 'linux'" }, + { name = "fsspec", marker = "sys_platform != 'darwin' and sys_platform != 'linux'" }, + { name = "jinja2", marker = "sys_platform != 'darwin' and sys_platform != 'linux'" }, + { name = "networkx", version = "3.4.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11' and sys_platform != 'darwin' and sys_platform != 'linux'" }, + { name = "networkx", version = "3.6rc0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11' and sys_platform != 'darwin' and sys_platform != 'linux'" }, + { name = "setuptools", marker = "python_full_version >= '3.12' and sys_platform != 'darwin' and sys_platform != 'linux'" }, + { name = "sympy", marker = "sys_platform != 'darwin' and sys_platform != 'linux'" }, { name = "triton", marker = "sys_platform == 'never'" }, - { name = "typing-extensions" }, + { name = "typing-extensions", marker = "sys_platform != 'darwin' and sys_platform != 'linux'" }, ] [[package]] From 7ad788e7f36d8ace3256e0d54e3bf8f590850d63 Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Thu, 13 Nov 2025 00:25:14 -0800 Subject: [PATCH 56/80] Performance improvements to Wan --- .../data/wan/wan_mock_energon_datamodule.py | 8 ++-- .../model/wan/flow_matching/flow_pipeline.py | 13 ++++--- dfm/src/megatron/model/wan/wan_layer_spec.py | 39 +++++++++++++------ dfm/src/megatron/model/wan/wan_provider.py | 2 + dfm/src/megatron/model/wan/wan_step.py | 4 +- dfm/src/megatron/recipes/wan/wan.py | 10 ++++- 6 files changed, 51 insertions(+), 25 deletions(-) diff --git a/dfm/src/megatron/data/wan/wan_mock_energon_datamodule.py b/dfm/src/megatron/data/wan/wan_mock_energon_datamodule.py index e8db2d67..72c489b2 100644 --- a/dfm/src/megatron/data/wan/wan_mock_energon_datamodule.py +++ b/dfm/src/megatron/data/wan/wan_mock_energon_datamodule.py @@ -108,13 +108,13 @@ class WanMockDataModuleConfig(DatasetProvider): W_latents: int = 60 patch_spatial: int = 2 patch_temporal: int = 1 - number_packed_samples: int = 3 + number_packed_samples: int = 1 context_seq_len: int = 512 context_embeddings_dim: int = 4096 def __post_init__(self): mock_ds = _MockDataset(length=1024) - self._train_dl = DataLoader( + self._train_dl = iter(DataLoader( mock_ds, batch_size=self.micro_batch_size, num_workers=self.num_workers, @@ -130,7 +130,9 @@ def __post_init__(self): ), shuffle=False, drop_last=False, - ) + pin_memory=True, + prefetch_factor=8, + )) self.sequence_length = self.seq_length def build_datasets(self, _context: DatasetBuildContext): diff --git a/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py b/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py index c90cda91..bc964b68 100644 --- a/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py +++ b/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py @@ -193,11 +193,13 @@ def training_step( noise = thd_split_inputs_cp( noise, packed_seq_params["self_attention"].cu_seqlens_q, parallel_state.get_context_parallel_group() ) - context_embeddings = thd_split_inputs_cp( - context_embeddings, - packed_seq_params["cross_attention"].cu_seqlens_kv, - parallel_state.get_context_parallel_group(), - ) + # We don't need to split context embeddings across context parallelism + # if we disable context parallelism for cross-attention + # context_embeddings = thd_split_inputs_cp( + # context_embeddings, + # packed_seq_params["cross_attention"].cu_seqlens_kv, + # parallel_state.get_context_parallel_group(), + # ) split_loss_mask = thd_split_inputs_cp( loss_mask, packed_seq_params["self_attention"].cu_seqlens_q, @@ -259,5 +261,4 @@ def training_step( context=context_embeddings, packed_seq_params=packed_seq_params, ) - return hidden_states diff --git a/dfm/src/megatron/model/wan/wan_layer_spec.py b/dfm/src/megatron/model/wan/wan_layer_spec.py index 053e3e94..4551f363 100644 --- a/dfm/src/megatron/model/wan/wan_layer_spec.py +++ b/dfm/src/megatron/model/wan/wan_layer_spec.py @@ -16,10 +16,12 @@ from dataclasses import dataclass from typing import Optional, Union +import copy import torch import torch.nn as nn from megatron.core.extensions.transformer_engine import TENorm +from megatron.core.jit import jit_fuser from megatron.core.process_groups_config import ProcessGroupCollection from megatron.core.transformer.attention import SelfAttentionSubmodules from megatron.core.transformer.custom_layers.transformer_engine import ( @@ -68,11 +70,11 @@ def forward(self, timestep_emb): e = (self.modulation + timestep_emb).chunk(6, dim=1) return e - # @jit_fuser + @jit_fuser def modulate(self, x, shift, scale): return x * (1 + scale) + shift - # @jit_fuser + @jit_fuser def scale_add(self, residual, x, gate): return residual + gate * x @@ -95,19 +97,31 @@ def __init__( pg_collection: Optional[ProcessGroupCollection] = None, vp_stage: Optional[int] = None, ): + def _replace_no_cp_submodules(submodules): + modified_submods = copy.deepcopy(submodules) + modified_submods.cross_attention = IdentityOp + return modified_submods + + # Replace any submodules that will have CP disabled and build them manually later after TransformerLayer init. + modified_submods = _replace_no_cp_submodules(submodules) super().__init__( - config=config, submodules=submodules, layer_number=layer_number, hidden_dropout=hidden_dropout + config=config, submodules=modified_submods, layer_number=layer_number, hidden_dropout=hidden_dropout ) - # # TODO: Override Cross Attention to disable TP Comm overlap as well. ??? - # # Not disabling will attempt re-use of buffer size same as Q and lead to incorrect tensor shapes. - # cp_override_config = copy.deepcopy(config) - # cp_override_config.tp_comm_overlap = False - # self.cross_attention = build_module( - # submodules.cross_attention, - # config=cp_override_config, - # layer_number=layer_number, - # ) + # Override Cross Attention to disable CP. + # Disable TP Comm overlap as well. Not disabling will attempt re-use of buffer size same as + # Q and lead to incorrect tensor shapes. + if submodules.cross_attention != IdentityOp: + cp_override_config = copy.deepcopy(config) + cp_override_config.context_parallel_size = 1 + cp_override_config.tp_comm_overlap = False + self.cross_attention = build_module( + submodules.cross_attention, + config=cp_override_config, + layer_number=layer_number, + ) + else: + self.cross_attention = None self.full_self_attention = build_module( submodules.full_self_attention, @@ -199,6 +213,7 @@ def forward( # ******************************************** cross attention ****************************************************** + packed_seq_params['cross_attention'].cu_seqlens_q = torch.tensor([0, hidden_states.shape[0]], device=packed_seq_params['cross_attention'].cu_seqlens_kv.device, dtype=torch.int32) attention_output, bias = self.cross_attention( self.norm3(hidden_states), attention_mask=context_mask, diff --git a/dfm/src/megatron/model/wan/wan_provider.py b/dfm/src/megatron/model/wan/wan_provider.py index 24e8c87d..8a198bcf 100644 --- a/dfm/src/megatron/model/wan/wan_provider.py +++ b/dfm/src/megatron/model/wan/wan_provider.py @@ -51,6 +51,8 @@ class WanModelProvider(TransformerConfig, ModelProviderMixin[VisionModule]): bf16: bool = False params_dtype: torch.dtype = torch.float32 qkv_format: str = "sbhd" # "thd". NOTE: if we use context parallelism, we need to use "thd" + apply_rope_fusion: bool = True + bias_activation_fusion: bool = True # these attributes are unused for images/videos, we just set because bridge training requires for LLMs seq_length: int = 1024 share_embeddings_and_output_weights: bool = False diff --git a/dfm/src/megatron/model/wan/wan_step.py b/dfm/src/megatron/model/wan/wan_step.py index 32a5b5e6..8973056f 100644 --- a/dfm/src/megatron/model/wan/wan_step.py +++ b/dfm/src/megatron/model/wan/wan_step.py @@ -32,10 +32,8 @@ def wan_data_step(qkv_format, dataloader_iter): - batch = next(iter(dataloader_iter.iterable)) - + batch = next(dataloader_iter) batch = {k: v.to(device="cuda", non_blocking=True) if torch.is_tensor(v) else v for k, v in batch.items()} - # Construct packed sequence parameters if ("seq_len_q" in batch) and ("seq_len_kv" in batch): cu_seqlens = batch["seq_len_q"].cumsum(dim=0).to(torch.int32) diff --git a/dfm/src/megatron/recipes/wan/wan.py b/dfm/src/megatron/recipes/wan/wan.py index 21bd1b50..061beaf7 100644 --- a/dfm/src/megatron/recipes/wan/wan.py +++ b/dfm/src/megatron/recipes/wan/wan.py @@ -26,6 +26,7 @@ RNGConfig, TokenizerConfig, TrainingConfig, + ProfilingConfig, ) from megatron.bridge.training.mixed_precision import MixedPrecisionConfig, get_mixed_precision_config from megatron.core.distributed import DistributedDataParallelConfig @@ -170,7 +171,7 @@ def pretrain_config( context_embeddings_dim=4096, micro_batch_size=micro_batch_size, global_batch_size=global_batch_size, - num_workers=10, + num_workers=16, packing_buffer_size=None, ) else: @@ -225,6 +226,13 @@ def pretrain_config( rng=RNGConfig(seed=1234), comm_overlap=comm_overlap_config, mixed_precision=precision_config, + profiling=ProfilingConfig( + use_nsys_profiler=False, + profile_step_start=10, + profile_step_end=11, + record_shapes=True, + profile_ranks=[0], + ), ) return cfg From 0fd0e274c2cf4d541d62cd43b85a434a4d340d13 Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Fri, 14 Nov 2025 12:23:41 -0800 Subject: [PATCH 57/80] Perf optimizations --- dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py | 6 ++---- dfm/src/megatron/model/wan/rope_utils.py | 6 +----- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py b/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py index bc964b68..d6575e27 100644 --- a/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py +++ b/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py @@ -141,10 +141,8 @@ def training_step( # because video_latents might be padded, we need to make sure noise also be padded to have the same shape sample_noise_seq_len = sample_noise.shape[0] - cu_seqlens_q_padded = packed_seq_params["self_attention"].cu_seqlens_q_padded - seq_len_q_padded = cu_seqlens_q_padded[i + 1] - cu_seqlens_q_padded[i] - if sample_noise_seq_len < seq_len_q_padded: - pad_len = seq_len_q_padded - sample_noise_seq_len + if sample_noise_seq_len < video_latents.shape[0]: + pad_len = video_latents.shape[0] - sample_noise_seq_len pad = torch.zeros( (pad_len, sample_noise.shape[1]), device=sample_noise.device, dtype=sample_noise.dtype ) diff --git a/dfm/src/megatron/model/wan/rope_utils.py b/dfm/src/megatron/model/wan/rope_utils.py index 2b64fdaa..993507e5 100644 --- a/dfm/src/megatron/model/wan/rope_utils.py +++ b/dfm/src/megatron/model/wan/rope_utils.py @@ -31,7 +31,7 @@ def __init__(self, dim_head, max_position_len): self.rope_params(max_position_len, 2 * (dim_head // 6)), ], dim=1, - ) + ).cuda() def rope_params(self, max_position_len, dim_head, theta=10000): assert dim_head % 2 == 0 @@ -41,10 +41,6 @@ def rope_params(self, max_position_len, dim_head, theta=10000): return freqs def forward(self, n_head, dim_head, cu_seqlens_q_padded, grid_sizes, device): - self.freqs = self.freqs.to( - device, - ) - n, c = n_head, dim_head // 2 # split freqs From fd373f96d6487066ff85ca4b745a2fb64a6bd273 Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Fri, 14 Nov 2025 14:34:32 -0800 Subject: [PATCH 58/80] Tiny fix --- dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py b/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py index ea886730..53d2b80f 100644 --- a/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py +++ b/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py @@ -261,8 +261,4 @@ def training_step( context=context_embeddings, packed_seq_params=packed_seq_params, ) -<<<<<<< HEAD -======= - ->>>>>>> 55c42e137929d45ec43e95bf241adc0ffc8df16b return hidden_states From 6a93703cb49e443474c6e9dda841d39015c35b31 Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Fri, 14 Nov 2025 16:16:29 -0800 Subject: [PATCH 59/80] Remove CP disable as packed sequences not supported --- .../model/wan/flow_matching/flow_pipeline.py | 10 +++--- dfm/src/megatron/model/wan/wan_layer_spec.py | 33 +++++++++++-------- 2 files changed, 24 insertions(+), 19 deletions(-) diff --git a/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py b/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py index 53d2b80f..629d796b 100644 --- a/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py +++ b/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py @@ -195,11 +195,11 @@ def training_step( ) # We don't need to split context embeddings across context parallelism # if we disable context parallelism for cross-attention - # context_embeddings = thd_split_inputs_cp( - # context_embeddings, - # packed_seq_params["cross_attention"].cu_seqlens_kv, - # parallel_state.get_context_parallel_group(), - # ) + context_embeddings = thd_split_inputs_cp( + context_embeddings, + packed_seq_params["cross_attention"].cu_seqlens_kv_padded, + parallel_state.get_context_parallel_group(), + ) split_loss_mask = thd_split_inputs_cp( loss_mask, packed_seq_params["self_attention"].cu_seqlens_q_padded, diff --git a/dfm/src/megatron/model/wan/wan_layer_spec.py b/dfm/src/megatron/model/wan/wan_layer_spec.py index 4551f363..21a96c7c 100644 --- a/dfm/src/megatron/model/wan/wan_layer_spec.py +++ b/dfm/src/megatron/model/wan/wan_layer_spec.py @@ -103,25 +103,25 @@ def _replace_no_cp_submodules(submodules): return modified_submods # Replace any submodules that will have CP disabled and build them manually later after TransformerLayer init. - modified_submods = _replace_no_cp_submodules(submodules) + # modified_submods = _replace_no_cp_submodules(submodules) super().__init__( - config=config, submodules=modified_submods, layer_number=layer_number, hidden_dropout=hidden_dropout + config=config, submodules=submodules, layer_number=layer_number, hidden_dropout=hidden_dropout ) # Override Cross Attention to disable CP. # Disable TP Comm overlap as well. Not disabling will attempt re-use of buffer size same as # Q and lead to incorrect tensor shapes. - if submodules.cross_attention != IdentityOp: - cp_override_config = copy.deepcopy(config) - cp_override_config.context_parallel_size = 1 - cp_override_config.tp_comm_overlap = False - self.cross_attention = build_module( - submodules.cross_attention, - config=cp_override_config, - layer_number=layer_number, - ) - else: - self.cross_attention = None + # if submodules.cross_attention != IdentityOp: + # cp_override_config = copy.deepcopy(config) + # cp_override_config.context_parallel_size = 1 + # cp_override_config.tp_comm_overlap = False + # self.cross_attention = build_module( + # submodules.cross_attention, + # config=cp_override_config, + # layer_number=layer_number, + # ) + # else: + # self.cross_attention = None self.full_self_attention = build_module( submodules.full_self_attention, @@ -213,7 +213,12 @@ def forward( # ******************************************** cross attention ****************************************************** - packed_seq_params['cross_attention'].cu_seqlens_q = torch.tensor([0, hidden_states.shape[0]], device=packed_seq_params['cross_attention'].cu_seqlens_kv.device, dtype=torch.int32) + # TODO (pmannan): Disable CP for CrossAttention as KV context is small. + # But needs better support for packed sequences and padding to ensure correct calculations + # packed_seq_params['cross_attention'].cu_seqlens_q = torch.tensor( + # [0, hidden_states.shape[0]], + # device=packed_seq_params['cross_attention'].cu_seqlens_kv.device, + # dtype=torch.int32) attention_output, bias = self.cross_attention( self.norm3(hidden_states), attention_mask=context_mask, From 345b53ee9e6357b2a76431e9158c440fef43de42 Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Fri, 14 Nov 2025 16:18:59 -0800 Subject: [PATCH 60/80] Fix comment --- dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py | 1 + 1 file changed, 1 insertion(+) diff --git a/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py b/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py index 629d796b..37ae28d7 100644 --- a/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py +++ b/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py @@ -193,6 +193,7 @@ def training_step( packed_seq_params["self_attention"].cu_seqlens_q_padded, parallel_state.get_context_parallel_group(), ) + # TODO (pmannan): Disable CP for CrossAttention as KV context is small. # We don't need to split context embeddings across context parallelism # if we disable context parallelism for cross-attention context_embeddings = thd_split_inputs_cp( From c9e55bbc1fc58d6ed53fb2fccacfe26390dbb669 Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Mon, 17 Nov 2025 13:33:51 -0800 Subject: [PATCH 61/80] Minor fixes. Revert video_latent comparison --- .../megatron/data/wan/wan_mock_datamodule.py | 40 ++++++++++--------- .../model/wan/flow_matching/flow_pipeline.py | 4 +- dfm/src/megatron/model/wan/wan_layer_spec.py | 6 +-- dfm/src/megatron/recipes/wan/wan.py | 8 ---- 4 files changed, 27 insertions(+), 31 deletions(-) diff --git a/dfm/src/megatron/data/wan/wan_mock_datamodule.py b/dfm/src/megatron/data/wan/wan_mock_datamodule.py index fb47ce60..c8c05624 100644 --- a/dfm/src/megatron/data/wan/wan_mock_datamodule.py +++ b/dfm/src/megatron/data/wan/wan_mock_datamodule.py @@ -119,25 +119,27 @@ class WanMockDataModuleConfig(DatasetProvider): def __post_init__(self): mock_ds = _MockDataset(length=1024) - self._train_dl = iter(DataLoader( - mock_ds, - batch_size=self.micro_batch_size, - num_workers=self.num_workers, - collate_fn=lambda samples: mock_batch( - F_latents=self.F_latents, - H_latents=self.H_latents, - W_latents=self.W_latents, - patch_temporal=self.patch_temporal, - patch_spatial=self.patch_spatial, - number_packed_samples=self.number_packed_samples, - context_seq_len=self.context_seq_len, - context_embeddings_dim=self.context_embeddings_dim, - ), - shuffle=False, - drop_last=False, - pin_memory=True, - prefetch_factor=8, - )) + self._train_dl = iter( + DataLoader( + mock_ds, + batch_size=self.micro_batch_size, + num_workers=self.num_workers, + collate_fn=lambda samples: mock_batch( + F_latents=self.F_latents, + H_latents=self.H_latents, + W_latents=self.W_latents, + patch_temporal=self.patch_temporal, + patch_spatial=self.patch_spatial, + number_packed_samples=self.number_packed_samples, + context_seq_len=self.context_seq_len, + context_embeddings_dim=self.context_embeddings_dim, + ), + shuffle=False, + drop_last=False, + pin_memory=True, + prefetch_factor=8, + ) + ) self.sequence_length = self.seq_length def build_datasets(self, _context: DatasetBuildContext): diff --git a/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py b/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py index 37ae28d7..bad25f04 100644 --- a/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py +++ b/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py @@ -141,8 +141,10 @@ def training_step( # because video_latents might be padded, we need to make sure noise also be padded to have the same shape sample_noise_seq_len = sample_noise.shape[0] + cu_seqlens_q_padded = packed_seq_params["self_attention"].cu_seqlens_q_padded + seq_len_q_padded = cu_seqlens_q_padded[i + 1] - cu_seqlens_q_padded[i] if sample_noise_seq_len < video_latents.shape[0]: - pad_len = video_latents.shape[0] - sample_noise_seq_len + pad_len = seq_len_q_padded - sample_noise_seq_len pad = torch.zeros( (pad_len, sample_noise.shape[1]), device=sample_noise.device, dtype=sample_noise.dtype ) diff --git a/dfm/src/megatron/model/wan/wan_layer_spec.py b/dfm/src/megatron/model/wan/wan_layer_spec.py index 21a96c7c..5ad752ec 100644 --- a/dfm/src/megatron/model/wan/wan_layer_spec.py +++ b/dfm/src/megatron/model/wan/wan_layer_spec.py @@ -14,9 +14,9 @@ # pylint: disable=C0115,C0116,C0301 +import copy from dataclasses import dataclass from typing import Optional, Union -import copy import torch import torch.nn as nn @@ -216,8 +216,8 @@ def forward( # TODO (pmannan): Disable CP for CrossAttention as KV context is small. # But needs better support for packed sequences and padding to ensure correct calculations # packed_seq_params['cross_attention'].cu_seqlens_q = torch.tensor( - # [0, hidden_states.shape[0]], - # device=packed_seq_params['cross_attention'].cu_seqlens_kv.device, + # [0, hidden_states.shape[0]], + # device=packed_seq_params['cross_attention'].cu_seqlens_kv.device, # dtype=torch.int32) attention_output, bias = self.cross_attention( self.norm3(hidden_states), diff --git a/dfm/src/megatron/recipes/wan/wan.py b/dfm/src/megatron/recipes/wan/wan.py index 847789b5..5fa16526 100644 --- a/dfm/src/megatron/recipes/wan/wan.py +++ b/dfm/src/megatron/recipes/wan/wan.py @@ -26,7 +26,6 @@ RNGConfig, TokenizerConfig, TrainingConfig, - ProfilingConfig, ) from megatron.bridge.training.mixed_precision import MixedPrecisionConfig, get_mixed_precision_config from megatron.core.distributed import DistributedDataParallelConfig @@ -226,13 +225,6 @@ def pretrain_config( rng=RNGConfig(seed=1234), comm_overlap=comm_overlap_config, mixed_precision=precision_config, - profiling=ProfilingConfig( - use_nsys_profiler=False, - profile_step_start=10, - profile_step_end=11, - record_shapes=True, - profile_ranks=[0], - ), ) return cfg From 3504d3f02749d29e57e99d1e2067f0300af1942a Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Mon, 17 Nov 2025 13:34:48 -0800 Subject: [PATCH 62/80] Fix missed check --- dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py b/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py index bad25f04..686401fd 100644 --- a/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py +++ b/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py @@ -143,7 +143,7 @@ def training_step( sample_noise_seq_len = sample_noise.shape[0] cu_seqlens_q_padded = packed_seq_params["self_attention"].cu_seqlens_q_padded seq_len_q_padded = cu_seqlens_q_padded[i + 1] - cu_seqlens_q_padded[i] - if sample_noise_seq_len < video_latents.shape[0]: + if sample_noise_seq_len < seq_len_q_padded: pad_len = seq_len_q_padded - sample_noise_seq_len pad = torch.zeros( (pad_len, sample_noise.shape[1]), device=sample_noise.device, dtype=sample_noise.dtype From b2fef2f151604f391c5a7586a759d087ed4e46dd Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Mon, 17 Nov 2025 13:39:45 -0800 Subject: [PATCH 63/80] Lint fix --- dfm/src/megatron/model/wan/wan_layer_spec.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dfm/src/megatron/model/wan/wan_layer_spec.py b/dfm/src/megatron/model/wan/wan_layer_spec.py index 5ad752ec..34a2c2af 100644 --- a/dfm/src/megatron/model/wan/wan_layer_spec.py +++ b/dfm/src/megatron/model/wan/wan_layer_spec.py @@ -101,7 +101,7 @@ def _replace_no_cp_submodules(submodules): modified_submods = copy.deepcopy(submodules) modified_submods.cross_attention = IdentityOp return modified_submods - + # Replace any submodules that will have CP disabled and build them manually later after TransformerLayer init. # modified_submods = _replace_no_cp_submodules(submodules) super().__init__( From b5ac64958f1dc897b00bb483d0d33b85c427e83b Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Tue, 18 Nov 2025 14:09:42 -0800 Subject: [PATCH 64/80] H100 mock pretraining perf config --- .../recipes/wan/conf/h100_pretrain_mock.yaml | 37 +++++++++++++++++++ 1 file changed, 37 insertions(+) create mode 100644 examples/megatron/recipes/wan/conf/h100_pretrain_mock.yaml diff --git a/examples/megatron/recipes/wan/conf/h100_pretrain_mock.yaml b/examples/megatron/recipes/wan/conf/h100_pretrain_mock.yaml new file mode 100644 index 00000000..4dc88479 --- /dev/null +++ b/examples/megatron/recipes/wan/conf/h100_pretrain_mock.yaml @@ -0,0 +1,37 @@ +model: + tensor_model_parallel_size: 2 + sequence_parallel: true + pipeline_model_parallel_size: 1 + context_parallel_size: 4 + recompute_granularity: full + recompute_method: block + recompute_num_layers: 10 + crossattn_emb_size: 5120 + hidden_size: 5120 + ffn_hidden_size: 13824 + num_attention_heads: 40 + num_layers: 40 + qkv_format: thd + seq_length: 2048 + +train: + global_batch_size: 64 + micro_batch_size: 1 + eval_iters: 0 + empty_unused_memory_level: 1 + +scheduler: + lr_decay_style: constant + lr_warmup_iters: 0 + +optimizer: + lr: 5e-6 + min_lr: 5e-6 + +dataset: + seq_length: 2048 # This is not used + global_batch_size: 64 + micro_batch_size: 1 + +logger: + log_interval: 1 \ No newline at end of file From 52083acfd96136bb6323856131041edae51cc728 Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Tue, 18 Nov 2025 14:10:39 -0800 Subject: [PATCH 65/80] Rename config file --- .../{h100_pretrain_mock.yaml => h100_perf_pretrain_mock.yaml} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename examples/megatron/recipes/wan/conf/{h100_pretrain_mock.yaml => h100_perf_pretrain_mock.yaml} (100%) diff --git a/examples/megatron/recipes/wan/conf/h100_pretrain_mock.yaml b/examples/megatron/recipes/wan/conf/h100_perf_pretrain_mock.yaml similarity index 100% rename from examples/megatron/recipes/wan/conf/h100_pretrain_mock.yaml rename to examples/megatron/recipes/wan/conf/h100_perf_pretrain_mock.yaml From fef3196f97cb24a84bf4deb04ef5205734be1343 Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Tue, 18 Nov 2025 14:18:31 -0800 Subject: [PATCH 66/80] Lint check Signed-off-by: Parth Mannan --- examples/megatron/recipes/wan/conf/h100_perf_pretrain_mock.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/megatron/recipes/wan/conf/h100_perf_pretrain_mock.yaml b/examples/megatron/recipes/wan/conf/h100_perf_pretrain_mock.yaml index 4dc88479..c64df86d 100644 --- a/examples/megatron/recipes/wan/conf/h100_perf_pretrain_mock.yaml +++ b/examples/megatron/recipes/wan/conf/h100_perf_pretrain_mock.yaml @@ -34,4 +34,4 @@ dataset: micro_batch_size: 1 logger: - log_interval: 1 \ No newline at end of file + log_interval: 1 From 601a63ac1ea05a16de921dd5bea69f3af6d8aad6 Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Tue, 18 Nov 2025 15:47:11 -0800 Subject: [PATCH 67/80] Adding GB200 perf config Signed-off-by: Parth Mannan --- .../wan/conf/gb200_perf_pretrain_mock.yaml | 33 +++++++++++++++++++ 1 file changed, 33 insertions(+) create mode 100644 examples/megatron/recipes/wan/conf/gb200_perf_pretrain_mock.yaml diff --git a/examples/megatron/recipes/wan/conf/gb200_perf_pretrain_mock.yaml b/examples/megatron/recipes/wan/conf/gb200_perf_pretrain_mock.yaml new file mode 100644 index 00000000..7b170d36 --- /dev/null +++ b/examples/megatron/recipes/wan/conf/gb200_perf_pretrain_mock.yaml @@ -0,0 +1,33 @@ +model: + tensor_model_parallel_size: 1 + sequence_parallel: false + pipeline_model_parallel_size: 1 + context_parallel_size: 4 + crossattn_emb_size: 5120 + hidden_size: 5120 + ffn_hidden_size: 13824 + num_attention_heads: 40 + num_layers: 40 + qkv_format: thd + seq_length: 2048 # This is not used + +train: + global_batch_size: 64 + micro_batch_size: 1 + eval_iters: 0 + +scheduler: + lr_decay_style: constant + lr_warmup_iters: 0 + +optimizer: + lr: 5e-6 + min_lr: 5e-6 + +dataset: + seq_length: 2048 # This is not used + global_batch_size: 64 + micro_batch_size: 1 + +logger: + log_interval: 1 From bc191b250fb63ee6421d7bde0eadf36a37d623b2 Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Tue, 18 Nov 2025 18:31:10 -0800 Subject: [PATCH 68/80] GB300 perf config Signed-off-by: Parth Mannan --- .../wan/conf/gb300_perf_pretrain_mock.yaml | 33 +++++++++++++++++++ 1 file changed, 33 insertions(+) create mode 100644 examples/megatron/recipes/wan/conf/gb300_perf_pretrain_mock.yaml diff --git a/examples/megatron/recipes/wan/conf/gb300_perf_pretrain_mock.yaml b/examples/megatron/recipes/wan/conf/gb300_perf_pretrain_mock.yaml new file mode 100644 index 00000000..a35e6238 --- /dev/null +++ b/examples/megatron/recipes/wan/conf/gb300_perf_pretrain_mock.yaml @@ -0,0 +1,33 @@ +model: + tensor_model_parallel_size: 1 + sequence_parallel: false + pipeline_model_parallel_size: 1 + context_parallel_size: 2 + crossattn_emb_size: 5120 + hidden_size: 5120 + ffn_hidden_size: 13824 + num_attention_heads: 40 + num_layers: 40 + qkv_format: thd + seq_length: 2048 # This is not used + +train: + global_batch_size: 64 + micro_batch_size: 1 + eval_iters: 0 + +scheduler: + lr_decay_style: constant + lr_warmup_iters: 0 + +optimizer: + lr: 5e-6 + min_lr: 5e-6 + +dataset: + seq_length: 2048 # This is not used + global_batch_size: 64 + micro_batch_size: 1 + +logger: + log_interval: 1 From 992c8fb7a1c566074c29b55a8af32cc05cb2ad64 Mon Sep 17 00:00:00 2001 From: Abhinav Garg Date: Wed, 19 Nov 2025 09:01:01 +0000 Subject: [PATCH 69/80] Refactor Energon data module to return wrapped dataloaders and add EnergonDataloader class for cyclic iteration. Introduce WAN pretrain mock data configuration for testing. --- .../data/common/base_energon_datamodule.py | 26 ++++++++++- .../wan_pretrain_mock_data.yaml | 44 +++++++++++++++++++ 2 files changed, 68 insertions(+), 2 deletions(-) create mode 100644 examples/megatron/override_configs/wan_pretrain_mock_data.yaml diff --git a/dfm/src/megatron/data/common/base_energon_datamodule.py b/dfm/src/megatron/data/common/base_energon_datamodule.py index 0d6cb99d..1f1da0f1 100644 --- a/dfm/src/megatron/data/common/base_energon_datamodule.py +++ b/dfm/src/megatron/data/common/base_energon_datamodule.py @@ -195,7 +195,7 @@ def train_dataloader(self) -> Any: train_dataset = self.datasets_provider(worker_config, split="train") energon_dataloader = get_savable_loader(train_dataset, worker_config=worker_config) self.train_dataloader_object = energon_dataloader - return self.train_dataloader_object + return EnergonDataloader(self.train_dataloader_object) def val_dataloader(self): """ @@ -233,7 +233,7 @@ def val_dataloader(self): val_dataset = self.datasets_provider(worker_config, split="val") energon_loader = get_savable_loader(val_dataset, worker_config=worker_config) self.val_dataloader_object = energon_loader - return self.val_dataloader_object + return EnergonDataloader(self.val_dataloader_object) def test_dataloader(self) -> None: """ @@ -337,3 +337,25 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: consumed_samples=consumed_samples, consistency_check=False, ) + + +class EnergonDataloader: + """A wrapper to use Megatron Energon dataloader with the Megatron-LM training loop.""" + def __init__(self, dataloader): + self._dataloader = dataloader + self._iter = iter(cyclic_iter(dataloader)) + + def __next__(self): + return self._iter.__next__() + + def __iter__(self): + return self._iter.__iter__() + + def save_state(self): + return self._dataloader.save_state_rank() + + +def cyclic_iter(iter): + while True: + for x in iter: + yield x \ No newline at end of file diff --git a/examples/megatron/override_configs/wan_pretrain_mock_data.yaml b/examples/megatron/override_configs/wan_pretrain_mock_data.yaml new file mode 100644 index 00000000..a36e9dc0 --- /dev/null +++ b/examples/megatron/override_configs/wan_pretrain_mock_data.yaml @@ -0,0 +1,44 @@ +# WAN Pretrain Mock Data Test Configuration +# Converted from L2_Function_Tests_GPU_Wan_Mock_Data.sh + +model: + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + context_parallel_size: 1 + crossattn_emb_size: 1536 + hidden_size: 1536 + ffn_hidden_size: 8960 + num_attention_heads: 12 + num_layers: 3 + qkv_format: thd + seq_length: 2048 + +train: + eval_iters: 0 + global_batch_size: 2 + micro_batch_size: 1 + +optimizer: + lr: 5.0e-6 + min_lr: 5.0e-6 + +scheduler: + lr_decay_style: constant + lr_warmup_iters: 0 + +checkpoint: + save: ${oc.env:CHECKPOINT_DIR,null} + load: ${oc.env:CHECKPOINT_DIR,null} + load_optim: false + save_interval: 200 + +dataset: + path: ${oc.env:DATASET_PATH,null} + seq_length: 2048 + global_batch_size: 2 + micro_batch_size: 1 + packing_buffer_size: 50 + +logger: + log_interval: 1 + From 3954dddee53150b6f8e7c3997473809d7cff35bb Mon Sep 17 00:00:00 2001 From: Abhinav Garg Date: Wed, 19 Nov 2025 09:26:37 +0000 Subject: [PATCH 70/80] Enhance DiffusionTaskEncoder to handle None attributes in stacking and concatenation methods. Add WAN pretrain mock data configuration for testing purposes. --- .../data/common/diffusion_task_encoder_with_sp.py | 10 ++++++++-- ...in_mock_data.yaml => wan_pretrain_sample_data.yaml} | 0 2 files changed, 8 insertions(+), 2 deletions(-) rename examples/megatron/override_configs/{wan_pretrain_mock_data.yaml => wan_pretrain_sample_data.yaml} (100%) diff --git a/dfm/src/megatron/data/common/diffusion_task_encoder_with_sp.py b/dfm/src/megatron/data/common/diffusion_task_encoder_with_sp.py index a44f36dd..bb38b033 100644 --- a/dfm/src/megatron/data/common/diffusion_task_encoder_with_sp.py +++ b/dfm/src/megatron/data/common/diffusion_task_encoder_with_sp.py @@ -88,10 +88,16 @@ def pack_selected_samples(self, samples: List[DiffusionSample]) -> DiffusionSamp """Construct a new Diffusion sample by concatenating the sequences.""" def stack(attr): - return torch.stack([getattr(sample, attr) for sample in samples], dim=0) + if hasattr(samples[0], attr) and getattr(samples[0], attr) is not None: + return torch.stack([getattr(sample, attr) for sample in samples], dim=0) + else: + return None def cat(attr): - return torch.cat([getattr(sample, attr) for sample in samples], dim=0) + if hasattr(samples[0], attr) and getattr(samples[0], attr) is not None: + return torch.cat([getattr(sample, attr) for sample in samples], dim=0) + else: + return None return DiffusionSample( __key__=",".join([s.__key__ for s in samples]), diff --git a/examples/megatron/override_configs/wan_pretrain_mock_data.yaml b/examples/megatron/override_configs/wan_pretrain_sample_data.yaml similarity index 100% rename from examples/megatron/override_configs/wan_pretrain_mock_data.yaml rename to examples/megatron/override_configs/wan_pretrain_sample_data.yaml From 45aff1be6ebcec654ea7b9245785618cf753a48e Mon Sep 17 00:00:00 2001 From: Abhinav Garg Date: Wed, 19 Nov 2025 14:28:25 +0000 Subject: [PATCH 71/80] Refactor data processing in dit_data_step to simplify batch retrieval and update WAN pretrain configuration to include train_iters. --- dfm/src/megatron/model/dit/dit_data_process.py | 3 +-- .../megatron/override_configs/wan_pretrain_sample_data.yaml | 1 + 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/dfm/src/megatron/model/dit/dit_data_process.py b/dfm/src/megatron/model/dit/dit_data_process.py index e599581a..97b5d76e 100644 --- a/dfm/src/megatron/model/dit/dit_data_process.py +++ b/dfm/src/megatron/model/dit/dit_data_process.py @@ -18,8 +18,7 @@ def dit_data_step(qkv_format, dataloader_iter): - # import pdb;pdb.set_trace() - batch = next(iter(dataloader_iter.iterable)) + batch = next(dataloader_iter) batch["is_preprocessed"] = True # assume data is preprocessed batch = {k: v.to(device="cuda", non_blocking=True) if torch.is_tensor(v) else v for k, v in batch.items()} batch = encode_seq_length(batch, format=qkv_format) diff --git a/examples/megatron/override_configs/wan_pretrain_sample_data.yaml b/examples/megatron/override_configs/wan_pretrain_sample_data.yaml index a36e9dc0..1e973cf4 100644 --- a/examples/megatron/override_configs/wan_pretrain_sample_data.yaml +++ b/examples/megatron/override_configs/wan_pretrain_sample_data.yaml @@ -15,6 +15,7 @@ model: train: eval_iters: 0 + train_iters: 10 global_batch_size: 2 micro_batch_size: 1 From 6aff745ddb65ef12b21fb4e1aad850814ba94ec6 Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Wed, 19 Nov 2025 19:04:02 -0800 Subject: [PATCH 72/80] Add op fusions Signed-off-by: Parth Mannan --- dfm/src/megatron/model/wan/wan_layer_spec.py | 33 +++++++++++-------- .../wan/conf/h100_perf_pretrain_mock.yaml | 4 +-- 2 files changed, 21 insertions(+), 16 deletions(-) diff --git a/dfm/src/megatron/model/wan/wan_layer_spec.py b/dfm/src/megatron/model/wan/wan_layer_spec.py index 8862f40f..f75888bf 100644 --- a/dfm/src/megatron/model/wan/wan_layer_spec.py +++ b/dfm/src/megatron/model/wan/wan_layer_spec.py @@ -66,10 +66,16 @@ def __init__(self, config: TransformerConfig): setattr(self.modulation, "sequence_parallel", config.sequence_parallel) + @jit_fuser def forward(self, timestep_emb): - e = (self.modulation + timestep_emb).chunk(6, dim=1) + e = (self.modulation + timestep_emb).transpose(0, 1) + e = e.chunk(6, dim=0) return e + @jit_fuser + def normalize_modulate(self, norm, hidden_states, shift, scale): + return self.modulate(norm(hidden_states), shift, scale) + @jit_fuser def modulate(self, x, shift, scale): return x * (1 + scale) + shift @@ -108,7 +114,7 @@ def _replace_no_cp_submodules(submodules): config=config, submodules=submodules, layer_number=layer_number, hidden_dropout=hidden_dropout ) - # Override Cross Attention to disable CP. + # TODO (pmannan): Override Cross Attention to disable CP. # Disable TP Comm overlap as well. Not disabling will attempt re-use of buffer size same as # Q and lead to incorrect tensor shapes. # if submodules.cross_attention != IdentityOp: @@ -161,6 +167,10 @@ def _mark_trainable_params_for_tp_grad_avg(self, modules: Optional[list] = None) if isinstance(param, nn.Parameter) and param.requires_grad: setattr(param, "average_gradients_across_tp_domain", True) + @jit_fuser + def add_residual(self, x: torch.Tensor, residual: torch.Tensor) -> torch.Tensor: + return x + residual + def forward( self, hidden_states, @@ -182,19 +192,13 @@ def forward( rope_emb = rotary_pos_emb shift_full, scale_full, gate_full, shift_mlp, scale_mlp, gate_mlp = self.adaLN(timestep_emb) - # transpose to bring it to [1, b, ...] format - shift_full = shift_full.transpose(0, 1) - scale_full = scale_full.transpose(0, 1) - gate_full = gate_full.transpose(0, 1) - shift_mlp = shift_mlp.transpose(0, 1) - scale_mlp = scale_mlp.transpose(0, 1) - gate_mlp = gate_mlp.transpose(0, 1) # ******************************************** full self attention ******************************************* # adaLN with scale + shift + gate - pre_full_attn_layernorm_output_ada = self.adaLN.modulate( - self.norm1(hidden_states), + pre_full_attn_layernorm_output_ada = self.adaLN.normalize_modulate( + self.norm1, + hidden_states, shift=shift_full, scale=scale_full, ) @@ -229,12 +233,13 @@ def forward( if bias is not None: attention_output = attention_output + bias - hidden_states = hidden_states + attention_output + hidden_states = self.add_residual(hidden_states, attention_output) # ******************************************** mlp ****************************************************** - pre_mlp_layernorm_output_ada = self.adaLN.modulate( - self.norm2(hidden_states), + pre_mlp_layernorm_output_ada = self.adaLN.normalize_modulate( + self.norm2, + hidden_states, shift=shift_mlp, scale=scale_mlp, ) diff --git a/examples/megatron/recipes/wan/conf/h100_perf_pretrain_mock.yaml b/examples/megatron/recipes/wan/conf/h100_perf_pretrain_mock.yaml index c64df86d..cfe45424 100644 --- a/examples/megatron/recipes/wan/conf/h100_perf_pretrain_mock.yaml +++ b/examples/megatron/recipes/wan/conf/h100_perf_pretrain_mock.yaml @@ -5,7 +5,7 @@ model: context_parallel_size: 4 recompute_granularity: full recompute_method: block - recompute_num_layers: 10 + recompute_num_layers: 8 crossattn_emb_size: 5120 hidden_size: 5120 ffn_hidden_size: 13824 @@ -18,7 +18,7 @@ train: global_batch_size: 64 micro_batch_size: 1 eval_iters: 0 - empty_unused_memory_level: 1 + empty_unused_memory_level: 0 scheduler: lr_decay_style: constant From 7a2ec65beecf3bc6f8dbd01113596759d7c324c1 Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Wed, 19 Nov 2025 23:29:14 -0800 Subject: [PATCH 73/80] Update H100 config Signed-off-by: Parth Mannan --- .../megatron/recipes/wan/conf/h100_perf_pretrain_mock.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/megatron/recipes/wan/conf/h100_perf_pretrain_mock.yaml b/examples/megatron/recipes/wan/conf/h100_perf_pretrain_mock.yaml index cfe45424..0013fb32 100644 --- a/examples/megatron/recipes/wan/conf/h100_perf_pretrain_mock.yaml +++ b/examples/megatron/recipes/wan/conf/h100_perf_pretrain_mock.yaml @@ -15,7 +15,7 @@ model: seq_length: 2048 train: - global_batch_size: 64 + global_batch_size: 128 micro_batch_size: 1 eval_iters: 0 empty_unused_memory_level: 0 @@ -30,7 +30,7 @@ optimizer: dataset: seq_length: 2048 # This is not used - global_batch_size: 64 + global_batch_size: 128 micro_batch_size: 1 logger: From b90fc5ace2c31896e3bc9ba93fec0d3e8bf49f4b Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Thu, 20 Nov 2025 11:10:03 -0800 Subject: [PATCH 74/80] Fix lint Signed-off-by: Parth Mannan --- dfm/src/megatron/data/common/base_energon_datamodule.py | 3 ++- .../megatron/override_configs/wan_pretrain_sample_data.yaml | 1 - 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/dfm/src/megatron/data/common/base_energon_datamodule.py b/dfm/src/megatron/data/common/base_energon_datamodule.py index 1f1da0f1..0bf711a9 100644 --- a/dfm/src/megatron/data/common/base_energon_datamodule.py +++ b/dfm/src/megatron/data/common/base_energon_datamodule.py @@ -341,6 +341,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: class EnergonDataloader: """A wrapper to use Megatron Energon dataloader with the Megatron-LM training loop.""" + def __init__(self, dataloader): self._dataloader = dataloader self._iter = iter(cyclic_iter(dataloader)) @@ -358,4 +359,4 @@ def save_state(self): def cyclic_iter(iter): while True: for x in iter: - yield x \ No newline at end of file + yield x diff --git a/examples/megatron/override_configs/wan_pretrain_sample_data.yaml b/examples/megatron/override_configs/wan_pretrain_sample_data.yaml index 1e973cf4..9648874e 100644 --- a/examples/megatron/override_configs/wan_pretrain_sample_data.yaml +++ b/examples/megatron/override_configs/wan_pretrain_sample_data.yaml @@ -42,4 +42,3 @@ dataset: logger: log_interval: 1 - From 78ef7b64ab89b601e1844348e39d917852b37263 Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Thu, 20 Nov 2025 11:14:19 -0800 Subject: [PATCH 75/80] Resolve conflict Signed-off-by: Parth Mannan --- .../megatron/data/wan/wan_mock_datamodule.py | 38 +++++++++---------- 1 file changed, 18 insertions(+), 20 deletions(-) diff --git a/dfm/src/megatron/data/wan/wan_mock_datamodule.py b/dfm/src/megatron/data/wan/wan_mock_datamodule.py index a3d7b1db..0ae9b96f 100644 --- a/dfm/src/megatron/data/wan/wan_mock_datamodule.py +++ b/dfm/src/megatron/data/wan/wan_mock_datamodule.py @@ -119,26 +119,24 @@ class WanMockDataModuleConfig(DatasetProvider): def __post_init__(self): mock_ds = _MockDataset(length=1024) - self._train_dl = iter( - DataLoader( - mock_ds, - batch_size=self.micro_batch_size, - num_workers=self.num_workers, - collate_fn=lambda samples: mock_batch( - F_latents=self.F_latents, - H_latents=self.H_latents, - W_latents=self.W_latents, - patch_temporal=self.patch_temporal, - patch_spatial=self.patch_spatial, - number_packed_samples=self.number_packed_samples, - context_seq_len=self.context_seq_len, - context_embeddings_dim=self.context_embeddings_dim, - ), - shuffle=False, - drop_last=False, - pin_memory=True, - prefetch_factor=8, - ) + self._train_dl = DataLoader( + mock_ds, + batch_size=self.micro_batch_size, + num_workers=self.num_workers, + collate_fn=lambda samples: mock_batch( + F_latents=self.F_latents, + H_latents=self.H_latents, + W_latents=self.W_latents, + patch_temporal=self.patch_temporal, + patch_spatial=self.patch_spatial, + number_packed_samples=self.number_packed_samples, + context_seq_len=self.context_seq_len, + context_embeddings_dim=self.context_embeddings_dim, + ), + shuffle=False, + drop_last=False, + pin_memory=True, + prefetch_factor=8, ) self._train_dl = iter(self._train_dl) self.sequence_length = self.seq_length From 1a2d66210d51bdc744474940d456d7c885c24984 Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Thu, 20 Nov 2025 11:20:45 -0800 Subject: [PATCH 76/80] Fix for mock dataloader test Signed-off-by: Parth Mannan --- dfm/src/megatron/data/wan/wan_mock_datamodule.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/dfm/src/megatron/data/wan/wan_mock_datamodule.py b/dfm/src/megatron/data/wan/wan_mock_datamodule.py index 0ae9b96f..1eb2b394 100644 --- a/dfm/src/megatron/data/wan/wan_mock_datamodule.py +++ b/dfm/src/megatron/data/wan/wan_mock_datamodule.py @@ -119,6 +119,9 @@ class WanMockDataModuleConfig(DatasetProvider): def __post_init__(self): mock_ds = _MockDataset(length=1024) + kwargs = {} + if self.num_workers > 0: + kwargs["prefetch_factor"] = 8 self._train_dl = DataLoader( mock_ds, batch_size=self.micro_batch_size, @@ -136,7 +139,7 @@ def __post_init__(self): shuffle=False, drop_last=False, pin_memory=True, - prefetch_factor=8, + **kwargs, ) self._train_dl = iter(self._train_dl) self.sequence_length = self.seq_length From 1b055c61508b3e20827ec33ba635a861fb71c179 Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Thu, 20 Nov 2025 12:28:10 -0800 Subject: [PATCH 77/80] Fix Dummyiter Signed-off-by: Parth Mannan --- tests/unit_tests/megatron/model/wan/test_wan_step.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit_tests/megatron/model/wan/test_wan_step.py b/tests/unit_tests/megatron/model/wan/test_wan_step.py index 8ee0e9cb..f0c9e49a 100644 --- a/tests/unit_tests/megatron/model/wan/test_wan_step.py +++ b/tests/unit_tests/megatron/model/wan/test_wan_step.py @@ -35,7 +35,7 @@ def test_wan_data_step_builds_packed_seq_params_cuda_guarded(): # include a tensor field to exercise device transfer "video_latents": torch.randn(8, 1, 4, dtype=torch.float32), } - it = _DummyIter(batch) + it = iter(_DummyIter(batch)) qkv_format = "sbhd" out = wan_data_step(qkv_format, it) From 04f6c14eb9688dcdcfdbf308602d5ba78b327b52 Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Thu, 20 Nov 2025 14:01:14 -0800 Subject: [PATCH 78/80] Fix test Signed-off-by: Parth Mannan --- tests/unit_tests/megatron/model/wan/test_wan_step.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit_tests/megatron/model/wan/test_wan_step.py b/tests/unit_tests/megatron/model/wan/test_wan_step.py index f0c9e49a..c48366e0 100644 --- a/tests/unit_tests/megatron/model/wan/test_wan_step.py +++ b/tests/unit_tests/megatron/model/wan/test_wan_step.py @@ -35,7 +35,7 @@ def test_wan_data_step_builds_packed_seq_params_cuda_guarded(): # include a tensor field to exercise device transfer "video_latents": torch.randn(8, 1, 4, dtype=torch.float32), } - it = iter(_DummyIter(batch)) + it = iter(_DummyIter(batch).iterable) qkv_format = "sbhd" out = wan_data_step(qkv_format, it) From 7247bc7fe07823c2e9735dc72d42516c0430f6c8 Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Thu, 20 Nov 2025 16:28:32 -0800 Subject: [PATCH 79/80] Make RoPE test only GPU Signed-off-by: Parth Mannan --- tests/unit_tests/megatron/model/wan/test_rope_utils.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/unit_tests/megatron/model/wan/test_rope_utils.py b/tests/unit_tests/megatron/model/wan/test_rope_utils.py index 7e31d8d0..54090e6c 100644 --- a/tests/unit_tests/megatron/model/wan/test_rope_utils.py +++ b/tests/unit_tests/megatron/model/wan/test_rope_utils.py @@ -12,11 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +import pytest import torch from dfm.src.megatron.model.wan.rope_utils import Wan3DRopeEmbeddings +@pytest.mark.skipif( + not torch.cuda.is_available(), + reason="CUDA not available", +) def test_wan3d_rope_embeddings_shapes_and_padding(): # Small, CPU-friendly config n_head = 2 From b17a40d4a9c5592422eabe15e8cfb7f4a336063e Mon Sep 17 00:00:00 2001 From: Parth Mannan Date: Thu, 20 Nov 2025 16:58:54 -0800 Subject: [PATCH 80/80] Rope cuda fix Signed-off-by: Parth Mannan --- dfm/src/megatron/model/wan/rope_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/dfm/src/megatron/model/wan/rope_utils.py b/dfm/src/megatron/model/wan/rope_utils.py index 993507e5..e3449275 100644 --- a/dfm/src/megatron/model/wan/rope_utils.py +++ b/dfm/src/megatron/model/wan/rope_utils.py @@ -31,7 +31,9 @@ def __init__(self, dim_head, max_position_len): self.rope_params(max_position_len, 2 * (dim_head // 6)), ], dim=1, - ).cuda() + ) + if torch.cuda.is_available(): + self.freqs = self.freqs.cuda() def rope_params(self, max_position_len, dim_head, theta=10000): assert dim_head % 2 == 0