diff --git a/examples/automodel/README.md b/examples/automodel/README.md
new file mode 100644
index 00000000..791bf66a
--- /dev/null
+++ b/examples/automodel/README.md
@@ -0,0 +1,258 @@
+# Diffusion Model Fine-tuning with Automodel Backend
+
+Train diffusion models with distributed training support using NeMo Automodel and flow matching.
+
+**Currently Supported:** Wan 2.1 Text-to-Video (1.3B and 14B models)
+
+---
+
+## Quick Start
+
+### 1. Docker Setup
+
+```bash
+# Build image
+docker build -f docker/Dockerfile.ci -t dfm-training .
+
+# Run container
+docker run --gpus all -it \
+    -v $(pwd):/workspace \
+    -v /path/to/data:/data \
+    --ipc=host \
+    --ulimit memlock=-1 \
+    --ulimit stack=67108864 \
+    dfm-training bash
+
+# Inside container: Initialize submodules
+export UV_PROJECT_ENVIRONMENT=
+git submodule update --init --recursive 3rdparty/
+```
+
+### 2. Prepare Data
+
+**Create a video dataset:**
+```
+/path/to/videos/
+├── video1.mp4
+├── video2.mp4
+└── meta.json
+```
+
+**meta.json format:**
+```json
+[
+  {
+    "file_name": "video1.mp4",
+    "width": 1280,
+    "height": 720,
+    "start_frame": 0,
+    "end_frame": 121,
+    "vila_caption": "A detailed description of the video content..."
+  }
+]
+```
+
+**Preprocess videos into `.meta` files:**
+
+There are two preprocessing modes:
+
+**Mode 1: Full video (recommended for training)**
+```bash
+python dfm/src/automodel/utils/data/preprocess_resize.py \
+    --mode video \
+    --video_folder /path/to/videos \
+    --output_folder ./processed_meta \
+    --model Wan-AI/Wan2.1-T2V-1.3B-Diffusers \
+    --height 480 \
+    --width 720 \
+    --center-crop
+```
+
+**Mode 2: Extract frames (for frame-based training)**
+```bash
+python dfm/src/automodel/utils/data/preprocess_resize.py \
+    --mode frames \
+    --num-frames 40 \
+    --video_folder /path/to/videos \
+    --output_folder ./processed_frames \
+    --model Wan-AI/Wan2.1-T2V-1.3B-Diffusers \
+    --height 240 \
+    --width 416 \
+    --center-crop
+```
+
+**Key arguments:**
+- `--mode`: `video` (process the full video) or `frames` (extract evenly spaced frames)
+- `--num-frames`: Number of frames to extract (only for `frames` mode)
+- `--height` / `--width`: Target resolution
+- `--center-crop`: Crop to the exact size after an aspect-preserving resize
+
+**Preprocessing modes:**
+- **`video` mode**: Processes the entire video sequence and creates one `.meta` file per video
+- **`frames` mode**: Extracts N evenly spaced frames and creates one `.meta` file per frame (each treated as a 1-frame video)
+
+**Output:** Each `.meta` file contains:
+- Encoded video latents (normalized)
+- Text embeddings (from UMT5)
+- First frame as JPEG (`video` mode only)
+- Metadata
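+
+To sanity-check what preprocessing wrote, you can open a `.meta` file directly. The sketch below assumes the `.meta` files are pickled dicts with keys such as `video_latents`, `text_embeddings`, and `metadata` (the layout the validation script reads); adjust the folder and key names to your setup.
+
+```python
+# Inspect one preprocessed .meta file (illustrative; key names may differ)
+import pickle
+from pathlib import Path
+
+meta_file = next(Path("./processed_meta").glob("*.meta"))
+with open(meta_file, "rb") as f:
+    data = pickle.load(f)
+
+for key, value in data.items():
+    shape = getattr(value, "shape", None)
+    print(key, type(value).__name__, tuple(shape) if shape is not None else "")
+print("caption:", data.get("metadata", {}).get("vila_caption", "")[:80])
+```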
+
+### 3. Train
+
+**Single-node (8 GPUs):**
+```bash
+export UV_PROJECT_ENVIRONMENT=
+
+uv run --group automodel --with . \
+    torchrun --nproc-per-node=8 \
+    examples/automodel/finetune/finetune.py \
+    -c examples/automodel/finetune/wan2_1_t2v_flow.yaml
+```
+
+**Multi-node with SLURM:**
+```bash
+#!/bin/bash
+#SBATCH -N 2
+#SBATCH --ntasks-per-node 1
+#SBATCH --gpus-per-node=8
+#SBATCH --exclusive
+
+export MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
+export MASTER_PORT=29500
+export NUM_GPUS=8
+
+# Per-rank UV cache to avoid conflicts
+unset UV_PROJECT_ENVIRONMENT
+mkdir -p /opt/uv_cache/${SLURM_JOB_ID}_${SLURM_PROCID}
+export UV_CACHE_DIR=/opt/uv_cache/${SLURM_JOB_ID}_${SLURM_PROCID}
+
+uv run --group automodel --with . \
+    torchrun \
+    --nnodes=$SLURM_NNODES \
+    --nproc-per-node=$NUM_GPUS \
+    --rdzv_backend=c10d \
+    --rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT \
+    examples/automodel/finetune/finetune.py \
+    -c examples/automodel/finetune/wan2_1_t2v_flow_multinode.yaml
+```
+
+### 4. Validate
+
+```bash
+uv run --group automodel --with . \
+    python examples/automodel/generate/wan_validate.py \
+    --meta_folder ./processed_meta \
+    --guidance_scale 5 \
+    --checkpoint ./checkpoints/step_1000 \
+    --num_samples 10
+```
+
+**Note:** You can pass `--checkpoint ./checkpoints/LATEST` to automatically use the most recent checkpoint.
+
+---
+
+## Configuration
+
+### Fine-tuning Config (`wan2_1_t2v_flow.yaml`)
+
+```yaml
+model:
+  pretrained_model_name_or_path: Wan-AI/Wan2.1-T2V-1.3B-Diffusers
+
+step_scheduler:
+  global_batch_size: 8
+  local_batch_size: 1
+  num_epochs: 10
+  ckpt_every_steps: 100
+
+data:
+  dataloader:
+    meta_folder: ""  # path to your preprocessed .meta folder
+    num_workers: 2
+
+optim:
+  learning_rate: 5e-6
+
+flow_matching:
+  timestep_sampling: "uniform"
+  flow_shift: 3.0
+
+fsdp:
+  dp_size: 8  # Single node: 8 GPUs
+
+checkpoint:
+  enabled: true
+  checkpoint_dir: "./checkpoints"
+```
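+
+The `flow_matching` block controls how training timesteps are sampled and how strongly they are shifted toward the noisy end of the schedule. As a rough illustration only (not this repo's implementation — the function and parameter names below are invented for the example), these knobs typically enter a flow-matching objective like this:
+
+```python
+import torch
+import torch.nn.functional as F
+
+def sample_timesteps(batch_size, sampling="uniform", logit_mean=0.0, logit_std=1.0, flow_shift=3.0):
+    if sampling == "uniform":
+        t = torch.rand(batch_size)
+    else:  # logit-normal: concentrates samples at mid-range noise; larger logit_std widens the spread
+        t = torch.sigmoid(torch.randn(batch_size) * logit_std + logit_mean)
+    # flow_shift > 1 biases timesteps toward higher noise levels
+    return flow_shift * t / (1 + (flow_shift - 1) * t)
+
+def flow_matching_loss(model, latents, text_emb, t):
+    noise = torch.randn_like(latents)
+    t_ = t.view(-1, *([1] * (latents.ndim - 1)))
+    noisy = (1 - t_) * latents + t_ * noise  # straight-line path from data to noise
+    target = noise - latents                 # velocity the model is trained to predict
+    return F.mse_loss(model(noisy, t, text_emb), target)
+```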
+
+### Multi-node Config Differences
+
+```yaml
+fsdp:
+  dp_size: 16           # 2 nodes × 8 GPUs
+  dp_replicate_size: 2  # Replicate across 2 nodes
+```
+
+### Pretraining vs Fine-tuning
+
+| Setting | Fine-tuning | Pretraining |
+|---------|-------------|-------------|
+| `learning_rate` | 5e-6 | 5e-5 |
+| `weight_decay` | 0.01 | 0.1 |
+| `flow_shift` | 3.0 | 2.5 |
+| `logit_std` | 1.0 | 1.5 |
+| Dataset size | 100s–1000s | 10K+ |
+
+---
+
+## Hardware Requirements
+
+| Component | Minimum | Recommended |
+|-----------|---------|-------------|
+| GPU | A100 40GB | A100 80GB / H100 |
+| GPUs | 4 | 8+ |
+| RAM | 128 GB | 256 GB+ |
+| Storage | 500 GB SSD | 2 TB NVMe |
+
+---
+
+## Features
+
+- ✅ **Flow Matching**: Pure flow matching training
+- ✅ **Distributed**: FSDP2 + Tensor Parallelism
+- ✅ **Mixed Precision**: BF16 by default
+- ✅ **WandB**: Automatic logging
+- ✅ **Checkpointing**: Consolidated and sharded formats
+- ✅ **Multi-node**: SLURM and torchrun support
+
+---
+
+## Supported Models
+
+| Model | Parameters | Parallelization | Status |
+|-------|------------|-----------------|--------|
+| Wan 2.1 T2V 1.3B | 1.3B | FSDP2 via Automodel + DDP | ✅ |
+| Wan 2.1 T2V 14B | 14B | FSDP2 via Automodel + DDP | ✅ |
+| FLUX | TBD | TBD | 🔄 In Progress |
+
+---
+
+## Advanced
+
+**Custom parallelization:**
+```yaml
+fsdp:
+  tp_size: 2  # Tensor parallel
+  dp_size: 4  # Data parallel
+```
+
+**Checkpoint cleanup:**
+```python
+from pathlib import Path
+import shutil
+
+def cleanup_old_checkpoints(checkpoint_dir, keep_last_n=3):
+    # Sort numerically by step so that step_1000 sorts after step_999
+    checkpoints = sorted(
+        Path(checkpoint_dir).glob("step_*"),
+        key=lambda p: int(p.name.split("_")[-1]),
+    )
+    for old_ckpt in checkpoints[:-keep_last_n]:
+        shutil.rmtree(old_ckpt)
+```
diff --git a/examples/automodel/generate/wan_validate.py b/examples/automodel/generate/wan_validate.py
index 9c85a0e3..ced640e0 100644
--- a/examples/automodel/generate/wan_validate.py
+++ b/examples/automodel/generate/wan_validate.py
@@ -15,129 +15,58 @@
 import argparse
 import os
 import pickle
-import subprocess
 from pathlib import Path
 
-import numpy as np
 import torch
 from diffusers import WanPipeline
 from diffusers.utils import export_to_video
-from PIL import Image
-
-
-try:
-    import wandb
-
-    WANDB_AVAILABLE = True
-except ImportError:
-    WANDB_AVAILABLE = False
-    print("[WARNING] wandb not installed. Install with: pip install wandb")
-
-
-def convert_to_gif(video_path):
-    gif_path = Path(video_path).with_suffix(".gif")
-    cmd = [
-        "ffmpeg",
-        "-y",
-        "-i",
-        str(video_path),
-        "-vf",
-        "fps=15,scale=512:-1:flags=lanczos",
-        "-loop",
-        "0",
-        str(gif_path),
-    ]
-    subprocess.run(cmd, check=True)
-    return str(gif_path)
 
 
 def parse_args():
-    p = argparse.ArgumentParser("WAN 2.1 T2V Validation with Precomputed Embeddings")
+    p = argparse.ArgumentParser("WAN 2.1 T2V Validation")
 
     # Model configuration
     p.add_argument("--model_id", type=str, default="Wan-AI/Wan2.1-T2V-1.3B-Diffusers")
     p.add_argument("--checkpoint", type=str, default=None, help="Path to checkpoint (optional)")
 
     # Data - load from .meta files
-    p.add_argument("--meta_folder", type=str, required=True, help="Folder containing .meta files with embeddings")
+    p.add_argument("--meta_folder", type=str, required=True, help="Folder containing .meta files with prompts")
 
     # Generation settings
-    p.add_argument("--num_samples", type=int, default=10, help="Number of samples (default: 10)")
+    p.add_argument("--num_samples", type=int, default=None, help="Number of samples (default: all)")
     p.add_argument("--num_inference_steps", type=int, default=50)
     p.add_argument("--guidance_scale", type=float, default=5.0)
+    p.add_argument("--negative_prompt", type=str, default="")
     p.add_argument("--seed", type=int, default=42)
+
+    # Video settings
+    p.add_argument("--height", type=int, default=480)
+    p.add_argument("--width", type=int, default=832)
+    p.add_argument("--num_frames", type=int, default=81)
     p.add_argument("--fps", type=int, default=16)
 
     # Output
     p.add_argument("--output_dir", type=str, default="./validation_outputs")
 
-    # Wandb settings
-    p.add_argument("--use_wandb", action="store_true", help="Upload results to Weights & Biases")
-    p.add_argument("--wandb_project", type=str, default="wan_t2v_valid", help="Wandb project name")
-    p.add_argument("--wandb_run_name", type=str, default=None, help="Wandb run name (default: auto-generated)")
-
     return p.parse_args()
 
 
-def infer_video_params_from_latents(latents):
+def load_prompts_from_meta_files(meta_folder: str):
     """
-    Infer video generation parameters from latent shape.
+    Load prompts from .meta files.
+    Each .meta file contains a 'metadata' dict with 'vila_caption'.
 
-    Args:
-        latents: torch.Tensor or np.ndarray with shape (16, T_latent, H_latent, W_latent)
-                 or (1, 16, T_latent, H_latent, W_latent)
-
-    Returns:
-        dict with num_frames, height, width
-    """
-    # Convert to tensor if numpy
-    if not isinstance(latents, torch.Tensor):
-        latents = torch.from_numpy(latents)
-
-    # Handle batch dimension
-    if latents.ndim == 5:
-        latents = latents[0]  # Remove batch dim: (16, T_latent, H_latent, W_latent)
-
-    C, T_latent, H_latent, W_latent = latents.shape
-
-    # WAN 2.1 VAE compression ratios
-    temporal_compression = 4
-    spatial_compression = 8
-
-    # Infer dimensions
-    num_frames = (T_latent - 1) * temporal_compression + 1
-    height = H_latent * spatial_compression
-    width = W_latent * spatial_compression
-
-    return {
-        "num_frames": num_frames,
-        "height": height,
-        "width": width,
-    }
-
-
-def load_data_from_meta_files(meta_folder: str, num_samples: int = 10):
-    """
-    Load text embeddings and metadata from .meta files.
-
-    Returns list of dicts: [{
-        "prompt": "...",
-        "name": "...",
-        "text_embeddings": tensor,
-        "num_frames": int,
-        "height": int,
-        "width": int
-    }, ...]
+    Returns list of dicts: [{"prompt": "...", "name": "...", "meta_file": "..."}, ...]
+    """
     meta_folder = Path(meta_folder)
-    meta_files = sorted(list(meta_folder.glob("*.meta")))[:num_samples]
+    meta_files = sorted(list(meta_folder.glob("*.meta")))
 
     if not meta_files:
         raise FileNotFoundError(f"No .meta files found in {meta_folder}")
 
-    print(f"[INFO] Found {len(meta_files)} .meta files (limited to first {num_samples})")
+    print(f"[INFO] Found {len(meta_files)} .meta files")
 
-    data_list = []
+    prompts = []
 
     for meta_file in meta_files:
         try:
@@ -152,122 +81,41 @@ def load_data_from_meta_files(meta_folder: str, num_samples: int = 10):
                 print(f"[WARNING] No vila_caption in {meta_file.name}, skipping...")
                 continue
 
-            # Get text embeddings
-            text_embeddings = data.get("text_embeddings")
-            if text_embeddings is None:
-                print(f"[WARNING] No text_embeddings in {meta_file.name}, skipping...")
-                continue
-
-            # Convert to tensor and remove batch dimensions
-            if not isinstance(text_embeddings, torch.Tensor):
-                text_embeddings = torch.from_numpy(text_embeddings)
-
-            # Squeeze out batch dimensions: (1, 1, seq_len, hidden_dim) -> (seq_len, hidden_dim)
-            while text_embeddings.ndim > 2 and text_embeddings.shape[0] == 1:
-                text_embeddings = text_embeddings.squeeze(0)
-
             # Get filename without extension
             name = meta_file.stem
 
-            # Infer video dimensions from latents
-            video_params = None
-            if "video_latents" in data:
-                try:
-                    video_params = infer_video_params_from_latents(data["video_latents"])
-                except Exception as e:
-                    print(f"[WARNING] Could not infer dimensions from {meta_file.name}: {e}")
-
-            item = {
-                "prompt": prompt,
-                "name": name,
-                "text_embeddings": text_embeddings,
-                "meta_file": str(meta_file),
-            }
-
-            # Add inferred dimensions if available
-            if video_params:
-                item.update(video_params)
-
-            data_list.append(item)
+            prompts.append({"prompt": prompt, "name": name, "meta_file": str(meta_file)})
 
         except Exception as e:
             print(f"[WARNING] Failed to load {meta_file.name}: {e}")
             continue
 
-    if not data_list:
-        raise ValueError(f"No valid data found in {meta_folder}")
+    if not prompts:
+        raise ValueError(f"No valid prompts found in {meta_folder}")
 
-    return data_list
+    return prompts
 
 
 def main():
     args = parse_args()
 
     print("=" * 80)
-    print("WAN 2.1 Text-to-Video Validation (Using Precomputed Embeddings)")
+    print("WAN 2.1 Text-to-Video Validation")
     print("=" * 80)
 
-    # Initialize wandb if requested
-    wandb_run = None
-    if args.use_wandb:
-        if not WANDB_AVAILABLE:
-            print("[ERROR] wandb requested but not installed. Install with: pip install wandb")
-            print("[INFO] Continuing without wandb...")
-        else:
-            print("\n[WANDB] Initializing Weights & Biases...")
-            print(f"[WANDB] Project: {args.wandb_project}")
-
-            # Generate run name if not provided
-            run_name = args.wandb_run_name
-            if run_name is None:
-                checkpoint_name = Path(args.checkpoint).name if args.checkpoint else "base_model"
-                run_name = f"validation_{checkpoint_name}"
-
-            wandb_run = wandb.init(
-                project=args.wandb_project,
-                name=run_name,
-                config={
-                    "model_id": args.model_id,
-                    "checkpoint": args.checkpoint,
-                    "num_samples": args.num_samples,
-                    "num_inference_steps": args.num_inference_steps,
-                    "guidance_scale": args.guidance_scale,
-                    "seed": args.seed,
-                    "fps": args.fps,
-                },
-            )
-            print(f"[WANDB] Run name: {run_name}")
-            print(f"[WANDB] Run URL: {wandb_run.get_url()}")
+    # Load prompts from .meta files
+    print(f"\n[1] Loading prompts from .meta files in: {args.meta_folder}")
+    prompts = load_prompts_from_meta_files(args.meta_folder)
 
-    # Load data from .meta files
-    print(f"\n[1] Loading data from .meta files in: {args.meta_folder}")
-    data_list = load_data_from_meta_files(args.meta_folder, args.num_samples)
+    if args.num_samples:
+        prompts = prompts[: args.num_samples]
 
-    print(f"[INFO] Loaded {len(data_list)} samples")
+    print(f"[INFO] Loaded {len(prompts)} prompts")
 
-    # Show first few samples with dimensions
+    # Show first few prompts
     print("\n[INFO] Sample prompts:")
-    for i, item in enumerate(data_list[:3]):
-        dims_str = ""
-        if "num_frames" in item:
-            dims_str = f" [{item['num_frames']} frames, {item['width']}x{item['height']}]"
-        emb_shape = item["text_embeddings"].shape
-        print(f"  {i + 1}. {item['name']}{dims_str}")
-        print(f"     Prompt: {item['prompt'][:60]}...")
-        print(f"     Text embeddings: {emb_shape}")
-
-    # Check dimension consistency
-    items_with_dims = [p for p in data_list if "num_frames" in p]
-    if items_with_dims:
-        unique_dims = set((p["num_frames"], p["height"], p["width"]) for p in items_with_dims)
-        if len(unique_dims) == 1:
-            num_frames, height, width = list(unique_dims)[0]
-            print(f"\n[INFO] All samples have consistent dimensions: {num_frames} frames, {width}x{height}")
-        else:
-            print(f"\n[INFO] Found {len(unique_dims)} different dimension sets across samples")
-            for dims in unique_dims:
-                count = sum(1 for p in items_with_dims if (p["num_frames"], p["height"], p["width"]) == dims)
-                print(f"  - {dims[0]} frames, {dims[2]}x{dims[1]}: {count} samples")
+    for i, item in enumerate(prompts[:3]):
+        print(f"  {i + 1}. {item['name']}: {item['prompt'][:60]}...")
 
     # Load pipeline
     print(f"\n[2] Loading pipeline: {args.model_id}")
@@ -283,137 +131,110 @@ def main():
 
     if args.checkpoint:
         print(f"\n[3] Loading checkpoint: {args.checkpoint}")
 
-        # Try consolidated checkpoint or EMA checkpoint
-        consolidated_path = os.path.join(args.checkpoint, "consolidated_model.bin")
+        # Try EMA checkpoint first (best quality)
         ema_path = os.path.join(args.checkpoint, "ema_shadow.pt")
+        consolidated_path = os.path.join(args.checkpoint, "consolidated_model.bin")
+        sharded_dir = os.path.join(args.checkpoint, "model")
 
-        if os.path.exists(consolidated_path):
-            print("[INFO] Loading consolidated checkpoint...")
-            state_dict = torch.load(consolidated_path, map_location="cuda")
-            pipe.transformer.load_state_dict(state_dict, strict=True)
-            print("[INFO] Loaded from consolidated checkpoint")
-        elif os.path.exists(ema_path):
+        if os.path.exists(ema_path):
            print("[INFO] Loading EMA checkpoint (best quality)...")
            ema_state = torch.load(ema_path, map_location="cuda")
            pipe.transformer.load_state_dict(ema_state, strict=True)
-            print("[INFO] Loaded from EMA checkpoint")
+            print("[INFO] ✅ Loaded from EMA checkpoint")
+        elif os.path.exists(consolidated_path):
+            print("[INFO] Loading consolidated checkpoint...")
+            state_dict = torch.load(consolidated_path, map_location="cuda")
+            pipe.transformer.load_state_dict(state_dict, strict=True)
+            print("[INFO] ✅ Loaded from consolidated checkpoint")
+        elif os.path.isdir(sharded_dir) and any(name.endswith(".distcp") for name in os.listdir(sharded_dir)):
+            print(f"[INFO] Detected sharded FSDP checkpoint at: {sharded_dir}")
+            print("[INFO] Loading sharded checkpoint via PyTorch Distributed Checkpoint (single process)...")
+
+            import torch.distributed as dist
+            from torch.distributed.checkpoint import FileSystemReader
+            from torch.distributed.checkpoint import load as dist_load
+            from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+            from torch.distributed.fsdp import StateDictType
+            from torch.distributed.fsdp.api import ShardedStateDictConfig
+
+            # Initialize a single-process group if not already initialized
+            init_dist = False
+            if not dist.is_initialized():
+                dist.init_process_group(backend="gloo", rank=0, world_size=1)
+                init_dist = True
+
+            # Wrap current transformer with FSDP to load sharded weights
+            base_transformer = pipe.transformer
+
+            # Ensure uniform dtype before FSDP wraps/flattening
+            base_transformer.to(dtype=torch.bfloat16)
+            fsdp_transformer = FSDP(base_transformer, use_orig_params=True)
+
+            # Configure to expect sharded state dict
+            FSDP.set_state_dict_type(
+                fsdp_transformer,
+                StateDictType.SHARDED_STATE_DICT,
+                state_dict_config=ShardedStateDictConfig(offload_to_cpu=True),
+            )
+
+            # Load shards into the FSDP-wrapped model
+            model_state = fsdp_transformer.state_dict()
+            dist_load(state_dict=model_state, storage_reader=FileSystemReader(sharded_dir))
+            fsdp_transformer.load_state_dict(model_state)
+
+            # Unwrap back to the original module for inference
+            pipe.transformer = fsdp_transformer.module
+
+            # Move to CUDA bf16 for inference
+            pipe.transformer.to("cuda", dtype=torch.bfloat16)
+
+            if init_dist:
+                dist.destroy_process_group()
+
+            print("[INFO] ✅ Loaded from sharded FSDP checkpoint")
         else:
-            print("[WARNING] No consolidated or EMA checkpoint found at specified path")
-            print("[INFO] Using base WAN 2.1 model weights from pipeline")
-    else:
-        print("\n[3] No checkpoint specified, using base WAN 2.1 model weights")
+            print("[WARNING] No EMA, consolidated, or sharded checkpoint found")
+            print("[INFO] Using base model weights")
 
     # Create output directory
     os.makedirs(args.output_dir, exist_ok=True)
 
     # Generate videos
-    print("\n[4] Generating videos using precomputed text embeddings...")
-    print(f"[INFO] Settings: {args.num_inference_steps} steps, guidance scale: {args.guidance_scale}")
+    print("\n[4] Generating videos...")
+    print(f"[INFO] Settings: {args.width}x{args.height}, {args.num_frames} frames, {args.num_inference_steps} steps")
+    print(f"[INFO] Guidance scale: {args.guidance_scale}")
 
     torch.manual_seed(args.seed)
     if torch.cuda.is_available():
         torch.cuda.manual_seed_all(args.seed)
 
-    # Track successful generations
-    num_generated = 0
-
-    for i, item in enumerate(data_list):
+    for i, item in enumerate(prompts):
         prompt = item["prompt"]
         name = item["name"]
-        text_embeddings = item["text_embeddings"]
-
-        # Get dimensions for this sample
-        num_frames = item.get("num_frames")
-        height = item.get("height")
-        width = item.get("width")
-
-        if num_frames is None or height is None or width is None:
-            print(f"\n[{i + 1}/{len(data_list)}] Skipping {name}: missing dimensions")
-            continue
 
-        print(f"\n[{i + 1}/{len(data_list)}] Generating: {name}")
+        print(f"\n[{i + 1}/{len(prompts)}] Generating: {name}")
         print(f"  Prompt: {prompt[:80]}...")
-        print(f"  Dimensions: {num_frames} frames, {width}x{height}")
-        print(f"  Text embeddings: {text_embeddings.shape}")
 
         try:
-            # Move embeddings to GPU
-            text_embeddings = text_embeddings.to(device="cuda", dtype=torch.bfloat16)
-
-            # Add batch dimension if needed: (seq_len, hidden_dim) -> (1, seq_len, hidden_dim)
-            if text_embeddings.ndim == 2:
-                text_embeddings = text_embeddings.unsqueeze(0)
-
-            # Generate using precomputed embeddings
+            # Generate from scratch (no latents needed!)
             generator = torch.Generator(device="cuda").manual_seed(args.seed + i)
 
-            # Call pipeline with prompt_embeds instead of prompt
             output = pipe(
-                prompt_embeds=text_embeddings,
-                negative_prompt="",  # Use empty string for negative prompt
-                height=height,
-                width=width,
-                num_frames=num_frames,
+                prompt=prompt,
+                negative_prompt=args.negative_prompt,
+                height=args.height,
+                width=args.width,
+                num_frames=args.num_frames,
                 guidance_scale=args.guidance_scale,
                 num_inference_steps=args.num_inference_steps,
                 generator=generator,
             ).frames[0]
 
-            # Save as image if single frame, otherwise as video
-            if num_frames == 1:
-                output_path = os.path.join(args.output_dir, f"{name}.png")
-
-                # output is a numpy array, squeeze out extra dimensions
-                frame = np.squeeze(output)  # Remove all dimensions of size 1
-
-                # Ensure we have the right shape (H, W, C)
-                if frame.ndim == 2:  # Grayscale
-                    pass
-                elif frame.ndim == 3 and frame.shape[-1] in [1, 3, 4]:  # RGB/RGBA
-                    pass
-                else:
-                    raise ValueError(f"Unexpected frame shape: {frame.shape}")
-
-                # Convert from float [0, 1] to uint8 [0, 255]
-                if frame.dtype in [np.float32, np.float64]:
-                    frame = (frame * 255).clip(0, 255).astype(np.uint8)
-
-                image = Image.fromarray(frame)
-                image.save(output_path)
-                print(f"  ✅ Saved image to {output_path}")
-
-                # Upload to wandb immediately
-                if wandb_run is not None:
-                    print("  📤 Uploading image to wandb...")
-                    wandb_run.log(
-                        {
-                            f"image/{name}": wandb.Image(image, caption=prompt[:100]),
-                            f"prompt/{name}": prompt,
-                            f"dimensions/{name}": f"{width}x{height}",
-                            "sample_index": i,
-                        }
-                    )
-                    print("  ✅ Uploaded to wandb!")
-
-            else:
-                output_path = os.path.join(args.output_dir, f"{name}.mp4")
-                export_to_video(output, output_path, fps=args.fps)
-                print(f"  ✅ Saved video to {output_path}")
-                gif_path = convert_to_gif(output_path)
-                # Upload to wandb immediately
-                if wandb_run is not None:
-                    print("  📤 Uploading video to wandb...")
-                    wandb_run.log(
-                        {
-                            f"video/{name}": wandb.Image(gif_path),
-                            f"prompt/{name}": prompt,
-                            f"dimensions/{name}": f"{num_frames} frames, {width}x{height}",
-                            "sample_index": i,
-                        }
-                    )
-                    print("  ✅ Uploaded to wandb!")
-
-            num_generated += 1
+            # Save video
+            output_path = os.path.join(args.output_dir, f"{name}.mp4")
+            export_to_video(output, output_path, fps=args.fps)
+
+            print(f"  ✅ Saved to {output_path}")
 
         except Exception as e:
            print(f"  ❌ Failed: {e}")
@@ -423,54 +244,10 @@ def main():
             continue
 
     print("\n" + "=" * 80)
-    print("Validation complete!")
-    print(f"Generated: {num_generated}/{len(data_list)} samples")
-    print(f"Outputs saved to: {args.output_dir}")
-    if wandb_run is not None:
-        print(f"Wandb results: {wandb_run.get_url()}")
+    print("✅ Validation complete!")
+    print(f"📁 Videos saved to: {args.output_dir}")
     print("=" * 80)
 
-    # Finish wandb run
-    if wandb_run is not None:
-        wandb_run.finish()
-
 
 if __name__ == "__main__":
     main()
-
-
-# ============================================================================
-# USAGE EXAMPLES
-# ============================================================================
-
-# 1. Basic usage (uses precomputed text embeddings from .meta files):
-#    python validate_t2v.py \
-#        --meta_folder /linnanw/hdvilla_sample/pika/wan21_codes/1.3B_meta
-
-# 2. With wandb logging:
-#    python validate_t2v.py \
-#        --meta_folder /linnanw/hdvilla_sample/pika/wan21_codes/1.3B_meta \
-#        --use_wandb \
-#        --wandb_project wan_t2v_valid \
-#        --wandb_run_name "validation_checkpoint_5000"
-
-# 3. With trained checkpoint and wandb:
-#    python validate_t2v.py \
-#        --meta_folder /linnanw/hdvilla_sample/pika/wan21_codes/1.3B_meta \
-#        --checkpoint ./wan_t2v_all_fixes/checkpoint-5000 \
-#        --use_wandb
-
-# 4. Limited samples with custom settings:
-#    python validate_t2v.py \
-#        --meta_folder /linnanw/hdvilla_sample/pika/wan21_codes/1.3B_meta \
-#        --checkpoint ./checkpoint-5000 \
-#        --num_samples 5 \
-#        --num_inference_steps 50 \
-#        --guidance_scale 5.0 \
-#        --use_wandb
-
-# 5. If no checkpoint found, uses base WAN 2.1 weights:
-#    python validate_t2v.py \
-#        --meta_folder /linnanw/hdvilla_sample/pika/wan21_codes/1.3B_meta \
-#        --checkpoint ./nonexistent_checkpoint \
-#        --use_wandb  # Will fall back to base model and log to wandb