diff --git a/flux.1-dev-trt-b200/README.md b/flux.1-dev-trt-b200/README.md new file mode 100644 index 000000000..80ff0cfbe --- /dev/null +++ b/flux.1-dev-trt-b200/README.md @@ -0,0 +1,280 @@ +# Flux 1.0 Dev - TensorRT + +This model provides high-quality text-to-image generation using [Flux.1-dev model](https://huggingface.co/black-forest-labs/FLUX.1-dev) optimized with TensorRT for the B200 GPU. + +## Model Information + +- **Model**: Flux 1.0 Dev (black-forest-labs/FLUX.1-dev) +- **Optimization**: TensorRT 8.6.1 +- **Hardware**: NVIDIA B200 GPU +- **Framework**: PyTorch with TensorRT acceleration + +## Features + +- High-quality image generation from text prompts +- TensorRT optimization for fast inference +- Support for custom image dimensions (must be multiples of 8) +- Configurable denoising steps and guidance scale +- CUDA graph optimization support + +## Usage + +### Basic Usage + +```bash +curl -X POST https://app.baseten.co/models/{MODEL_ID}/predict \ + -H "Authorization: Api-Key API_KEY" \ + -H "Content-Type: application/json" \ + -d '{ + "prompt": "A beautiful landscape with mountains and a lake, photorealistic, high quality", + "negative_prompt": "blurry, low quality, distorted", + "height": 1024, + "width": 1024, + "num_inference_steps": 30, + "guidance_scale": 3.5, + "seed": 42, + "batch_size": 1, + "batch_count": 1 + }' | python show.py +``` + +### Batch Processing + +The model supports efficient batch processing for generating multiple images in a single request. You can generate up to 4 images simultaneously with either the same prompt or different prompts for each image. + +#### Same Prompt for All Images + +```bash +curl -X POST https://app.baseten.co/models/{MODEL_ID}/predict \ + -H "Authorization: Api-Key API_KEY" \ + -H "Content-Type: application/json" \ + -d '{ + "prompt": "A beautiful landscape with mountains and a lake, photorealistic, high quality", + "negative_prompt": "blurry, low quality, distorted", + "height": 1024, + "width": 1024, + "num_inference_steps": 30, + "guidance_scale": 3.5, + "seed": 42, + "batch_size": 4, + "batch_count": 1 + }' | python show_batch.py +``` + +#### Different Prompts for Each Image + +```bash +curl -X POST https://app.baseten.co/models/{MODEL_ID}/predict \ + -H "Authorization: Api-Key API_KEY" \ + -H "Content-Type: application/json" \ + -d '{ + "prompt": [ + "A beautiful landscape with mountains and a lake, photorealistic, high quality", + "A futuristic city skyline at sunset, neon lights, cyberpunk style, high quality", + "A cute cat sitting in a garden, soft lighting, detailed, high quality", + "Abstract geometric patterns in vibrant colors, modern art style, high quality" + ], + "negative_prompt": [ + "blurry, low quality, distorted", + "blurry, low quality, distorted", + "blurry, low quality, distorted", + "blurry, low quality, distorted" + ], + "height": 1024, + "width": 1024, + "num_inference_steps": 30, + "guidance_scale": 3.5, + "seed": 42, + "batch_size": 4, + "batch_count": 1 + }' | python show_batch.py +``` + +#### Batch Processing Benefits + +- **Parallel Processing**: All images in a batch are generated simultaneously, not sequentially +- **Better GPU Utilization**: More efficient use of GPU resources compared to separate requests +- **Faster Total Time**: Generating 4 images in a batch is significantly faster than 4 separate API calls +- **Consistent Parameters**: All images in a batch use the same dimensions, inference steps, and guidance scale + +#### Batch Processing Limitations + +- **Maximum Batch Size**: Limited 
to 4 images per batch (MAX_BATCH_SIZE = 4) +- **Prompt Array Length**: When using different prompts, the prompt array length must match the batch_size +- **Memory Requirements**: Larger batches require more GPU memory + +#### Displaying Batch Results + +Use the included `show_batch.py` script to handle multiple images in the response: + +```bash +# The script automatically detects single vs multiple images +curl ... | python show_batch.py +``` + +The script will: +- Save each image with a unique filename +- Automatically open all generated images +- Print status information about the batch + +### Load Test +Before running the load test, update `load_test.py` with your actual endpoint URL and API key. Replace the placeholder values for `api_url` and `api_key` with your deployment's information (lines 43 and 44). + +```bash +python load_test.py --save-all-images --use-varied-prompts --concurrent --num-requests 30 + +šŸš€ Starting Flux Truss API test... +================================================== +šŸŽØ Using 30 varied prompts for load testing +šŸ“ First 3 prompts to be tested: + 1. a beautiful photograph of Mt. Fuji during cherry blossom, photorealistic, high quality + 2. a majestic dragon soaring through a mystical forest, digital art, detailed + 3. a cozy coffee shop interior with warm lighting, people working on laptops, photorealistic + ... and 27 more +šŸš€ Starting concurrent load test with 30 requests, max 5 workers +============================================================ +šŸ“¤ Sending request 1/30: 'a beautiful photograph of Mt. Fuji during cherry b...' +Testing Truss API endpoint with prompt: 'a beautiful photograph of Mt. Fuji during cherry blossom, photorealistic, high quality' + +... + +============================================================ +šŸ“Š LOAD TEST SUMMARY +============================================================ +Total requests: 30 +Successful: 30 +Failed: 0 +Success rate: 100.0% +Total time: 74.72 seconds +Average request time: 11.63 seconds +Min request time: 2.91 seconds +Max request time: 12.86 seconds +Throughput: 0.40 requests/second + +================================================== + +šŸ’¾ Saving 30 successful images... +āœ… Successfully saved 30/30 images to './output' +šŸ“ Opened output directory + +================================================== +``` + +### Advanced Usage + +```python +{ + "prompt": "A beautiful landscape with mountains and a lake, photorealistic, high quality", + "prompt2": "Additional prompt for T5 tokenizer", # Optional, uses prompt if not provided + "negative_prompt": "blurry, low quality, distorted", # Optional + "height": 1024, # Must be multiple of 8 + "width": 1024, # Must be multiple of 8 + "denoising_steps": 50, # Number of denoising steps + "guidance_scale": 3.5, # Classifier-free guidance scale (must be > 1) + "seed": 42, # Random seed for reproducible results + "batch_size": 1, # Number of images to generate + "batch_count": 1, # Number of batches + "num_warmup_runs": 0, # Number of warmup runs + "max_sequence_length": 512, # Max sequence length (up to 512 for flux.1-dev) + "t5_ws_percentage": None, # T5 weight streaming percentage + "transformer_ws_percentage": None # Transformer weight streaming percentage +} +``` + +## Parameters + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `prompt` | string/array | required | Text prompt(s) for image generation. 
Can be a single string or array of strings for different prompts per batch item | +| `prompt2` | string/array | same as prompt | Additional prompt(s) for T5 tokenizer. Can be a single string or array of strings | +| `negative_prompt` | string/array | "" | Negative prompt(s) to avoid certain elements. Can be a single string or array of strings | +| `height` | int | 1024 | Image height (must be multiple of 8) | +| `width` | int | 1024 | Image width (must be multiple of 8) | +| `denoising_steps` | int | 50 | Number of denoising steps | +| `guidance_scale` | float | 3.5 | Classifier-free guidance scale (> 1) | +| `seed` | int | None | Random seed for reproducibility | +| `batch_size` | int | 1 | Number of images per batch | +| `batch_count` | int | 1 | Number of batches | +| `num_warmup_runs` | int | 0 | Number of warmup runs | +| `max_sequence_length` | int | 512 | Maximum sequence length (≤ 512) | +| `t5_ws_percentage` | int | None | T5 weight streaming percentage | +| `transformer_ws_percentage` | int | None | Transformer weight streaming percentage | + +## Response Format + +The model returns a JSON response with the following structure: + +### Single Image Response + +```json +{ + "status": "success", + "data": "base64_encoded_image", + "time": 2.34, + "prompt": "A beautiful landscape with mountains and a lake, photorealistic, high quality", + "negative_prompt": "blurry, low quality, distorted", + "height": 1024, + "width": 1024, + "num_inference_steps": 30, + "guidance_scale": 3.5, + "seed": 42 +} +``` + +### Batch Response (Multiple Images) + +```json +{ + "status": "success", + "data": [ + "base64_encoded_image_1", + "base64_encoded_image_2", + "base64_encoded_image_3", + "base64_encoded_image_4" + ], + "time": 7.23, + "prompt": [ + "A beautiful landscape with mountains and a lake, photorealistic, high quality", + "A futuristic city skyline at sunset, neon lights, cyberpunk style, high quality", + "A cute cat sitting in a garden, soft lighting, detailed, high quality", + "Abstract geometric patterns in vibrant colors, modern art style, high quality" + ], + "negative_prompt": [ + "blurry, low quality, distorted", + "blurry, low quality, distorted", + "blurry, low quality, distorted", + "blurry, low quality, distorted" + ], + "height": 1024, + "width": 1024, + "num_inference_steps": 30, + "guidance_scale": 3.5, + "seed": 42 +} +``` + +## Performance Notes + +The model is optimized with a pre-compiled TensorRT engine for the NVIDIA B200 GPU. +Performance characteristic is described below for the [basic usage](#basic-usage) + + +```text +|------------------|--------------| +| Module | Latency | +|------------------|--------------| +| CLIP | 2.02 ms | +| T5 | 6.43 ms | +| Transformer x 50 | 2361.44 ms | +| VAE-Dec | 11.67 ms | +|------------------|--------------| +| Pipeline | 2382.45 ms | +|------------------|--------------| +``` + +## Model Variants + +This implementation supports the `flux.1-dev` variant with: +- Maximum sequence length: 512 tokens +- Default image dimensions: 1024x1024 +- Optimized for high-quality image generation diff --git a/flux.1-dev-trt-b200/config.yaml b/flux.1-dev-trt-b200/config.yaml new file mode 100644 index 000000000..d2b6b2e2f --- /dev/null +++ b/flux.1-dev-trt-b200/config.yaml @@ -0,0 +1,17 @@ +base_image: + image: nvcr.io/nvidia/pytorch:25.06-py3 +description: Generate high-quality images from text prompts using Black Forest Labs's Flux model with TensorRT optimization. 
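+# Deployment notes (editor additions, not exhaustive):
+# - The NGC PyTorch 25.06 base image above supplies the CUDA / TensorRT 10.x
+#   runtime stack used for B200 (Blackwell) engine builds and inference.
+# - FLUX.1-dev weights are gated on Hugging Face, so populate the
+#   hf_access_token secret declared below (e.g. in your Baseten workspace
+#   secrets) before deploying, otherwise the checkpoint download will fail.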
+external_package_dirs: [] +model_name: Flux 1.0 Dev - TensorRT +requirements_file: ./requirements.txt +resources: + accelerator: B200 + use_gpu: true +runtime: + predict_concurrency: 1 +secrets: + hf_access_token: null +system_packages: +- ffmpeg +- libsm6 +- libxext6 diff --git a/flux.1-dev-trt-b200/load_test.py b/flux.1-dev-trt-b200/load_test.py new file mode 100644 index 000000000..673289cd9 --- /dev/null +++ b/flux.1-dev-trt-b200/load_test.py @@ -0,0 +1,650 @@ +#!/usr/bin/env python3 +""" +Test script for the Flux model using Truss API endpoint. + +Usage: + python test_truss.py + python test_truss.py --save-image + python test_truss.py --prompt "your custom prompt here" +""" + +import argparse +import base64 +import os +import time +import requests +import concurrent.futures +from typing import List, Dict, Any + + +def test_truss_api_call( + prompt="a beautiful photograph of Mt. Fuji during cherry blossom, photorealistic, high quality", + height=1024, + width=1024, + steps=30, + guidance_scale=3.5, + seed=42, +): + """Test Truss API endpoint prediction.""" + print(f"Testing Truss API endpoint with prompt: '{prompt}'") + + # Prepare input - same structure as test_model.py + model_input = { + "prompt": prompt, + "negative_prompt": "blurry, low quality, distorted", + "height": height, + "width": width, + "num_inference_steps": steps, + "guidance_scale": guidance_scale, + "seed": seed, + "batch_size": 1, + "batch_count": 1, + } + + # API endpoint configuration + api_url = "https://app.baseten.co/models/{MODEL_ID}/predict" + api_key = os.getenv( + "BASETEN_API_KEY", "YOUR_API_KEY" + ) # Get from environment variable + + headers = { + "Authorization": f"Api-Key {api_key}", + "Content-Type": "application/json", + } + + try: + start_time = time.time() + response = requests.post(api_url, headers=headers, json=model_input) + end_time = time.time() + + if response.status_code == 200: + result = response.json() + + if result.get("status") == "success": + print("āœ… Truss API prediction successful!") + print(f" Time: {end_time - start_time:.2f} seconds") + print(f" Prompt: {result.get('prompt', prompt)}") + print( + f" Dimensions: {result.get('width', width)}x{result.get('height', height)}" + ) + print(f" Steps: {result.get('num_inference_steps', steps)}") + print( + f" Guidance Scale: {result.get('guidance_scale', guidance_scale)}" + ) + print(f" Seed: {result.get('seed', seed)}") + + return { + "result": result, + "time": end_time - start_time, + "success": True, + } + else: + print( + f"āŒ Truss API prediction failed: {result.get('error', 'Unknown error')}" + ) + return { + "result": None, + "time": end_time - start_time, + "success": False, + "error": result.get("error", "Unknown error"), + } + else: + print(f"āŒ API request failed with status code: {response.status_code}") + print(f" Response: {response.text}") + return { + "result": None, + "time": end_time - start_time, + "success": False, + "error": f"HTTP {response.status_code}", + } + + except Exception as e: + print(f"āŒ Truss API prediction failed with exception: {e}") + return {"result": None, "time": 0, "success": False, "error": str(e)} + + +def test_concurrent_requests( + prompts: List[str], + max_workers: int = 5, + height=1024, + width=1024, + steps=30, + guidance_scale=3.5, + seed=42, + sizes: List[tuple] = None, +): + """Test Truss API endpoint with concurrent requests.""" + print( + f"šŸš€ Starting concurrent load test with {len(prompts)} requests, max {max_workers} workers" + ) + if sizes: + print(f"šŸ“ Using variable 
image sizes: {len(sizes)} different dimensions") + print("=" * 60) + + results = [] + start_time = time.time() + + def make_request(prompt, request_id): + """Make a single API request.""" + # Use variable size if provided, otherwise use default + if sizes and request_id < len(sizes): + h, w = sizes[request_id] + print( + f"šŸ“¤ Sending request {request_id + 1}/{len(prompts)} ({h}x{w}): '{prompt[:50]}{'...' if len(prompt) > 50 else ''}'" + ) + return test_truss_api_call(prompt, h, w, steps, guidance_scale, seed) + else: + print( + f"šŸ“¤ Sending request {request_id + 1}/{len(prompts)}: '{prompt[:50]}{'...' if len(prompt) > 50 else ''}'" + ) + return test_truss_api_call( + prompt, height, width, steps, guidance_scale, seed + ) + + # Use ThreadPoolExecutor for concurrent requests + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: + # Submit all requests + future_to_prompt = { + executor.submit(make_request, prompt, i): (prompt, i) + for i, prompt in enumerate(prompts) + } + + # Collect results as they complete + for future in concurrent.futures.as_completed(future_to_prompt): + prompt, request_id = future_to_prompt[future] + try: + result = future.result() + results.append(result) + print(f"šŸ“„ Completed request {request_id + 1}/{len(prompts)}") + except Exception as e: + print(f"āŒ Request {request_id + 1} failed with exception: {e}") + results.append( + {"result": None, "time": 0, "success": False, "error": str(e)} + ) + + end_time = time.time() + total_time = end_time - start_time + + # Calculate statistics + successful_requests = [r for r in results if r["success"]] + failed_requests = [r for r in results if not r["success"]] + + if successful_requests: + avg_time = sum(r["time"] for r in successful_requests) / len( + successful_requests + ) + min_time = min(r["time"] for r in successful_requests) + max_time = max(r["time"] for r in successful_requests) + else: + avg_time = min_time = max_time = 0 + + # Print summary + print("\n" + "=" * 60) + print("šŸ“Š LOAD TEST SUMMARY") + print("=" * 60) + print(f"Total requests: {len(prompts)}") + print(f"Successful: {len(successful_requests)}") + print(f"Failed: {len(failed_requests)}") + print(f"Success rate: {len(successful_requests) / len(prompts) * 100:.1f}%") + print(f"Total time: {total_time:.2f} seconds") + print(f"Average request time: {avg_time:.2f} seconds") + print(f"Min request time: {min_time:.2f} seconds") + print(f"Max request time: {max_time:.2f} seconds") + print(f"Throughput: {len(successful_requests) / total_time:.2f} requests/second") + + if failed_requests: + print("\nāŒ Failed requests:") + for i, result in enumerate(failed_requests): + print(f" Request {i + 1}: {result.get('error', 'Unknown error')}") + + return results + + +def get_varied_prompts(num_prompts: int) -> List[str]: + """Generate a list of varied prompts for load testing.""" + base_prompts = [ + "a beautiful photograph of Mt. 
Fuji during cherry blossom, photorealistic, high quality", + "a majestic dragon soaring through a mystical forest, digital art, detailed", + "a cozy coffee shop interior with warm lighting, people working on laptops, photorealistic", + "a futuristic cityscape at sunset with flying cars and neon lights, cinematic", + "a serene mountain lake reflecting snow-capped peaks, nature photography, high resolution", + "a steampunk airship floating above Victorian-era buildings, detailed illustration", + "a magical library with floating books and glowing crystals, fantasy art", + "a peaceful Japanese garden with koi pond and cherry blossoms, traditional art style", + "a cyberpunk street scene with neon signs and rain, Blade Runner style", + "a whimsical treehouse village connected by rope bridges, children's book illustration", + "a dramatic storm over the ocean with lightning and waves, nature photography", + "a cozy cabin in the woods with smoke from chimney, winter scene, photorealistic", + "a space station orbiting Earth with stars and nebula in background, sci-fi art", + "a bustling medieval marketplace with merchants and colorful stalls, fantasy", + "a tranquil zen meditation room with candles and incense, minimalist design", + "a roaring waterfall in a tropical jungle with exotic birds, nature photography", + "a steampunk robot butler serving tea in a Victorian parlor, detailed illustration", + "a magical crystal cave with glowing formations and underground lake, fantasy", + "a peaceful farm at golden hour with rolling hills and grazing animals, pastoral", + "a futuristic robot city with advanced technology and clean architecture, sci-fi", + ] + + # If we need more prompts than we have, cycle through them + if num_prompts <= len(base_prompts): + return base_prompts[:num_prompts] + else: + # Cycle through prompts and add variations + prompts = [] + for i in range(num_prompts): + base_prompt = base_prompts[i % len(base_prompts)] + if i >= len(base_prompts): + # Add variation number for additional prompts + variation_num = (i // len(base_prompts)) + 1 + prompts.append(f"{base_prompt} (variation {variation_num})") + else: + prompts.append(base_prompt) + return prompts + + +def get_variable_sizes(num_sizes: int, size_range: str = "512-1024") -> List[tuple]: + """Generate a list of variable image sizes for load testing.""" + try: + min_size, max_size = map(int, size_range.split("-")) + # Ensure sizes are multiples of 8 (model requirement) + min_size = (min_size // 8) * 8 + max_size = (max_size // 8) * 8 + + # Common sizes that work well with the model + common_sizes = [ + (512, 512), + (512, 768), + (768, 512), + (768, 768), + (768, 1024), + (1024, 768), + (1024, 1024), + (1024, 1280), + (1280, 1024), + (1280, 1280), + (1280, 1536), + (1536, 1280), + (1536, 1536), + (1536, 1792), + (1792, 1536), + (1792, 1792), + (1792, 2048), + (2048, 1792), + (2048, 2048), + ] + + # Filter sizes within the specified range + valid_sizes = [ + (h, w) + for h, w in common_sizes + if min_size <= h <= max_size and min_size <= w <= max_size + ] + + if not valid_sizes: + # Fallback to simple multiples of 8 within range + valid_sizes = [] + for size in range(min_size, max_size + 1, 64): # Step by 64 for variety + if size % 8 == 0: + valid_sizes.append((size, size)) + + # If we need more sizes than available, cycle through them + if num_sizes <= len(valid_sizes): + return valid_sizes[:num_sizes] + else: + # Cycle through sizes and add variations + sizes = [] + for i in range(num_sizes): + base_size = valid_sizes[i % 
len(valid_sizes)] + if i >= len(valid_sizes): + # Add variation by slightly modifying dimensions + variation_num = (i // len(valid_sizes)) + 1 + h, w = base_size + # Add small variations while keeping multiples of 8 + h_var = h + (variation_num * 8) % 64 + w_var = w + (variation_num * 8) % 64 + sizes.append((h_var, w_var)) + else: + sizes.append(base_size) + return sizes + + except Exception as e: + print(f"āŒ Error parsing size range '{size_range}': {e}") + print(" Using default size 1024x1024") + return [(1024, 1024)] * num_sizes + + +def save_image(result, filename="test_truss_output.jpg", output_dir="./output"): + """Save the generated image to a file using the same decoding logic as test_model.py.""" + # Handle both old and new result formats + if isinstance(result, dict) and "result" in result: + # New format from concurrent testing + actual_result = result["result"] + if not actual_result or actual_result.get("status") != "success": + print("āŒ No valid result to save") + return + else: + # Old format from single request + actual_result = result + if not actual_result or actual_result.get("status") != "success": + print("āŒ No valid result to save") + return + + try: + # Create output directory if it doesn't exist + os.makedirs(output_dir, exist_ok=True) + + # Use output directory for filename + output_path = os.path.join(output_dir, filename) + + # Decode base64 image - same logic as test_model.py + image_data = base64.b64decode(actual_result["data"]) + + # Save to file + with open(output_path, "wb") as f: + f.write(image_data) + + print(f"āœ… Image saved as '{output_path}'") + + # Try to open the image (macOS) + try: + os.system(f"open {output_path}") + print("āœ… Image opened in default viewer") + except: + print(f"šŸ“ Image saved at: {os.path.abspath(output_path)}") + + except Exception as e: + print(f"āŒ Failed to save image: {e}") + + +def save_image_bulk(result, filename="test_truss_output.jpg", output_dir="./output"): + """Save the generated image to a file without opening it (for bulk operations).""" + # Handle both old and new result formats + if isinstance(result, dict) and "result" in result: + # New format from concurrent testing + actual_result = result["result"] + if not actual_result or actual_result.get("status") != "success": + return False + else: + # Old format from single request + actual_result = result + if not actual_result or actual_result.get("status") != "success": + return False + + try: + # Create output directory if it doesn't exist + os.makedirs(output_dir, exist_ok=True) + + # Use output directory for filename + output_path = os.path.join(output_dir, filename) + + # Decode base64 image - same logic as test_model.py + image_data = base64.b64decode(actual_result["data"]) + + # Save to file + with open(output_path, "wb") as f: + f.write(image_data) + + return True + + except Exception as e: + print(f"āŒ Failed to save image: {e}") + return False + + +def save_all_images( + results: List[Dict[str, Any]], + prompts: List[str], + output_dir="./output", + sizes: List[tuple] = None, +): + """Save all successful images from concurrent load test results.""" + if not results: + print("āŒ No results to save") + return + + successful_results = [r for r in results if r["success"]] + if not successful_results: + print("āŒ No successful results to save") + return + + print(f"\nšŸ’¾ Saving {len(successful_results)} successful images...") + + # Create output directory if it doesn't exist + os.makedirs(output_dir, exist_ok=True) + + saved_count = 0 + for i, 
result in enumerate(successful_results): + try: + # Get the corresponding prompt for this result + # Note: This assumes results are in the same order as prompts + # In a real scenario, you might want to track which prompt corresponds to which result + prompt = prompts[i] if i < len(prompts) else f"prompt_{i + 1}" + + # Get size information if available + size_info = "" + if sizes and i < len(sizes): + h, w = sizes[i] + size_info = f"_{h}x{w}" + elif ( + result.get("result") + and "width" in result["result"] + and "height" in result["result"] + ): + h = result["result"]["height"] + w = result["result"]["width"] + size_info = f"_{h}x{w}" + + # Create a safe filename from the prompt + safe_filename = "".join( + c for c in prompt[:50] if c.isalnum() or c in (" ", "-", "_") + ).rstrip() + safe_filename = safe_filename.replace(" ", "_") + filename = f"load_test_{i + 1:03d}_{safe_filename}{size_info}.jpg" + + # Save the image without opening it (for bulk saves) + save_image_bulk(result, filename=filename, output_dir=output_dir) + saved_count += 1 + + except Exception as e: + print(f"āŒ Failed to save image {i + 1}: {e}") + + print( + f"āœ… Successfully saved {saved_count}/{len(successful_results)} images to '{output_dir}'" + ) + + # Try to open the output directory + try: + os.system(f"open {output_dir}") + print("šŸ“ Opened output directory") + except: + print(f"šŸ“ Images saved in: {os.path.abspath(output_dir)}") + + +def main(): + parser = argparse.ArgumentParser( + description="Test the Flux model using Truss API endpoint" + ) + parser.add_argument( + "--save-image", action="store_true", help="Save the generated image to a file" + ) + parser.add_argument( + "--prompt", + type=str, + default="a beautiful photograph of Mt. Fuji during cherry blossom, photorealistic, high quality", + help="Custom prompt for testing", + ) + parser.add_argument( + "--steps", type=int, default=50, help="Number of inference steps" + ) + parser.add_argument("--height", type=int, default=1024, help="Image height") + parser.add_argument("--width", type=int, default=1024, help="Image width") + parser.add_argument( + "--api-key", + type=str, + help="Baseten API key (or set BASETEN_API_KEY environment variable)", + ) + parser.add_argument( + "--output-dir", + type=str, + default="./output", + help="Output directory for saved images (default: ./output)", + ) + + # Load testing arguments + parser.add_argument( + "--concurrent", action="store_true", help="Run concurrent load test" + ) + parser.add_argument( + "--num-requests", + type=int, + default=5, + help="Number of concurrent requests (default: 5)", + ) + parser.add_argument( + "--max-workers", + type=int, + default=5, + help="Maximum number of concurrent workers (default: 5)", + ) + parser.add_argument( + "--prompt-file", + type=str, + help="File containing prompts (one per line) for load testing", + ) + parser.add_argument( + "--use-varied-prompts", + action="store_true", + help="Use different preset prompts for load testing", + ) + parser.add_argument( + "--save-all-images", + action="store_true", + help="Save all successful images from concurrent load test (automatically enables image saving)", + ) + parser.add_argument( + "--variable-sizes", + action="store_true", + help="Use variable image sizes during load testing", + ) + parser.add_argument( + "--size-range", + type=str, + default="512-1024", + help="Range of sizes for variable testing (format: min-max, e.g., 512-1024)", + ) + + args = parser.parse_args() + + print("šŸš€ Starting Flux Truss API test...") + 
print("=" * 50) + + # Set API key if provided + if args.api_key: + os.environ["BASETEN_API_KEY"] = args.api_key + + if args.concurrent: + # Load testing mode + prompts = [] + + if args.prompt_file: + # Load prompts from file + try: + with open(args.prompt_file, "r") as f: + prompts = [line.strip() for line in f if line.strip()] + print(f"šŸ“„ Loaded {len(prompts)} prompts from {args.prompt_file}") + except FileNotFoundError: + print(f"āŒ Prompt file not found: {args.prompt_file}") + return + else: + # Generate prompts based on num_requests + if args.use_varied_prompts: + # Use varied preset prompts + prompts = get_varied_prompts(args.num_requests) + print(f"šŸŽØ Using {len(prompts)} varied prompts for load testing") + else: + # Use base prompt with variations + base_prompt = args.prompt + for i in range(args.num_requests): + if i == 0: + prompts.append(base_prompt) + else: + # Create variations of the base prompt + prompts.append(f"{base_prompt} (variation {i + 1})") + + # Generate variable sizes if requested + sizes = None + if args.variable_sizes: + sizes = get_variable_sizes(len(prompts), args.size_range) + print(f"šŸ“ Generated {len(sizes)} variable sizes:") + for i, (h, w) in enumerate(sizes[:5]): # Show first 5 + print(f" {i + 1}. {h}x{w}") + if len(sizes) > 5: + print(f" ... and {len(sizes) - 5} more") + + # Show prompts being used (first few) + if len(prompts) <= 5: + print("šŸ“ Prompts to be tested:") + for i, prompt in enumerate(prompts): + size_info = ( + f" ({sizes[i][0]}x{sizes[i][1]})" + if sizes and i < len(sizes) + else "" + ) + print(f" {i + 1}. {prompt}{size_info}") + else: + print("šŸ“ First 3 prompts to be tested:") + for i, prompt in enumerate(prompts[:3]): + size_info = ( + f" ({sizes[i][0]}x{sizes[i][1]})" + if sizes and i < len(sizes) + else "" + ) + print(f" {i + 1}. {prompt}{size_info}") + print(f" ... 
and {len(prompts) - 3} more") + + # Run concurrent load test + results = test_concurrent_requests( + prompts=prompts, + max_workers=args.max_workers, + height=args.height, + width=args.width, + steps=args.steps, + guidance_scale=3.5, + seed=42, + sizes=sizes, + ) + + # Save images if requested + if args.save_image or args.save_all_images: + if args.save_all_images: + # Save all successful images + print("\n" + "=" * 50) + save_all_images(results, prompts, args.output_dir, sizes) + else: + # Save only first successful result + successful_results = [r for r in results if r["success"]] + if successful_results: + print("\n" + "=" * 50) + save_image(successful_results[0], output_dir=args.output_dir) + + else: + # Single request mode + result = test_truss_api_call( + prompt=args.prompt, + height=args.height, + width=args.width, + steps=args.steps, + guidance_scale=3.5, + seed=42, + ) + + if result["success"] and args.save_image: + print("\n" + "=" * 50) + save_image(result, output_dir=args.output_dir) + + print("\n" + "=" * 50) + print("šŸŽ‰ Truss API test completed!") + + +if __name__ == "__main__": + main() diff --git a/flux.1-dev-trt-b200/model/__init__.py b/flux.1-dev-trt-b200/model/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/flux.1-dev-trt-b200/model/demo_diffusion/__init__.py b/flux.1-dev-trt-b200/model/demo_diffusion/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/flux.1-dev-trt-b200/model/demo_diffusion/dd_argparse.py b/flux.1-dev-trt-b200/model/demo_diffusion/dd_argparse.py new file mode 100644 index 000000000..b68e0e77b --- /dev/null +++ b/flux.1-dev-trt-b200/model/demo_diffusion/dd_argparse.py @@ -0,0 +1,559 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import annotations + +import argparse +from typing import Any, Dict, Tuple + +import torch + +# Define valid optimization levels for TensorRT engine build +VALID_OPTIMIZATION_LEVELS = list(range(6)) + + +def parse_key_value_pairs(string: str) -> Dict[str, str]: + """Parse a string of comma-separated key-value pairs into a dictionary. + + Args: + string (str): A string of comma-separated key-value pairs. + + Returns: + Dict[str, str]: Parsed dictionary of key-value pairs. + + Example: + >>> parse_key_value_pairs("key1:value1,key2:value2") + {"key1": "value1", "key2": "value2"} + """ + parsed = {} + + for key_value_pair in string.split(","): + if not key_value_pair: + continue + + key_value_pair = key_value_pair.split(":") + if len(key_value_pair) != 2: + raise argparse.ArgumentTypeError( + f"Invalid key-value pair: {key_value_pair}. Must have length 2." 
+ ) + key, value = key_value_pair + parsed[key] = value + + return parsed + + +def add_arguments(parser): + # Stable Diffusion configuration + parser.add_argument( + "--version", + type=str, + default="1.5", + choices=( + "1.4", + "1.5", + "dreamshaper-7", + "2.0-base", + "2.0", + "2.1-base", + "2.1", + "xl-1.0", + "xl-turbo", + "svd-xt-1.1", + "sd3", + "3.5-medium", + "3.5-large", + "cascade", + "flux.1-dev", + "flux.1-schnell", + "flux.1-dev-canny", + "flux.1-dev-depth", + ), + help="Version of Stable Diffusion", + ) + parser.add_argument( + "prompt", nargs="*", help="Text prompt(s) to guide image generation" + ) + parser.add_argument( + "--negative-prompt", + nargs="*", + default=[""], + help="The negative prompt(s) to guide the image generation.", + ) + parser.add_argument( + "--batch-size", + type=int, + default=1, + choices=[1, 2, 4], + help="Batch size (repeat prompt)", + ) + parser.add_argument( + "--batch-count", + type=int, + default=1, + help="Number of images to generate in sequence, one at a time.", + ) + parser.add_argument( + "--height", + type=int, + default=512, + help="Height of image to generate (must be multiple of 8)", + ) + parser.add_argument( + "--width", + type=int, + default=512, + help="Height of image to generate (must be multiple of 8)", + ) + parser.add_argument( + "--denoising-steps", type=int, default=30, help="Number of denoising steps" + ) + parser.add_argument( + "--scheduler", + type=str, + default=None, + choices=( + "DDIM", + "DDPM", + "EulerA", + "Euler", + "LCM", + "LMSD", + "PNDM", + "UniPC", + "DDPMWuerstchen", + "FlowMatchEuler", + ), + help="Scheduler for diffusion process", + ) + parser.add_argument( + "--guidance-scale", + type=float, + default=7.5, + help="Value of classifier-free guidance scale (must be greater than 1)", + ) + parser.add_argument( + "--lora-scale", + type=float, + default=1.0, + help="Controls how much to influence the outputs with the LoRA parameters. (must between 0 and 1)", + ) + parser.add_argument( + "--lora-weight", + type=float, + nargs="+", + default=None, + help="The LoRA adapter(s) weights to use with the UNet. (must between 0 and 1)", + ) + parser.add_argument( + "--lora-path", + type=str, + nargs="+", + default=None, + help="Path to LoRA adaptor. Ex: 'latent-consistency/lcm-lora-sdv1-5'", + ) + parser.add_argument( + "--bf16", action="store_true", help="Run pipeline in BFloat16 precision" + ) + + # ONNX export + parser.add_argument( + "--onnx-opset", + type=int, + default=19, + choices=range(7, 20), + help="Select ONNX opset version to target for exported models", + ) + parser.add_argument( + "--onnx-dir", default="onnx", help="Output directory for ONNX export" + ) + parser.add_argument( + "--custom-onnx-paths", + type=parse_key_value_pairs, + help=( + "[FLUX only] Custom override paths to pre-exported ONNX model files. These ONNX models are directly used to " + "build TRT engines without further optimization on the ONNX graphs. Paths should be a comma-separated list " + "of : pairs. For example: " + "--custom-onnx-paths=transformer:/path/to/transformer.onnx,vae:/path/to/vae.onnx. Call " + ".get_model_names(...) for the list of supported model names." 
+ ), + ) + parser.add_argument( + "--onnx-export-only", + action="store_true", + help="If set, only performs the export of models to ONNX, skipping engine build and inference.", + ) + parser.add_argument( + "--download-onnx-models", + action="store_true", + help=("[FLUX only] Download pre-exported ONNX models"), + ) + + # Framework model ckpt + parser.add_argument( + "--framework-model-dir", + default="pytorch_model", + help="Directory for HF saved models", + ) + + # TensorRT engine build + parser.add_argument( + "--engine-dir", default="engine", help="Output directory for TensorRT engines" + ) + parser.add_argument( + "--custom-engine-paths", + type=parse_key_value_pairs, + help=( + "[FLUX only] Custom override paths to pre-built engine files. Paths should be a comma-separated list of " + ": pairs. For example: " + "--custom-onnx-paths=transformer:/path/to/transformer.plan,vae:/path/to/vae.plan. Call " + ".get_model_names(...) for the list of supported model names." + ), + ) + + parser.add_argument( + "--optimization-level", + type=int, + default=None, + help=f"Set the builder optimization level to build the engine with. A higher level allows TensorRT to spend more building time for more optimization options. Must be one of {VALID_OPTIMIZATION_LEVELS}.", + ) + parser.add_argument( + "--build-static-batch", + action="store_true", + help="Build TensorRT engines with fixed batch size.", + ) + parser.add_argument( + "--build-dynamic-shape", + action="store_true", + help="Build TensorRT engines with dynamic image shapes.", + ) + parser.add_argument( + "--build-enable-refit", + action="store_true", + help="Enable Refit option in TensorRT engines during build.", + ) + parser.add_argument( + "--build-all-tactics", + action="store_true", + help="Build TensorRT engines using all tactic sources.", + ) + parser.add_argument( + "--timing-cache", + default=None, + type=str, + help="Path to the precached timing measurements to accelerate build.", + ) + parser.add_argument( + "--ws", + action="store_true", + help="Build TensorRT engines with weight streaming enabled.", + ) + + # Quantization configuration. + parser.add_argument("--int8", action="store_true", help="Apply int8 quantization.") + parser.add_argument("--fp8", action="store_true", help="Apply fp8 quantization.") + parser.add_argument("--fp4", action="store_true", help="Apply fp4 quantization.") + parser.add_argument( + "--quantization-level", + type=float, + default=0.0, + choices=[0.0, 1.0, 2.0, 2.5, 3.0, 4.0], + help="int8/fp8 quantization level, 1: CNN, 2: CNN + FFN, 2.5: CNN + FFN + QKV, 3: CNN + Almost all Linear (Including FFN, QKV, Proj and others), 4: CNN + Almost all Linear + fMHA, 0: Default to 2.5 for int8 and 4.0 for fp8.", + ) + parser.add_argument( + "--quantization-percentile", + type=float, + default=1.0, + help="Control quantization scaling factors (amax) collecting range, where the minimum amax in range(n_steps * percentile) will be collected. Recommendation: 1.0.", + ) + parser.add_argument( + "--quantization-alpha", + type=float, + default=0.8, + help="The alpha parameter for SmoothQuant quantization used for linear layers. Recommendation: 0.8 for SDXL.", + ) + parser.add_argument( + "--calibration-size", + type=int, + default=32, + help="The number of steps to use for calibrating the model for quantization. 
Recommendation: 32, 64, 128 for SDXL", + ) + + # Inference + parser.add_argument( + "--num-warmup-runs", + type=int, + default=5, + help="Number of warmup runs before benchmarking performance", + ) + parser.add_argument( + "--use-cuda-graph", action="store_true", help="Enable cuda graph" + ) + parser.add_argument( + "--nvtx-profile", + action="store_true", + help="Enable NVTX markers for performance profiling", + ) + parser.add_argument( + "--torch-inference", + default="", + help="Run inference with PyTorch (using specified compilation mode) instead of TensorRT.", + ) + parser.add_argument( + "--torch-fallback", + default=None, + type=str, + help="[FLUX only] Comma separated list of models to be inferenced using torch instead of TRT. For example --torch-fallback t5,transformer. If --torch-inference set, this parameter will be ignored.", + ) + parser.add_argument( + "--low-vram", + action="store_true", + help="[FLUX only] Optimize for low VRAM usage, possibly at the expense of inference performance. Disabled by default.", + ) + parser.add_argument( + "--seed", + type=int, + default=None, + help="Seed for random generator to get consistent results", + ) + parser.add_argument( + "--output-dir", + default="output", + help="Output directory for logs and image artifacts", + ) + parser.add_argument( + "--hf-token", + type=str, + help="HuggingFace API access token for downloading model checkpoints", + ) + parser.add_argument( + "-v", "--verbose", action="store_true", help="Show verbose output" + ) + return parser + + +def process_pipeline_args( + args: argparse.Namespace, +) -> Tuple[Dict[str, Any], Dict[str, Any], Tuple]: + """Validate parsed arguments and process argument values. + + Some argument values are resolved or overwritten during processing. + + Args: + args (argparse.Namespace): Parsed argument. This is modified in-place. + + Returns: + Dict[str, Any]: Keyword arguments for initializing a pipeline. This is only used in legacy pipelines that do not + have factory methods `FromArgs` that construct the pipeline directly from the parsed argument. + Dict[str, Any]: Keyword arguments for calling the `.load_engine` method of the pipeline. + Tuple: Arguments for calling the `.run` method of the pipeline. + """ + + # GPU device info + device_info = torch.cuda.get_device_properties(0) + sm_version = device_info.major * 10 + device_info.minor + + is_flux = args.version.startswith("flux") + is_sd35 = args.version.startswith("3.5") + + if args.height % 8 != 0 or args.width % 8 != 0: + raise ValueError( + f"Image height and width have to be divisible by 8 but specified as: {args.image_height} and {args.width}." + ) + + # Handle batch size + max_batch_size = 4 + if args.batch_size > max_batch_size: + raise ValueError( + f"Batch size {args.batch_size} is larger than allowed {max_batch_size}." + ) + + if args.use_cuda_graph and ( + not args.build_static_batch or args.build_dynamic_shape + ): + raise ValueError( + "Using CUDA graph requires static dimensions. Enable `--build-static-batch` and do not specify `--build-dynamic-shape`" + ) + + # TensorRT builder optimization level + if args.optimization_level is None: + # optimization level set to 3 for all Flux pipelines to reduce GPU memory usage + if args.int8 or args.fp8 and not is_flux: + args.optimization_level = 4 + else: + args.optimization_level = 3 + + if args.optimization_level not in VALID_OPTIMIZATION_LEVELS: + raise ValueError( + f"Optimization level {args.optimization_level} not valid. 
Valid values are: {VALID_OPTIMIZATION_LEVELS}" + ) + + # Quantized pipeline + # int8 support + if args.int8 and not any( + args.version.startswith(prefix) for prefix in ("xl", "1.4", "1.5", "2.1") + ): + raise ValueError( + "int8 quantization is only supported for SDXL, SD1.4, SD1.5 and SD2.1 pipelines." + ) + + # fp8 support + if args.fp8 and not ( + any(args.version.startswith(prefix) for prefix in ("xl", "1.4", "1.5", "2.1")) + or is_flux + ): + raise ValueError( + "fp8 quantization is only supported for SDXL, SD1.4, SD1.5, SD2.1 and FLUX pipelines." + ) + + if args.fp8 and args.int8: + raise ValueError( + "Cannot apply both int8 and fp8 quantization, please choose only one." + ) + + if args.fp8 and sm_version < 89: + raise ValueError( + f"Cannot apply FP8 quantization for GPU with compute capability {sm_version / 10.0}. Only Ada and Hopper are supported." + ) + + # TensorRT ModelOpt quantization level + if args.quantization_level == 0.0: + + def override_quant_level(level: float, dtype_str: str): + args.quantization_level = level + print( + f"[W] The default quantization level has been set to {level} for {dtype_str}." + ) + + if args.fp8: + # L4 fp8 fMHA on Hopper not yet enabled. + if sm_version == 90 and is_flux: + override_quant_level(3.0, "FP8") + else: + override_quant_level( + 3.0 if args.version in ("1.4", "1.5") else 4.0, "FP8" + ) + + elif args.int8: + override_quant_level(3.0, "INT8") + + if args.quantization_level == 3.0 and args.download_onnx_models: + raise ValueError( + "Transformer ONNX model for Quantization level 3 is not available for download. Please export the quantized Transformer model natively with the removal of --download-onnx-models." + ) + if args.fp4: + # FP4 precision is only supported for Flux Pipelines + assert is_flux, "FP4 precision is only supported for Flux pipelines" + + # Handle LoRA + # FLUX canny and depth official LoRAs are not supported because they modify the transformer architecture, conflicting with refit + if args.lora_path and not any( + args.version.startswith(prefix) + for prefix in ("1.5", "2.1", "xl", "flux.1-dev", "flux.1-schnell") + ): + raise ValueError( + "LoRA adapter support is only supported for SD1.5, SD2.1, SDXL, FLUX.1-dev and FLUX.1-schnell pipelines" + ) + + if args.lora_weight: + for weight in (weight for weight in args.lora_weight if not 0 <= weight <= 1): + raise ValueError( + f"LoRA adapter weights must be between 0 and 1, provided {weight}" + ) + + if not 0 <= args.lora_scale <= 1: + raise ValueError( + f"LoRA scale value must be between 0 and 1, provided {args.lora_scale}" + ) + + # Force lora merge when fp8 or int8 is used with LoRA + if args.build_enable_refit and args.lora_path and (args.int8 or args.fp8): + raise ValueError( + "Engine refit should not be enabled for quantized models with LoRA. ModelOpt recommends fusing the LoRA to the model before quantization. \ + See https://github.com/NVIDIA/TensorRT-Model-Optimizer/tree/main/examples/diffusers/quantization#lora" + ) + + # Torch-fallback and Torch-inference + if args.torch_fallback and not args.torch_inference: + assert is_flux or is_sd35, ( + "PyTorch Fallback is only supported for Flux and Stable Diffusion 3.5 pipelines." + ) + args.torch_fallback = args.torch_fallback.split(",") + + if args.torch_fallback and args.torch_inference: + print( + "[W] All models will run in PyTorch when --torch-inference is set. Parameter --torch-fallback will be ignored." 
+ ) + args.torch_fallback = None + + # low-vram + if args.low_vram: + assert is_flux or is_sd35, ( + "low-vram mode is only supported for Flux and Stable Diffusion 3.5 pipelines." + ) + + # Pack arguments + kwargs_init_pipeline = { + "version": args.version, + "max_batch_size": max_batch_size, + "denoising_steps": args.denoising_steps, + "scheduler": args.scheduler, + "guidance_scale": args.guidance_scale, + "output_dir": args.output_dir, + "hf_token": args.hf_token, + "verbose": args.verbose, + "nvtx_profile": args.nvtx_profile, + "use_cuda_graph": args.use_cuda_graph, + "lora_scale": args.lora_scale, + "lora_weight": args.lora_weight, + "lora_path": args.lora_path, + "framework_model_dir": args.framework_model_dir, + "torch_inference": args.torch_inference, + } + + kwargs_load_engine = { + "onnx_opset": args.onnx_opset, + "opt_batch_size": args.batch_size, + "opt_image_height": args.height, + "opt_image_width": args.width, + "optimization_level": args.optimization_level, + "static_batch": args.build_static_batch, + "static_shape": not args.build_dynamic_shape, + "enable_all_tactics": args.build_all_tactics, + "enable_refit": args.build_enable_refit, + "timing_cache": args.timing_cache, + "int8": args.int8, + "fp8": args.fp8, + "fp4": args.fp4, + "quantization_level": args.quantization_level, + "quantization_percentile": args.quantization_percentile, + "quantization_alpha": args.quantization_alpha, + "calibration_size": args.calibration_size, + "onnx_export_only": args.onnx_export_only, + "download_onnx_models": args.download_onnx_models, + } + + args_run_demo = ( + args.prompt, + args.negative_prompt, + args.height, + args.width, + args.batch_size, + args.batch_count, + args.num_warmup_runs, + args.use_cuda_graph, + ) + + return kwargs_init_pipeline, kwargs_load_engine, args_run_demo diff --git a/flux.1-dev-trt-b200/model/demo_diffusion/dynamic_import.py b/flux.1-dev-trt-b200/model/demo_diffusion/dynamic_import.py new file mode 100755 index 000000000..a8bdf5485 --- /dev/null +++ b/flux.1-dev-trt-b200/model/demo_diffusion/dynamic_import.py @@ -0,0 +1,36 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import warnings +from importlib import import_module + + +def import_from_diffusers(model_name, module_name): + try: + module = import_module(module_name) + return getattr(module, model_name) + except ImportError: + warnings.warn( + f"Failed to import {module_name}. 
The {model_name} model will not be available.", + ImportWarning, + ) + except AttributeError: + warnings.warn( + f"The {model_name} model is not available in the installed version of diffusers.", + ImportWarning, + ) + return None diff --git a/flux.1-dev-trt-b200/model/demo_diffusion/engine.py b/flux.1-dev-trt-b200/model/demo_diffusion/engine.py new file mode 100644 index 000000000..60076c233 --- /dev/null +++ b/flux.1-dev-trt-b200/model/demo_diffusion/engine.py @@ -0,0 +1,367 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import gc +import os +import subprocess +import warnings +from collections import OrderedDict, defaultdict + +import numpy as np +import tensorrt as trt +import torch +from cuda import cudart +from polygraphy.backend.common import bytes_from_path +from polygraphy.backend.trt import ( + engine_from_bytes, +) + +import onnx +from onnx import numpy_helper + +TRT_LOGGER = trt.Logger(trt.Logger.ERROR) + + +# Map of TensorRT dtype -> torch dtype +trt_to_torch_dtype_dict = { + trt.DataType.BOOL: torch.bool, + trt.DataType.UINT8: torch.uint8, + trt.DataType.INT8: torch.int8, + trt.DataType.INT32: torch.int32, + trt.DataType.INT64: torch.int64, + trt.DataType.HALF: torch.float16, + trt.DataType.FLOAT: torch.float32, + trt.DataType.BF16: torch.bfloat16, +} + + +def _CUASSERT(cuda_ret): + err = cuda_ret[0] + if err != cudart.cudaError_t.cudaSuccess: + raise RuntimeError( + f"CUDA ERROR: {err}, error code reference: https://nvidia.github.io/cuda-python/module/cudart.html#cuda.cudart.cudaError_t" + ) + if len(cuda_ret) > 1: + return cuda_ret[1] + return None + + +def get_refit_weights( + state_dict, onnx_opt_path, weight_name_mapping, weight_shape_mapping +): + onnx_opt_dir = os.path.dirname(onnx_opt_path) + onnx_opt_model = onnx.load(onnx_opt_path) + # Create initializer data hashes + initializer_hash_mapping = {} + for initializer in onnx_opt_model.graph.initializer: + initializer_data = numpy_helper.to_array( + initializer, base_dir=onnx_opt_dir + ).astype(np.float16) + initializer_hash = hash(initializer_data.data.tobytes()) + initializer_hash_mapping[initializer.name] = initializer_hash + + refit_weights = OrderedDict() + updated_weight_names = ( + set() + ) # save names of updated weights to refit only the required weights + for wt_name, wt in state_dict.items(): + # query initializer to compare + initializer_name = weight_name_mapping[wt_name] + initializer_hash = initializer_hash_mapping[initializer_name] + + # get shape transform info + initializer_shape, is_transpose = weight_shape_mapping[wt_name] + if is_transpose: + wt = torch.transpose(wt, 0, 1) + else: + wt = torch.reshape(wt, initializer_shape) + + # include weight if hashes differ + wt_hash = hash(wt.cpu().detach().numpy().astype(np.float16).data.tobytes()) + if initializer_hash != wt_hash: + updated_weight_names.add(initializer_name) + # Store all weights as the 
refitter may require unchanged weights too + # docs: https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#refitting-engine-c + refit_weights[initializer_name] = wt.contiguous() + return refit_weights, updated_weight_names + + +class Engine: + def __init__( + self, + engine_path, + ): + self.engine_path = engine_path + self.engine = None + self.context = None + self.buffers = OrderedDict() + self.tensors = OrderedDict() + self.cuda_graph_instance = None # cuda graph + + def __del__(self): + del self.engine + del self.context + del self.buffers + del self.tensors + + def refit(self, refit_weights, updated_weight_names): + # Initialize refitter + refitter = trt.Refitter(self.engine, TRT_LOGGER) + refitted_weights = set() + + def refit_single_weight(trt_weight_name): + # get weight from state dict + trt_datatype = refitter.get_weights_prototype(trt_weight_name).dtype + refit_weights[trt_weight_name] = refit_weights[trt_weight_name].to( + trt_to_torch_dtype_dict[trt_datatype] + ) + + # trt.Weight and trt.TensorLocation + trt_wt_tensor = trt.Weights( + trt_datatype, + refit_weights[trt_weight_name].data_ptr(), + torch.numel(refit_weights[trt_weight_name]), + ) + trt_wt_location = ( + trt.TensorLocation.DEVICE + if refit_weights[trt_weight_name].is_cuda + else trt.TensorLocation.HOST + ) + + # apply refit + refitter.set_named_weights(trt_weight_name, trt_wt_tensor, trt_wt_location) + refitted_weights.add(trt_weight_name) + + # iterate through all tensorrt refittable weights + for trt_weight_name in refitter.get_all_weights(): + if trt_weight_name not in updated_weight_names: + continue + + refit_single_weight(trt_weight_name) + + # iterate through missing weights required by tensorrt - addresses the case where lora_scale=0 + for trt_weight_name in refitter.get_missing_weights(): + refit_single_weight(trt_weight_name) + + if not refitter.refit_cuda_engine(): + print("Error: failed to refit new weights.") + exit(0) + + print(f"[I] Total refitted weights {len(refitted_weights)}.") + + def build( + self, + onnx_path, + strongly_typed=False, + fp16=True, + bf16=False, + tf32=False, + int8=False, + fp8=False, + input_profile=None, + enable_refit=False, + enable_all_tactics=False, + timing_cache=None, + update_output_names=None, + native_instancenorm=True, + verbose=False, + weight_streaming=False, + builder_optimization_level=3, + precision_constraints="none", + ): + print(f"Building TensorRT engine for {onnx_path}: {self.engine_path}") + + # Handle weight streaming case: https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#streaming-weights. 
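+        # TensorRT weight streaming is only supported for strongly typed engines, so
+        # when it is requested the explicit precision flags are cleared here and
+        # polygraphy is invoked with --strongly-typed and --weight-streaming below.
+        # The actual streaming budget is applied later, in load(), via
+        # weight_streaming_budget_v2.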
+ if weight_streaming: + strongly_typed, fp16, bf16, int8, fp8 = True, False, False, False, False + + # Base command + build_command = [ + f"polygraphy convert {onnx_path} --convert-to trt --output {self.engine_path}" + ] + + # Precision flags + build_args = [ + "--fp16" if fp16 else "", + "--bf16" if bf16 else "", + "--tf32" if tf32 else "", + "--fp8" if fp8 else "", + "--int8" if int8 else "", + "--strongly-typed" if strongly_typed else "", + ] + + # Additional arguments + build_args.extend( + [ + "--weight-streaming" if weight_streaming else "", + "--refittable" if enable_refit else "", + "--tactic-sources" if not enable_all_tactics else "", + "--onnx-flags native_instancenorm" if native_instancenorm else "", + f"--builder-optimization-level {builder_optimization_level}", + f"--precision-constraints {precision_constraints}", + ] + ) + + # Timing cache + if timing_cache: + build_args.extend( + [ + f"--load-timing-cache {timing_cache}", + f"--save-timing-cache {timing_cache}", + ] + ) + + # Verbosity setting + verbosity = "extra_verbose" if verbose else "error" + build_args.append(f"--verbosity {verbosity}") + + # Output names + if update_output_names: + print(f"Updating network outputs to {update_output_names}") + build_args.append(f"--trt-outputs {' '.join(update_output_names)}") + + # Input profiles + if input_profile: + profile_args = defaultdict(str) + for name, dims in input_profile.items(): + assert len(dims) == 3 + profile_args["--trt-min-shapes"] += ( + f"{name}:{str(list(dims[0])).replace(' ', '')} " + ) + profile_args["--trt-opt-shapes"] += ( + f"{name}:{str(list(dims[1])).replace(' ', '')} " + ) + profile_args["--trt-max-shapes"] += ( + f"{name}:{str(list(dims[2])).replace(' ', '')} " + ) + + build_args.extend(f"{k} {v}" for k, v in profile_args.items()) + + # Filter out empty strings and join command + build_args = [arg for arg in build_args if arg] + final_command = " ".join(build_command + build_args) + + # Execute command with improved error handling + try: + print(f"Engine build command: {final_command}") + subprocess.run(final_command, check=True, shell=True) + except subprocess.CalledProcessError as exc: + error_msg = ( + f"Failed to build TensorRT engine. Error details:\nCommand: {exc.cmd}\n" + ) + raise RuntimeError(error_msg) from exc + + def load(self, weight_streaming=False, weight_streaming_budget_percentage=None): + if self.engine is not None: + print(f"[W]: Engine {self.engine_path} already loaded, skip reloading") + return + if not hasattr(self, "engine_bytes_cpu") or self.engine_bytes_cpu is None: + # keep a cpu copy of the engine to reduce reloading time. + print(f"Loading TensorRT engine to cpu bytes: {self.engine_path}") + self.engine_bytes_cpu = bytes_from_path(self.engine_path) + print(f"Loading TensorRT engine from bytes: {self.engine_path}") + self.engine = engine_from_bytes(self.engine_bytes_cpu) + if weight_streaming: + if weight_streaming_budget_percentage is None: + warnings.warn( + f"Weight streaming budget is not set for {self.engine_path}. Weights will not be streamed." 
+ ) + else: + self.engine.weight_streaming_budget_v2 = int( + weight_streaming_budget_percentage + / 100 + * self.engine.streamable_weights_size + ) + + def unload(self): + if self.engine is not None: + print(f"Unloading TensorRT engine: {self.engine_path}") + del self.engine + self.engine = None + gc.collect() + else: + print(f"[W]: Unload an unloaded engine {self.engine_path}, skip unloading") + + def activate(self, device_memory=None): + if device_memory: + self.context = self.engine.create_execution_context_without_device_memory() + self.context.device_memory = device_memory + else: + self.context = self.engine.create_execution_context() + + def reactivate(self, device_memory): + assert self.context + self.context.device_memory = device_memory + + def deactivate(self): + del self.context + self.context = None + + def allocate_buffers(self, shape_dict=None, device="cuda"): + for binding in range(self.engine.num_io_tensors): + name = self.engine.get_tensor_name(binding) + if shape_dict and name in shape_dict: + shape = shape_dict[name] + else: + shape = self.engine.get_tensor_shape(name) + print( + f"[W]: {self.engine_path}: Could not find '{name}' in shape dict {shape_dict}. Using shape {shape} inferred from the engine." + ) + if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT: + self.context.set_input_shape(name, shape) + dtype = trt_to_torch_dtype_dict[self.engine.get_tensor_dtype(name)] + tensor = torch.empty(tuple(shape), dtype=dtype).to(device=device) + self.tensors[name] = tensor + + def deallocate_buffers(self): + for idx in range(self.engine.num_io_tensors): + binding = self.engine[idx] + del self.tensors[binding] + + def infer(self, feed_dict, stream, use_cuda_graph=False): + for name, buf in feed_dict.items(): + self.tensors[name].copy_(buf) + + for name, tensor in self.tensors.items(): + self.context.set_tensor_address(name, tensor.data_ptr()) + + if use_cuda_graph: + if self.cuda_graph_instance is not None: + _CUASSERT(cudart.cudaGraphLaunch(self.cuda_graph_instance, stream)) + _CUASSERT(cudart.cudaStreamSynchronize(stream)) + else: + # do inference before CUDA graph capture + noerror = self.context.execute_async_v3(stream) + if not noerror: + raise ValueError(f"ERROR: inference of {self.engine_path} failed.") + # capture cuda graph + _CUASSERT( + cudart.cudaStreamBeginCapture( + stream, cudart.cudaStreamCaptureMode.cudaStreamCaptureModeGlobal + ) + ) + self.context.execute_async_v3(stream) + self.graph = _CUASSERT(cudart.cudaStreamEndCapture(stream)) + self.cuda_graph_instance = _CUASSERT( + cudart.cudaGraphInstantiate(self.graph, 0) + ) + else: + noerror = self.context.execute_async_v3(stream) + if not noerror: + raise ValueError(f"ERROR: inference of {self.engine_path} failed.") + + return self.tensors diff --git a/flux.1-dev-trt-b200/model/demo_diffusion/image/__init__.py b/flux.1-dev-trt-b200/model/demo_diffusion/image/__init__.py new file mode 100644 index 000000000..f0c165e09 --- /dev/null +++ b/flux.1-dev-trt-b200/model/demo_diffusion/image/__init__.py @@ -0,0 +1,34 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from demo_diffusion.image.load import ( + download_image, + prepare_mask_and_masked_image, + preprocess_image, + save_image, +) +from demo_diffusion.image.resize import resize_with_antialiasing +from demo_diffusion.image.video import tensor2vid + +__all__ = [ + "preprocess_image", + "prepare_mask_and_masked_image", + "download_image", + "save_image", + "resize_with_antialiasing", + "tensor2vid", +] diff --git a/flux.1-dev-trt-b200/model/demo_diffusion/image/load.py b/flux.1-dev-trt-b200/model/demo_diffusion/image/load.py new file mode 100644 index 000000000..f501ec89e --- /dev/null +++ b/flux.1-dev-trt-b200/model/demo_diffusion/image/load.py @@ -0,0 +1,78 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import random +from io import BytesIO + +import numpy as np +import requests +import torch +from PIL import Image + + +def preprocess_image(image): + """ + image: torch.Tensor + """ + w, h = image.size + w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 + image = image.resize((w, h)) + image = np.array(image).astype(np.float32) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image).contiguous() + return 2.0 * image - 1.0 + + +def prepare_mask_and_masked_image(image, mask): + """ + image: PIL.Image.Image + mask: PIL.Image.Image + """ + if isinstance(image, Image.Image): + image = np.array(image.convert("RGB")) + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image).to(dtype=torch.float32).contiguous() / 127.5 - 1.0 + if isinstance(mask, Image.Image): + mask = np.array(mask.convert("L")) + mask = mask.astype(np.float32) / 255.0 + mask = mask[None, None] + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + mask = torch.from_numpy(mask).to(dtype=torch.float32).contiguous() + + masked_image = image * (mask < 0.5) + + return mask, masked_image + + +def download_image(url): + response = requests.get(url) + return Image.open(BytesIO(response.content)).convert("RGB") + + +def save_image(images, image_path_dir, image_name_prefix, image_name_suffix): + """ + Save the generated images to png files. 
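+ Filenames are formed as {image_name_prefix}{image index}-{random 4-digit id}-{image_name_suffix}.png.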
+ """ + for i in range(images.shape[0]): + image_path = os.path.join( + image_path_dir, + f"{image_name_prefix}{i + 1}-{random.randint(1000, 9999)}-{image_name_suffix}.png", + ) + print(f"Saving image {i + 1} / {images.shape[0]} to: {image_path}") + Image.fromarray(images[i]).save(image_path) diff --git a/flux.1-dev-trt-b200/model/demo_diffusion/image/resize.py b/flux.1-dev-trt-b200/model/demo_diffusion/image/resize.py new file mode 100644 index 000000000..3ec529c70 --- /dev/null +++ b/flux.1-dev-trt-b200/model/demo_diffusion/image/resize.py @@ -0,0 +1,130 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + + +# Taken from https://github.com/huggingface/diffusers/blob/be62c85cd973f2001ab8c5d8919a9a6811fc7e43/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py#L633 +def resize_with_antialiasing(input, size, interpolation="bicubic", align_corners=True): + h, w = input.shape[-2:] + factors = (h / size[0], w / size[1]) + + # First, we have to determine sigma + # Taken from skimage: https://github.com/scikit-image/scikit-image/blob/v0.19.2/skimage/transform/_warps.py#L171 + sigmas = ( + max((factors[0] - 1.0) / 2.0, 0.001), + max((factors[1] - 1.0) / 2.0, 0.001), + ) + + # Now kernel size. Good results are for 3 sigma, but that is kind of slow. Pillow uses 1 sigma + # https://github.com/python-pillow/Pillow/blob/master/src/libImaging/Resample.c#L206 + # But they do it in the 2 passes, which gives better results. 
Let's try 2 sigmas for now + ks = int(max(2.0 * 2 * sigmas[0], 3)), int(max(2.0 * 2 * sigmas[1], 3)) + + # Make sure it is odd + if (ks[0] % 2) == 0: + ks = ks[0] + 1, ks[1] + + if (ks[1] % 2) == 0: + ks = ks[0], ks[1] + 1 + + input = _gaussian_blur2d(input, ks, sigmas) + + output = torch.nn.functional.interpolate( + input, size=size, mode=interpolation, align_corners=align_corners + ) + return output + + +def _compute_padding(kernel_size): + """Compute padding tuple.""" + # 4 or 6 ints: (padding_left, padding_right,padding_top,padding_bottom) + # https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad + if len(kernel_size) < 2: + raise AssertionError(kernel_size) + computed = [k - 1 for k in kernel_size] + + # for even kernels we need to do asymmetric padding :( + out_padding = 2 * len(kernel_size) * [0] + + for i in range(len(kernel_size)): + computed_tmp = computed[-(i + 1)] + + pad_front = computed_tmp // 2 + pad_rear = computed_tmp - pad_front + + out_padding[2 * i + 0] = pad_front + out_padding[2 * i + 1] = pad_rear + + return out_padding + + +def _filter2d(input, kernel): + # prepare kernel + b, c, h, w = input.shape + tmp_kernel = kernel[:, None, ...].to(device=input.device, dtype=input.dtype) + + tmp_kernel = tmp_kernel.expand(-1, c, -1, -1) + + height, width = tmp_kernel.shape[-2:] + + padding_shape: list[int] = _compute_padding([height, width]) + input = torch.nn.functional.pad(input, padding_shape, mode="reflect") + + # kernel and input tensor reshape to align element-wise or batch-wise params + tmp_kernel = tmp_kernel.reshape(-1, 1, height, width) + input = input.view(-1, tmp_kernel.size(0), input.size(-2), input.size(-1)) + + # convolve the tensor with the kernel. + output = torch.nn.functional.conv2d( + input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1 + ) + + out = output.view(b, c, h, w) + return out + + +def _gaussian(window_size: int, sigma): + if isinstance(sigma, float): + sigma = torch.tensor([[sigma]]) + + batch_size = sigma.shape[0] + + x = ( + torch.arange(window_size, device=sigma.device, dtype=sigma.dtype) + - window_size // 2 + ).expand(batch_size, -1) + + if window_size % 2 == 0: + x = x + 0.5 + + gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0))) + + return gauss / gauss.sum(-1, keepdim=True) + + +def _gaussian_blur2d(input, kernel_size, sigma): + if isinstance(sigma, tuple): + sigma = torch.tensor([sigma], dtype=input.dtype) + else: + sigma = sigma.to(dtype=input.dtype) + + ky, kx = int(kernel_size[0]), int(kernel_size[1]) + bs = sigma.shape[0] + kernel_x = _gaussian(kx, sigma[:, 1].view(bs, 1)) + kernel_y = _gaussian(ky, sigma[:, 0].view(bs, 1)) + out_x = _filter2d(input, kernel_x[..., None, :]) + out = _filter2d(out_x, kernel_y[..., None]) + + return out diff --git a/flux.1-dev-trt-b200/model/demo_diffusion/image/video.py b/flux.1-dev-trt-b200/model/demo_diffusion/image/video.py new file mode 100644 index 000000000..825aec840 --- /dev/null +++ b/flux.1-dev-trt-b200/model/demo_diffusion/image/video.py @@ -0,0 +1,37 @@ +# +# Copyright (c) Alibaba, Inc. and its affiliates. +# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import torch + + +# Not a contribution +# Changes made by NVIDIA CORPORATION & AFFILIATES enabling tensor2vid or otherwise documented as +# NVIDIA-proprietary are not a contribution and subject to the terms and conditions at the top of the file +def tensor2vid(video: torch.Tensor, processor, output_type="np"): + # Based on: + # https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78 + + batch_size, channels, num_frames, height, width = video.shape + outputs = [] + for batch_idx in range(batch_size): + batch_vid = video[batch_idx].permute(1, 0, 2, 3) + batch_output = processor.postprocess(batch_vid, output_type) + + outputs.append(batch_output) + + return outputs diff --git a/flux.1-dev-trt-b200/model/demo_diffusion/model/__init__.py b/flux.1-dev-trt-b200/model/demo_diffusion/model/__init__.py new file mode 100644 index 000000000..077f163fa --- /dev/null +++ b/flux.1-dev-trt-b200/model/demo_diffusion/model/__init__.py @@ -0,0 +1,99 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from demo_diffusion.model.base_model import BaseModel +from demo_diffusion.model.clip import ( + CLIPImageProcessorModel, + CLIPModel, + CLIPVisionWithProjModel, + CLIPWithProjModel, + SD3_CLIPGModel, + SD3_CLIPLModel, + SD3_T5XXLModel, + get_clip_embedding_dim, +) +from demo_diffusion.model.diffusion_transformer import ( + FluxTransformerModel, + SD3_MMDiTModel, + SD3TransformerModel, +) +from demo_diffusion.model.gan import VQGANModel +from demo_diffusion.model.load import unload_torch_model +from demo_diffusion.model.lora import FLUXLoraLoader, SDLoraLoader, merge_loras +from demo_diffusion.model.scheduler import make_scheduler +from demo_diffusion.model.t5 import T5Model +from demo_diffusion.model.tokenizer import make_tokenizer +from demo_diffusion.model.unet import ( + UNetCascadeModel, + UNetModel, + UNetTemporalModel, + UNetXLModel, + UNetXLModelControlNet, +) +from demo_diffusion.model.vae import ( + SD3_VAEDecoderModel, + SD3_VAEEncoderModel, + TorchVAEEncoder, + VAEDecTemporalModel, + VAEEncoderModel, + VAEModel, +) + +__all__ = [ + # base_model + "BaseModel", + # clip + "get_clip_embedding_dim", + "CLIPModel", + "CLIPWithProjModel", + "SD3_CLIPGModel", + "SD3_CLIPLModel", + "SD3_T5XXLModel", + "CLIPVisionWithProjModel", + "CLIPImageProcessorModel", + # diffusion_transformer + "SD3_MMDiTModel", + "FluxTransformerModel", + "SD3TransformerModel", + # gan + "VQGANModel", + # lora + "SDLoraLoader", + "FLUXLoraLoader", + "merge_loras", + # scheduler + "make_scheduler", + # t5 + "T5Model", + # tokenizer + "make_tokenizer", + # unet + "UNetModel", + "UNetXLModel", + "UNetXLModelControlNet", + "UNetTemporalModel", + "UNetCascadeModel", + # vae + "VAEModel", + "SD3_VAEDecoderModel", + "VAEDecTemporalModel", + "TorchVAEEncoder", + "VAEEncoderModel", + "SD3_VAEEncoderModel", + # load + "unload_torch_model", +] diff --git a/flux.1-dev-trt-b200/model/demo_diffusion/model/base_model.py b/flux.1-dev-trt-b200/model/demo_diffusion/model/base_model.py new file mode 100644 index 000000000..53c207062 --- /dev/null +++ b/flux.1-dev-trt-b200/model/demo_diffusion/model/base_model.py @@ -0,0 +1,310 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import json +import os + +import numpy as np +import torch +from diffusers import DiffusionPipeline + +import onnx +from demo_diffusion.model import load, optimizer +from demo_diffusion.model.lora import merge_loras +from onnx import numpy_helper + + +class BaseModel: + def __init__( + self, + version="1.5", + pipeline=None, + device="cuda", + hf_token="", + verbose=True, + framework_model_dir="pytorch_model", + fp16=False, + tf32=False, + bf16=False, + int8=False, + fp8=False, + max_batch_size=16, + text_maxlen=77, + embedding_dim=768, + compression_factor=8, + ): + self.name = self.__class__.__name__ + self.pipeline = pipeline.name + self.version = version + self.path = load.get_path(version, pipeline) + self.device = device + self.hf_token = hf_token + self.hf_safetensor = not (pipeline.is_inpaint() and version in ("1.4", "1.5")) + self.verbose = verbose + self.framework_model_dir = framework_model_dir + + self.fp16 = fp16 + self.tf32 = tf32 + self.bf16 = bf16 + self.int8 = int8 + self.fp8 = fp8 + + self.compression_factor = compression_factor + self.min_batch = 1 + self.max_batch = max_batch_size + self.min_image_shape = 256 # min image resolution: 256x256 + self.max_image_shape = 1344 # max image resolution: 1344x1344 + self.min_latent_shape = self.min_image_shape // self.compression_factor + self.max_latent_shape = self.max_image_shape // self.compression_factor + + self.text_maxlen = text_maxlen + self.embedding_dim = embedding_dim + self.extra_output_names = [] + + self.do_constant_folding = True + + def get_pipeline(self): + model_opts = ( + {"variant": "fp16", "torch_dtype": torch.float16} if self.fp16 else {} + ) + model_opts = {"torch_dtype": torch.bfloat16} if self.bf16 else model_opts + return DiffusionPipeline.from_pretrained( + self.path, + use_safetensors=self.hf_safetensor, + token=self.hf_token, + **model_opts, + ).to(self.device) + + def get_model(self, torch_inference=""): + pass + + def get_input_names(self): + pass + + def get_output_names(self): + pass + + def get_dynamic_axes(self): + return None + + def get_sample_input(self, batch_size, image_height, image_width, static_shape): + pass + + def get_input_profile( + self, batch_size, image_height, image_width, static_batch, static_shape + ): + return None + + def get_shape_dict(self, batch_size, image_height, image_width): + return None + + # Helper utility for ONNX export + def export_onnx( + self, + onnx_path, + onnx_opt_path, + onnx_opset, + opt_image_height, + opt_image_width, + custom_model=None, + enable_lora_merge=False, + static_shape=False, + lora_loader=None, + ): + onnx_opt_graph = None + # Export optimized ONNX model (if missing) + if not os.path.exists(onnx_opt_path): + if not os.path.exists(onnx_path): + print(f"[I] Exporting ONNX model: {onnx_path}") + + def export_onnx(model): + if enable_lora_merge: + assert lora_loader is not None + model = merge_loras(model, lora_loader) + inputs = self.get_sample_input( + 1, opt_image_height, opt_image_width, static_shape + ) + torch.onnx.export( + model, + inputs, + onnx_path, + export_params=True, + opset_version=onnx_opset, + do_constant_folding=self.do_constant_folding, + input_names=self.get_input_names(), + output_names=self.get_output_names(), + dynamic_axes=self.get_dynamic_axes(), + verbose=False, + ) + + if custom_model: + with torch.inference_mode(): + export_onnx(custom_model) + else: + # WAR: Enable autocast for BF16 Stable Cascade pipeline + do_autocast = ( + True if self.version == "cascade" and self.bf16 else False + ) + with ( + 
torch.inference_mode(), + torch.autocast("cuda", enabled=do_autocast), + ): + export_onnx(self.get_model()) + else: + print(f"[I] Found cached ONNX model: {onnx_path}") + + print(f"[I] Optimizing ONNX model: {onnx_opt_path}") + onnx_opt_graph = self.optimize(onnx.load(onnx_path)) + if load.onnx_graph_needs_external_data(onnx_opt_graph): + onnx.save_model( + onnx_opt_graph, + onnx_opt_path, + save_as_external_data=True, + all_tensors_to_one_file=True, + convert_attribute=False, + ) + else: + onnx.save(onnx_opt_graph, onnx_opt_path) + else: + print(f"[I] Found cached optimized ONNX model: {onnx_opt_path} ") + + # Helper utility for weights map + def export_weights_map(self, onnx_opt_path, weights_map_path): + if not os.path.exists(weights_map_path): + onnx_opt_dir = os.path.dirname(onnx_opt_path) + onnx_opt_model = onnx.load(onnx_opt_path) + state_dict = self.get_model().state_dict() + # Create initializer data hashes + initializer_hash_mapping = {} + for initializer in onnx_opt_model.graph.initializer: + initializer_data = numpy_helper.to_array( + initializer, base_dir=onnx_opt_dir + ).astype(np.float16) + initializer_hash = hash(initializer_data.data.tobytes()) + initializer_hash_mapping[initializer.name] = ( + initializer_hash, + initializer_data.shape, + ) + + weights_name_mapping = {} + weights_shape_mapping = {} + # set to keep track of initializers already added to the name_mapping dict + initializers_mapped = set() + for wt_name, wt in state_dict.items(): + # get weight hash + wt = wt.cpu().detach().numpy().astype(np.float16) + wt_hash = hash(wt.data.tobytes()) + wt_t_hash = hash(np.transpose(wt).data.tobytes()) + + for initializer_name, ( + initializer_hash, + initializer_shape, + ) in initializer_hash_mapping.items(): + # Due to constant folding, some weights are transposed during export + # To account for the transpose op, we compare the initializer hash to the + # hash for the weight and its transpose + if wt_hash == initializer_hash or wt_t_hash == initializer_hash: + # The assert below ensures there is a 1:1 mapping between + # PyTorch and ONNX weight names. 
It can be removed in cases where 1:many + # mapping is found and name_mapping[wt_name] = list() + assert initializer_name not in initializers_mapped + weights_name_mapping[wt_name] = initializer_name + initializers_mapped.add(initializer_name) + is_transpose = False if wt_hash == initializer_hash else True + weights_shape_mapping[wt_name] = ( + initializer_shape, + is_transpose, + ) + + # Sanity check: Were any weights not matched + if wt_name not in weights_name_mapping: + print( + f"[I] PyTorch weight {wt_name} not matched with any ONNX initializer" + ) + print( + f"[I] {len(weights_name_mapping.keys())} PyTorch weights were matched with ONNX initializers" + ) + assert weights_name_mapping.keys() == weights_shape_mapping.keys() + with open(weights_map_path, "w") as fp: + json.dump([weights_name_mapping, weights_shape_mapping], fp) + else: + print(f"[I] Found cached weights map: {weights_map_path} ") + + def optimize(self, onnx_graph, return_onnx=True, **kwargs): + opt = optimizer.Optimizer(onnx_graph, verbose=self.verbose) + opt.info(self.name + ": original") + opt.cleanup() + opt.info(self.name + ": cleanup") + if kwargs.get("modify_fp8_graph", False): + is_fp16_io = kwargs.get("is_fp16_io", True) + opt.modify_fp8_graph(is_fp16_io=is_fp16_io) + opt.info(self.name + ": modify fp8 graph") + if self.version.startswith("flux.1") and self.fp8: + opt.flux_convert_rope_weight_type() + opt.info(self.name + ": convert rope weight type for fp8 flux") + opt.fold_constants() + opt.info(self.name + ": fold constants") + opt.infer_shapes() + opt.info(self.name + ": shape inference") + if kwargs.get("fuse_mha_qkv_int8", False): + opt.fuse_mha_qkv_int8_sq() + opt.info(self.name + ": fuse QKV nodes") + onnx_opt_graph = opt.cleanup(return_onnx=return_onnx) + opt.info(self.name + ": finished") + return onnx_opt_graph + + def check_dims(self, batch_size, image_height, image_width): + assert batch_size >= self.min_batch and batch_size <= self.max_batch + latent_height = image_height // self.compression_factor + latent_width = image_width // self.compression_factor + assert ( + latent_height >= self.min_latent_shape + and latent_height <= self.max_latent_shape + ) + assert ( + latent_width >= self.min_latent_shape + and latent_width <= self.max_latent_shape + ) + return (latent_height, latent_width) + + def get_minmax_dims( + self, batch_size, image_height, image_width, static_batch, static_shape + ): + min_batch = batch_size if static_batch else self.min_batch + max_batch = batch_size if static_batch else self.max_batch + latent_height = image_height // self.compression_factor + latent_width = image_width // self.compression_factor + min_image_height = image_height if static_shape else self.min_image_shape + max_image_height = image_height if static_shape else self.max_image_shape + min_image_width = image_width if static_shape else self.min_image_shape + max_image_width = image_width if static_shape else self.max_image_shape + min_latent_height = latent_height if static_shape else self.min_latent_shape + max_latent_height = latent_height if static_shape else self.max_latent_shape + min_latent_width = latent_width if static_shape else self.min_latent_shape + max_latent_width = latent_width if static_shape else self.max_latent_shape + return ( + min_batch, + max_batch, + min_image_height, + max_image_height, + min_image_width, + max_image_width, + min_latent_height, + max_latent_height, + min_latent_width, + max_latent_width, + ) diff --git a/flux.1-dev-trt-b200/model/demo_diffusion/model/clip.py 
b/flux.1-dev-trt-b200/model/demo_diffusion/model/clip.py new file mode 100644 index 000000000..104979f59 --- /dev/null +++ b/flux.1-dev-trt-b200/model/demo_diffusion/model/clip.py @@ -0,0 +1,628 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os + +import torch +from huggingface_hub import hf_hub_download +from safetensors import safe_open +from transformers import ( + CLIPImageProcessor, + CLIPTextModel, + CLIPTextModelWithProjection, + CLIPVisionModelWithProjection, +) + +from demo_diffusion.model import base_model, load, optimizer +from demo_diffusion.utils_sd3.other_impls import ( + SDClipModel, + SDXLClipG, + T5XXLModel, + load_into, +) + + +def get_clipwithproj_embedding_dim(version: str, subfolder: str) -> int: + """Return the embedding dimension of a CLIP with projection model.""" + if version in ("xl-1.0", "xl-turbo", "cascade"): + return 1280 + elif version in {"3.5-medium", "3.5-large"} and subfolder == "text_encoder": + return 768 + elif version in {"3.5-medium", "3.5-large"} and subfolder == "text_encoder_2": + return 1280 + else: + raise ValueError(f"Invalid version {version} + subfolder {subfolder}") + + +def get_clip_embedding_dim(version, pipeline): + if version in ( + "1.4", + "1.5", + "dreamshaper-7", + "flux.1-dev", + "flux.1-schnell", + "flux.1-dev-canny", + "flux.1-dev-depth", + ): + return 768 + elif version in ("2.0", "2.0-base", "2.1", "2.1-base"): + return 1024 + elif version in ("xl-1.0", "xl-turbo") and pipeline.is_sd_xl_base(): + return 768 + elif version in ("sd3"): + return 4096 + else: + raise ValueError(f"Invalid version {version} + pipeline {pipeline}") + + +class CLIPModel(base_model.BaseModel): + def __init__( + self, + version, + pipeline, + device, + hf_token, + verbose, + framework_model_dir, + max_batch_size, + embedding_dim, + fp16=False, + tf32=False, + bf16=False, + output_hidden_states=False, + keep_pooled_output=False, + subfolder="text_encoder", + ): + super(CLIPModel, self).__init__( + version, + pipeline, + device=device, + hf_token=hf_token, + verbose=verbose, + framework_model_dir=framework_model_dir, + fp16=fp16, + tf32=tf32, + bf16=bf16, + max_batch_size=max_batch_size, + embedding_dim=embedding_dim, + ) + self.subfolder = subfolder + self.hidden_layer_offset = 0 if pipeline.is_cascade() else -1 + self.keep_pooled_output = keep_pooled_output + + # Output the final hidden state + if output_hidden_states: + self.extra_output_names = ["hidden_states"] + + def get_model(self, torch_inference=""): + model_opts = ( + {"torch_dtype": torch.float16} + if self.fp16 + else {"torch_dtype": torch.bfloat16} + if self.bf16 + else {} + ) + clip_model_dir = load.get_checkpoint_dir( + self.framework_model_dir, self.version, self.pipeline, self.subfolder + ) + if not load.is_model_cached( + clip_model_dir, model_opts, self.hf_safetensor, model_name="model" + ): + model = 
CLIPTextModel.from_pretrained( + self.path, + subfolder=self.subfolder, + use_safetensors=self.hf_safetensor, + token=self.hf_token, + **model_opts, + ).to(self.device) + model.save_pretrained(clip_model_dir, **model_opts) + else: + print(f"[I] Load CLIPTextModel model from: {clip_model_dir}") + model = CLIPTextModel.from_pretrained(clip_model_dir, **model_opts).to( + self.device + ) + model = optimizer.optimize_checkpoint(model, torch_inference) + return model + + def get_input_names(self): + return ["input_ids"] + + def get_output_names(self): + output_names = ["text_embeddings"] + if self.keep_pooled_output: + output_names += ["pooled_embeddings"] + return output_names + + def get_dynamic_axes(self): + dynamic_axes = { + "input_ids": {0: "B"}, + "text_embeddings": {0: "B"}, + } + if self.keep_pooled_output: + dynamic_axes["pooled_embeddings"] = {0: "B"} + return dynamic_axes + + def get_input_profile( + self, batch_size, image_height, image_width, static_batch, static_shape + ): + self.check_dims(batch_size, image_height, image_width) + min_batch, max_batch, _, _, _, _, _, _, _, _ = self.get_minmax_dims( + batch_size, image_height, image_width, static_batch, static_shape + ) + return { + "input_ids": [ + (min_batch, self.text_maxlen), + (batch_size, self.text_maxlen), + (max_batch, self.text_maxlen), + ] + } + + def get_shape_dict(self, batch_size, image_height, image_width): + self.check_dims(batch_size, image_height, image_width) + output = { + "input_ids": (batch_size, self.text_maxlen), + "text_embeddings": (batch_size, self.text_maxlen, self.embedding_dim), + } + if self.keep_pooled_output: + output["pooled_embeddings"] = (batch_size, self.embedding_dim) + if "hidden_states" in self.extra_output_names: + output["hidden_states"] = (batch_size, self.text_maxlen, self.embedding_dim) + return output + + def get_sample_input(self, batch_size, image_height, image_width, static_shape): + self.check_dims(batch_size, image_height, image_width) + return torch.zeros( + batch_size, self.text_maxlen, dtype=torch.int32, device=self.device + ) + + def optimize(self, onnx_graph): + opt = optimizer.Optimizer(onnx_graph, verbose=self.verbose) + opt.info(self.name + ": original") + keep_outputs = [0, 1] if self.keep_pooled_output else [0] + opt.select_outputs(keep_outputs) + opt.cleanup() + opt.fold_constants() + opt.info(self.name + ": fold constants") + opt.infer_shapes() + opt.info(self.name + ": shape inference") + opt.select_outputs( + keep_outputs, names=self.get_output_names() + ) # rename network outputs + opt.info(self.name + ": rename network output(s)") + opt_onnx_graph = opt.cleanup(return_onnx=True) + if "hidden_states" in self.extra_output_names: + opt_onnx_graph = opt.clip_add_hidden_states( + self.hidden_layer_offset, return_onnx=True + ) + opt.info(self.name + ": added hidden_states") + opt.info(self.name + ": finished") + return opt_onnx_graph + + +class CLIPWithProjModel(CLIPModel): + def __init__( + self, + version, + pipeline, + device, + hf_token, + verbose, + framework_model_dir, + fp16=False, + bf16=False, + max_batch_size=16, + output_hidden_states=False, + subfolder="text_encoder_2", + ): + super(CLIPWithProjModel, self).__init__( + version, + pipeline, + device=device, + hf_token=hf_token, + verbose=verbose, + framework_model_dir=framework_model_dir, + fp16=fp16, + bf16=bf16, + max_batch_size=max_batch_size, + embedding_dim=get_clipwithproj_embedding_dim(version, subfolder), + output_hidden_states=output_hidden_states, + ) + self.subfolder = subfolder + + def 
get_model(self, torch_inference=""): + model_opts = ( + {"variant": "fp16", "torch_dtype": torch.float16} + if self.fp16 + else {"torch_dtype": torch.bfloat16} + ) + clip_model_dir = load.get_checkpoint_dir( + self.framework_model_dir, self.version, self.pipeline, self.subfolder + ) + if not load.is_model_cached( + clip_model_dir, model_opts, self.hf_safetensor, model_name="model" + ): + model = CLIPTextModelWithProjection.from_pretrained( + self.path, + subfolder=self.subfolder, + use_safetensors=self.hf_safetensor, + token=self.hf_token, + **model_opts, + ).to(self.device) + model.save_pretrained(clip_model_dir, **model_opts) + else: + print(f"[I] Load CLIPTextModelWithProjection model from: {clip_model_dir}") + model = CLIPTextModelWithProjection.from_pretrained( + clip_model_dir, **model_opts + ).to(self.device) + model = optimizer.optimize_checkpoint(model, torch_inference) + return model + + def get_input_names(self): + return ["input_ids", "attention_mask"] + + def get_output_names(self): + return ["text_embeddings"] + + def get_dynamic_axes(self): + return { + "input_ids": {0: "B"}, + "attention_mask": {0: "B"}, + "text_embeddings": {0: "B"}, + } + + def get_input_profile( + self, batch_size, image_height, image_width, static_batch, static_shape + ): + self.check_dims(batch_size, image_height, image_width) + min_batch, max_batch, _, _, _, _, _, _, _, _ = self.get_minmax_dims( + batch_size, image_height, image_width, static_batch, static_shape + ) + return { + "input_ids": [ + (min_batch, self.text_maxlen), + (batch_size, self.text_maxlen), + (max_batch, self.text_maxlen), + ], + "attention_mask": [ + (min_batch, self.text_maxlen), + (batch_size, self.text_maxlen), + (max_batch, self.text_maxlen), + ], + } + + def get_shape_dict(self, batch_size, image_height, image_width): + self.check_dims(batch_size, image_height, image_width) + output = { + "input_ids": (batch_size, self.text_maxlen), + "attention_mask": (batch_size, self.text_maxlen), + "text_embeddings": (batch_size, self.embedding_dim), + } + if "hidden_states" in self.extra_output_names: + output["hidden_states"] = (batch_size, self.text_maxlen, self.embedding_dim) + return output + + def get_sample_input(self, batch_size, image_height, image_width, static_shape): + self.check_dims(batch_size, image_height, image_width) + return ( + torch.zeros( + batch_size, self.text_maxlen, dtype=torch.int32, device=self.device + ), + torch.zeros( + batch_size, self.text_maxlen, dtype=torch.int32, device=self.device + ), + ) + + +class SD3_CLIPGModel(CLIPModel): + def __init__( + self, + version, + pipeline, + device, + hf_token, + verbose, + framework_model_dir, + max_batch_size, + embedding_dim=None, + fp16=False, + pooled_output=False, + ): + self.CLIPG_CONFIG = { + "hidden_act": "gelu", + "hidden_size": 1280, + "intermediate_size": 5120, + "num_attention_heads": 20, + "num_hidden_layers": 32, + } + super(SD3_CLIPGModel, self).__init__( + version, + pipeline, + device=device, + hf_token=hf_token, + verbose=verbose, + framework_model_dir=framework_model_dir, + fp16=fp16, + max_batch_size=max_batch_size, + embedding_dim=self.CLIPG_CONFIG["hidden_size"] + if embedding_dim is None + else embedding_dim, + ) + self.subfolder = "text_encoders" + if pooled_output: + self.extra_output_names = ["pooled_output"] + + def get_model(self, torch_inference=""): + clip_g_model_dir = load.get_checkpoint_dir( + self.framework_model_dir, self.version, self.pipeline, self.subfolder + ) + clip_g_filename = "clip_g.safetensors" + clip_g_model_path = 
f"{clip_g_model_dir}/{clip_g_filename}" + if not os.path.exists(clip_g_model_path): + hf_hub_download( + repo_id=self.path, + filename=clip_g_filename, + local_dir=load.get_checkpoint_dir( + self.framework_model_dir, self.version, self.pipeline, "" + ), + subfolder=self.subfolder, + ) + with safe_open(clip_g_model_path, framework="pt", device=self.device) as f: + dtype = torch.float16 if self.fp16 else torch.float32 + model = SDXLClipG(self.CLIPG_CONFIG, device=self.device, dtype=dtype) + load_into(f, model.transformer, "", self.device, dtype) + model = optimizer.optimize_checkpoint(model, torch_inference) + return model + + def get_shape_dict(self, batch_size, image_height, image_width): + self.check_dims(batch_size, image_height, image_width) + output = { + "input_ids": (batch_size, self.text_maxlen), + "text_embeddings": (batch_size, self.text_maxlen, self.embedding_dim), + } + if "pooled_output" in self.extra_output_names: + output["pooled_output"] = (batch_size, self.embedding_dim) + + return output + + def optimize(self, onnx_graph): + opt = optimizer.Optimizer(onnx_graph, verbose=self.verbose) + opt.info(self.name + ": original") + opt.select_outputs([0, 1]) + opt.cleanup() + opt.fold_constants() + opt.info(self.name + ": fold constants") + opt.infer_shapes() + opt.info(self.name + ": shape inference") + opt.select_outputs( + [0, 1], names=["text_embeddings", "pooled_output"] + ) # rename network output + opt.info(self.name + ": rename output[0] and output[1]") + opt_onnx_graph = opt.cleanup(return_onnx=True) + opt.info(self.name + ": finished") + return opt_onnx_graph + + +class SD3_CLIPLModel(SD3_CLIPGModel): + def __init__( + self, + version, + pipeline, + device, + hf_token, + verbose, + framework_model_dir, + max_batch_size, + fp16=False, + pooled_output=False, + ): + self.CLIPL_CONFIG = { + "hidden_act": "quick_gelu", + "hidden_size": 768, + "intermediate_size": 3072, + "num_attention_heads": 12, + "num_hidden_layers": 12, + } + super(SD3_CLIPLModel, self).__init__( + version, + pipeline, + device=device, + hf_token=hf_token, + verbose=verbose, + framework_model_dir=framework_model_dir, + fp16=fp16, + max_batch_size=max_batch_size, + embedding_dim=self.CLIPL_CONFIG["hidden_size"], + ) + self.subfolder = "text_encoders" + if pooled_output: + self.extra_output_names = ["pooled_output"] + + def get_model(self, torch_inference=""): + clip_l_model_dir = load.get_checkpoint_dir( + self.framework_model_dir, self.version, self.pipeline, self.subfolder + ) + clip_l_filename = "clip_l.safetensors" + clip_l_model_path = f"{clip_l_model_dir}/{clip_l_filename}" + if not os.path.exists(clip_l_model_path): + hf_hub_download( + repo_id=self.path, + filename=clip_l_filename, + local_dir=load.get_checkpoint_dir( + self.framework_model_dir, self.version, self.pipeline, "" + ), + subfolder=self.subfolder, + ) + with safe_open(clip_l_model_path, framework="pt", device=self.device) as f: + dtype = torch.float16 if self.fp16 else torch.float32 + model = SDClipModel( + layer="hidden", + layer_idx=-2, + device=self.device, + dtype=dtype, + layer_norm_hidden_state=False, + return_projected_pooled=False, + textmodel_json_config=self.CLIPL_CONFIG, + ) + load_into(f, model.transformer, "", self.device, dtype) + model = optimizer.optimize_checkpoint(model, torch_inference) + return model + + +# NOTE: For legacy reasons, even though this is a T5 model, it inherits from CLIPModel. 
+class SD3_T5XXLModel(CLIPModel): + def __init__( + self, + version, + pipeline, + device, + hf_token, + verbose, + framework_model_dir, + max_batch_size, + embedding_dim, + fp16=False, + ): + super(SD3_T5XXLModel, self).__init__( + version, + pipeline, + device=device, + hf_token=hf_token, + verbose=verbose, + framework_model_dir=framework_model_dir, + fp16=fp16, + max_batch_size=max_batch_size, + embedding_dim=embedding_dim, + ) + self.T5_CONFIG = { + "d_ff": 10240, + "d_model": 4096, + "num_heads": 64, + "num_layers": 24, + "vocab_size": 32128, + } + self.subfolder = "text_encoders" + + def get_model(self, torch_inference=""): + t5xxl_model_dir = load.get_checkpoint_dir( + self.framework_model_dir, self.version, self.pipeline, self.subfolder + ) + t5xxl_filename = "t5xxl_fp16.safetensors" + t5xxl_model_path = f"{t5xxl_model_dir}/{t5xxl_filename}" + if not os.path.exists(t5xxl_model_path): + hf_hub_download( + repo_id=self.path, + filename=t5xxl_filename, + local_dir=load.get_checkpoint_dir( + self.framework_model_dir, self.version, self.pipeline, "" + ), + subfolder=self.subfolder, + ) + with safe_open(t5xxl_model_path, framework="pt", device=self.device) as f: + dtype = torch.float16 if self.fp16 else torch.float32 + model = T5XXLModel(self.T5_CONFIG, device=self.device, dtype=dtype) + load_into(f, model.transformer, "", self.device, dtype) + model = optimizer.optimize_checkpoint(model, torch_inference) + return model + + +class CLIPVisionWithProjModel(base_model.BaseModel): + def __init__( + self, + version, + pipeline, + device, + hf_token, + verbose, + framework_model_dir, + max_batch_size=1, + subfolder="image_encoder", + ): + super(CLIPVisionWithProjModel, self).__init__( + version, + pipeline, + device=device, + hf_token=hf_token, + verbose=verbose, + framework_model_dir=framework_model_dir, + max_batch_size=max_batch_size, + ) + self.subfolder = subfolder + + def get_model(self, torch_inference=""): + clip_model_dir = load.get_checkpoint_dir( + self.framework_model_dir, self.version, self.pipeline, self.subfolder + ) + if not os.path.exists(clip_model_dir): + model = CLIPVisionModelWithProjection.from_pretrained( + self.path, + subfolder=self.subfolder, + use_safetensors=self.hf_safetensor, + token=self.hf_token, + ).to(self.device) + model.save_pretrained(clip_model_dir) + else: + print( + f"[I] Load CLIPVisionModelWithProjection model from: {clip_model_dir}" + ) + model = CLIPVisionModelWithProjection.from_pretrained(clip_model_dir).to( + self.device + ) + model = optimizer.optimize_checkpoint(model, torch_inference) + return model + + +class CLIPImageProcessorModel(base_model.BaseModel): + def __init__( + self, + version, + pipeline, + device, + hf_token, + verbose, + framework_model_dir, + max_batch_size=1, + subfolder="feature_extractor", + ): + super(CLIPImageProcessorModel, self).__init__( + version, + pipeline, + device=device, + hf_token=hf_token, + verbose=verbose, + framework_model_dir=framework_model_dir, + max_batch_size=max_batch_size, + ) + self.subfolder = subfolder + + def get_model(self, torch_inference=""): + clip_model_dir = load.get_checkpoint_dir( + self.framework_model_dir, self.version, self.pipeline, self.subfolder + ) + # NOTE to(device) not supported + if not os.path.exists(clip_model_dir): + model = CLIPImageProcessor.from_pretrained( + self.path, + subfolder=self.subfolder, + use_safetensors=self.hf_safetensor, + token=self.hf_token, + ) + model.save_pretrained(clip_model_dir) + else: + print(f"[I] Load CLIPImageProcessor model from: 
{clip_model_dir}") + model = CLIPImageProcessor.from_pretrained(clip_model_dir) + model = optimizer.optimize_checkpoint(model, torch_inference) + return model diff --git a/flux.1-dev-trt-b200/model/demo_diffusion/model/diffusion_transformer.py b/flux.1-dev-trt-b200/model/demo_diffusion/model/diffusion_transformer.py new file mode 100644 index 000000000..e4fa46150 --- /dev/null +++ b/flux.1-dev-trt-b200/model/demo_diffusion/model/diffusion_transformer.py @@ -0,0 +1,753 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os + +import torch +from huggingface_hub import hf_hub_download +from safetensors import safe_open + +from demo_diffusion.dynamic_import import import_from_diffusers +from demo_diffusion.model import base_model, load, optimizer +from demo_diffusion.utils_sd3.other_impls import load_into +from demo_diffusion.utils_sd3.sd3_impls import BaseModel as BaseModelSD3 + +# List of models to import from diffusers.models +models_to_import = ["FluxTransformer2DModel", "SD3Transformer2DModel"] +for model in models_to_import: + globals()[model] = import_from_diffusers(model, "diffusers.models") + + +class SD3_MMDiTModel(base_model.BaseModel): + def __init__( + self, + version, + pipeline, + device, + hf_token, + verbose, + framework_model_dir, + shift=1.0, + fp16=False, + max_batch_size=16, + text_maxlen=77, + ): + super(SD3_MMDiTModel, self).__init__( + version, + pipeline, + device=device, + hf_token=hf_token, + verbose=verbose, + framework_model_dir=framework_model_dir, + fp16=fp16, + max_batch_size=max_batch_size, + text_maxlen=text_maxlen, + ) + self.subfolder = "sd3" + self.mmdit_dim = 16 + self.shift = shift + self.xB = 2 + + def get_model(self, torch_inference=""): + sd3_model_dir = load.get_checkpoint_dir( + self.framework_model_dir, self.version, self.pipeline, self.subfolder + ) + sd3_filename = "sd3_medium.safetensors" + sd3_model_path = f"{sd3_model_dir}/{sd3_filename}" + if not os.path.exists(sd3_model_path): + hf_hub_download( + repo_id=self.path, filename=sd3_filename, local_dir=sd3_model_dir + ) + with safe_open(sd3_model_path, framework="pt", device=self.device) as f: + model = BaseModelSD3( + shift=self.shift, + file=f, + prefix="model.diffusion_model.", + device=self.device, + dtype=torch.float16, + ).eval() + load_into(f, model, "model.", self.device, torch.float16) + model = optimizer.optimize_checkpoint(model, torch_inference) + return model + + def get_input_names(self): + return ["sample", "sigma", "c_crossattn", "y"] + + def get_output_names(self): + return ["latent"] + + def get_dynamic_axes(self): + xB = "2B" if self.xB == 2 else "B" + return { + "sample": {0: xB, 2: "H", 3: "W"}, + "sigma": {0: xB}, + "c_crossattn": {0: xB}, + "y": {0: xB}, + "latent": {0: xB, 2: "H", 3: "W"}, + } + + def get_input_profile( + self, batch_size, image_height, image_width, static_batch, static_shape + ): + latent_height, 
latent_width = self.check_dims( + batch_size, image_height, image_width + ) + ( + min_batch, + max_batch, + _, + _, + _, + _, + min_latent_height, + max_latent_height, + min_latent_width, + max_latent_width, + ) = self.get_minmax_dims( + batch_size, image_height, image_width, static_batch, static_shape + ) + return { + "sample": [ + ( + self.xB * min_batch, + self.mmdit_dim, + min_latent_height, + min_latent_width, + ), + (self.xB * batch_size, self.mmdit_dim, latent_height, latent_width), + ( + self.xB * max_batch, + self.mmdit_dim, + max_latent_height, + max_latent_width, + ), + ], + "sigma": [ + (self.xB * min_batch,), + (self.xB * batch_size,), + (self.xB * max_batch,), + ], + "c_crossattn": [ + (self.xB * min_batch, 154, 4096), + (self.xB * batch_size, 154, 4096), + (self.xB * max_batch, 154, 4096), + ], + "y": [ + (self.xB * min_batch, 2048), + (self.xB * batch_size, 2048), + (self.xB * max_batch, 2048), + ], + } + + def get_shape_dict(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims( + batch_size, image_height, image_width + ) + return { + "sample": ( + self.xB * batch_size, + self.mmdit_dim, + latent_height, + latent_width, + ), + "sigma": (self.xB * batch_size,), + "c_crossattn": (self.xB * batch_size, 154, 4096), + "y": (self.xB * batch_size, 2048), + "latent": ( + self.xB * batch_size, + self.mmdit_dim, + latent_height, + latent_width, + ), + } + + def get_sample_input(self, batch_size, image_height, image_width, static_shape): + latent_height, latent_width = self.check_dims( + batch_size, image_height, image_width + ) + dtype = torch.float16 if self.fp16 else torch.float32 + return ( + torch.randn( + batch_size, + self.mmdit_dim, + latent_height, + latent_width, + dtype=dtype, + device=self.device, + ), + torch.randn(batch_size, dtype=dtype, device=self.device), + { + "c_crossattn": torch.randn( + batch_size, 154, 4096, dtype=dtype, device=self.device + ), + "y": torch.randn(batch_size, 2048, dtype=dtype, device=self.device), + }, + ) + + +class FluxTransformerModel(base_model.BaseModel): + def __init__( + self, + version, + pipeline, + device, + hf_token, + verbose, + framework_model_dir, + fp16=False, + tf32=False, + int8=False, + fp8=False, + bf16=False, + max_batch_size=16, + text_maxlen=77, + build_strongly_typed=False, + weight_streaming=False, + weight_streaming_budget_percentage=None, + ): + super(FluxTransformerModel, self).__init__( + version, + pipeline, + device=device, + hf_token=hf_token, + verbose=verbose, + framework_model_dir=framework_model_dir, + fp16=fp16, + tf32=tf32, + int8=int8, + fp8=fp8, + bf16=bf16, + max_batch_size=max_batch_size, + text_maxlen=text_maxlen, + ) + self.subfolder = "transformer" + self.transformer_model_dir = load.get_checkpoint_dir( + self.framework_model_dir, self.version, self.pipeline, self.subfolder + ) + if not os.path.exists(self.transformer_model_dir): + self.config = FluxTransformer2DModel.load_config( + self.path, subfolder=self.subfolder, token=self.hf_token + ) + else: + print( + f"[I] Load FluxTransformer2DModel config from: {self.transformer_model_dir}" + ) + self.config = FluxTransformer2DModel.load_config(self.transformer_model_dir) + self.build_strongly_typed = build_strongly_typed + self.weight_streaming = weight_streaming + self.weight_streaming_budget_percentage = weight_streaming_budget_percentage + self.out_channels = ( + self.config.get("out_channels") or self.config["in_channels"] + ) + + def get_model(self, torch_inference=""): + model_opts = ( + {"torch_dtype": 
torch.float16} + if self.fp16 + else {"torch_dtype": torch.bfloat16} + if self.bf16 + else {} + ) + if not load.is_model_cached( + self.transformer_model_dir, model_opts, self.hf_safetensor + ): + model = FluxTransformer2DModel.from_pretrained( + self.path, + subfolder=self.subfolder, + use_safetensors=self.hf_safetensor, + token=self.hf_token, + **model_opts, + ).to(self.device) + model.save_pretrained(self.transformer_model_dir, **model_opts) + else: + print( + f"[I] Load FluxTransformer2DModel model from: {self.transformer_model_dir}" + ) + model = FluxTransformer2DModel.from_pretrained( + self.transformer_model_dir, **model_opts + ).to(self.device) + if torch_inference: + model.to(memory_format=torch.channels_last) + model = optimizer.optimize_checkpoint(model, torch_inference) + return model + + def get_input_names(self): + return [ + "hidden_states", + "encoder_hidden_states", + "pooled_projections", + "timestep", + "img_ids", + "txt_ids", + "guidance", + ] + + def get_output_names(self): + return ["latent"] + + def get_dynamic_axes(self): + dynamic_axes = { + "hidden_states": {0: "B", 1: "latent_dim"}, + "encoder_hidden_states": {0: "B"}, + "pooled_projections": {0: "B"}, + "timestep": {0: "B"}, + "img_ids": {0: "latent_dim"}, + } + if self.config["guidance_embeds"]: + dynamic_axes["guidance"] = {0: "B"} + return dynamic_axes + + def get_input_profile( + self, batch_size, image_height, image_width, static_batch, static_shape + ): + latent_height, latent_width = self.check_dims( + batch_size, image_height, image_width + ) + ( + min_batch, + max_batch, + min_image_height, + max_image_height, + min_image_width, + max_image_width, + min_latent_height, + max_latent_height, + min_latent_width, + max_latent_width, + ) = self.get_minmax_dims( + batch_size, image_height, image_width, static_batch, static_shape + ) + input_profile = { + "hidden_states": [ + ( + min_batch, + (min_latent_height // 2) * (min_latent_width // 2), + self.config["in_channels"], + ), + ( + batch_size, + (latent_height // 2) * (latent_width // 2), + self.config["in_channels"], + ), + ( + max_batch, + (max_latent_height // 2) * (max_latent_width // 2), + self.config["in_channels"], + ), + ], + "encoder_hidden_states": [ + (min_batch, self.text_maxlen, self.config["joint_attention_dim"]), + (batch_size, self.text_maxlen, self.config["joint_attention_dim"]), + (max_batch, self.text_maxlen, self.config["joint_attention_dim"]), + ], + "pooled_projections": [ + (min_batch, self.config["pooled_projection_dim"]), + (batch_size, self.config["pooled_projection_dim"]), + (max_batch, self.config["pooled_projection_dim"]), + ], + "timestep": [(min_batch,), (batch_size,), (max_batch,)], + "img_ids": [ + ((min_latent_height // 2) * (min_latent_width // 2), 3), + ((latent_height // 2) * (latent_width // 2), 3), + ((max_latent_height // 2) * (max_latent_width // 2), 3), + ], + "txt_ids": [ + (self.text_maxlen, 3), + (self.text_maxlen, 3), + (self.text_maxlen, 3), + ], + } + if self.config["guidance_embeds"]: + input_profile["guidance"] = [(min_batch,), (batch_size,), (max_batch,)] + return input_profile + + def get_shape_dict(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims( + batch_size, image_height, image_width + ) + shape_dict = { + "hidden_states": ( + batch_size, + (latent_height // 2) * (latent_width // 2), + self.config["in_channels"], + ), + "encoder_hidden_states": ( + batch_size, + self.text_maxlen, + self.config["joint_attention_dim"], + ), + "pooled_projections": (batch_size, 
self.config["pooled_projection_dim"]), + "timestep": (batch_size,), + "img_ids": ((latent_height // 2) * (latent_width // 2), 3), + "txt_ids": (self.text_maxlen, 3), + "latent": ( + batch_size, + (latent_height // 2) * (latent_width // 2), + self.out_channels, + ), + } + if self.config["guidance_embeds"]: + shape_dict["guidance"] = (batch_size,) + return shape_dict + + def get_sample_input(self, batch_size, image_height, image_width, static_shape): + latent_height, latent_width = self.check_dims( + batch_size, image_height, image_width + ) + dtype = torch.float32 + assert not (self.fp16 and self.bf16), ( + "fp16 and bf16 cannot be enabled simultaneously" + ) + tensor_dtype = ( + torch.bfloat16 + if self.bf16 + else (torch.float16 if self.fp16 else torch.float32) + ) + + sample_input = ( + torch.randn( + batch_size, + (latent_height // 2) * (latent_width // 2), + self.config["in_channels"], + dtype=tensor_dtype, + device=self.device, + ), + torch.randn( + batch_size, + self.text_maxlen, + self.config["joint_attention_dim"], + dtype=tensor_dtype, + device=self.device, + ), + torch.randn( + batch_size, + self.config["pooled_projection_dim"], + dtype=tensor_dtype, + device=self.device, + ), + torch.tensor([1.0] * batch_size, dtype=tensor_dtype, device=self.device), + torch.randn( + (latent_height // 2) * (latent_width // 2), + 3, + dtype=dtype, + device=self.device, + ), + torch.randn(self.text_maxlen, 3, dtype=dtype, device=self.device), + {}, + ) + if self.config["guidance_embeds"]: + sample_input[-1]["guidance"] = torch.tensor( + [1.0] * batch_size, dtype=dtype, device=self.device + ) + return sample_input + + def optimize(self, onnx_graph): + if self.fp8: + return super().optimize(onnx_graph) + if self.int8: + return super().optimize(onnx_graph, fuse_mha_qkv_int8=True) + return super().optimize(onnx_graph) + + +class UpcastLayer(torch.nn.Module): + def __init__(self, base_layer: torch.nn.Module, upcast_to: torch.dtype): + super().__init__() + self.output_dtype = next(base_layer.parameters()).dtype + self.upcast_to = upcast_to + + base_layer = base_layer.to(dtype=self.upcast_to) + self.base_layer = base_layer + + def forward(self, *inputs, **kwargs): + casted_inputs = tuple( + in_val.to(self.upcast_to) if isinstance(in_val, torch.Tensor) else in_val + for in_val in inputs + ) + + kwarg_casted = {} + for name, val in kwargs.items(): + kwarg_casted[name] = ( + val.to(dtype=self.upcast_to) if isinstance(val, torch.Tensor) else val + ) + + output = self.base_layer(*casted_inputs, **kwarg_casted) + if isinstance(output, tuple): + output = tuple( + out.to(self.output_dtype) if isinstance(out, torch.Tensor) else out + for out in output + ) + else: + output = output.to(dtype=self.output_dtype) + return output + + +class SD3TransformerModel(base_model.BaseModel): + def __init__( + self, + version, + pipeline, + device, + hf_token, + verbose, + framework_model_dir, + fp16=False, + tf32=False, + bf16=False, + max_batch_size=16, + text_maxlen=256, + build_strongly_typed=False, + weight_streaming=False, + weight_streaming_budget_percentage=None, + do_classifier_free_guidance=False, + ): + super(SD3TransformerModel, self).__init__( + version, + pipeline, + device=device, + hf_token=hf_token, + verbose=verbose, + framework_model_dir=framework_model_dir, + fp16=fp16, + tf32=tf32, + bf16=bf16, + max_batch_size=max_batch_size, + text_maxlen=text_maxlen, + ) + self.subfolder = "transformer" + self.transformer_model_dir = load.get_checkpoint_dir( + self.framework_model_dir, self.version, self.pipeline, 
self.subfolder + ) + if not os.path.exists(self.transformer_model_dir): + self.config = SD3Transformer2DModel.load_config( + self.path, subfolder=self.subfolder, token=self.hf_token + ) + else: + print( + f"[I] Load SD3Transformer2DModel config from: {self.transformer_model_dir}" + ) + self.config = SD3Transformer2DModel.load_config(self.transformer_model_dir) + self.build_strongly_typed = build_strongly_typed + self.weight_streaming = weight_streaming + self.weight_streaming_budget_percentage = weight_streaming_budget_percentage + self.out_channels = self.config.get("out_channels") + self.xB = 2 if do_classifier_free_guidance else 1 # batch multiplier + + def get_model(self, torch_inference=""): + model_opts = ( + {"torch_dtype": torch.float16} + if self.fp16 + else {"torch_dtype": torch.bfloat16} + if self.bf16 + else {} + ) + if not load.is_model_cached( + self.transformer_model_dir, model_opts, self.hf_safetensor + ): + model = SD3Transformer2DModel.from_pretrained( + self.path, + subfolder=self.subfolder, + use_safetensors=self.hf_safetensor, + token=self.hf_token, + **model_opts, + ).to(self.device) + model.save_pretrained(self.transformer_model_dir, **model_opts) + else: + print( + f"[I] Load SD3Transformer2DModel model from: {self.transformer_model_dir}" + ) + model = SD3Transformer2DModel.from_pretrained( + self.transformer_model_dir, **model_opts + ).to(self.device) + + if self.version == "3.5-large": + model.transformer_blocks[35] = UpcastLayer( + model.transformer_blocks[35], torch.float32 + ) + + if torch_inference: + model.to(memory_format=torch.channels_last) + model = optimizer.optimize_checkpoint(model, torch_inference) + return model + + def get_input_names(self): + return [ + "hidden_states", + "encoder_hidden_states", + "pooled_projections", + "timestep", + ] + + def get_output_names(self): + return ["latent"] + + def get_dynamic_axes(self): + xB = "2B" if self.xB == 2 else "B" + dynamic_axes = { + "hidden_states": {0: xB, 2: "H", 3: "W"}, + "encoder_hidden_states": {0: xB}, + "pooled_projections": {0: xB}, + "timestep": {0: xB}, + "latent": {0: xB, 2: "H", 3: "W"}, + } + return dynamic_axes + + def get_input_profile( + self, + batch_size: int, + image_height: int, + image_width: int, + static_batch: bool, + static_shape: bool, + ): + latent_height, latent_width = self.check_dims( + batch_size, image_height, image_width + ) + ( + min_batch, + max_batch, + _, + _, + _, + _, + min_latent_height, + max_latent_height, + min_latent_width, + max_latent_width, + ) = self.get_minmax_dims( + batch_size, image_height, image_width, static_batch, static_shape + ) + + input_profile = { + "hidden_states": [ + ( + self.xB * min_batch, + self.config["in_channels"], + min_latent_height, + min_latent_width, + ), + ( + self.xB * batch_size, + self.config["in_channels"], + latent_height, + latent_width, + ), + ( + self.xB * max_batch, + self.config["in_channels"], + max_latent_height, + max_latent_width, + ), + ], + "encoder_hidden_states": [ + ( + self.xB * min_batch, + self.text_maxlen, + self.config["joint_attention_dim"], + ), + ( + self.xB * batch_size, + self.text_maxlen, + self.config["joint_attention_dim"], + ), + ( + self.xB * max_batch, + self.text_maxlen, + self.config["joint_attention_dim"], + ), + ], + "pooled_projections": [ + (self.xB * min_batch, self.config["pooled_projection_dim"]), + (self.xB * batch_size, self.config["pooled_projection_dim"]), + (self.xB * max_batch, self.config["pooled_projection_dim"]), + ], + "timestep": [ + (self.xB * min_batch,), + (self.xB * 
batch_size,), + (self.xB * max_batch,), + ], + } + return input_profile + + def get_shape_dict(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims( + batch_size, image_height, image_width + ) + shape_dict = { + "hidden_states": ( + self.xB * batch_size, + self.config["in_channels"], + latent_height, + latent_width, + ), + "encoder_hidden_states": ( + self.xB * batch_size, + self.text_maxlen, + self.config["joint_attention_dim"], + ), + "pooled_projections": ( + self.xB * batch_size, + self.config["pooled_projection_dim"], + ), + "timestep": (self.xB * batch_size,), + "latent": ( + self.xB * batch_size, + self.out_channels, + latent_height, + latent_width, + ), + } + return shape_dict + + def get_sample_input(self, batch_size, image_height, image_width, static_shape): + assert not (self.fp16 and self.bf16), ( + "fp16 and bf16 cannot be enabled simultaneously" + ) + dtype = ( + torch.float16 + if self.fp16 + else torch.bfloat16 + if self.bf16 + else torch.float32 + ) + latent_height, latent_width = self.check_dims( + batch_size, image_height, image_width + ) + sample_input = ( + torch.randn( + self.xB * batch_size, + self.config["in_channels"], + latent_height, + latent_width, + dtype=dtype, + device=self.device, + ), + torch.randn( + self.xB * batch_size, + self.text_maxlen, + self.config["joint_attention_dim"], + dtype=dtype, + device=self.device, + ), + torch.randn( + self.xB * batch_size, + self.config["pooled_projection_dim"], + dtype=dtype, + device=self.device, + ), + torch.randn(self.xB * batch_size, dtype=torch.float32, device=self.device), + ) + return sample_input diff --git a/flux.1-dev-trt-b200/model/demo_diffusion/model/gan.py b/flux.1-dev-trt-b200/model/demo_diffusion/model/gan.py new file mode 100644 index 000000000..f7ef6b25f --- /dev/null +++ b/flux.1-dev-trt-b200/model/demo_diffusion/model/gan.py @@ -0,0 +1,187 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
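+#
+# NOTE: despite the generic file name, this module wraps the Wuerstchen /
+# Stable Cascade PaellaVQModel (a VQGAN decoder); the exported graph takes a
+# single "latent" input and produces an "images" output.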
+# + +import torch +from diffusers.pipelines.wuerstchen import PaellaVQModel + +from demo_diffusion.model import base_model, load, optimizer + + +class VQGANModel(base_model.BaseModel): + def __init__( + self, + version, + pipeline, + device, + hf_token, + verbose, + framework_model_dir, + fp16=False, + bf16=False, + max_batch_size=16, + compression_factor=42, + latent_dim_scale=10.67, + scale_factor=0.3764, + ): + super(VQGANModel, self).__init__( + version, + pipeline, + device=device, + hf_token=hf_token, + verbose=verbose, + framework_model_dir=framework_model_dir, + fp16=fp16, + bf16=bf16, + max_batch_size=max_batch_size, + compression_factor=compression_factor, + ) + self.subfolder = "vqgan" + self.latent_dim_scale = latent_dim_scale + self.scale_factor = scale_factor + + def get_model(self, torch_inference=""): + model_opts = ( + {"variant": "bf16", "torch_dtype": torch.bfloat16} if self.bf16 else {} + ) + vqgan_model_dir = load.get_checkpoint_dir( + self.framework_model_dir, self.version, self.pipeline, self.subfolder + ) + if not load.is_model_cached( + vqgan_model_dir, model_opts, self.hf_safetensor, model_name="model" + ): + model = PaellaVQModel.from_pretrained( + self.path, + subfolder=self.subfolder, + use_safetensors=self.hf_safetensor, + token=self.hf_token, + **model_opts, + ).to(self.device) + model.save_pretrained(vqgan_model_dir, **model_opts) + else: + print(f"[I] Load VQGAN pytorch model from: {vqgan_model_dir}") + model = PaellaVQModel.from_pretrained(vqgan_model_dir, **model_opts).to( + self.device + ) + model.forward = model.decode + model = optimizer.optimize_checkpoint(model, torch_inference) + return model + + def get_input_names(self): + return ["latent"] + + def get_output_names(self): + return ["images"] + + def get_dynamic_axes(self): + return { + "latent": {0: "B", 2: "H", 3: "W"}, + "images": {0: "B", 2: "8H", 3: "8W"}, + } + + def get_input_profile( + self, batch_size, image_height, image_width, static_batch, static_shape + ): + latent_height, latent_width = self.check_dims( + batch_size, image_height, image_width + ) + ( + min_batch, + max_batch, + _, + _, + _, + _, + min_latent_height, + max_latent_height, + min_latent_width, + max_latent_width, + ) = self.get_minmax_dims( + batch_size, image_height, image_width, static_batch, static_shape + ) + return { + "latent": [ + (min_batch, 4, min_latent_height, min_latent_width), + (batch_size, 4, latent_height, latent_width), + (max_batch, 4, max_latent_height, max_latent_width), + ] + } + + def get_shape_dict(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims( + batch_size, image_height, image_width + ) + return { + "latent": (batch_size, 4, latent_height, latent_width), + "images": (batch_size, 3, image_height, image_width), + } + + def get_sample_input(self, batch_size, image_height, image_width, static_shape): + latent_height, latent_width = self.check_dims( + batch_size, image_height, image_width + ) + dtype = ( + torch.float16 + if self.fp16 + else torch.bfloat16 + if self.bf16 + else torch.float32 + ) + return torch.randn( + batch_size, 4, latent_height, latent_width, dtype=dtype, device=self.device + ) + + def check_dims(self, batch_size, image_height, image_width): + latent_height, latent_width = super().check_dims( + batch_size, image_height, image_width + ) + latent_height = int(latent_height * self.latent_dim_scale) + latent_width = int(latent_width * self.latent_dim_scale) + return (latent_height, latent_width) + + def get_minmax_dims( + self, batch_size, 
image_height, image_width, static_batch, static_shape + ): + ( + min_batch, + max_batch, + min_image_height, + max_image_height, + min_image_width, + max_image_width, + min_latent_height, + max_latent_height, + min_latent_width, + max_latent_width, + ) = super().get_minmax_dims( + batch_size, image_height, image_width, static_batch, static_shape + ) + min_latent_height = int(min_latent_height * self.latent_dim_scale) + min_latent_width = int(min_latent_width * self.latent_dim_scale) + max_latent_height = int(max_latent_height * self.latent_dim_scale) + max_latent_width = int(max_latent_width * self.latent_dim_scale) + return ( + min_batch, + max_batch, + min_image_height, + max_image_height, + min_image_width, + max_image_width, + min_latent_height, + max_latent_height, + min_latent_width, + max_latent_width, + ) diff --git a/flux.1-dev-trt-b200/model/demo_diffusion/model/load.py b/flux.1-dev-trt-b200/model/demo_diffusion/model/load.py new file mode 100644 index 000000000..1e0a3a288 --- /dev/null +++ b/flux.1-dev-trt-b200/model/demo_diffusion/model/load.py @@ -0,0 +1,127 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Functions for loading models. +""" + +from __future__ import annotations + +import gc +import glob +import os +import sys +from typing import List, Optional + +import onnx +import torch + + +def onnx_graph_needs_external_data(onnx_graph: onnx.ModelProto) -> bool: + """Return true if ONNX graph needs to store external data.""" + if sys.platform == "win32": + # ByteSize is broken (wraps around) on Windows, so always assume external data is needed. 
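+        # (ONNX models are serialized as protobuf, which caps a single file at
+        # 2 GB; larger graphs must store their weights as external data.)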
+ return True + else: + TWO_GIGABYTES = 2147483648 + return onnx_graph.ByteSize() > TWO_GIGABYTES + + +def get_path( + version: str, + pipeline: "pipeline.DiffusionPipeline", + controlnets: Optional[List[str]] = None, +) -> str: + """Return the relative path to the model files directory.""" + if controlnets is not None: + if version == "xl-1.0": + return ["diffusers/controlnet-canny-sdxl-1.0"] + return ["lllyasviel/sd-controlnet-" + modality for modality in controlnets] + + if version in ("1.4", "1.5") and pipeline.is_inpaint(): + return "benjamin-paine/stable-diffusion-v1-5-inpainting" + elif version == "1.4": + return "CompVis/stable-diffusion-v1-4" + elif version == "1.5": + return "KiwiXR/stable-diffusion-v1-5" + elif version == "dreamshaper-7": + return "Lykon/dreamshaper-7" + elif version in ("2.0-base", "2.0") and pipeline.is_inpaint(): + return "stabilityai/stable-diffusion-2-inpainting" + elif version == "2.0-base": + return "stabilityai/stable-diffusion-2-base" + elif version == "2.0": + return "stabilityai/stable-diffusion-2" + elif version == "2.1-base": + return "stabilityai/stable-diffusion-2-1-base" + elif version == "2.1": + return "stabilityai/stable-diffusion-2-1" + elif version == "xl-1.0" and pipeline.is_sd_xl_base(): + return "stabilityai/stable-diffusion-xl-base-1.0" + elif version == "xl-1.0" and pipeline.is_sd_xl_refiner(): + return "stabilityai/stable-diffusion-xl-refiner-1.0" + # TODO SDXL turbo with refiner + elif version == "xl-turbo" and pipeline.is_sd_xl_base(): + return "stabilityai/sdxl-turbo" + elif version == "sd3": + return "stabilityai/stable-diffusion-3-medium" + elif version == "3.5-medium": + return "stabilityai/stable-diffusion-3.5-medium" + elif version == "3.5-large": + return "stabilityai/stable-diffusion-3.5-large" + elif version == "svd-xt-1.1" and pipeline.is_img2vid(): + return "stabilityai/stable-video-diffusion-img2vid-xt-1-1" + elif version == "cascade": + if pipeline.is_cascade_decoder(): + return "stabilityai/stable-cascade" + else: + return "stabilityai/stable-cascade-prior" + elif version == "flux.1-dev": + return "black-forest-labs/FLUX.1-dev" + elif version == "flux.1-schnell": + return "black-forest-labs/FLUX.1-schnell" + elif version == "flux.1-dev-canny": + return "black-forest-labs/FLUX.1-Canny-dev" + elif version == "flux.1-dev-depth": + return "black-forest-labs/FLUX.1-Depth-dev" + else: + raise ValueError(f"Unsupported version {version} + pipeline {pipeline.name}") + + +# FIXME serialization not supported for torch.compile +def get_checkpoint_dir( + framework_model_dir: str, version: str, pipeline: str, subfolder: str +) -> str: + """Return the path to the torch model checkpoint directory.""" + return os.path.join(framework_model_dir, version, pipeline, subfolder) + + +def is_model_cached( + model_dir, model_opts, hf_safetensor, model_name="diffusion_pytorch_model" +) -> bool: + """Return True if model was cached.""" + variant = "." 
+ model_opts.get("variant") if "variant" in model_opts else "" + suffix = ".safetensors" if hf_safetensor else ".bin" + # WAR with * for larger models that are split into multiple smaller ckpt files + model_file = model_name + variant + "*" + suffix + return bool(glob.glob(os.path.join(model_dir, model_file))) + + +def unload_torch_model(model): + if model: + del model + torch.cuda.empty_cache() + gc.collect() diff --git a/flux.1-dev-trt-b200/model/demo_diffusion/model/lora.py b/flux.1-dev-trt-b200/model/demo_diffusion/model/lora.py new file mode 100644 index 000000000..d9691444f --- /dev/null +++ b/flux.1-dev-trt-b200/model/demo_diffusion/model/lora.py @@ -0,0 +1,67 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from abc import ABC +from diffusers.loaders import StableDiffusionLoraLoaderMixin, FluxLoraLoaderMixin + + +class LoraLoader(ABC): + def __init__(self, paths, weights, scale): + self.paths = paths + self.weights = weights + self.scale = scale + + +class SDLoraLoader(LoraLoader, StableDiffusionLoraLoaderMixin): + def __init__(self, paths, weights, scale): + super().__init__(paths, weights, scale) + + +class FLUXLoraLoader(LoraLoader, FluxLoraLoaderMixin): + def __init__(self, paths, weights, scale): + super().__init__(paths, weights, scale) + + +def merge_loras(model, lora_loader): + paths, weights, scale = lora_loader.paths, lora_loader.weights, lora_loader.scale + for i, path in enumerate(paths): + print(f"[I] Loading LoRA: {path}, weight {weights[i]}") + if isinstance(lora_loader, SDLoraLoader): + state_dict, network_alphas = lora_loader.lora_state_dict( + path, unet_config=model.config + ) + lora_loader.load_lora_into_unet( + state_dict, network_alphas=network_alphas, unet=model, adapter_name=path + ) + elif isinstance(lora_loader, FLUXLoraLoader): + state_dict, network_alphas = lora_loader.lora_state_dict( + path, return_alphas=True + ) + lora_loader.load_lora_into_transformer( + state_dict, + network_alphas=network_alphas, + transformer=model, + adapter_name=path, + ) + else: + raise ValueError(f"Unsupported LoRA loader: {lora_loader}") + + model.set_adapters(paths, weights=weights) + # NOTE: fuse_lora an experimental API in Diffusers + model.fuse_lora(adapter_names=paths, lora_scale=scale) + model.unload_lora() + return model diff --git a/flux.1-dev-trt-b200/model/demo_diffusion/model/optimizer.py b/flux.1-dev-trt-b200/model/demo_diffusion/model/optimizer.py new file mode 100644 index 000000000..4e047702b --- /dev/null +++ b/flux.1-dev-trt-b200/model/demo_diffusion/model/optimizer.py @@ -0,0 +1,222 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import re +import tempfile + +import onnx_graphsurgeon as gs +import torch +from onnxmltools.utils.float16_converter import convert_float_to_float16 +from polygraphy.backend.onnx.loader import fold_constants + +import onnx +from demo_diffusion.model import load +from demo_diffusion.utils_modelopt import ( + cast_fp8_mha_io, + cast_resize_io, + convert_fp16_io, + convert_zp_fp8, +) +from onnx import shape_inference + +# FIXME update callsites after serialization support for torch.compile is added +TORCH_INFERENCE_MODELS = ["default", "reduce-overhead", "max-autotune"] + + +def optimize_checkpoint(model, torch_inference: str): + """Optimize a torch model checkpoint using torch.compile.""" + if not torch_inference or torch_inference == "eager": + return model + assert torch_inference in TORCH_INFERENCE_MODELS + return torch.compile(model, mode=torch_inference, dynamic=False, fullgraph=False) + + +class Optimizer: + def __init__(self, onnx_graph, verbose=False): + self.graph = gs.import_onnx(onnx_graph) + self.verbose = verbose + + def info(self, prefix): + if self.verbose: + print( + f"{prefix} .. {len(self.graph.nodes)} nodes, {len(self.graph.tensors().keys())} tensors, {len(self.graph.inputs)} inputs, {len(self.graph.outputs)} outputs" + ) + + def cleanup(self, return_onnx=False): + self.graph.cleanup().toposort() + return gs.export_onnx(self.graph) if return_onnx else self.graph + + def select_outputs(self, keep, names=None): + self.graph.outputs = [self.graph.outputs[o] for o in keep] + if names: + for i, name in enumerate(names): + self.graph.outputs[i].name = name + + def fold_constants(self, return_onnx=False): + onnx_graph = fold_constants( + gs.export_onnx(self.graph), allow_onnxruntime_shape_inference=True + ) + self.graph = gs.import_onnx(onnx_graph) + if return_onnx: + return onnx_graph + + def infer_shapes(self, return_onnx=False): + onnx_graph = gs.export_onnx(self.graph) + if load.onnx_graph_needs_external_data(onnx_graph): + temp_dir = tempfile.TemporaryDirectory().name + os.makedirs(temp_dir, exist_ok=True) + onnx_orig_path = os.path.join(temp_dir, "model.onnx") + onnx_inferred_path = os.path.join(temp_dir, "inferred.onnx") + onnx.save_model( + onnx_graph, + onnx_orig_path, + save_as_external_data=True, + all_tensors_to_one_file=True, + convert_attribute=False, + ) + onnx.shape_inference.infer_shapes_path(onnx_orig_path, onnx_inferred_path) + onnx_graph = onnx.load(onnx_inferred_path) + else: + onnx_graph = shape_inference.infer_shapes(onnx_graph) + + self.graph = gs.import_onnx(onnx_graph) + if return_onnx: + return onnx_graph + + def clip_add_hidden_states(self, hidden_layer_offset, return_onnx=False): + hidden_layers = -1 + onnx_graph = gs.export_onnx(self.graph) + for i in range(len(onnx_graph.graph.node)): + for j in range(len(onnx_graph.graph.node[i].output)): + name = onnx_graph.graph.node[i].output[j] + if "layers" in name: + hidden_layers = max( + int(name.split(".")[1].split("/")[0]), hidden_layers + ) + for i in range(len(onnx_graph.graph.node)): + for j in range(len(onnx_graph.graph.node[i].output)): + if onnx_graph.graph.node[i].output[ + 
j + ] == "/text_model/encoder/layers.{}/Add_1_output_0".format( + hidden_layers + hidden_layer_offset + ): + onnx_graph.graph.node[i].output[j] = "hidden_states" + for j in range(len(onnx_graph.graph.node[i].input)): + if onnx_graph.graph.node[i].input[ + j + ] == "/text_model/encoder/layers.{}/Add_1_output_0".format( + hidden_layers + hidden_layer_offset + ): + onnx_graph.graph.node[i].input[j] = "hidden_states" + if return_onnx: + return onnx_graph + + def fuse_mha_qkv_int8_sq(self): + tensors = self.graph.tensors() + keys = tensors.keys() + + # mha : fuse QKV QDQ nodes + # mhca : fuse KV QDQ nodes + q_pat = ( + "/down_blocks.\\d+/attentions.\\d+/transformer_blocks" + ".\\d+/attn\\d+/to_q/input_quantizer/DequantizeLinear_output_0" + ) + k_pat = ( + "/down_blocks.\\d+/attentions.\\d+/transformer_blocks" + ".\\d+/attn\\d+/to_k/input_quantizer/DequantizeLinear_output_0" + ) + v_pat = ( + "/down_blocks.\\d+/attentions.\\d+/transformer_blocks" + ".\\d+/attn\\d+/to_v/input_quantizer/DequantizeLinear_output_0" + ) + + qs = list( + sorted( + map( + lambda x: x.group(0), # type: ignore + filter( + lambda x: x is not None, [re.match(q_pat, key) for key in keys] + ), + ) + ) + ) + ks = list( + sorted( + map( + lambda x: x.group(0), # type: ignore + filter( + lambda x: x is not None, [re.match(k_pat, key) for key in keys] + ), + ) + ) + ) + vs = list( + sorted( + map( + lambda x: x.group(0), # type: ignore + filter( + lambda x: x is not None, [re.match(v_pat, key) for key in keys] + ), + ) + ) + ) + + removed = 0 + assert len(qs) == len(ks) == len(vs), "Failed to collect tensors" + for q, k, v in zip(qs, ks, vs): + is_mha = all(["attn1" in tensor for tensor in [q, k, v]]) + is_mhca = all(["attn2" in tensor for tensor in [q, k, v]]) + assert (is_mha or is_mhca) and (not (is_mha and is_mhca)) + + if is_mha: + tensors[k].outputs[0].inputs[0] = tensors[q] + tensors[v].outputs[0].inputs[0] = tensors[q] + del tensors[k] + del tensors[v] + removed += 2 + else: # is_mhca + tensors[k].outputs[0].inputs[0] = tensors[v] + del tensors[k] + removed += 1 + print(f"Removed {removed} QDQ nodes") + return removed # expected 72 for L2.5 + + def modify_fp8_graph(self, is_fp16_io=True): + onnx_graph = gs.export_onnx(self.graph) + # Convert INT8 Zero to FP8. + onnx_graph = convert_zp_fp8(onnx_graph) + # Convert weights and activations to FP16 and insert Cast nodes in FP8 MHA. + onnx_graph = convert_float_to_float16( + onnx_graph, keep_io_types=True, disable_shape_infer=True + ) + self.graph = gs.import_onnx(onnx_graph) + # Add cast nodes to Resize I/O. + cast_resize_io(self.graph) + # Convert model inputs and outputs to fp16 I/O. + if is_fp16_io: + convert_fp16_io(self.graph) + # Add cast nodes to MHA's BMM1 and BMM2's I/O. + cast_fp8_mha_io(self.graph) + + def flux_convert_rope_weight_type(self): + for node in self.graph.nodes: + if node.op == "Einsum": + print( + f"Fixed RoPE (Rotary Position Embedding) weight type: {node.name}" + ) + return gs.export_onnx(self.graph) diff --git a/flux.1-dev-trt-b200/model/demo_diffusion/model/scheduler.py b/flux.1-dev-trt-b200/model/demo_diffusion/model/scheduler.py new file mode 100644 index 000000000..75331b45e --- /dev/null +++ b/flux.1-dev-trt-b200/model/demo_diffusion/model/scheduler.py @@ -0,0 +1,41 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os + +from demo_diffusion.model import load + + +def make_scheduler( + cls, version, pipeline, hf_token, framework_model_dir, subfolder="scheduler" +): + scheduler_dir = os.path.join( + framework_model_dir, + version, + pipeline.name, + next(iter({cls.__name__})).lower(), + subfolder, + ) + if not os.path.exists(scheduler_dir): + scheduler = cls.from_pretrained( + load.get_path(version, pipeline), subfolder=subfolder, token=hf_token + ) + scheduler.save_pretrained(scheduler_dir) + else: + print(f"[I] Load Scheduler {cls.__name__} from: {scheduler_dir}") + scheduler = cls.from_pretrained(scheduler_dir) + return scheduler diff --git a/flux.1-dev-trt-b200/model/demo_diffusion/model/t5.py b/flux.1-dev-trt-b200/model/demo_diffusion/model/t5.py new file mode 100644 index 000000000..a0f89f556 --- /dev/null +++ b/flux.1-dev-trt-b200/model/demo_diffusion/model/t5.py @@ -0,0 +1,139 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
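+#
+# NOTE: wrapper around the T5 text encoder (T5EncoderModel); the exported graph
+# maps "input_ids" -> "text_embeddings" at a fixed sequence length of
+# text_maxlen (512 by default).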
+# + +import os + +import torch +from transformers import ( + AutoConfig, + T5EncoderModel, +) + +from demo_diffusion.model import base_model, load, optimizer + + +class T5Model(base_model.BaseModel): + def __init__( + self, + version, + pipeline, + device, + hf_token, + verbose, + framework_model_dir, + max_batch_size, + fp16=False, + tf32=False, + bf16=False, + subfolder="text_encoder", + text_maxlen=512, + build_strongly_typed=False, + weight_streaming=False, + weight_streaming_budget_percentage=None, + ): + super(T5Model, self).__init__( + version, + pipeline, + device=device, + hf_token=hf_token, + verbose=verbose, + framework_model_dir=framework_model_dir, + fp16=fp16, + tf32=tf32, + bf16=bf16, + max_batch_size=max_batch_size, + text_maxlen=text_maxlen, + ) + self.subfolder = subfolder + self.t5_model_dir = load.get_checkpoint_dir( + self.framework_model_dir, self.version, self.pipeline, self.subfolder + ) + if not os.path.exists(self.t5_model_dir): + self.config = AutoConfig.from_pretrained( + self.path, subfolder=self.subfolder, token=self.hf_token + ) + else: + print(f"[I] Load T5Encoder Config from: {self.t5_model_dir}") + self.config = AutoConfig.from_pretrained(self.t5_model_dir) + self.build_strongly_typed = build_strongly_typed + self.weight_streaming = weight_streaming + self.weight_streaming_budget_percentage = weight_streaming_budget_percentage + + def get_model(self, torch_inference=""): + model_opts = ( + {"torch_dtype": torch.float16} + if self.fp16 + else {"torch_dtype": torch.bfloat16} + if self.bf16 + else {} + ) + if not load.is_model_cached( + self.t5_model_dir, model_opts, self.hf_safetensor, model_name="model" + ): + model = T5EncoderModel.from_pretrained( + self.path, + subfolder=self.subfolder, + use_safetensors=self.hf_safetensor, + token=self.hf_token, + **model_opts, + ).to(self.device) + model.save_pretrained(self.t5_model_dir, **model_opts) + else: + print(f"[I] Load T5EncoderModel model from: {self.t5_model_dir}") + model = T5EncoderModel.from_pretrained(self.t5_model_dir, **model_opts).to( + self.device + ) + model = optimizer.optimize_checkpoint(model, torch_inference) + return model + + def get_input_names(self): + return ["input_ids"] + + def get_output_names(self): + return ["text_embeddings"] + + def get_dynamic_axes(self): + return {"input_ids": {0: "B"}, "text_embeddings": {0: "B"}} + + def get_input_profile( + self, batch_size, image_height, image_width, static_batch, static_shape + ): + self.check_dims(batch_size, image_height, image_width) + min_batch, max_batch, _, _, _, _, _, _, _, _ = self.get_minmax_dims( + batch_size, image_height, image_width, static_batch, static_shape + ) + return { + "input_ids": [ + (min_batch, self.text_maxlen), + (batch_size, self.text_maxlen), + (max_batch, self.text_maxlen), + ] + } + + def get_shape_dict(self, batch_size, image_height, image_width): + self.check_dims(batch_size, image_height, image_width) + output = { + "input_ids": (batch_size, self.text_maxlen), + "text_embeddings": (batch_size, self.text_maxlen, self.config.d_model), + } + return output + + def get_sample_input(self, batch_size, image_height, image_width, static_shape): + self.check_dims(batch_size, image_height, image_width) + return torch.zeros( + batch_size, self.text_maxlen, dtype=torch.int32, device=self.device + ) diff --git a/flux.1-dev-trt-b200/model/demo_diffusion/model/tokenizer.py b/flux.1-dev-trt-b200/model/demo_diffusion/model/tokenizer.py new file mode 100644 index 000000000..0218e5b74 --- /dev/null +++ 
b/flux.1-dev-trt-b200/model/demo_diffusion/model/tokenizer.py @@ -0,0 +1,58 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os + +from transformers import ( + CLIPTokenizer, + T5TokenizerFast, +) + +from demo_diffusion.model import load + + +def make_tokenizer( + version, + pipeline, + hf_token, + framework_model_dir, + subfolder="tokenizer", + tokenizer_type="clip", +): + if tokenizer_type == "clip": + tokenizer_class = CLIPTokenizer + elif tokenizer_type == "t5": + tokenizer_class = T5TokenizerFast + else: + raise ValueError( + f"Unsupported tokenizer_type {tokenizer_type}. Only tokenizer_type clip and t5 are currently supported" + ) + tokenizer_model_dir = load.get_checkpoint_dir( + framework_model_dir, version, pipeline.name, subfolder + ) + if not os.path.exists(tokenizer_model_dir): + model = tokenizer_class.from_pretrained( + load.get_path(version, pipeline), + subfolder=subfolder, + use_safetensors=pipeline.is_sd_xl(), + token=hf_token, + ) + model.save_pretrained(tokenizer_model_dir) + else: + print(f"[I] Load {tokenizer_class.__name__} model from: {tokenizer_model_dir}") + model = tokenizer_class.from_pretrained(tokenizer_model_dir) + return model diff --git a/flux.1-dev-trt-b200/model/demo_diffusion/model/unet.py b/flux.1-dev-trt-b200/model/demo_diffusion/model/unet.py new file mode 100644 index 000000000..cf7b065e4 --- /dev/null +++ b/flux.1-dev-trt-b200/model/demo_diffusion/model/unet.py @@ -0,0 +1,1307 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Model definitions for UNet models. 
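+Includes plain and ControlNet-augmented UNets, the SDXL UNet, the temporal UNet
+used by Stable Video Diffusion, and the Stable Cascade prior/decoder UNets.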
+""" + +import torch + +from demo_diffusion.dynamic_import import import_from_diffusers +from demo_diffusion.model import base_model, load, optimizer + +# List of models to import from diffusers.models +models_to_import = [ + "ControlNetModel", + "UNet2DConditionModel", + "UNetSpatioTemporalConditionModel", + "StableCascadeUNet", +] +for model in models_to_import: + globals()[model] = import_from_diffusers(model, "diffusers.models") + + +def get_unet_embedding_dim(version, pipeline): + if version in ("1.4", "1.5", "dreamshaper-7"): + return 768 + elif version in ("2.0", "2.0-base", "2.1", "2.1-base"): + return 1024 + elif version in ("xl-1.0", "xl-turbo") and pipeline.is_sd_xl_base(): + return 2048 + elif version in ("cascade"): + return 1280 + elif version in ("xl-1.0", "xl-turbo") and pipeline.is_sd_xl_refiner(): + return 1280 + elif pipeline.is_img2vid(): + return 1024 + else: + raise ValueError(f"Invalid version {version} + pipeline {pipeline}") + + +class UNet2DConditionControlNetModel(torch.nn.Module): + def __init__(self, unet, controlnets) -> None: + super().__init__() + self.unet = unet + self.controlnets = controlnets + + def forward( + self, + sample, + timestep, + encoder_hidden_states, + images, + controlnet_scales, + added_cond_kwargs=None, + ): + for i, (image, conditioning_scale, controlnet) in enumerate( + zip(images, controlnet_scales, self.controlnets) + ): + down_samples, mid_sample = controlnet( + sample, + timestep, + encoder_hidden_states=encoder_hidden_states, + controlnet_cond=image, + return_dict=False, + added_cond_kwargs=added_cond_kwargs, + ) + + down_samples = [ + down_sample * conditioning_scale for down_sample in down_samples + ] + mid_sample *= conditioning_scale + + # merge samples + if i == 0: + down_block_res_samples, mid_block_res_sample = down_samples, mid_sample + else: + down_block_res_samples = [ + samples_prev + samples_curr + for samples_prev, samples_curr in zip( + down_block_res_samples, down_samples + ) + ] + mid_block_res_sample += mid_sample + + noise_pred = self.unet( + sample, + timestep, + encoder_hidden_states=encoder_hidden_states, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + added_cond_kwargs=added_cond_kwargs, + ) + return noise_pred + + +class UNetModel(base_model.BaseModel): + def __init__( + self, + version, + pipeline, + device, + hf_token, + verbose, + framework_model_dir, + fp16=False, + int8=False, + fp8=False, + max_batch_size=16, + text_maxlen=77, + controlnets=None, + do_classifier_free_guidance=False, + ): + super(UNetModel, self).__init__( + version, + pipeline, + device=device, + hf_token=hf_token, + verbose=verbose, + framework_model_dir=framework_model_dir, + fp16=fp16, + int8=int8, + fp8=fp8, + max_batch_size=max_batch_size, + text_maxlen=text_maxlen, + embedding_dim=get_unet_embedding_dim(version, pipeline), + ) + self.subfolder = "unet" + self.controlnets = ( + load.get_path(version, pipeline, controlnets) if controlnets else None + ) + self.unet_dim = 9 if pipeline.is_inpaint() else 4 + self.xB = 2 if do_classifier_free_guidance else 1 # batch multiplier + + def get_model(self, torch_inference=""): + model_opts = ( + {"variant": "fp16", "torch_dtype": torch.float16} if self.fp16 else {} + ) + if self.controlnets: + unet_model = UNet2DConditionModel.from_pretrained( + self.path, + subfolder=self.subfolder, + use_safetensors=self.hf_safetensor, + token=self.hf_token, + **model_opts, + ).to(self.device) + cnet_model_opts = {"torch_dtype": 
torch.float16} if self.fp16 else {} + controlnets = torch.nn.ModuleList( + [ + ControlNetModel.from_pretrained(path, **cnet_model_opts).to( + self.device + ) + for path in self.controlnets + ] + ) + # FIXME - cache UNet2DConditionControlNetModel + model = UNet2DConditionControlNetModel(unet_model, controlnets) + else: + unet_model_dir = load.get_checkpoint_dir( + self.framework_model_dir, self.version, self.pipeline, self.subfolder + ) + if not load.is_model_cached(unet_model_dir, model_opts, self.hf_safetensor): + model = UNet2DConditionModel.from_pretrained( + self.path, + subfolder=self.subfolder, + use_safetensors=self.hf_safetensor, + token=self.hf_token, + **model_opts, + ).to(self.device) + model.save_pretrained(unet_model_dir, **model_opts) + else: + print(f"[I] Load UNet2DConditionModel model from: {unet_model_dir}") + model = UNet2DConditionModel.from_pretrained( + unet_model_dir, **model_opts + ).to(self.device) + if torch_inference: + model.to(memory_format=torch.channels_last) + model = optimizer.optimize_checkpoint(model, torch_inference) + return model + + def get_input_names(self): + if self.controlnets is None: + return ["sample", "timestep", "encoder_hidden_states"] + else: + return [ + "sample", + "timestep", + "encoder_hidden_states", + "images", + "controlnet_scales", + ] + + def get_output_names(self): + return ["latent"] + + def get_dynamic_axes(self): + xB = "2B" if self.xB == 2 else "B" + if self.controlnets is None: + return { + "sample": {0: xB, 2: "H", 3: "W"}, + "encoder_hidden_states": {0: xB}, + "latent": {0: xB, 2: "H", 3: "W"}, + } + else: + return { + "sample": {0: xB, 2: "H", 3: "W"}, + "encoder_hidden_states": {0: xB}, + "images": {1: xB, 3: "8H", 4: "8W"}, + "latent": {0: xB, 2: "H", 3: "W"}, + } + + def get_input_profile( + self, batch_size, image_height, image_width, static_batch, static_shape + ): + # WAR to enable inference for H/W that are not multiples of 16 + # If building with Dynamic Shapes: ensure image height and width are not multiples of 16 for ONNX export and TensorRT engine build + if not static_shape: + image_height = image_height - 8 if image_height % 16 == 0 else image_height + image_width = image_width - 8 if image_width % 16 == 0 else image_width + latent_height, latent_width = self.check_dims( + batch_size, image_height, image_width + ) + ( + min_batch, + max_batch, + min_image_height, + max_image_height, + min_image_width, + max_image_width, + min_latent_height, + max_latent_height, + min_latent_width, + max_latent_width, + ) = self.get_minmax_dims( + batch_size, image_height, image_width, static_batch, static_shape + ) + if self.controlnets is None: + return { + "sample": [ + ( + self.xB * min_batch, + self.unet_dim, + min_latent_height, + min_latent_width, + ), + (self.xB * batch_size, self.unet_dim, latent_height, latent_width), + ( + self.xB * max_batch, + self.unet_dim, + max_latent_height, + max_latent_width, + ), + ], + "encoder_hidden_states": [ + (self.xB * min_batch, self.text_maxlen, self.embedding_dim), + (self.xB * batch_size, self.text_maxlen, self.embedding_dim), + (self.xB * max_batch, self.text_maxlen, self.embedding_dim), + ], + } + else: + return { + "sample": [ + ( + self.xB * min_batch, + self.unet_dim, + min_latent_height, + min_latent_width, + ), + (self.xB * batch_size, self.unet_dim, latent_height, latent_width), + ( + self.xB * max_batch, + self.unet_dim, + max_latent_height, + max_latent_width, + ), + ], + "encoder_hidden_states": [ + (self.xB * min_batch, self.text_maxlen, self.embedding_dim), + (self.xB 
* batch_size, self.text_maxlen, self.embedding_dim), + (self.xB * max_batch, self.text_maxlen, self.embedding_dim), + ], + "images": [ + ( + len(self.controlnets), + self.xB * min_batch, + 3, + min_image_height, + min_image_width, + ), + ( + len(self.controlnets), + self.xB * batch_size, + 3, + image_height, + image_width, + ), + ( + len(self.controlnets), + self.xB * max_batch, + 3, + max_image_height, + max_image_width, + ), + ], + } + + def get_shape_dict(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims( + batch_size, image_height, image_width + ) + if self.controlnets is None: + return { + "sample": ( + self.xB * batch_size, + self.unet_dim, + latent_height, + latent_width, + ), + "encoder_hidden_states": ( + self.xB * batch_size, + self.text_maxlen, + self.embedding_dim, + ), + "latent": (self.xB * batch_size, 4, latent_height, latent_width), + } + else: + return { + "sample": ( + self.xB * batch_size, + self.unet_dim, + latent_height, + latent_width, + ), + "encoder_hidden_states": ( + self.xB * batch_size, + self.text_maxlen, + self.embedding_dim, + ), + "images": ( + len(self.controlnets), + self.xB * batch_size, + 3, + image_height, + image_width, + ), + "latent": (self.xB * batch_size, 4, latent_height, latent_width), + } + + def get_sample_input(self, batch_size, image_height, image_width, static_shape): + # WAR to enable inference for H/W that are not multiples of 16 + # If building with Dynamic Shapes: ensure image height and width are not multiples of 16 for ONNX export and TensorRT engine build + if not static_shape: + image_height = image_height - 8 if image_height % 16 == 0 else image_height + image_width = image_width - 8 if image_width % 16 == 0 else image_width + latent_height, latent_width = self.check_dims( + batch_size, image_height, image_width + ) + dtype = torch.float16 if self.fp16 else torch.float32 + if self.controlnets is None: + return ( + torch.randn( + batch_size, + self.unet_dim, + latent_height, + latent_width, + dtype=dtype, + device=self.device, + ), + torch.tensor([1.0], dtype=dtype, device=self.device), + torch.randn( + batch_size, + self.text_maxlen, + self.embedding_dim, + dtype=dtype, + device=self.device, + ), + ) + else: + return ( + torch.randn( + batch_size, + self.unet_dim, + latent_height, + latent_width, + dtype=dtype, + device=self.device, + ), + torch.tensor(999, dtype=dtype, device=self.device), + torch.randn( + batch_size, + self.text_maxlen, + self.embedding_dim, + dtype=dtype, + device=self.device, + ), + torch.randn( + len(self.controlnets), + batch_size, + 3, + image_height, + image_width, + dtype=dtype, + device=self.device, + ), + torch.randn(len(self.controlnets), dtype=dtype, device=self.device), + ) + + def optimize(self, onnx_graph): + if self.fp8: + return super().optimize(onnx_graph, modify_fp8_graph=True) + if self.int8: + return super().optimize(onnx_graph, fuse_mha_qkv_int8=True) + return super().optimize(onnx_graph) + + +class UNetXLModel(base_model.BaseModel): + def __init__( + self, + version, + pipeline, + device, + hf_token, + verbose, + framework_model_dir, + fp16=False, + int8=False, + fp8=False, + max_batch_size=16, + text_maxlen=77, + do_classifier_free_guidance=False, + ): + super(UNetXLModel, self).__init__( + version, + pipeline, + device=device, + hf_token=hf_token, + verbose=verbose, + framework_model_dir=framework_model_dir, + fp16=fp16, + int8=int8, + fp8=fp8, + max_batch_size=max_batch_size, + text_maxlen=text_maxlen, + 
embedding_dim=get_unet_embedding_dim(version, pipeline), + ) + self.subfolder = "unet" + self.unet_dim = 9 if pipeline.is_inpaint() else 4 + self.time_dim = 5 if pipeline.is_sd_xl_refiner() else 6 + self.xB = 2 if do_classifier_free_guidance else 1 # batch multiplier + + def get_model(self, torch_inference=""): + model_opts = ( + {"variant": "fp16", "torch_dtype": torch.float16} if self.fp16 else {} + ) + unet_model_dir = load.get_checkpoint_dir( + self.framework_model_dir, self.version, self.pipeline, self.subfolder + ) + if not load.is_model_cached(unet_model_dir, model_opts, self.hf_safetensor): + model = UNet2DConditionModel.from_pretrained( + self.path, + subfolder=self.subfolder, + use_safetensors=self.hf_safetensor, + token=self.hf_token, + **model_opts, + ).to(self.device) + # Use default attention processor for ONNX export + if not torch_inference: + model.set_default_attn_processor() + model.save_pretrained(unet_model_dir, **model_opts) + else: + print(f"[I] Load UNet2DConditionModel model from: {unet_model_dir}") + model = UNet2DConditionModel.from_pretrained( + unet_model_dir, **model_opts + ).to(self.device) + model = optimizer.optimize_checkpoint(model, torch_inference) + return model + + def get_input_names(self): + return [ + "sample", + "timestep", + "encoder_hidden_states", + "text_embeds", + "time_ids", + ] + + def get_output_names(self): + return ["latent"] + + def get_dynamic_axes(self): + xB = "2B" if self.xB == 2 else "B" + return { + "sample": {0: xB, 2: "H", 3: "W"}, + "encoder_hidden_states": {0: xB}, + "latent": {0: xB, 2: "H", 3: "W"}, + "text_embeds": {0: xB}, + "time_ids": {0: xB}, + } + + def get_input_profile( + self, batch_size, image_height, image_width, static_batch, static_shape + ): + # WAR to enable inference for H/W that are not multiples of 16 + # If building with Dynamic Shapes: ensure image height and width are not multiples of 16 for ONNX export and TensorRT engine build + if not static_shape: + image_height = image_height - 8 if image_height % 16 == 0 else image_height + image_width = image_width - 8 if image_width % 16 == 0 else image_width + latent_height, latent_width = self.check_dims( + batch_size, image_height, image_width + ) + ( + min_batch, + max_batch, + _, + _, + _, + _, + min_latent_height, + max_latent_height, + min_latent_width, + max_latent_width, + ) = self.get_minmax_dims( + batch_size, image_height, image_width, static_batch, static_shape + ) + return { + "sample": [ + ( + self.xB * min_batch, + self.unet_dim, + min_latent_height, + min_latent_width, + ), + (self.xB * batch_size, self.unet_dim, latent_height, latent_width), + ( + self.xB * max_batch, + self.unet_dim, + max_latent_height, + max_latent_width, + ), + ], + "encoder_hidden_states": [ + (self.xB * min_batch, self.text_maxlen, self.embedding_dim), + (self.xB * batch_size, self.text_maxlen, self.embedding_dim), + (self.xB * max_batch, self.text_maxlen, self.embedding_dim), + ], + "text_embeds": [ + (self.xB * min_batch, 1280), + (self.xB * batch_size, 1280), + (self.xB * max_batch, 1280), + ], + "time_ids": [ + (self.xB * min_batch, self.time_dim), + (self.xB * batch_size, self.time_dim), + (self.xB * max_batch, self.time_dim), + ], + } + + def get_shape_dict(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims( + batch_size, image_height, image_width + ) + return { + "sample": ( + self.xB * batch_size, + self.unet_dim, + latent_height, + latent_width, + ), + "encoder_hidden_states": ( + self.xB * batch_size, + self.text_maxlen, 
+ self.embedding_dim, + ), + "latent": (self.xB * batch_size, 4, latent_height, latent_width), + "text_embeds": (self.xB * batch_size, 1280), + "time_ids": (self.xB * batch_size, self.time_dim), + } + + def get_sample_input(self, batch_size, image_height, image_width, static_shape): + # WAR to enable inference for H/W that are not multiples of 16 + # If building with Dynamic Shapes: ensure image height and width are not multiples of 16 for ONNX export and TensorRT engine build + if not static_shape: + image_height = image_height - 8 if image_height % 16 == 0 else image_height + image_width = image_width - 8 if image_width % 16 == 0 else image_width + latent_height, latent_width = self.check_dims( + batch_size, image_height, image_width + ) + dtype = torch.float16 if self.fp16 else torch.float32 + return ( + torch.randn( + self.xB * batch_size, + self.unet_dim, + latent_height, + latent_width, + dtype=dtype, + device=self.device, + ), + torch.tensor([1.0], dtype=dtype, device=self.device), + torch.randn( + self.xB * batch_size, + self.text_maxlen, + self.embedding_dim, + dtype=dtype, + device=self.device, + ), + { + "added_cond_kwargs": { + "text_embeds": torch.randn( + self.xB * batch_size, 1280, dtype=dtype, device=self.device + ), + "time_ids": torch.randn( + self.xB * batch_size, + self.time_dim, + dtype=dtype, + device=self.device, + ), + } + }, + ) + + def optimize(self, onnx_graph): + if self.fp8: + return super().optimize(onnx_graph, modify_fp8_graph=True) + if self.int8: + return super().optimize(onnx_graph, fuse_mha_qkv_int8=True) + return super().optimize(onnx_graph) + + +class UNetXLModelControlNet(UNetXLModel): + def __init__( + self, + version, + pipeline, + device, + hf_token, + verbose, + framework_model_dir, + fp16=False, + int8=False, + fp8=False, + max_batch_size=16, + text_maxlen=77, + controlnets=None, + do_classifier_free_guidance=False, + ): + super().__init__( + version=version, + pipeline=pipeline, + device=device, + hf_token=hf_token, + verbose=verbose, + framework_model_dir=framework_model_dir, + fp16=fp16, + int8=int8, + fp8=fp8, + max_batch_size=max_batch_size, + text_maxlen=text_maxlen, + do_classifier_free_guidance=do_classifier_free_guidance, + ) + self.controlnets = ( + load.get_path(version, pipeline, controlnets) if controlnets else None + ) + + def get_model(self, torch_inference=""): + model_opts = ( + {"variant": "fp16", "torch_dtype": torch.float16} if self.fp16 else {} + ) + unet_model = UNet2DConditionModel.from_pretrained( + self.path, + subfolder=self.subfolder, + use_safetensors=self.hf_safetensor, + token=self.hf_token, + **model_opts, + ).to(self.device) + cnet_model_opts = {"torch_dtype": torch.float16} if self.fp16 else {} + controlnets = torch.nn.ModuleList( + [ + ControlNetModel.from_pretrained(path, **cnet_model_opts).to(self.device) + for path in self.controlnets + ] + ) + # FIXME - cache UNet2DConditionControlNetModel + model = UNet2DConditionControlNetModel(unet_model, controlnets) + model = optimizer.optimize_checkpoint(model, torch_inference) + return model + + def get_input_names(self): + return [ + "sample", + "timestep", + "encoder_hidden_states", + "images", + "controlnet_scales", + "text_embeds", + "time_ids", + ] + + def get_dynamic_axes(self): + xB = "2B" if self.xB == 2 else "B" + result = super().get_dynamic_axes() + result["images"] = {1: xB, 3: "8H", 4: "8W"} + return result + + def get_input_profile( + self, batch_size, image_height, image_width, static_batch, static_shape + ): + ( + min_batch, + max_batch, + 
min_image_height, + max_image_height, + min_image_width, + max_image_width, + _, + _, + _, + _, + ) = self.get_minmax_dims( + batch_size, image_height, image_width, static_batch, static_shape + ) + result = super().get_input_profile( + batch_size, image_height, image_width, static_batch, static_shape + ) + result["images"] = [ + ( + len(self.controlnets), + self.xB * min_batch, + 3, + min_image_height, + min_image_width, + ), + (len(self.controlnets), self.xB * batch_size, 3, image_height, image_width), + ( + len(self.controlnets), + self.xB * max_batch, + 3, + max_image_height, + max_image_width, + ), + ] + return result + + def get_shape_dict(self, batch_size, image_height, image_width): + result = super().get_shape_dict(batch_size, image_height, image_width) + result["images"] = ( + len(self.controlnets), + self.xB * batch_size, + 3, + image_height, + image_width, + ) + return result + + def get_sample_input(self, batch_size, image_height, image_width, static_shape): + dtype = torch.float16 if self.fp16 else torch.float32 + result = super().get_sample_input( + batch_size, image_height, image_width, static_shape + ) + result = ( + result[:-1] + + ( + torch.randn( + len(self.controlnets), + self.xB * batch_size, + 3, + image_height, + image_width, + dtype=dtype, + device=self.device, + ), # images + torch.randn( + len(self.controlnets), dtype=dtype, device=self.device + ), # controlnet_scales + ) + + result[-1:] + ) + return result + + +class UNetTemporalModel(base_model.BaseModel): + def __init__( + self, + version, + pipeline, + device, + hf_token, + verbose, + framework_model_dir, + fp16=False, + fp8=False, + max_batch_size=16, + num_frames=14, + do_classifier_free_guidance=True, + ): + super(UNetTemporalModel, self).__init__( + version, + pipeline, + device=device, + hf_token=hf_token, + verbose=verbose, + framework_model_dir=framework_model_dir, + fp16=fp16, + fp8=fp8, + max_batch_size=max_batch_size, + embedding_dim=get_unet_embedding_dim(version, pipeline), + ) + self.subfolder = "unet" + self.unet_dim = 4 + self.num_frames = num_frames + self.out_channels = 4 + self.cross_attention_dim = 1024 + self.xB = 2 if do_classifier_free_guidance else 1 # batch multiplier + + def get_model(self, torch_inference=""): + model_opts = {"torch_dtype": torch.float16} if self.fp16 else {} + unet_model_dir = load.get_checkpoint_dir( + self.framework_model_dir, self.version, self.pipeline, self.subfolder + ) + if not load.is_model_cached(unet_model_dir, model_opts, self.hf_safetensor): + model = UNetSpatioTemporalConditionModel.from_pretrained( + self.path, + subfolder=self.subfolder, + use_safetensors=self.hf_safetensor, + token=self.hf_token, + **model_opts, + ).to(self.device) + model.save_pretrained(unet_model_dir, **model_opts) + else: + print( + f"[I] Load UNetSpatioTemporalConditionModel model from: {unet_model_dir}" + ) + model = UNetSpatioTemporalConditionModel.from_pretrained( + unet_model_dir, **model_opts + ).to(self.device) + model = optimizer.optimize_checkpoint(model, torch_inference) + return model + + def get_input_names(self): + return ["sample", "timestep", "encoder_hidden_states", "added_time_ids"] + + def get_output_names(self): + return ["latent"] + + def get_dynamic_axes(self): + xB = str(self.xB) + "B" + return { + "sample": {0: xB, 1: "num_frames", 3: "H", 4: "W"}, + "encoder_hidden_states": {0: xB}, + "added_time_ids": {0: xB}, + } + + def get_input_profile( + self, batch_size, image_height, image_width, static_batch, static_shape + ): + latent_height, latent_width = 
self.check_dims( + batch_size, image_height, image_width + ) + ( + min_batch, + max_batch, + min_image_height, + max_image_height, + min_image_width, + max_image_width, + min_latent_height, + max_latent_height, + min_latent_width, + max_latent_width, + ) = self.get_minmax_dims( + batch_size, image_height, image_width, static_batch, static_shape + ) + return { + "sample": [ + ( + self.xB * min_batch, + self.num_frames, + 2 * self.out_channels, + min_latent_height, + min_latent_width, + ), + ( + self.xB * batch_size, + self.num_frames, + 2 * self.out_channels, + latent_height, + latent_width, + ), + ( + self.xB * max_batch, + self.num_frames, + 2 * self.out_channels, + max_latent_height, + max_latent_width, + ), + ], + "encoder_hidden_states": [ + (self.xB * min_batch, 1, self.cross_attention_dim), + (self.xB * batch_size, 1, self.cross_attention_dim), + (self.xB * max_batch, 1, self.cross_attention_dim), + ], + "added_time_ids": [ + (self.xB * min_batch, 3), + (self.xB * batch_size, 3), + (self.xB * max_batch, 3), + ], + } + + def get_shape_dict(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims( + batch_size, image_height, image_width + ) + return { + "sample": ( + self.xB * batch_size, + self.num_frames, + 2 * self.out_channels, + latent_height, + latent_width, + ), + "timestep": (1,), + "encoder_hidden_states": ( + self.xB * batch_size, + 1, + self.cross_attention_dim, + ), + "added_time_ids": (self.xB * batch_size, 3), + } + + def get_sample_input(self, batch_size, image_height, image_width, static_shape): + # TODO chunk_size if forward_chunking is used + latent_height, latent_width = self.check_dims( + batch_size, image_height, image_width + ) + + dtype = torch.float16 if self.fp16 else torch.float32 + return ( + torch.randn( + self.xB * batch_size, + self.num_frames, + 2 * self.out_channels, + latent_height, + latent_width, + dtype=dtype, + device=self.device, + ), + torch.tensor([1.0], dtype=torch.float32, device=self.device), + torch.randn( + self.xB * batch_size, + 1, + self.cross_attention_dim, + dtype=dtype, + device=self.device, + ), + torch.randn(self.xB * batch_size, 3, dtype=dtype, device=self.device), + ) + + def optimize(self, onnx_graph): + return super().optimize(onnx_graph, modify_fp8_graph=self.fp8) + + +class UNetCascadeModel(base_model.BaseModel): + def __init__( + self, + version, + pipeline, + device, + hf_token, + verbose, + framework_model_dir, + fp16=False, + bf16=False, + max_batch_size=16, + text_maxlen=77, + do_classifier_free_guidance=False, + compression_factor=42, + latent_dim_scale=10.67, + image_embedding_dim=768, + lite=False, + ): + super(UNetCascadeModel, self).__init__( + version, + pipeline, + device=device, + hf_token=hf_token, + verbose=verbose, + framework_model_dir=framework_model_dir, + fp16=fp16, + bf16=bf16, + max_batch_size=max_batch_size, + text_maxlen=text_maxlen, + embedding_dim=get_unet_embedding_dim(version, pipeline), + compression_factor=compression_factor, + ) + self.is_prior = True if pipeline.is_cascade_prior() else False + self.subfolder = "prior" if self.is_prior else "decoder" + if lite: + self.subfolder += "_lite" + self.prior_dim = 16 + self.decoder_dim = 4 + self.xB = 2 if do_classifier_free_guidance else 1 # batch multiplier + self.latent_dim_scale = latent_dim_scale + self.min_latent_shape = self.min_image_shape // self.compression_factor + self.max_latent_shape = self.max_image_shape // self.compression_factor + self.do_constant_folding = False + self.image_embedding_dim = 
image_embedding_dim + + def get_model(self, torch_inference=""): + # FP16 variant doesn't exist + model_opts = {"torch_dtype": torch.float16} if self.fp16 else {} + model_opts = ( + {"variant": "bf16", "torch_dtype": torch.bfloat16} + if self.bf16 + else model_opts + ) + unet_model_dir = load.get_checkpoint_dir( + self.framework_model_dir, self.version, self.pipeline, self.subfolder + ) + if not load.is_model_cached(unet_model_dir, model_opts, self.hf_safetensor): + model = StableCascadeUNet.from_pretrained( + self.path, + subfolder=self.subfolder, + use_safetensors=self.hf_safetensor, + token=self.hf_token, + **model_opts, + ).to(self.device) + model.save_pretrained(unet_model_dir, **model_opts) + else: + print(f"[I] Load Stable Cascade UNet pytorch model from: {unet_model_dir}") + model = StableCascadeUNet.from_pretrained(unet_model_dir, **model_opts).to( + self.device + ) + model = optimizer.optimize_checkpoint(model, torch_inference) + return model + + def get_input_names(self): + if self.is_prior: + return [ + "sample", + "timestep_ratio", + "clip_text_pooled", + "clip_text", + "clip_img", + ] + else: + return ["sample", "timestep_ratio", "clip_text_pooled", "effnet"] + + def get_output_names(self): + return ["latent"] + + def get_dynamic_axes(self): + xB = "2B" if self.xB == 2 else "B" + if self.is_prior: + return { + "sample": {0: xB, 2: "H", 3: "W"}, + "timestep_ratio": {0: xB}, + "clip_text_pooled": {0: xB}, + "clip_text": {0: xB}, + "clip_img": {0: xB}, + "latent": {0: xB, 2: "H", 3: "W"}, + } + else: + return { + "sample": {0: xB, 2: "H", 3: "W"}, + "timestep_ratio": {0: xB}, + "clip_text_pooled": {0: xB}, + "effnet": {0: xB, 2: "H_effnet", 3: "W_effnet"}, + "latent": {0: xB, 2: "H", 3: "W"}, + } + + def get_input_profile( + self, batch_size, image_height, image_width, static_batch, static_shape + ): + latent_height, latent_width = self.check_dims( + batch_size, image_height, image_width + ) + ( + min_batch, + max_batch, + _, + _, + _, + _, + min_latent_height, + max_latent_height, + min_latent_width, + max_latent_width, + ) = self.get_minmax_dims( + batch_size, image_height, image_width, static_batch, static_shape + ) + if self.is_prior: + return { + "sample": [ + ( + self.xB * min_batch, + self.prior_dim, + min_latent_height, + min_latent_width, + ), + (self.xB * batch_size, self.prior_dim, latent_height, latent_width), + ( + self.xB * max_batch, + self.prior_dim, + max_latent_height, + max_latent_width, + ), + ], + "timestep_ratio": [ + (self.xB * min_batch,), + (self.xB * batch_size,), + (self.xB * max_batch,), + ], + "clip_text_pooled": [ + (self.xB * min_batch, 1, self.embedding_dim), + (self.xB * batch_size, 1, self.embedding_dim), + (self.xB * max_batch, 1, self.embedding_dim), + ], + "clip_text": [ + (self.xB * min_batch, self.text_maxlen, self.embedding_dim), + (self.xB * batch_size, self.text_maxlen, self.embedding_dim), + (self.xB * max_batch, self.text_maxlen, self.embedding_dim), + ], + "clip_img": [ + (self.xB * min_batch, 1, self.image_embedding_dim), + (self.xB * batch_size, 1, self.image_embedding_dim), + (self.xB * max_batch, 1, self.image_embedding_dim), + ], + } + else: + return { + "sample": [ + ( + self.xB * min_batch, + self.decoder_dim, + int(min_latent_height * self.latent_dim_scale), + int(min_latent_width * self.latent_dim_scale), + ), + ( + self.xB * batch_size, + self.decoder_dim, + int(latent_height * self.latent_dim_scale), + int(latent_width * self.latent_dim_scale), + ), + ( + self.xB * max_batch, + self.decoder_dim, + int(max_latent_height * 
self.latent_dim_scale), + int(max_latent_width * self.latent_dim_scale), + ), + ], + "timestep_ratio": [ + (self.xB * min_batch,), + (self.xB * batch_size,), + (self.xB * max_batch,), + ], + "clip_text_pooled": [ + (self.xB * min_batch, 1, self.embedding_dim), + (self.xB * batch_size, 1, self.embedding_dim), + (self.xB * max_batch, 1, self.embedding_dim), + ], + "effnet": [ + ( + self.xB * min_batch, + self.prior_dim, + min_latent_height, + min_latent_width, + ), + (self.xB * batch_size, self.prior_dim, latent_height, latent_width), + ( + self.xB * max_batch, + self.prior_dim, + max_latent_height, + max_latent_width, + ), + ], + } + + def get_shape_dict(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims( + batch_size, image_height, image_width + ) + if self.is_prior: + return { + "sample": ( + self.xB * batch_size, + self.prior_dim, + latent_height, + latent_width, + ), + "timestep_ratio": (self.xB * batch_size,), + "clip_text_pooled": (self.xB * batch_size, 1, self.embedding_dim), + "clip_text": ( + self.xB * batch_size, + self.text_maxlen, + self.embedding_dim, + ), + "clip_img": (self.xB * batch_size, 1, self.image_embedding_dim), + "latent": ( + self.xB * batch_size, + self.prior_dim, + latent_height, + latent_width, + ), + } + else: + return { + "sample": ( + self.xB * batch_size, + self.decoder_dim, + int(latent_height * self.latent_dim_scale), + int(latent_width * self.latent_dim_scale), + ), + "timestep_ratio": (self.xB * batch_size,), + "clip_text_pooled": (self.xB * batch_size, 1, self.embedding_dim), + "effnet": ( + self.xB * batch_size, + self.prior_dim, + latent_height, + latent_width, + ), + "latent": ( + self.xB * batch_size, + self.decoder_dim, + int(latent_height * self.latent_dim_scale), + int(latent_width * self.latent_dim_scale), + ), + } + + def get_sample_input(self, batch_size, image_height, image_width, static_shape): + latent_height, latent_width = self.check_dims( + batch_size, image_height, image_width + ) + dtype = ( + torch.float16 + if self.fp16 + else torch.bfloat16 + if self.bf16 + else torch.float32 + ) + if self.is_prior: + return ( + torch.randn( + batch_size, + self.prior_dim, + latent_height, + latent_width, + dtype=dtype, + device=self.device, + ), + torch.tensor([1.0] * batch_size, dtype=dtype, device=self.device), + torch.randn( + batch_size, 1, self.embedding_dim, dtype=dtype, device=self.device + ), + { + "clip_text": torch.randn( + batch_size, + self.text_maxlen, + self.embedding_dim, + dtype=dtype, + device=self.device, + ), + "clip_img": torch.randn( + batch_size, + 1, + self.image_embedding_dim, + dtype=dtype, + device=self.device, + ), + }, + ) + else: + return ( + torch.randn( + batch_size, + self.decoder_dim, + int(latent_height * self.latent_dim_scale), + int(latent_width * self.latent_dim_scale), + dtype=dtype, + device=self.device, + ), + torch.tensor([1.0] * batch_size, dtype=dtype, device=self.device), + torch.randn( + batch_size, 1, self.embedding_dim, dtype=dtype, device=self.device + ), + { + "effnet": torch.randn( + batch_size, + self.prior_dim, + latent_height, + latent_width, + dtype=dtype, + device=self.device, + ), + }, + ) diff --git a/flux.1-dev-trt-b200/model/demo_diffusion/model/vae.py b/flux.1-dev-trt-b200/model/demo_diffusion/model/vae.py new file mode 100644 index 000000000..a873838bb --- /dev/null +++ b/flux.1-dev-trt-b200/model/demo_diffusion/model/vae.py @@ -0,0 +1,675 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. 
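For orientation, here is a quick worked example of the Stable Cascade shape arithmetic implied by the constructor defaults above (`compression_factor=42`, `latent_dim_scale=10.67`); the numbers are illustrative only:

```python
# The prior stage operates on image_size // compression_factor latents; the
# decoder stage operates on those latents scaled up by latent_dim_scale.
compression_factor = 42
latent_dim_scale = 10.67

image_height = image_width = 1024
prior_latent = image_height // compression_factor      # 24
decoder_latent = int(prior_latent * latent_dim_scale)   # 256

print(prior_latent, decoder_latent)  # 24 256
```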
All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os + +import torch +from huggingface_hub import hf_hub_download +from safetensors import safe_open + +from demo_diffusion.dynamic_import import import_from_diffusers +from demo_diffusion.model import base_model, load, optimizer +from demo_diffusion.utils_sd3.other_impls import load_into +from demo_diffusion.utils_sd3.sd3_impls import SDVAE + +# List of models to import from diffusers.models +models_to_import = [ + "AutoencoderKL", + "AutoencoderKLTemporalDecoder", +] +for model in models_to_import: + globals()[model] = import_from_diffusers(model, "diffusers.models") + + +class VAEModel(base_model.BaseModel): + def __init__( + self, + version, + pipeline, + device, + hf_token, + verbose, + framework_model_dir, + fp16=False, + tf32=False, + bf16=False, + max_batch_size=16, + ): + super(VAEModel, self).__init__( + version, + pipeline, + device=device, + hf_token=hf_token, + verbose=verbose, + framework_model_dir=framework_model_dir, + fp16=fp16, + tf32=tf32, + bf16=bf16, + max_batch_size=max_batch_size, + ) + self.subfolder = "vae" + self.vae_decoder_model_dir = load.get_checkpoint_dir( + self.framework_model_dir, self.version, self.pipeline, self.subfolder + ) + if not os.path.exists(self.vae_decoder_model_dir): + self.config = AutoencoderKL.load_config( + self.path, subfolder=self.subfolder, token=self.hf_token + ) + else: + print( + f"[I] Load AutoencoderKL (decoder) config from: {self.vae_decoder_model_dir}" + ) + self.config = AutoencoderKL.load_config(self.vae_decoder_model_dir) + + def get_model(self, torch_inference=""): + model_opts = ( + {"torch_dtype": torch.float16} + if self.fp16 + else {"torch_dtype": torch.bfloat16} + if self.bf16 + else {} + ) + if not load.is_model_cached( + self.vae_decoder_model_dir, model_opts, self.hf_safetensor + ): + model = AutoencoderKL.from_pretrained( + self.path, + subfolder=self.subfolder, + use_safetensors=self.hf_safetensor, + token=self.hf_token, + **model_opts, + ).to(self.device) + model.save_pretrained(self.vae_decoder_model_dir, **model_opts) + else: + print( + f"[I] Load AutoencoderKL (decoder) model from: {self.vae_decoder_model_dir}" + ) + model = AutoencoderKL.from_pretrained( + self.vae_decoder_model_dir, **model_opts + ).to(self.device) + model.forward = model.decode + model = optimizer.optimize_checkpoint(model, torch_inference) + return model + + def get_input_names(self): + return ["latent"] + + def get_output_names(self): + return ["images"] + + def get_dynamic_axes(self): + return { + "latent": {0: "B", 2: "H", 3: "W"}, + "images": {0: "B", 2: "8H", 3: "8W"}, + } + + def get_input_profile( + self, batch_size, image_height, image_width, static_batch, static_shape + ): + latent_height, latent_width = self.check_dims( + batch_size, image_height, image_width + ) + ( + min_batch, + max_batch, + _, + _, + _, + _, + min_latent_height, + max_latent_height, + min_latent_width, + max_latent_width, + ) = 
self.get_minmax_dims( + batch_size, image_height, image_width, static_batch, static_shape + ) + return { + "latent": [ + ( + min_batch, + self.config["latent_channels"], + min_latent_height, + min_latent_width, + ), + ( + batch_size, + self.config["latent_channels"], + latent_height, + latent_width, + ), + ( + max_batch, + self.config["latent_channels"], + max_latent_height, + max_latent_width, + ), + ] + } + + def get_shape_dict(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims( + batch_size, image_height, image_width + ) + return { + "latent": ( + batch_size, + self.config["latent_channels"], + latent_height, + latent_width, + ), + "images": (batch_size, 3, image_height, image_width), + } + + def get_sample_input(self, batch_size, image_height, image_width, static_shape): + latent_height, latent_width = self.check_dims( + batch_size, image_height, image_width + ) + dtype = ( + torch.float16 + if self.fp16 + else torch.bfloat16 + if self.bf16 + else torch.float32 + ) + return torch.randn( + batch_size, + self.config["latent_channels"], + latent_height, + latent_width, + dtype=dtype, + device=self.device, + ) + + +class SD3_VAEDecoderModel(base_model.BaseModel): + def __init__( + self, + version, + pipeline, + device, + hf_token, + verbose, + framework_model_dir, + max_batch_size, + fp16=False, + ): + super(SD3_VAEDecoderModel, self).__init__( + version, + pipeline, + device=device, + hf_token=hf_token, + verbose=verbose, + framework_model_dir=framework_model_dir, + fp16=fp16, + max_batch_size=max_batch_size, + ) + self.subfolder = "sd3" + + def get_model(self, torch_inference=""): + dtype = torch.float16 if self.fp16 else torch.float32 + sd3_model_dir = load.get_checkpoint_dir( + self.framework_model_dir, self.version, self.pipeline, self.subfolder + ) + sd3_filename = "sd3_medium.safetensors" + sd3_model_path = f"{sd3_model_dir}/{sd3_filename}" + if not os.path.exists(sd3_model_path): + hf_hub_download( + repo_id=self.path, filename=sd3_filename, local_dir=sd3_model_dir + ) + with safe_open(sd3_model_path, framework="pt", device=self.device) as f: + model = SDVAE(device=self.device, dtype=dtype).eval().cuda() + prefix = "" + if any(k.startswith("first_stage_model.") for k in f.keys()): + prefix = "first_stage_model." 
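The shape helpers above (`get_sample_input`, `get_input_names`, `get_output_names`, `get_dynamic_axes`) are the pieces a `torch.onnx.export` call needs. A simplified sketch of how they fit together for the VAE decoder follows; the wrapper function and the opset value are assumptions, and the project's shared export path adds graph optimization and other handling on top of this:

```python
import torch


def export_vae_decoder_onnx(vae_model, onnx_path, batch_size=1, height=1024, width=1024):
    """Illustrative export of a VAEModel-style object to ONNX."""
    model = vae_model.get_model()  # forward already remapped to decode()
    sample = vae_model.get_sample_input(batch_size, height, width, static_shape=False)
    with torch.inference_mode():
        torch.onnx.export(
            model,
            sample,
            onnx_path,
            input_names=vae_model.get_input_names(),    # ["latent"]
            output_names=vae_model.get_output_names(),  # ["images"]
            dynamic_axes=vae_model.get_dynamic_axes(),
            opset_version=19,                           # assumed opset
        )
```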
+ load_into(f, model, prefix, self.device, dtype) + model.forward = model.decode + model = optimizer.optimize_checkpoint(model, torch_inference) + return model + + def get_input_names(self): + return ["latent"] + + def get_output_names(self): + return ["images"] + + def get_dynamic_axes(self): + return { + "latent": {0: "B", 2: "H", 3: "W"}, + "images": {0: "B", 2: "8H", 3: "8W"}, + } + + def get_input_profile( + self, batch_size, image_height, image_width, static_batch, static_shape + ): + latent_height, latent_width = self.check_dims( + batch_size, image_height, image_width + ) + ( + min_batch, + max_batch, + _, + _, + _, + _, + min_latent_height, + max_latent_height, + min_latent_width, + max_latent_width, + ) = self.get_minmax_dims( + batch_size, image_height, image_width, static_batch, static_shape + ) + return { + "latent": [ + (min_batch, 16, min_latent_height, min_latent_width), + (batch_size, 16, latent_height, latent_width), + (max_batch, 16, max_latent_height, max_latent_width), + ] + } + + def get_shape_dict(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims( + batch_size, image_height, image_width + ) + return { + "latent": (batch_size, 16, latent_height, latent_width), + "images": (batch_size, 3, image_height, image_width), + } + + def get_sample_input(self, batch_size, image_height, image_width, static_shape): + latent_height, latent_width = self.check_dims( + batch_size, image_height, image_width + ) + dtype = torch.float16 if self.fp16 else torch.float32 + return torch.randn( + batch_size, 16, latent_height, latent_width, dtype=dtype, device=self.device + ) + + +class VAEDecTemporalModel(base_model.BaseModel): + def __init__( + self, + version, + pipeline, + device, + hf_token, + verbose, + framework_model_dir, + max_batch_size=16, + decode_chunk_size=14, + ): + super(VAEDecTemporalModel, self).__init__( + version, + pipeline, + device=device, + hf_token=hf_token, + verbose=verbose, + framework_model_dir=framework_model_dir, + max_batch_size=max_batch_size, + ) + self.subfolder = "vae" + self.decode_chunk_size = decode_chunk_size + + def get_model(self, torch_inference=""): + vae_decoder_model_path = load.get_checkpoint_dir( + self.framework_model_dir, self.version, self.pipeline, self.subfolder + ) + if not os.path.exists(vae_decoder_model_path): + model = AutoencoderKLTemporalDecoder.from_pretrained( + self.path, + subfolder=self.subfolder, + use_safetensors=self.hf_safetensor, + token=self.hf_token, + ).to(self.device) + model.save_pretrained(vae_decoder_model_path) + else: + print( + f"[I] Load AutoencoderKLTemporalDecoder model from: {vae_decoder_model_path}" + ) + model = AutoencoderKLTemporalDecoder.from_pretrained( + vae_decoder_model_path + ).to(self.device) + model.forward = model.decode + model = optimizer.optimize_checkpoint(model, torch_inference) + return model + + def get_input_names(self): + return ["latent", "num_frames_in"] + + def get_output_names(self): + return ["frames"] + + def get_dynamic_axes(self): + return { + "latent": {0: "num_frames_in", 2: "H", 3: "W"}, + "frames": {0: "num_frames_in", 2: "8H", 3: "8W"}, + } + + def get_input_profile( + self, batch_size, image_height, image_width, static_batch, static_shape + ): + latent_height, latent_width = self.check_dims( + batch_size, image_height, image_width + ) + assert batch_size == 1 + ( + _, + _, + _, + _, + _, + _, + min_latent_height, + max_latent_height, + min_latent_width, + max_latent_width, + ) = self.get_minmax_dims( + batch_size, image_height, 
image_width, static_batch, static_shape + ) + return { + "latent": [ + (1, 4, min_latent_height, min_latent_width), + (self.decode_chunk_size, 4, latent_height, latent_width), + (self.decode_chunk_size, 4, max_latent_height, max_latent_width), + ], + "num_frames_in": [(1,), (1,), (1,)], + } + + def get_shape_dict(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims( + batch_size, image_height, image_width + ) + assert batch_size == 1 + return { + "latent": (self.decode_chunk_size, 4, latent_height, latent_width), + #'num_frames_in': (1,), + "frames": (self.decode_chunk_size, 3, image_height, image_width), + } + + def get_sample_input(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims( + batch_size, image_height, image_width + ) + assert batch_size == 1 + return ( + torch.randn( + self.decode_chunk_size, + 4, + latent_height, + latent_width, + dtype=torch.float32, + device=self.device, + ), + self.decode_chunk_size, + ) + + +class TorchVAEEncoder(torch.nn.Module): + def __init__( + self, + version, + pipeline, + hf_token, + device, + path, + framework_model_dir, + subfolder, + fp16=False, + bf16=False, + hf_safetensor=False, + ): + super().__init__() + model_opts = ( + {"torch_dtype": torch.float16} + if fp16 + else {"torch_dtype": torch.bfloat16} + if bf16 + else {} + ) + vae_encoder_model_dir = load.get_checkpoint_dir( + framework_model_dir, version, pipeline, subfolder + ) + if not load.is_model_cached(vae_encoder_model_dir, model_opts, hf_safetensor): + self.vae_encoder = AutoencoderKL.from_pretrained( + path, + subfolder="vae", + use_safetensors=hf_safetensor, + token=hf_token, + **model_opts, + ).to(device) + self.vae_encoder.save_pretrained(vae_encoder_model_dir, **model_opts) + else: + print( + f"[I] Load AutoencoderKL (encoder) model from: {vae_encoder_model_dir}" + ) + self.vae_encoder = AutoencoderKL.from_pretrained( + vae_encoder_model_dir, **model_opts + ).to(device) + + def forward(self, x): + return self.vae_encoder.encode(x).latent_dist.sample() + + +class VAEEncoderModel(base_model.BaseModel): + def __init__( + self, + version, + pipeline, + device, + hf_token, + verbose, + framework_model_dir, + fp16=False, + tf32=False, + bf16=False, + max_batch_size=16, + ): + super(VAEEncoderModel, self).__init__( + version, + pipeline, + device=device, + hf_token=hf_token, + verbose=verbose, + framework_model_dir=framework_model_dir, + fp16=fp16, + tf32=tf32, + bf16=bf16, + max_batch_size=max_batch_size, + ) + self.subfolder = "vae" + self.vae_encoder_model_dir = load.get_checkpoint_dir( + framework_model_dir, version, self.pipeline, self.subfolder + ) + if not os.path.exists(self.vae_encoder_model_dir): + self.config = AutoencoderKL.load_config( + self.path, subfolder=self.subfolder, token=self.hf_token + ) + else: + print( + f"[I] Load AutoencoderKL (encoder) config from: {self.vae_encoder_model_dir}" + ) + self.config = AutoencoderKL.load_config(self.vae_encoder_model_dir) + + def get_model(self, torch_inference=""): + vae_encoder = TorchVAEEncoder( + self.version, + self.pipeline, + self.hf_token, + self.device, + self.path, + self.framework_model_dir, + self.subfolder, + self.fp16, + self.bf16, + hf_safetensor=self.hf_safetensor, + ) + return vae_encoder + + def get_input_names(self): + return ["images"] + + def get_output_names(self): + return ["latent"] + + def get_dynamic_axes(self): + return { + "images": {0: "B", 2: "8H", 3: "8W"}, + "latent": {0: "B", 2: "H", 3: "W"}, + } + + def 
get_input_profile( + self, batch_size, image_height, image_width, static_batch, static_shape + ): + assert batch_size >= self.min_batch and batch_size <= self.max_batch + min_batch = batch_size if static_batch else self.min_batch + max_batch = batch_size if static_batch else self.max_batch + self.check_dims(batch_size, image_height, image_width) + ( + min_batch, + max_batch, + min_image_height, + max_image_height, + min_image_width, + max_image_width, + _, + _, + _, + _, + ) = self.get_minmax_dims( + batch_size, image_height, image_width, static_batch, static_shape + ) + + return { + "images": [ + (min_batch, 3, min_image_height, min_image_width), + (batch_size, 3, image_height, image_width), + (max_batch, 3, max_image_height, max_image_width), + ], + } + + def get_shape_dict(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims( + batch_size, image_height, image_width + ) + return { + "images": (batch_size, 3, image_height, image_width), + "latent": ( + batch_size, + self.config["latent_channels"], + latent_height, + latent_width, + ), + } + + def get_sample_input(self, batch_size, image_height, image_width, static_shape): + self.check_dims(batch_size, image_height, image_width) + dtype = ( + torch.float16 + if self.fp16 + else torch.bfloat16 + if self.bf16 + else torch.float32 + ) + return torch.randn( + batch_size, 3, image_height, image_width, dtype=dtype, device=self.device + ) + + +class SD3_VAEEncoderModel(base_model.BaseModel): + def __init__( + self, + version, + pipeline, + device, + hf_token, + verbose, + framework_model_dir, + max_batch_size, + fp16=False, + ): + super(SD3_VAEEncoderModel, self).__init__( + version, + pipeline, + device=device, + hf_token=hf_token, + verbose=verbose, + framework_model_dir=framework_model_dir, + fp16=fp16, + max_batch_size=max_batch_size, + ) + self.subfolder = "sd3" + + def get_model(self, torch_inference=""): + dtype = torch.float16 if self.fp16 else torch.float32 + sd3_model_dir = load.get_checkpoint_dir( + self.framework_model_dir, self.version, self.pipeline, self.subfolder + ) + sd3_filename = "sd3_medium.safetensors" + sd3_model_path = f"{sd3_model_dir}/{sd3_filename}" + if not os.path.exists(sd3_model_path): + hf_hub_download( + repo_id=self.path, filename=sd3_filename, local_dir=sd3_model_dir + ) + with safe_open(sd3_model_path, framework="pt", device=self.device) as f: + model = SDVAE(device=self.device, dtype=dtype).eval().cuda() + prefix = "" + if any(k.startswith("first_stage_model.") for k in f.keys()): + prefix = "first_stage_model." 
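As a usage sketch, the `TorchVAEEncoder` wrapper defined above can be called like a regular module; its forward pass runs `encode()` and samples from the returned latent distribution. The checkpoint id and constructor values here are placeholder examples (the wrapper downloads or loads the VAE from the framework cache), not something this file prescribes:

```python
import torch

# Hypothetical values; `pipeline_obj` stands in for the pipeline descriptor
# the wrapper passes to load.get_checkpoint_dir().
encoder = TorchVAEEncoder(
    version="1.5",
    pipeline=pipeline_obj,
    hf_token=None,
    device="cuda",
    path="runwayml/stable-diffusion-v1-5",
    framework_model_dir="pytorch_model",
    subfolder="vae",
    fp16=True,
)

images = torch.randn(1, 3, 512, 512, dtype=torch.float16, device="cuda")
latents = encoder(images)     # encode() + latent_dist.sample()
print(latents.shape)          # expected (1, 4, 64, 64): each spatial dim shrinks by 8x
```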
+ load_into(f, model, prefix, self.device, dtype) + model.forward = model.encode + model = optimizer.optimize_checkpoint(model, torch_inference) + return model + + def get_input_names(self): + return ["images"] + + def get_output_names(self): + return ["latent"] + + def get_dynamic_axes(self): + return { + "images": {0: "B", 2: "8H", 3: "8W"}, + "latent": {0: "B", 2: "H", 3: "W"}, + } + + def get_input_profile( + self, batch_size, image_height, image_width, static_batch, static_shape + ): + min_batch, max_batch, _, _, _, _, _, _, _, _ = self.get_minmax_dims( + batch_size, image_height, image_width, static_batch, static_shape + ) + return { + "images": [ + (min_batch, 3, image_height, image_width), + (batch_size, 3, image_height, image_width), + (max_batch, 3, image_height, image_width), + ] + } + + def get_shape_dict(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims( + batch_size, image_height, image_width + ) + return { + "images": (batch_size, 3, image_height, image_width), + "latent": (batch_size, 16, latent_height, latent_width), + } + + def get_sample_input(self, batch_size, image_height, image_width, static_shape): + dtype = torch.float16 if self.fp16 else torch.float32 + return torch.randn( + batch_size, 3, image_height, image_width, dtype=dtype, device=self.device + ) diff --git a/flux.1-dev-trt-b200/model/demo_diffusion/path/__init__.py b/flux.1-dev-trt-b200/model/demo_diffusion/path/__init__.py new file mode 100644 index 000000000..d9afdfb7e --- /dev/null +++ b/flux.1-dev-trt-b200/model/demo_diffusion/path/__init__.py @@ -0,0 +1,21 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from demo_diffusion.path.dd_path import DDPath +from demo_diffusion.path.resolve_path import resolve_path + +__all__ = ["DDPath", "resolve_path"] diff --git a/flux.1-dev-trt-b200/model/demo_diffusion/path/dd_path.py b/flux.1-dev-trt-b200/model/demo_diffusion/path/dd_path.py new file mode 100644 index 000000000..78651542e --- /dev/null +++ b/flux.1-dev-trt-b200/model/demo_diffusion/path/dd_path.py @@ -0,0 +1,64 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Define a data structure for storing various paths used in DemoDiffusion. 
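The prefix check used when loading the single-file SD3 checkpoint can be reproduced in isolation with `safetensors.safe_open`: single-file checkpoints store the VAE tensors under `first_stage_model.*`, which `load_into` strips. A small sketch, assuming a local copy of the checkpoint:

```python
from safetensors import safe_open

checkpoint_path = "sd3_medium.safetensors"  # assumed local path
with safe_open(checkpoint_path, framework="pt", device="cpu") as f:
    keys = list(f.keys())
    prefix = "first_stage_model." if any(k.startswith("first_stage_model.") for k in keys) else ""
    vae_keys = [k for k in keys if k.startswith(prefix)]

print(f"{len(vae_keys)} VAE tensors found under prefix {prefix!r}")
```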
+""" + +import dataclasses +import os +from typing import Dict + + +@dataclasses.dataclass +class DDPath: + """Data class that stores various paths used in DemoDiffusion.""" + + model_name_to_optimized_onnx_path: Dict[str, str] = dataclasses.field( + default_factory=dict + ) + model_name_to_engine_path: Dict[str, str] = dataclasses.field(default_factory=dict) + + # Artifact paths. + model_name_to_unoptimized_onnx_path: Dict[str, str] = dataclasses.field( + default_factory=dict + ) + model_name_to_weights_map_path: Dict[str, str] = dataclasses.field( + default_factory=dict + ) + model_name_to_refit_weights_path: Dict[str, str] = dataclasses.field( + default_factory=dict + ) + model_name_to_quantized_model_state_dict_path: Dict[str, str] = dataclasses.field( + default_factory=dict + ) + + def create_directory(self) -> None: + """Create directories for all paths, if they do not exist.""" + all_paths = [ + value + for name_to_path in dataclasses.astuple(self) + for value in name_to_path.values() + ] + + for path in all_paths: + directory = os.path.dirname(path) + + # If `path` does not have a directory component, `directory` will be an empty string. + # Only proceed if `directory` is non-empty. + if directory: + os.makedirs(directory, exist_ok=True) diff --git a/flux.1-dev-trt-b200/model/demo_diffusion/path/resolve_path.py b/flux.1-dev-trt-b200/model/demo_diffusion/path/resolve_path.py new file mode 100644 index 000000000..3842d0f7d --- /dev/null +++ b/flux.1-dev-trt-b200/model/demo_diffusion/path/resolve_path.py @@ -0,0 +1,184 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import annotations + +import argparse +import hashlib +import os +from typing import Dict, List + +import tensorrt as trt + +from demo_diffusion import pipeline +from demo_diffusion.path import dd_path + +ARTIFACT_CACHE_DIRECTORY = os.path.join(os.getcwd(), "artifacts_cache") + + +def resolve_path( + model_names: List[str], + args: argparse.Namespace, + pipeline_type: pipeline.PIPELINE_TYPE, + pipeline_uid: str, +) -> dd_path.DDPath: + """Resolve all paths and store them in a newly constructed dd_path.DDPath object. + + Args: + model_names (List[str]): List of model names. + args (argparse.Namespace): Parsed arguments. + + Returns: + dd_path.DDPath: Path object containing all the resolved paths. + """ + path = dd_path.DDPath() + model_name_to_model_uri = { + model_name: _resolve_model_uri(model_name, args, pipeline_type, pipeline_uid) + for model_name in model_names + } + + _resolve_default_path(model_name_to_model_uri, args, path) + _resolve_custom_path(args, path) + + path.create_directory() + + return path + + +def _resolve_model_uri( + model_name: str, + args: argparse.Namespace, + pipeline_type: pipeline.PIPELINE_TYPE, + pipeline_uid: str, +) -> str: + """Resolve and return the model URI. 
+ + The model URI is a partial path that uniquely identifies the model. It is used to construct various model paths like + artifact cache path, checkpoint path, etc. + """ + # Lora unique ID represents the lora configuration. + if args.lora_path and args.lora_weight: + lora_config_uid = "-".join( + sorted( + [ + f"{hashlib.sha256(lora_path.encode()).hexdigest()}-{lora_weight}-{args.lora_scale}" + for lora_path, lora_weight in zip(args.lora_path, args.lora_weight) + if args.lora_path + ] + ) + ) + else: + lora_config_uid = "" + + # Quantization config unique ID represents the quantization configuration. + def _is_quantized() -> bool: + """Return True if model is quantized, False if otherwise. + + When quantization flags are set in `args`, only a subset of the models are actually quantized. + """ + is_unet = model_name == "unet" + is_unetxl_base = pipeline_type.is_sd_xl_base() and model_name == "unetxl" + is_flux_transformer = ( + args.version.startswith("flux.1") and model_name == "transformer" + ) + + if args.int8: + return is_unet or is_unetxl_base + elif args.fp8: + return is_unet or is_unetxl_base or is_flux_transformer + elif args.fp4: + return is_flux_transformer + else: + return False + + if _is_quantized(): + if args.int8 or args.fp8: + quantization_config_uid = ( + f"{'int8' if args.int8 else 'fp8'}.l{args.quantization_level}.bs2" + f".c{args.calibration_size}.p{args.quantization_percentile}.a{args.quantization_alpha}" + ) + else: + quantization_config_uid = "fp4" + else: + quantization_config_uid = "" + + # Model unique ID represents the model name and its configuration. It is unique under the same pipeline. + model_uid = "_".join( + [s for s in [model_name, lora_config_uid, quantization_config_uid] if s] + ) + + # Model URI is the concatenation of pipeline unique ID and model unique ID. + model_uri = os.path.join(pipeline_uid, model_uid) + + return model_uri + + +def _resolve_default_path( + model_name_to_model_uri: Dict[str, str], + args: argparse.Namespace, + path: dd_path.DDPath, +) -> None: + """Resolve the default paths. + + Args: + model_name_to_model_uri (Dict[str, str]): Dictionary of model name to model URI. + args (argparse.Namespace): Parsed arguments. + path (dd_path.DDPath): Path object. This object is modified in-place to store all resolved default paths. + """ + for model_name, model_uri in model_name_to_model_uri.items(): + path.model_name_to_optimized_onnx_path[model_name] = os.path.join( + args.onnx_dir, model_uri, "model_optimized.onnx" + ) + path.model_name_to_engine_path[model_name] = os.path.join( + args.engine_dir, model_uri, f"engine_trt{trt.__version__}.plan" + ) + + # Resolve artifact paths. + artifact_dir = os.path.join(ARTIFACT_CACHE_DIRECTORY, model_uri) + + path.model_name_to_unoptimized_onnx_path[model_name] = os.path.join( + artifact_dir, "model_unoptimized.onnx" + ) + path.model_name_to_weights_map_path[model_name] = os.path.join( + artifact_dir, "weights_map.json" + ) + path.model_name_to_refit_weights_path[model_name] = os.path.join( + artifact_dir, "refit_weights.json" + ) + path.model_name_to_quantized_model_state_dict_path[model_name] = os.path.join( + artifact_dir, "quantized_model_state_dict.json" + ) + + +def _resolve_custom_path(args: argparse.Namespace, path: dd_path.DDPath) -> None: + """Resolve the custom paths. + + If a different path already exists in `path`, it will be overridden. + + Args: + args (argparse.Namespace): Parsed arguments. + path (dd_path.DDPath): Path object. 
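Putting `_resolve_model_uri` and `_resolve_default_path` together, the default on-disk layout for a single model looks roughly like the sketch below. The pipeline UID, quantization UID, and directory names are illustrative examples, not actual output of the code:

```python
import os

pipeline_uid = "FluxPipeline_flux.1-dev"                  # passed into resolve_path()
model_uid = "transformer_fp8.l4.0.bs2.c128.p1.0.a0.8"     # model name + quantization config UID
model_uri = os.path.join(pipeline_uid, model_uid)

onnx_opt_path = os.path.join("onnx", model_uri, "model_optimized.onnx")
engine_path = os.path.join("engine", model_uri, "engine_trt<TRT_VERSION>.plan")
artifact_dir = os.path.join(os.getcwd(), "artifacts_cache", model_uri)  # unoptimized ONNX, weights map, refit weights

print(onnx_opt_path, engine_path, artifact_dir, sep="\n")
```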
This object is modified in-place to store or override all resolved paths. + """ + # Resolve and override custom ONNX paths. + if args.custom_onnx_paths: + for model_name, optimized_onnx_path in args.custom_onnx_paths.items(): + path.model_name_to_optimized_onnx_path[model_name] = optimized_onnx_path + + # Resolve and override custom engine paths. + if args.custom_engine_paths: + for model_name, engine_path in args.custom_engine_paths.items(): + path.model_name_to_engine_path[model_name] = engine_path diff --git a/flux.1-dev-trt-b200/model/demo_diffusion/pipeline/__init__.py b/flux.1-dev-trt-b200/model/demo_diffusion/pipeline/__init__.py new file mode 100644 index 000000000..c77059f08 --- /dev/null +++ b/flux.1-dev-trt-b200/model/demo_diffusion/pipeline/__init__.py @@ -0,0 +1,40 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from demo_diffusion.pipeline.diffusion_pipeline import DiffusionPipeline +from demo_diffusion.pipeline.flux_pipeline import FluxPipeline +from demo_diffusion.pipeline.stable_cascade_pipeline import StableCascadePipeline +from demo_diffusion.pipeline.stable_diffusion_3_pipeline import StableDiffusion3Pipeline +from demo_diffusion.pipeline.stable_diffusion_35_pipeline import ( + StableDiffusion35Pipeline, +) +from demo_diffusion.pipeline.stable_diffusion_pipeline import StableDiffusionPipeline +from demo_diffusion.pipeline.stable_video_diffusion_pipeline import ( + StableVideoDiffusionPipeline, +) +from demo_diffusion.pipeline.type import PIPELINE_TYPE + +__all__ = [ + "DiffusionPipeline", + "FluxPipeline", + "StableCascadePipeline", + "StableDiffusion3Pipeline", + "StableDiffusion35Pipeline", + "StableDiffusionPipeline", + "StableVideoDiffusionPipeline", + "PIPELINE_TYPE", +] diff --git a/flux.1-dev-trt-b200/model/demo_diffusion/pipeline/calibrate.py b/flux.1-dev-trt-b200/model/demo_diffusion/pipeline/calibrate.py new file mode 100644 index 000000000..29d9bdcca --- /dev/null +++ b/flux.1-dev-trt-b200/model/demo_diffusion/pipeline/calibrate.py @@ -0,0 +1,37 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import os + +from diffusers.utils import load_image + + +def load_calib_prompts(batch_size, calib_data_path): + with open(calib_data_path, "r", encoding="utf-8") as file: + lst = [line.rstrip("\n") for line in file] + return [lst[i : i + batch_size] for i in range(0, len(lst), batch_size)] + + +def load_calibration_images(folder_path): + images = [] + for filename in os.listdir(folder_path): + img_path = os.path.join(folder_path, filename) + if os.path.isfile(img_path): + image = load_image(img_path) + if image is not None: + images.append(image) + return images diff --git a/flux.1-dev-trt-b200/model/demo_diffusion/pipeline/diffusion_pipeline.py b/flux.1-dev-trt-b200/model/demo_diffusion/pipeline/diffusion_pipeline.py new file mode 100755 index 000000000..cc835c275 --- /dev/null +++ b/flux.1-dev-trt-b200/model/demo_diffusion/pipeline/diffusion_pipeline.py @@ -0,0 +1,1129 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import annotations + +import abc +import argparse +import gc +import json +import os +import pathlib +import sys +from abc import ABC, abstractmethod +from typing import Any, List + +import modelopt.torch.opt as mto +import modelopt.torch.quantization as mtq +import nvtx +import torch +from cuda import cudart +from diffusers import ( + DDIMScheduler, + DDPMScheduler, + DDPMWuerstchenScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + FlowMatchEulerDiscreteScheduler, + LCMScheduler, + LMSDiscreteScheduler, + PNDMScheduler, + UniPCMultistepScheduler, +) +from torch.utils.data import DataLoader + +import demo_diffusion.engine as engine_module +import demo_diffusion.image as image_module +from demo_diffusion.model import ( + make_scheduler, + merge_loras, + unload_torch_model, +) +from demo_diffusion.pipeline.calibrate import load_calib_prompts +from demo_diffusion.pipeline.model_memory_manager import ModelMemoryManager +from demo_diffusion.pipeline.type import PIPELINE_TYPE +from demo_diffusion.utils_modelopt import ( + SD_FP8_BF16_FLUX_MMDIT_BMM2_FP8_OUTPUT_CONFIG, + SD_FP8_FP16_DEFAULT_CONFIG, + SD_FP8_FP32_DEFAULT_CONFIG, + PromptImageDataset, + SameSizeSampler, + check_lora, + custom_collate, + filter_func, + filter_func_no_proj_out, + fp8_mha_disable, + generate_fp8_scales, + get_int8_config, + infinite_dataloader, + quantize_lvl, + set_fmha, + set_quant_precision, +) + + +class DiffusionPipeline(ABC): + """ + Application showcasing the acceleration of Stable Diffusion pipelines using NVidia TensorRT. 
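`load_calib_prompts` simply chunks a newline-delimited prompt file into batches. A usage sketch, assuming a local `calibration-prompts.txt` with one prompt per line:

```python
from demo_diffusion.pipeline.calibrate import load_calib_prompts

batches = load_calib_prompts(batch_size=2, calib_data_path="calibration-prompts.txt")
# e.g. [["a photo of an astronaut riding a horse", "a watercolor of a lighthouse"], ...]
print(len(batches), "calibration batches")
print(batches[0])
```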
+ """ + + VALID_DIFFUSION_PIPELINES = ( + "1.4", + "1.5", + "dreamshaper-7", + "2.0-base", + "2.0", + "2.1-base", + "2.1", + "xl-1.0", + "xl-turbo", + "svd-xt-1.1", + "sd3", + "3.5-medium", + "3.5-large", + "cascade", + "flux.1-dev", + "flux.1-dev-canny", + "flux.1-dev-depth", + "flux.1-schnell", + ) + SCHEDULER_DEFAULTS = { + "1.4": "PNDM", + "1.5": "PNDM", + "dreamshaper-7": "PNDM", + "2.0-base": "DDIM", + "2.0": "DDIM", + "2.1-base": "PNDM", + "2.1": "DDIM", + "xl-1.0": "Euler", + "xl-turbo": "EulerA", + "3.5-large": "FlowMatchEuler", + "3.5-medium": "FlowMatchEuler", + "svd-xt-1.1": "Euler", + "cascade": "DDPMWuerstchen", + "flux.1-dev": "FlowMatchEuler", + "flux.1-dev-canny": "FlowMatchEuler", + "flux.1-dev-depth": "FlowMatchEuler", + "flux.1-schnell": "FlowMatchEuler", + } + + def __init__( + self, + dd_path, + version="1.5", + pipeline_type=PIPELINE_TYPE.TXT2IMG, + bf16=False, + max_batch_size=16, + denoising_steps=30, + scheduler=None, + device="cuda", + output_dir=".", + hf_token=None, + verbose=False, + nvtx_profile=False, + use_cuda_graph=False, + framework_model_dir="pytorch_model", + return_latents=False, + low_vram=False, + torch_inference="", + torch_fallback=None, + weight_streaming=False, + text_encoder_weight_streaming_budget_percentage=None, + denoiser_weight_streaming_budget_percentage=None, + ): + """ + Initializes the Diffusion pipeline. + + Args: + dd_path (load_module.DDPath): DDPath object that contains all paths used in DemoDiffusion. + version (str): + The version of the pipeline. Should be one of the values listed in DiffusionPipeline.VALID_DIFFUSION_PIPELINES. + pipeline_type (PIPELINE_TYPE): + Task performed by the current pipeline. Should be one of PIPELINE_TYPE.__members__. + max_batch_size (int): + Maximum batch size for dynamic batch engine. + bf16 (`bool`, defaults to False): + Whether to run the pipeline in BFloat16 precision. + denoising_steps (int): + The number of denoising steps. + More denoising steps usually lead to a higher quality image at the expense of slower inference. + scheduler (str): + The scheduler to guide the denoising process. Must be one of the values listed in DiffusionPipeline.SCHEDULER_DEFAULTS.values(). + lora_scale (float): + Controls how much to influence the outputs with the LoRA parameters. (must between 0 and 1). + lora_weight (float): + The LoRA adapter(s) weights to use with the UNet. (must between 0 and 1). + lora_path (str): + Path to LoRA adaptor. Ex: 'latent-consistency/lcm-lora-sdv1-5'. + device (str): + PyTorch device to run inference. Default: 'cuda'. + output_dir (str): + Output directory for log files and image artifacts. + hf_token (str): + HuggingFace User Access Token to use for downloading Stable Diffusion model checkpoints. + verbose (bool): + Enable verbose logging. + nvtx_profile (bool): + Insert NVTX profiling markers. + use_cuda_graph (bool): + Use CUDA graph to capture engine execution and then launch inference. + framework_model_dir (str): + cache directory for framework checkpoints. + return_latents (bool): + Skip decoding the image and return latents instead. + low_vram (bool): + [FLUX only] Optimize for low VRAM usage, possibly at the expense of inference performance. Disabled by default. + torch_inference (str): + Run inference with PyTorch (using specified compilation mode) instead of TensorRT. The compilation mode specified should be one of ['eager', 'reduce-overhead', 'max-autotune']. 
+ torch_fallback (str): + [FLUX only] Comma separated list of models to be inferenced using PyTorch instead of TRT. For example --torch-fallback t5,transformer. If --torch-inference set, this parameter will be ignored. + weight_streaming (`bool`, defaults to False): + Whether to enable weight streaming during TensorRT engine build. + text_encoder_ws_budget_percentage (`int`, defaults to None): + Weight streaming budget as a percentage of the size of total streamable weights for the text encoder model. + denoiser_weight_streaming_budget_percentage (`int`, defaults to None): + Weight streaming budget as a percentage of the size of total streamable weights for the denoiser model. + """ + self.bf16 = bf16 + self.dd_path = dd_path + + self.denoising_steps = denoising_steps + self.max_batch_size = max_batch_size + + self.framework_model_dir = framework_model_dir + self.output_dir = output_dir + for directory in [self.framework_model_dir, self.output_dir]: + if not os.path.exists(directory): + print(f"[I] Create directory: {directory}") + pathlib.Path(directory).mkdir(parents=True) + + self.hf_token = hf_token + self.device = device + self.verbose = verbose + self.nvtx_profile = nvtx_profile + + self.version = version + self.pipeline_type = pipeline_type + self.return_latents = return_latents + + self.low_vram = low_vram + self.weight_streaming = weight_streaming + self.text_encoder_weight_streaming_budget_percentage = ( + text_encoder_weight_streaming_budget_percentage + ) + self.denoiser_weight_streaming_budget_percentage = ( + denoiser_weight_streaming_budget_percentage + ) + + self.stages = self.get_model_names(self.pipeline_type) + # config to store additional info + self.config = {} + if torch_fallback: + assert type(torch_fallback) is list + for model_name in torch_fallback: + if model_name not in self.stages: + raise ValueError( + f'Model "{model_name}" set in --torch-fallback does not exist' + ) + self.config[model_name.replace("-", "_") + "_torch_fallback"] = True + print(f"[I] Setting torch_fallback for {model_name} model.") + + if not scheduler: + scheduler = ( + "UniPC" + if self.pipeline_type.is_controlnet() + else self.SCHEDULER_DEFAULTS.get(version, "DDIM") + ) + print(f"[I] Autoselected scheduler: {scheduler}") + + scheduler_class_map = { + "DDIM": DDIMScheduler, + "DDPM": DDPMScheduler, + "EulerA": EulerAncestralDiscreteScheduler, + "Euler": EulerDiscreteScheduler, + "LCM": LCMScheduler, + "LMSD": LMSDiscreteScheduler, + "PNDM": PNDMScheduler, + "UniPC": UniPCMultistepScheduler, + "DDPMWuerstchen": DDPMWuerstchenScheduler, + "FlowMatchEuler": FlowMatchEulerDiscreteScheduler, + } + try: + scheduler_class = scheduler_class_map[scheduler] + except KeyError: + raise ValueError( + f"Unsupported scheduler {scheduler}. Should be one of {list(scheduler_class.keys())}." 
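The scheduler selection above boils down to a two-step lookup: pick a default name per version (or `UniPC` for ControlNet pipelines), then map that name to a diffusers scheduler class. A condensed, illustrative version with only an excerpt of the tables:

```python
from diffusers import (
    DDIMScheduler,
    EulerDiscreteScheduler,
    FlowMatchEulerDiscreteScheduler,
    UniPCMultistepScheduler,
)

SCHEDULER_DEFAULTS = {"xl-1.0": "Euler", "flux.1-dev": "FlowMatchEuler"}  # excerpt
SCHEDULER_CLASSES = {
    "DDIM": DDIMScheduler,
    "Euler": EulerDiscreteScheduler,
    "FlowMatchEuler": FlowMatchEulerDiscreteScheduler,
    "UniPC": UniPCMultistepScheduler,
}


def pick_scheduler_class(version, is_controlnet=False):
    name = "UniPC" if is_controlnet else SCHEDULER_DEFAULTS.get(version, "DDIM")
    return SCHEDULER_CLASSES[name]


print(pick_scheduler_class("flux.1-dev").__name__)  # FlowMatchEulerDiscreteScheduler
```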
+ ) + self.scheduler = make_scheduler( + scheduler_class, version, pipeline_type, hf_token, framework_model_dir + ) + + self.torch_inference = torch_inference + if self.torch_inference: + torch._inductor.config.conv_1x1_as_mm = True + torch._inductor.config.coordinate_descent_tuning = True + torch._inductor.config.epilogue_fusion = False + torch._inductor.config.coordinate_descent_check_all_directions = True + self.use_cuda_graph = use_cuda_graph + + # initialized in load_engines() + self.models = {} + self.torch_models = {} + self.engine = {} + self.shape_dicts = {} + self.shared_device_memory = None + self.lora_loader = None + + # initialized in load_resources() + self.events = {} + self.generator = None + self.markers = {} + self.seed = None + self.stream = None + self.tokenizer = None + + def model_memory_manager(self, model_names, low_vram=False): + return ModelMemoryManager(self, model_names, low_vram) + + @classmethod + @abc.abstractmethod + def FromArgs( + cls, args: argparse.Namespace, pipeline_type: PIPELINE_TYPE + ) -> DiffusionPipeline: + """Factory method to construct a concrete pipeline object from parsed arguments.""" + raise NotImplementedError( + "FromArgs cannot be called from the abstract base class." + ) + + @classmethod + @abc.abstractmethod + def get_model_names(cls, pipeline_type: PIPELINE_TYPE) -> List[str]: + """Return a list of model names used by this pipeline.""" + raise NotImplementedError( + "get_model_names cannot be called from the abstract base class." + ) + + @classmethod + def _get_pipeline_uid(cls, version: str) -> str: + """Return the unique ID of this pipeline. + + This is typically used to determine the default path for things like engine files, artifacts caches, etc. + """ + return f"{cls.__name__}_{version}" + + def profile_start(self, name, color="blue", domain=None): + if self.nvtx_profile: + self.markers[name] = nvtx.start_range( + message=name, color=color, domain=domain + ) + if name in self.events: + cudart.cudaEventRecord(self.events[name][0], 0) + + def profile_stop(self, name): + if name in self.events: + cudart.cudaEventRecord(self.events[name][1], 0) + if self.nvtx_profile: + nvtx.end_range(self.markers[name]) + + def load_resources(self, image_height, image_width, batch_size, seed): + # Initialize noise generator + if seed is not None: + self.seed = seed + self.generator = torch.Generator(device="cuda").manual_seed(seed) + + # Create CUDA events and stream + for stage in self.stages: + self.events[stage] = [ + cudart.cudaEventCreate()[1], + cudart.cudaEventCreate()[1], + ] + self.stream = cudart.cudaStreamCreate()[1] + + # Allocate TensorRT I/O buffers + if not self.torch_inference: + for model_name, obj in self.models.items(): + if self.torch_fallback[model_name]: + continue + self.shape_dicts[model_name] = obj.get_shape_dict( + batch_size, image_height, image_width + ) + if not self.low_vram: + self.engine[model_name].allocate_buffers( + shape_dict=obj.get_shape_dict( + batch_size, image_height, image_width + ), + device=self.device, + ) + + @abstractmethod + def _initialize_models(self, *args, **kwargs): + raise NotImplementedError("Please Implement the _initialize_models method") + + def _prepare_model_configs(self, enable_refit, int8, fp8, fp4): + model_names = self.models.keys() + self.torch_fallback = dict( + zip( + model_names, + [ + self.torch_inference + or self.config.get( + model_name.replace("-", "_") + "_torch_fallback", False + ) + for model_name in model_names + ], + ) + ) + + configs = {} + for model_name in model_names: 
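The event pairs created in `load_resources` and recorded in `profile_start`/`profile_stop` are standard CUDA events; reading a stage timing back follows the usual record/synchronize/elapsed pattern. A minimal sketch with the `cuda-python` runtime bindings used above:

```python
from cuda import cudart

_, start_event = cudart.cudaEventCreate()
_, stop_event = cudart.cudaEventCreate()

cudart.cudaEventRecord(start_event, 0)   # 0 = default stream
# ... launch the stage's GPU work here ...
cudart.cudaEventRecord(stop_event, 0)
cudart.cudaEventSynchronize(stop_event)

_, elapsed_ms = cudart.cudaEventElapsedTime(start_event, stop_event)
print(f"stage took {elapsed_ms:.2f} ms")
```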
+ # Initialize config + do_engine_refit = ( + enable_refit + and not self.pipeline_type.is_sd_xl_refiner() + and any( + model_name.startswith(prefix) for prefix in ("unet", "transformer") + ) + ) + do_lora_merge = ( + not enable_refit + and self.lora_loader + and any( + model_name.startswith(prefix) for prefix in ("unet", "transformer") + ) + ) + + config = { + "do_engine_refit": do_engine_refit, + "do_lora_merge": do_lora_merge, + "use_int8": False, + "use_fp8": False, + "use_fp4": False, + } + + # TODO: Move this to when arguments are first being validated in dd_argparse.py + # 8-bit/4-bit precision inference + if int8: + assert self.pipeline_type.is_sd_xl_base() or self.version in [ + "1.5", + "2.1", + "2.1-base", + ], "int8 quantization only supported for SDXL, SD1.5 and SD2.1 pipeline" + if (self.pipeline_type.is_sd_xl() and model_name == "unetxl") or ( + model_name == "unet" + ): + config["use_int8"] = True + + elif fp8: + assert ( + self.pipeline_type.is_sd_xl() + or self.version in ["1.5", "2.1", "2.1-base"] + or self.version.startswith("flux.1") + ), ( + "fp8 quantization only supported for SDXL, SD1.5, SD2.1 and FLUX pipeline" + ) + if ( + (self.pipeline_type.is_sd_xl() and model_name == "unetxl") + or ( + (self.version.startswith("flux.1")) + and model_name == "transformer" + ) + or (model_name == "unet") + ): + config["use_fp8"] = True + elif fp4: + config["use_fp4"] = True + + # Setup paths + config["onnx_path"] = self.dd_path.model_name_to_unoptimized_onnx_path[ + model_name + ] + config["onnx_opt_path"] = self.dd_path.model_name_to_optimized_onnx_path[ + model_name + ] + config["engine_path"] = self.dd_path.model_name_to_engine_path[model_name] + config["weights_map_path"] = ( + self.dd_path.model_name_to_weights_map_path[model_name] + if config["do_engine_refit"] + else None + ) + config["state_dict_path"] = ( + self.dd_path.model_name_to_quantized_model_state_dict_path[model_name] + ) + config["refit_weights_path"] = ( + self.dd_path.model_name_to_refit_weights_path[model_name] + ) + + configs[model_name] = config + + return configs + + def _calibrate_and_save_model( + self, + pipeline, + model, + model_config, + quantization_level, + quantization_percentile, + quantization_alpha, + calibration_size, + calib_batch_size, + enable_lora_merge=False, + **kwargs, + ): + print( + f"[I] Calibrated weights not found, generating {model_config['state_dict_path']}" + ) + + # TODO check size > calibration_size + def do_calibrate(pipeline, calibration_prompts, **kwargs): + for i_th, prompts in enumerate(calibration_prompts): + if i_th >= kwargs["calib_size"]: + return + if kwargs["model_id"] in ("flux.1-dev", "flux.1-schnell"): + common_args = { + "prompt": prompts, + "prompt_2": prompts, + "num_inference_steps": kwargs["n_steps"], + "height": kwargs.get("height", 1024), + "width": kwargs.get("width", 1024), + "guidance_scale": 3.5, + "max_sequence_length": 512 + if kwargs["model_id"] == "flux.1-dev" + else 256, + } + else: + common_args = { + "prompt": prompts, + "num_inference_steps": kwargs["n_steps"], + "negative_prompt": [ + "normal quality, low quality, worst quality, low res, blurry, nsfw, nude" + ] + * len(prompts), + } + + pipeline(**common_args).images + + def do_calibrate_img2img(pipeline, dataloader, **kwargs): + for i_th, (img_conds, prompts) in enumerate(dataloader): + if i_th >= kwargs["calib_size"]: + return + + common_args = { + "prompt": list(prompts), + "control_image": img_conds, + "num_inference_steps": kwargs["n_steps"], + "height": img_conds.size(2), + "width": 
img_conds.size(3), + "generator": torch.Generator().manual_seed(42), + "guidance_scale": 3.5, + "max_sequence_length": 512, + } + pipeline(**common_args).images + + if self.version in ("flux.1-dev-depth", "flux.1-dev-canny"): + dataset = PromptImageDataset( + root_dir=self.calibration_dataset, + ) + + dataloader = DataLoader( + dataset, + batch_size=calib_batch_size, + shuffle=False, + num_workers=0, + sampler=SameSizeSampler(dataset=dataset, batch_size=calib_batch_size), + collate_fn=custom_collate, + ) + else: + root_dir = os.path.dirname( + os.path.abspath(sys.modules["__main__"].__file__) + ) + calibration_file = os.path.join( + root_dir, "calibration_data", "calibration-prompts.txt" + ) + calibration_prompts = load_calib_prompts(calib_batch_size, calibration_file) + + def forward_loop(model): + if self.version not in ( + "sd3", + "flux.1-dev", + "flux.1-schnell", + "flux.1-dev-depth", + "flux.1-dev-canny", + ): + pipeline.unet = model + else: + pipeline.transformer = model + + if self.version in ("flux.1-dev-depth", "flux.1-dev-canny"): + do_calibrate_img2img( + pipeline=pipeline, + dataloader=infinite_dataloader(dataloader), + calib_size=calibration_size // calib_batch_size, + n_steps=self.denoising_steps, + model_id=self.version, + ) + else: + do_calibrate( + pipeline=pipeline, + calibration_prompts=calibration_prompts, + calib_size=calibration_size // calib_batch_size, + n_steps=self.denoising_steps, + model_id=self.version, + **kwargs, + ) + + print(f"[I] Performing calibration for {calibration_size} steps.") + if model_config["use_int8"]: + quant_config = get_int8_config( + model, + quantization_level, + quantization_alpha, + quantization_percentile, + self.denoising_steps, + ) + elif model_config["use_fp8"]: + if self.version.startswith("flux.1"): + quant_config = SD_FP8_BF16_FLUX_MMDIT_BMM2_FP8_OUTPUT_CONFIG + elif self.version == "2.1": + quant_config = SD_FP8_FP32_DEFAULT_CONFIG + else: + quant_config = SD_FP8_FP16_DEFAULT_CONFIG + + # Handle LoRA + if enable_lora_merge: + assert self.lora_loader is not None + model = merge_loras(model, self.lora_loader) + + check_lora(model) + + if self.version.startswith("flux.1"): + set_quant_precision(quant_config, "BFloat16") + mtq.quantize(model, quant_config, forward_loop) + mto.save(model, model_config["state_dict_path"]) + + def _get_quantized_model( + self, + obj, + model_config, + quantization_level, + quantization_percentile, + quantization_alpha, + calibration_size, + calib_batch_size, + enable_lora_merge=False, + **kwargs, + ): + pipeline = obj.get_pipeline() + is_flux = self.version.startswith("flux.1") + model = ( + pipeline.unet + if self.version + not in ( + "sd3", + "flux.1-dev", + "flux.1-schnell", + "flux.1-dev-depth", + "flux.1-dev-canny", + ) + else pipeline.transformer + ) + if model_config["use_fp8"] and quantization_level == 4.0: + set_fmha(model, is_flux=is_flux) + + if not os.path.exists(model_config["state_dict_path"]): + self._calibrate_and_save_model( + pipeline, + model, + model_config, + quantization_level, + quantization_percentile, + quantization_alpha, + calibration_size, + calib_batch_size, + enable_lora_merge, + **kwargs, + ) + else: + mto.restore(model, model_config["state_dict_path"]) + + if not os.path.exists(model_config["onnx_path"]): + quantize_lvl(self.version, model, quantization_level) + if self.version.startswith("flux.1"): + mtq.disable_quantizer(model, filter_func_no_proj_out) + else: + mtq.disable_quantizer(model, filter_func) + if model_config["use_fp8"] and not 
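The calibrate-and-save flow above follows the standard TensorRT Model Optimizer pattern: quantize with a forward loop that feeds representative inputs, then persist the quantizer state so later runs can `mto.restore` instead of re-calibrating. Below is a self-contained toy sketch; the toy module and `FP8_DEFAULT_CFG` stand in for the diffusion denoiser and the SD/FLUX-specific configs selected above:

```python
import torch
import modelopt.torch.opt as mto
import modelopt.torch.quantization as mtq

model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU(), torch.nn.Linear(64, 64))


def forward_loop(m):
    # Run representative inputs so amax statistics get collected.
    for _ in range(8):
        m(torch.randn(4, 64))


mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop)
mto.save(model, "toy_quantized_state.pth")
# Later runs: mto.restore(model, "toy_quantized_state.pth")
```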
self.version.startswith("flux.1"): + generate_fp8_scales(model) + if quantization_level == 4.0: + fp8_mha_disable( + model, quantized_mha_output=False + ) # Remove Q/DQ after BMM2 in MHA + else: + model = None + + return model + + @abstractmethod + def download_onnx_models( + self, model_name: str, model_config: dict[str, Any] + ) -> None: + """Download pre-exported ONNX Models""" + raise NotImplementedError("Please Implement the download_onnx_models method") + + def is_native_export_supported(self, model_config: dict[str, Any]) -> bool: + """Check if pipeline supports native ONNX export""" + # Native export is supported by default + return True + + def _export_onnx( + self, + obj, + model_name, + model_config, + opt_image_height, + opt_image_width, + static_shape, + onnx_opset, + quantization_level, + quantization_percentile, + quantization_alpha, + calibration_size, + calib_batch_size, + onnx_export_only, + download_onnx_models, + ): + # With onnx_export_only True, the export still happens even if the TRT engine exists. However, it will not re-run the export if the onnx exists. + do_export_onnx = ( + not os.path.exists(model_config["engine_path"]) or onnx_export_only + ) and not os.path.exists(model_config["onnx_opt_path"]) + do_export_weights_map = model_config["weights_map_path"] and not os.path.exists( + model_config["weights_map_path"] + ) + + # If ONNX export is required, either download ONNX models or check if the pipeline supports native ONNX export + if do_export_onnx: + if download_onnx_models: + self.download_onnx_models(model_name, model_config) + do_export_onnx = False + else: + self.is_native_export_supported(model_config) + + if do_export_onnx or do_export_weights_map: + if not model_config["use_int8"] and not model_config["use_fp8"]: + obj.export_onnx( + model_config["onnx_path"], + model_config["onnx_opt_path"], + onnx_opset, + opt_image_height, + opt_image_width, + enable_lora_merge=model_config["do_lora_merge"], + static_shape=static_shape, + lora_loader=self.lora_loader, + ) + else: + print( + f"[I] Generating quantized ONNX model: {model_config['onnx_path']}" + ) + quantized_model = self._get_quantized_model( + obj, + model_config, + quantization_level, + quantization_percentile, + quantization_alpha, + calibration_size, + calib_batch_size, + height=opt_image_width, + width=opt_image_width, + enable_lora_merge=model_config["do_lora_merge"], + ) + obj.export_onnx( + model_config["onnx_path"], + model_config["onnx_opt_path"], + onnx_opset, + opt_image_height, + opt_image_width, + custom_model=quantized_model, + static_shape=static_shape, + ) + + # FIXME do_export_weights_map needs ONNX graph + if do_export_weights_map: + print(f"[I] Saving weights map: {model_config['weights_map_path']}") + obj.export_weights_map( + model_config["onnx_opt_path"], model_config["weights_map_path"] + ) + + def _build_engine( + self, + obj, + engine, + model_config, + opt_batch_size, + opt_image_height, + opt_image_width, + optimization_level, + static_batch, + static_shape, + enable_all_tactics, + timing_cache, + ): + update_output_names = ( + obj.get_output_names() + obj.extra_output_names + if obj.extra_output_names + else None + ) + fp16amp = ( + False + if (model_config["use_fp8"] or getattr(obj, "build_strongly_typed", False)) + else obj.fp16 + ) + tf32amp = obj.tf32 + bf16amp = ( + False + if (model_config["use_fp8"] or getattr(obj, "build_strongly_typed", False)) + else obj.bf16 + ) + strongly_typed = ( + True + if (model_config["use_fp8"] or getattr(obj, "build_strongly_typed", 
False)) + else False + ) + weight_streaming = getattr(obj, "weight_streaming", False) + int8amp = model_config.get("use_int8", False) + precision_constraints = "prefer" if int8amp else "none" + engine.build( + model_config["onnx_opt_path"], + strongly_typed=strongly_typed, + fp16=fp16amp, + tf32=tf32amp, + bf16=bf16amp, + int8=int8amp, + input_profile=obj.get_input_profile( + opt_batch_size, + opt_image_height, + opt_image_width, + static_batch=static_batch, + static_shape=static_shape, + ), + enable_refit=model_config["do_engine_refit"], + enable_all_tactics=enable_all_tactics, + timing_cache=timing_cache, + update_output_names=update_output_names, + weight_streaming=weight_streaming, + verbose=self.verbose, + builder_optimization_level=optimization_level, + precision_constraints=precision_constraints, + ) + + def _refit_engine(self, obj, model_name, model_config): + assert model_config["weights_map_path"] + with open(model_config["weights_map_path"], "r") as fp_wts: + print(f"[I] Loading weights map: {model_config['weights_map_path']} ") + [weights_name_mapping, weights_shape_mapping] = json.load(fp_wts) + + if not os.path.exists(model_config["refit_weights_path"]): + model = merge_loras(obj.get_model(), self.lora_loader) + refit_weights, updated_weight_names = engine_module.get_refit_weights( + model.state_dict(), + model_config["onnx_opt_path"], + weights_name_mapping, + weights_shape_mapping, + ) + print(f"[I] Saving refit weights: {model_config['refit_weights_path']}") + torch.save( + (refit_weights, updated_weight_names), + model_config["refit_weights_path"], + ) + unload_torch_model(model) + else: + print( + f"[I] Loading refit weights: {model_config['refit_weights_path']}" + ) + refit_weights, updated_weight_names = torch.load( + model_config["refit_weights_path"] + ) + self.engine[model_name].refit(refit_weights, updated_weight_names) + + def _load_torch_models(self): + # Load torch models + for model_name, obj in self.models.items(): + if self.torch_fallback[model_name]: + self.torch_models[model_name] = obj.get_model( + torch_inference=self.torch_inference + ) + if self.low_vram: + self.torch_models[model_name] = self.torch_models[model_name].to( + "cpu" + ) + torch.cuda.empty_cache() + + def load_engines( + self, + framework_model_dir, + onnx_opset, + opt_batch_size, + opt_image_height, + opt_image_width, + optimization_level=3, + static_batch=False, + static_shape=True, + enable_refit=False, + enable_all_tactics=False, + timing_cache=None, + int8=False, + fp8=False, + fp4=False, + quantization_level=2.5, + quantization_percentile=1.0, + quantization_alpha=0.8, + calibration_size=32, + calib_batch_size=2, + onnx_export_only=False, + download_onnx_models=False, + ): + """ + Build and load engines for TensorRT accelerated inference. + Export ONNX models first, if applicable. + + Args: + framework_model_dir (str): + Directory to store the framework model ckpt. + onnx_opset (int): + ONNX opset version to export the models. + opt_batch_size (int): + Batch size to optimize for during engine building. + opt_image_height (int): + Image height to optimize for during engine building. Must be a multiple of 8. + opt_image_width (int): + Image width to optimize for during engine building. Must be a multiple of 8. + optimization_level (int): + Optimization level to build the TensorRT engine with. + static_batch (bool): + Build engine only for specified opt_batch_size. + static_shape (bool): + Build engine only for specified opt_image_height & opt_image_width. Default = True. 
+ enable_refit (bool): + Build engines with refit option enabled. + enable_all_tactics (bool): + Enable all tactic sources during TensorRT engine builds. + timing_cache (str): + Path to the timing cache to speed up TensorRT build. + int8 (bool): + Whether to quantize to int8 format or not (SDXL, SD15 and SD21 only). + fp8 (bool): + Whether to quantize to fp8 format or not (SDXL, SD15 and SD21 only). + quantization_level (float): + Controls which layers to quantize. 1: CNN, 2: CNN+FFN, 2.5: CNN+FFN+QKV, 3: CNN+FC + quantization_percentile (float): + Control quantization scaling factors (amax) collecting range, where the minimum amax in + range(n_steps * percentile) will be collected. Recommendation: 1.0 + quantization_alpha (float): + The alpha parameter for SmoothQuant quantization used for linear layers. + Recommendation: 0.8 for SDXL + calibration_size (int): + The number of steps to use for calibrating the model for quantization. + Recommendation: 32, 64, 128 for SDXL + calib_batch_size (int): + The batch size to use for calibration. Defaults to 2. + onnx_export_only (bool): + Whether only export onnx without building the TRT engine. + download_onnx_models (bool): + Download pre-exported ONNX models + """ + self._initialize_models(framework_model_dir, int8, fp8, fp4) + + model_configs = self._prepare_model_configs(enable_refit, int8, fp8, fp4) + + # Export models to ONNX + for model_name, obj in self.models.items(): + if self.torch_fallback[model_name]: + continue + self._export_onnx( + obj, + model_name, + model_configs[model_name], + opt_image_height, + opt_image_width, + static_shape, + onnx_opset, + quantization_level, + quantization_percentile, + quantization_alpha, + calibration_size, + calib_batch_size, + onnx_export_only, + download_onnx_models, + ) + + # Release temp GPU memory during onnx export to avoid OOM. + gc.collect() + torch.cuda.empty_cache() + + if onnx_export_only: + return + + # Build TensorRT engines + for model_name, obj in self.models.items(): + if self.torch_fallback[model_name]: + continue + + model_config = model_configs[model_name] + engine = engine_module.Engine(model_config["engine_path"]) + if not os.path.exists(model_config["engine_path"]): + self._build_engine( + obj, + engine, + model_config, + opt_batch_size, + opt_image_height, + opt_image_width, + optimization_level, + static_batch, + static_shape, + enable_all_tactics, + timing_cache, + ) + self.engine[model_name] = engine + + # Load and refit TensorRT engines + for model_name, obj in self.models.items(): + if self.torch_fallback[model_name]: + continue + model_config = model_configs[model_name] + + # For non low_vram case, the engines will remain in GPU memory from now on. + assert self.engine[model_name].engine is None + if not self.low_vram: + weight_streaming = getattr(obj, "weight_streaming", False) + weight_streaming_budget_percentage = getattr( + obj, "weight_streaming_budget_percentage", None + ) + self.engine[model_name].load( + weight_streaming, weight_streaming_budget_percentage + ) + + if model_config["do_engine_refit"] and self.lora_loader: + # For low_vram, using on-demand load and unload for refit. 
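+ # In low-VRAM mode the engine plan is loaded just long enough to apply the LoRA refit weights and is unloaded again right after, so the engine does not hold GPU memory outside this step.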
+ if self.low_vram: + assert self.engine[model_name].engine is None + self.engine[model_name].load() + self._refit_engine(obj, model_name, model_config) + if self.low_vram: + self.engine[model_name].unload() + + # Load PyTorch models if torch-inference mode is enabled + self._load_torch_models() + + # Reclaim GPU memory from torch cache + torch.cuda.empty_cache() + + def calculate_max_device_memory(self): + max_device_memory = 0 + for model_name, engine in self.engine.items(): + if self.low_vram: + engine.load() + max_device_memory = max(max_device_memory, engine.engine.device_memory_size) + if self.low_vram: + engine.unload() + return max_device_memory + + def get_device_memory_sizes(self): + device_memory_sizes = {} + for model_name, engine in self.engine.items(): + engine.load() + device_memory_sizes[model_name] = engine.engine.device_memory_size + engine.unload() + return device_memory_sizes + + def activate_engines(self, shared_device_memory=None): + if shared_device_memory is None: + max_device_memory = self.calculate_max_device_memory() + _, shared_device_memory = cudart.cudaMalloc(max_device_memory) + self.shared_device_memory = shared_device_memory + # Load and activate TensorRT engines + if not self.low_vram: + for engine in self.engine.values(): + engine.activate(device_memory=self.shared_device_memory) + + def run_engine(self, model_name, feed_dict): + engine = self.engine[model_name] + # CUDA graphs should be disabled when low_vram is enabled. + if self.low_vram: + assert self.use_cuda_graph == False + return engine.infer(feed_dict, self.stream, use_cuda_graph=self.use_cuda_graph) + + def teardown(self): + for e in self.events.values(): + cudart.cudaEventDestroy(e[0]) + cudart.cudaEventDestroy(e[1]) + + for engine in self.engine.values(): + del engine + + if self.shared_device_memory: + cudart.cudaFree(self.shared_device_memory) + + for torch_model in self.torch_models.values(): + del torch_model + + cudart.cudaStreamDestroy(self.stream) + del self.stream + + def initialize_latents( + self, + batch_size, + unet_channels, + latent_height, + latent_width, + latents_dtype=torch.float32, + ): + latents_shape = (batch_size, unet_channels, latent_height, latent_width) + latents = torch.randn( + latents_shape, + device=self.device, + dtype=latents_dtype, + generator=self.generator, + ) + # Scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def save_image(self, images, pipeline, prompt, seed): + # Save image + prompt_prefix = "".join( + set([prompt[i].replace(" ", "_")[:10] for i in range(len(prompt))]) + ) + image_name_prefix = "-".join([pipeline, prompt_prefix, str(seed)]) + image_name_suffix = "torch" if self.torch_inference else "trt" + image_module.save_image( + images, self.output_dir, image_name_prefix, image_name_suffix + ) + + @abstractmethod + def print_summary(self): + """Print a summary of the pipeline's configuration.""" + raise NotImplementedError("Please Implement the print_summary method") + + @abstractmethod + def infer(self): + """Perform inference using the pipeline.""" + raise NotImplementedError("Please Implement the infer method") + + @abstractmethod + def run(self): + """Run the pipeline.""" + raise NotImplementedError("Please Implement the run method") diff --git a/flux.1-dev-trt-b200/model/demo_diffusion/pipeline/flux_pipeline.py b/flux.1-dev-trt-b200/model/demo_diffusion/pipeline/flux_pipeline.py new file mode 100644 index 000000000..356862057 --- /dev/null +++ 
b/flux.1-dev-trt-b200/model/demo_diffusion/pipeline/flux_pipeline.py @@ -0,0 +1,943 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import annotations + +import argparse +import inspect +import os +import time +import warnings +from typing import Any, List, Optional + +import numpy as np +import tensorrt as trt +import torch +from cuda import cudart +from diffusers.image_processor import VaeImageProcessor +from huggingface_hub import snapshot_download + +from demo_diffusion import path as path_module +from demo_diffusion.model import ( + CLIPModel, + FLUXLoraLoader, + FluxTransformerModel, + T5Model, + VAEEncoderModel, + VAEModel, + get_clip_embedding_dim, + load, + make_tokenizer, +) +from demo_diffusion.pipeline.diffusion_pipeline import DiffusionPipeline +from demo_diffusion.pipeline.type import PIPELINE_TYPE + +TRT_LOGGER = trt.Logger(trt.Logger.ERROR) + + +def calculate_shift( + image_seq_len, + base_seq_len: int = 256, + max_seq_len: int = 4096, + base_shift: float = 0.5, + max_shift: float = 1.16, +): + m = (max_shift - base_shift) / (max_seq_len - base_seq_len) + b = base_shift - m * base_seq_len + mu = image_seq_len * m + b + return mu + + +class FluxPipeline(DiffusionPipeline): + """ + Application showcasing the acceleration of Flux pipelines using Nvidia TensorRT. + """ + + def __init__( + self, + version="flux.1-dev", + pipeline_type=PIPELINE_TYPE.TXT2IMG, + guidance_scale=3.5, + max_sequence_length=512, + calibration_dataset=None, + t5_weight_streaming_budget_percentage=None, + transformer_weight_streaming_budget_percentage=None, + lora_scale: float = 1.0, + lora_weight: Optional[List[float]] = None, + lora_path: Optional[List[str]] = None, + **kwargs, + ): + """ + Initializes the Flux pipeline. + + Args: + version (`str`, defaults to `flux.1-dev`) + Version of the underlying Flux model. + guidance_scale (`float`, defaults to 3.5): + Guidance scale is enabled by setting as > 1. + Higher guidance scale encourages to generate images that are closely linked to the text prompt, usually at the expense of lower image quality. + max_sequence_length (`int`, defaults to 512): + Maximum sequence length to use with the `prompt`. + t5_weight_streaming_budget_percentage (`int`, defaults to None): + Weight streaming budget as a percentage of the size of total streamable weights for the T5 model. + transformer_weight_streaming_budget_percentage (`int`, defaults to None): + Weight streaming budget as a percentage of the size of total streamable weights for the transformer model. 
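+ calibration_dataset (`str`, *optional*, defaults to None):
+ Path to a prompt/image dataset used for quantization calibration. Currently supported for the Flux ControlNet (depth/canny) pipelines only.
+ lora_scale (`float`, defaults to 1.0):
+ Global scale applied to the LoRA weights.
+ lora_weight (`List[float]`, *optional*, defaults to None):
+ Per-adapter weight for each entry in `lora_path`; must have the same length as `lora_path`.
+ lora_path (`List[str]`, *optional*, defaults to None):
+ Paths of the LoRA adapters to apply.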
+ """ + super().__init__( + version=version, + pipeline_type=pipeline_type, + text_encoder_weight_streaming_budget_percentage=t5_weight_streaming_budget_percentage, + denoiser_weight_streaming_budget_percentage=transformer_weight_streaming_budget_percentage, + **kwargs, + ) + self.guidance_scale = guidance_scale + self.max_sequence_length = max_sequence_length + self.calibration_dataset = calibration_dataset # Currently supported for Flux ControlNet pipelines only + + # Initialize LoRA + self.lora_loader = None + if lora_path: + self.lora_weights = dict() + self.lora_loader = FLUXLoraLoader(lora_path, lora_weight, lora_scale) + assert len(lora_path) == len(lora_weight) + for i, path in enumerate(lora_path): + self.lora_weights[path] = lora_weight[i] + + @classmethod + def FromArgs( + cls, args: argparse.Namespace, pipeline_type: PIPELINE_TYPE + ) -> FluxPipeline: + """Factory method to construct a `FluxPipeline` object from parsed arguments. + + Overrides: + DiffusionPipeline.FromArgs + """ + MAX_BATCH_SIZE = 4 + DEVICE = "cuda" + DO_RETURN_LATENTS = False + + # Resolve all paths. + dd_path = path_module.resolve_path( + cls.get_model_names(pipeline_type), + args, + pipeline_type, + cls._get_pipeline_uid(args.version), + ) + + return cls( + dd_path=dd_path, + version=args.version, + pipeline_type=pipeline_type, + guidance_scale=args.guidance_scale, + max_sequence_length=args.max_sequence_length, + bf16=args.bf16, + calibration_dataset=args.calibration_dataset + if hasattr(args, "calibration_dataset") + else None, + low_vram=args.low_vram, + torch_fallback=args.torch_fallback, + weight_streaming=args.ws, + t5_weight_streaming_budget_percentage=args.t5_ws_percentage, + transformer_weight_streaming_budget_percentage=args.transformer_ws_percentage, + max_batch_size=MAX_BATCH_SIZE, + denoising_steps=args.denoising_steps, + scheduler=args.scheduler, + lora_scale=args.lora_scale, + lora_weight=args.lora_weight, + lora_path=args.lora_path, + device=DEVICE, + output_dir=args.output_dir, + hf_token=args.hf_token, + verbose=args.verbose, + nvtx_profile=args.nvtx_profile, + use_cuda_graph=args.use_cuda_graph, + framework_model_dir=args.framework_model_dir, + return_latents=DO_RETURN_LATENTS, + torch_inference=args.torch_inference, + ) + + @classmethod + def get_model_names(cls, pipeline_type: PIPELINE_TYPE) -> List[str]: + """Return a list of model names used by this pipeline. + + Overrides: + DiffusionPipeline.get_model_names + """ + if pipeline_type.is_img2img(): + return ["clip", "t5", "transformer", "vae", "vae_encoder"] + else: + return ["clip", "t5", "transformer", "vae"] + + def download_onnx_models( + self, model_name: str, model_config: dict[str, Any] + ) -> None: + if self.fp16: + raise ValueError( + "ONNX models can be downloaded only for the following precisions: BF16, FP8, FP4. This pipeline is running in FP16." 
+ ) + + hf_download_path = "-".join( + [load.get_path(self.version, self.pipeline_type.name), "onnx"] + ) + model_path = model_config["onnx_opt_path"] + base_dir = os.path.dirname(os.path.dirname(model_config["onnx_opt_path"])) + + if not os.path.exists(model_path): + if model_name == "clip": + dirname = "clip.opt" + elif model_name == "t5": + dirname = "t5.opt" + elif model_name == "transformer": + if model_config["use_fp4"]: + dirname = "transformer.opt/fp4" + elif model_config["use_fp8"]: + dirname = "transformer.opt/fp8" + elif self.bf16: + dirname = "transformer.opt/bf16" + elif model_name == "vae": + dirname = "vae.opt" + elif model_name == "vae_encoder": + dirname = "vae_encoder.opt" + else: + raise ValueError(f"{model_name} not found in {self.stages}") + + snapshot_download( + repo_id=hf_download_path, + allow_patterns=os.path.join(dirname, "*"), + local_dir=base_dir, + token=self.hf_token, + ) + # Rename directory from .opt to + saved_dir = os.path.join(base_dir, dirname) + model_dir = os.path.dirname(model_path) + os.rename(saved_dir, model_dir) + # Rename model from model.onnx to model_optimized.onnx + os.rename(os.path.join(model_dir, "model.onnx"), model_path) + + def is_native_export_supported(self, model_config: dict[str, Any]) -> bool: + if self.version.startswith("flux.1") and model_config["use_fp4"]: + # Native export not supported for FP4. + raise ValueError( + f"No ONNX model found in {model_config['onnx_opt_path']}. Please pass --download-onnx-models." + ) + if ( + self.version in ["flux.1-dev-canny", "flux.1-dev-depth"] + and model_config["use_fp8"] + and not self.calibration_dataset + ): + # Native export of FP8 model requires calibration data. + raise ValueError( + f"No ONNX model found in {model_config['onnx_opt_path']}. Please pass --download-onnx-models. If you would like to quantize and export natively, please provide calibration data using --calibration-." 
+ ) + return True + + def _initialize_models(self, framework_model_dir, int8, fp8, fp4): + # Load text tokenizer(s) + self.tokenizer = make_tokenizer( + self.version, + self.pipeline_type, + self.hf_token, + framework_model_dir, + ) + self.tokenizer2 = make_tokenizer( + self.version, + self.pipeline_type, + self.hf_token, + framework_model_dir, + subfolder="tokenizer_2", + tokenizer_type="t5", + ) + + # Load pipeline models + models_args = { + "version": self.version, + "pipeline": self.pipeline_type, + "device": self.device, + "hf_token": self.hf_token, + "verbose": self.verbose, + "framework_model_dir": framework_model_dir, + "max_batch_size": self.max_batch_size, + } + + self.bf16 = True if int8 or fp8 or fp4 else self.bf16 + self.fp16 = True if not self.bf16 else False + self.tf32 = True + if "clip" in self.stages: + self.models["clip"] = CLIPModel( + **models_args, + fp16=self.fp16, + tf32=self.tf32, + bf16=self.bf16, + embedding_dim=get_clip_embedding_dim(self.version, self.pipeline_type), + keep_pooled_output=True, + subfolder="text_encoder", + ) + + if "t5" in self.stages: + # Known accuracy issues with FP16 + self.models["t5"] = T5Model( + **models_args, + fp16=self.fp16, + tf32=self.tf32, + bf16=self.bf16, + subfolder="text_encoder_2", + text_maxlen=self.max_sequence_length, + build_strongly_typed=True, + weight_streaming=self.weight_streaming, + weight_streaming_budget_percentage=self.text_encoder_weight_streaming_budget_percentage, + ) + + if "transformer" in self.stages: + self.models["transformer"] = FluxTransformerModel( + **models_args, + bf16=self.bf16, + fp16=self.fp16, + int8=int8, + fp8=fp8, + tf32=self.tf32, + text_maxlen=self.max_sequence_length, + build_strongly_typed=True, + weight_streaming=self.weight_streaming, + weight_streaming_budget_percentage=self.denoiser_weight_streaming_budget_percentage, + ) + + if "vae" in self.stages: + # Accuracy issues with FP16 + self.models["vae"] = VAEModel( + **models_args, fp16=False, tf32=self.tf32, bf16=self.bf16 + ) + + self.vae_scale_factor = ( + 2 ** (len(self.models["vae"].config["block_out_channels"])) + if "vae" in self.stages and self.models["vae"] is not None + else 16 + ) + + if "vae_encoder" in self.stages: + self.models["vae_encoder"] = VAEEncoderModel( + **models_args, fp16=False, tf32=self.tf32, bf16=self.bf16 + ) + self.vae_latent_channels = ( + self.models["vae"].config["latent_channels"] + if "vae" in self.stages and self.models["vae"] is not None + else 16 + ) + self.image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor * 2, + vae_latent_channels=self.vae_latent_channels, + ) + + def encode_image(self, input_image, encoder="vae_encoder"): + self.profile_start(encoder, color="red") + cast_to = ( + torch.float16 + if self.models[encoder].fp16 + else torch.bfloat16 + if self.models[encoder].bf16 + else torch.float32 + ) + input_image = input_image.to(dtype=cast_to) + if self.torch_inference or self.torch_fallback[encoder]: + image_latents = self.torch_models[encoder](input_image) + else: + image_latents = self.run_engine(encoder, {"images": input_image})["latent"] + + image_latents = self.models[encoder].config["scaling_factor"] * ( + image_latents - self.models[encoder].config["shift_factor"] + ) + self.profile_stop(encoder) + return image_latents + + # Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py#L546 + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, 
+ do_classifier_free_guidance=False, + guess_mode=False, + ): + if isinstance(image, torch.Tensor): + pass + else: + image = self.image_processor.preprocess(image, height=height, width=width) + + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + # Copied from https://github.com/huggingface/diffusers/blob/v0.30.1/src/diffusers/pipelines/flux/pipeline_flux.py#L436 + @staticmethod + def _pack_latents(latents, batch_size, num_channels_latents, height, width): + """ + Reshapes latents from (B, C, H, W) to (B, H/2, W/2, C*4) as expected by the denoiser + """ + latents = latents.view( + batch_size, num_channels_latents, height // 2, 2, width // 2, 2 + ) + latents = latents.permute(0, 2, 4, 1, 3, 5) + latents = latents.reshape( + batch_size, (height // 2) * (width // 2), num_channels_latents * 4 + ) + + return latents + + # Copied from https://github.com/huggingface/diffusers/blob/v0.30.1/src/diffusers/pipelines/flux/pipeline_flux.py#L444 + @staticmethod + def _unpack_latents(latents, height, width, vae_scale_factor): + """ + Reshapes denoised latents to the format (B, C, H, W) + """ + batch_size, num_patches, channels = latents.shape + + height = height // vae_scale_factor + width = width // vae_scale_factor + + latents = latents.view(batch_size, height, width, channels // 4, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + + latents = latents.reshape( + batch_size, channels // (2 * 2), height * 2, width * 2 + ) + + return latents + + # Copied from https://github.com/huggingface/diffusers/blob/v0.30.1/src/diffusers/pipelines/flux/pipeline_flux.py#L421 + @staticmethod + def _prepare_latent_image_ids(height, width, dtype, device): + """ + Prepares latent image indices + """ + latent_image_ids = torch.zeros(height // 2, width // 2, 3) + latent_image_ids[..., 1] = ( + latent_image_ids[..., 1] + torch.arange(height // 2)[:, None] + ) + latent_image_ids[..., 2] = ( + latent_image_ids[..., 2] + torch.arange(width // 2)[None, :] + ) + + latent_image_id_height, latent_image_id_width, latent_image_id_channels = ( + latent_image_ids.shape + ) + + latent_image_ids = latent_image_ids.reshape( + latent_image_id_height * latent_image_id_width, latent_image_id_channels + ) + + return latent_image_ids.to(device=device, dtype=dtype) + + def initialize_latents( + self, + batch_size, + num_channels_latents, + latent_height, + latent_width, + latent_timestep=None, + image_latents=None, + latents_dtype=torch.float32, + ): + latents_dtype = latents_dtype # text_embeddings.dtype + latents_shape = (batch_size, num_channels_latents, latent_height, latent_width) + latents = torch.randn( + latents_shape, + device=self.device, + dtype=latents_dtype, + generator=self.generator, + ) + + if image_latents is not None: + image_latents = torch.cat([image_latents], dim=0).to(latents_dtype) + latents = self.scheduler.scale_noise( + image_latents, latent_timestep, latents + ) + + latents = self._pack_latents( + latents, batch_size, num_channels_latents, latent_height, latent_width + ) + + latent_image_ids = self._prepare_latent_image_ids( + latent_height, latent_width, latents_dtype, self.device + ) + + return latents, latent_image_ids + + # Copied from 
https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/flux/pipeline_flux_img2img.py#L416C1 + def get_timesteps(self, num_inference_steps, strength): + # get the original timestep using init_timestep + init_timestep = min(num_inference_steps * strength, num_inference_steps) + + t_start = int(max(num_inference_steps - init_timestep, 0)) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + if hasattr(self.scheduler, "set_begin_index"): + self.scheduler.set_begin_index(t_start * self.scheduler.order) + + return timesteps, num_inference_steps - t_start + + def encode_prompt( + self, prompt, encoder="clip", max_sequence_length=None, pooled_output=False + ): + self.profile_start(encoder, color="green") + + tokenizer = self.tokenizer2 if encoder == "t5" else self.tokenizer + max_sequence_length = ( + tokenizer.model_max_length + if max_sequence_length is None + else max_sequence_length + ) + + def tokenize(prompt, max_sequence_length): + text_input_ids = ( + tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + return_overflowing_tokens=False, + return_length=False, + return_tensors="pt", + ) + .input_ids.type(torch.int32) + .to(self.device) + ) + + untruncated_ids = ( + tokenizer(prompt, padding="longest", return_tensors="pt") + .input_ids.type(torch.int32) + .to(self.device) + ) + if untruncated_ids.shape[-1] >= text_input_ids.shape[ + -1 + ] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = tokenizer.batch_decode( + untruncated_ids[:, max_sequence_length - 1 : -1] + ) + warnings.warn( + "The following part of your input was truncated because `max_sequence_length` is set to " + f"{max_sequence_length} tokens: {removed_text}" + ) + + if self.torch_inference or self.torch_fallback[encoder]: + outputs = self.torch_models[encoder]( + text_input_ids, output_hidden_states=False + ) + text_encoder_output = ( + outputs[0].clone() + if pooled_output == False + else outputs.pooler_output.clone() + ) + else: + # NOTE: output tensor for the encoder must be cloned because it will be overwritten when called again for prompt2 + outputs = self.run_engine(encoder, {"input_ids": text_input_ids}) + output_name = ( + "text_embeddings" if not pooled_output else "pooled_embeddings" + ) + text_encoder_output = outputs[output_name].clone() + + return text_encoder_output + + # Tokenize prompt + text_encoder_output = tokenize(prompt, max_sequence_length) + + self.profile_stop(encoder) + return ( + text_encoder_output.to(torch.float16) + if self.fp16 + else text_encoder_output.to(torch.bfloat16) + if self.bf16 + else text_encoder_output + ) + + def denoise_latent( + self, + latents, + timesteps, + text_embeddings, + pooled_embeddings, + text_ids, + latent_image_ids, + denoiser="transformer", + guidance=None, + control_latent=None, + ): + do_autocast = self.torch_inference != "" and self.models[denoiser].fp16 + with torch.autocast("cuda", enabled=do_autocast): + self.profile_start(denoiser, color="blue") + + # handle guidance + if self.models[denoiser].config["guidance_embeds"] and guidance is None: + guidance = torch.full( + [1], self.guidance_scale, device=self.device, dtype=torch.float32 + ) + guidance = guidance.expand(latents.shape[0]) + + for step_index, timestep in enumerate(timesteps): + # Prepare latents + latents_input = ( + latents + if control_latent is None + else torch.cat((latents, control_latent), dim=-1) + ) + # prepare inputs + timestep_inp = 
timestep.expand(latents.shape[0]).to(latents_input.dtype) + params = { + "hidden_states": latents_input, + "timestep": timestep_inp / 1000, + "pooled_projections": pooled_embeddings, + "encoder_hidden_states": text_embeddings, + "txt_ids": text_ids.float(), + "img_ids": latent_image_ids.float(), + } + if guidance is not None: + params.update({"guidance": guidance}) + + # Predict the noise residual + if self.torch_inference or self.torch_fallback[denoiser]: + noise_pred = self.torch_models[denoiser](**params)["sample"] + else: + noise_pred = self.run_engine(denoiser, params)["latent"] + + latents = self.scheduler.step( + noise_pred, timestep, latents, return_dict=False + )[0] + + self.profile_stop(denoiser) + return ( + latents.to(dtype=torch.bfloat16) + if self.bf16 + else latents.to(dtype=torch.float32) + ) + + def decode_latent(self, latents, decoder="vae"): + self.profile_start(decoder, color="red") + cast_to = ( + torch.float16 + if self.models[decoder].fp16 + else torch.bfloat16 + if self.models[decoder].bf16 + else torch.float32 + ) + latents = latents.to(dtype=cast_to) + if self.torch_inference or self.torch_fallback[decoder]: + images = self.torch_models[decoder](latents, return_dict=False)[0] + else: + images = self.run_engine(decoder, {"latent": latents})["images"] + self.profile_stop(decoder) + return images + + def print_summary(self, denoising_steps, walltime_ms, batch_size): + print("|-----------------|--------------|") + print("| {:^15} | {:^12} |".format("Module", "Latency")) + print("|-----------------|--------------|") + print( + "| {:^15} | {:>9.2f} ms |".format( + "CLIP", + cudart.cudaEventElapsedTime( + self.events["clip"][0], self.events["clip"][1] + )[1], + ) + ) + print( + "| {:^15} | {:>9.2f} ms |".format( + "T5", + cudart.cudaEventElapsedTime(self.events["t5"][0], self.events["t5"][1])[ + 1 + ], + ) + ) + if "vae_encoder" in self.stages: + print( + "| {:^15} | {:>9.2f} ms |".format( + "VAE-Enc", + cudart.cudaEventElapsedTime( + self.events["vae_encoder"][0], self.events["vae_encoder"][1] + )[1], + ) + ) + print( + "| {:^15} | {:>9.2f} ms |".format( + "Transformer x " + str(denoising_steps), + cudart.cudaEventElapsedTime( + self.events["transformer"][0], self.events["transformer"][1] + )[1], + ) + ) + print( + "| {:^15} | {:>9.2f} ms |".format( + "VAE-Dec", + cudart.cudaEventElapsedTime( + self.events["vae"][0], self.events["vae"][1] + )[1], + ) + ) + print("|-----------------|--------------|") + print("| {:^15} | {:>9.2f} ms |".format("Pipeline", walltime_ms)) + print("|-----------------|--------------|") + print("Throughput: {:.2f} image/s".format(batch_size * 1000.0 / walltime_ms)) + + def infer( + self, + prompt, + prompt2, + image_height, + image_width, + input_image=None, + image_strength=1.0, + control_image=None, + warmup=False, + save_image=True, + ): + """ + Run the diffusion pipeline. + + Args: + prompt (str): + The text prompt to guide image generation. + prompt2 (str): + The prompt to be sent to the T5 tokenizer and text encoder + image_height (int): + Height (in pixels) of the image to be generated. Must be a multiple of 8. + image_width (int): + Width (in pixels) of the image to be generated. Must be a multiple of 8. + input_image (PIL.Image.Image): + `Image` representing an image batch to be used as the starting point. + image_strength (`float`, *optional*, defaults to 1.0): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. 
The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + control_image (PIL.Image.Image): + The ControlNet input condition to provide guidance to the `transformer` for generation. + warmup (bool): + Indicate if this is a warmup run. + save_image (bool): + Save the generated image (if applicable) + """ + assert len(prompt) == len(prompt2) + batch_size = len(prompt) + + # Spatial dimensions of latent tensor + latent_height = 2 * (int(image_height) // self.vae_scale_factor) + latent_width = 2 * (int(image_width) // self.vae_scale_factor) + + num_inference_steps = self.denoising_steps + latent_kwargs = {} + + with torch.inference_mode(), trt.Runtime(TRT_LOGGER): + torch.cuda.synchronize() + e2e_tic = time.perf_counter() + + num_channels_latents = self.models["transformer"].config["in_channels"] // 4 + if control_image: + num_channels_latents = ( + self.models["transformer"].config["in_channels"] // 8 + ) + + # Prepare control latents + control_image = self.prepare_image( + image=control_image, + width=image_width, + height=image_height, + batch_size=batch_size, + num_images_per_prompt=1, + device=self.device, + dtype=torch.float16 + if self.models["vae"].fp16 + else torch.bfloat16 + if self.models["vae"].bf16 + else torch.float32, + ) + + if control_image.ndim == 4: + with self.model_memory_manager( + ["vae_encoder"], low_vram=self.low_vram + ): + control_image = self.encode_image(control_image) + + height_control_image, width_control_image = control_image.shape[2:] + control_image = self._pack_latents( + control_image, + batch_size, + num_channels_latents, + height_control_image, + width_control_image, + ) + + # CLIP and T5 text encoder(s) + with self.model_memory_manager(["clip", "t5"], low_vram=self.low_vram): + pooled_embeddings = self.encode_prompt(prompt, pooled_output=True) + text_embeddings = self.encode_prompt( + prompt2, encoder="t5", max_sequence_length=self.max_sequence_length + ) + text_ids = torch.zeros(text_embeddings.shape[1], 3).to( + device=self.device, dtype=text_embeddings.dtype + ) + + # Prepare timesteps + sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) + image_seq_len = (latent_height // 2) * (latent_width // 2) + mu = calculate_shift( + image_seq_len, + self.scheduler.config.base_image_seq_len, + self.scheduler.config.max_image_seq_len, + self.scheduler.config.base_shift, + self.scheduler.config.max_shift, + ) + timesteps = None + # TODO: support custom timesteps + if timesteps is not None: + if ( + "timesteps" + not in inspect.signature(self.scheduler.set_timesteps).parameters + ): + raise ValueError( + f"The current scheduler class {self.scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." 
+ ) + self.scheduler.set_timesteps(timesteps=timesteps, device=self.device) + assert self.denoising_steps == len(self.scheduler.timesteps) + else: + self.scheduler.set_timesteps(sigmas=sigmas, mu=mu, device=self.device) + timesteps = self.scheduler.timesteps.to(self.device) + num_inference_steps = len(timesteps) + + # Pre-process input image and timestep for the img2img pipeline + if input_image: + input_image = self.image_processor.preprocess( + input_image, height=image_height, width=image_width + ).to(self.device) + with self.model_memory_manager(["vae_encoder"], low_vram=self.low_vram): + image_latents = self.encode_image(input_image) + + timesteps, num_inference_steps = self.get_timesteps( + self.denoising_steps, image_strength + ) + if num_inference_steps < 1: + raise ValueError( + f"After adjusting the num_inference_steps by strength parameter: {image_strength}, the number of pipeline" + f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline." + ) + latent_timestep = timesteps[:1].repeat(batch_size) + + latent_kwargs.update( + {"image_latents": image_latents, "latent_timestep": latent_timestep} + ) + + # Initialize latents + latents, latent_image_ids = self.initialize_latents( + batch_size=batch_size, + num_channels_latents=num_channels_latents, + latent_height=latent_height, + latent_width=latent_width, + latents_dtype=torch.float16 + if self.fp16 + else torch.bfloat16 + if self.bf16 + else torch.float32, + **latent_kwargs, + ) + + # DiT denoiser + with self.model_memory_manager(["transformer"], low_vram=self.low_vram): + latents = self.denoise_latent( + latents, + timesteps, + text_embeddings, + pooled_embeddings, + text_ids, + latent_image_ids, + control_latent=control_image, + ) + + # VAE decode latent + with self.model_memory_manager(["vae"], low_vram=self.low_vram): + latents = self._unpack_latents( + latents, image_height, image_width, self.vae_scale_factor + ) + latents = ( + latents / self.models["vae"].config["scaling_factor"] + ) + self.models["vae"].config["shift_factor"] + images = self.decode_latent(latents) + + torch.cuda.synchronize() + e2e_toc = time.perf_counter() + + walltime_ms = (e2e_toc - e2e_tic) * 1000.0 + if not warmup: + self.print_summary(num_inference_steps, walltime_ms, batch_size) + if not self.return_latents and save_image: + # post-process images + images = ( + ((images + 1) * 255 / 2) + .clamp(0, 255) + .detach() + .permute(0, 2, 3, 1) + .round() + .type(torch.uint8) + .cpu() + .numpy() + ) + self.save_image( + images, self.pipeline_type.name.lower(), prompt, self.seed + ) + + return (latents, walltime_ms) if self.return_latents else (images, walltime_ms) + + def run( + self, + prompt, + prompt2, + height, + width, + batch_count, + num_warmup_runs, + use_cuda_graph, + **kwargs, + ): + if self.low_vram and self.use_cuda_graph: + print("[W] Using low_vram, use_cuda_graph will be disabled") + self.use_cuda_graph = False + num_warmup_runs = max(1, num_warmup_runs) if use_cuda_graph else num_warmup_runs + if num_warmup_runs > 0: + print("[I] Warming up ..") + for _ in range(num_warmup_runs): + self.infer(prompt, prompt2, height, width, warmup=True, **kwargs) + + for _ in range(batch_count): + print("[I] Running Flux pipeline") + if self.nvtx_profile: + cudart.cudaProfilerStart() + self.infer(prompt, prompt2, height, width, warmup=False, **kwargs) + if self.nvtx_profile: + cudart.cudaProfilerStop() diff --git a/flux.1-dev-trt-b200/model/demo_diffusion/pipeline/model_memory_manager.py 
b/flux.1-dev-trt-b200/model/demo_diffusion/pipeline/model_memory_manager.py new file mode 100644 index 000000000..47555b48f --- /dev/null +++ b/flux.1-dev-trt-b200/model/demo_diffusion/pipeline/model_memory_manager.py @@ -0,0 +1,82 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import torch +from cuda import cudart + + +class ModelMemoryManager: + """ + Context manager for efficiently loading and unloading models to optimize VRAM usage. + + This class provides a context to temporarily load models into GPU memory for inference + and automatically unload them afterward. It's especially useful in low VRAM environments + where models need to be swapped in and out of GPU memory. + + Args: + parent: The parent class instance that contains the model references and resources. + model_names (list): List of model names to load and unload. + low_vram (bool, optional): If True, enables VRAM optimization. If False, the context manager does nothing. Defaults to False. + """ + + def __init__(self, parent, model_names, low_vram=False): + self.parent = parent + self.model_names = model_names + self.low_vram = low_vram + + def __enter__(self): + if not self.low_vram: + return + for model_name in self.model_names: + if not self.parent.torch_fallback[model_name]: + # creating engine object (load from plan file) + self.parent.engine[model_name].load() + # allocate device memory + _, shared_device_memory = cudart.cudaMalloc( + self.parent.device_memory_sizes[model_name] + ) + self.parent.shared_device_memory = shared_device_memory + # creating context + self.parent.engine[model_name].activate( + device_memory=self.parent.shared_device_memory + ) + # creating input and output buffer + self.parent.engine[model_name].allocate_buffers( + shape_dict=self.parent.shape_dicts[model_name], + device=self.parent.device, + ) + else: + print(f"[I] Reloading torch model {model_name} from cpu.") + self.parent.torch_models[model_name] = self.parent.torch_models[ + model_name + ].to("cuda") + + def __exit__(self, exc_type, exc_val, exc_tb): + if not self.low_vram: + return + for model_name in self.model_names: + if not self.parent.torch_fallback[model_name]: + self.parent.engine[model_name].deallocate_buffers() + self.parent.engine[model_name].deactivate() + self.parent.engine[model_name].unload() + cudart.cudaFree(self.parent.shared_device_memory) + else: + print(f"[I] Offloading torch model {model_name} to cpu.") + self.parent.torch_models[model_name] = self.parent.torch_models[ + model_name + ].to("cpu") + torch.cuda.empty_cache() diff --git a/flux.1-dev-trt-b200/model/demo_diffusion/pipeline/stable_cascade_pipeline.py b/flux.1-dev-trt-b200/model/demo_diffusion/pipeline/stable_cascade_pipeline.py new file mode 100644 index 000000000..316038c7f --- /dev/null +++ b/flux.1-dev-trt-b200/model/demo_diffusion/pipeline/stable_cascade_pipeline.py @@ -0,0 +1,480 @@ +# +# 
SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import inspect +import time + +import tensorrt as trt +import torch +from cuda import cudart +from diffusers import DDPMWuerstchenScheduler + +from demo_diffusion.model import ( + CLIPWithProjModel, + UNetCascadeModel, + VQGANModel, + make_tokenizer, +) +from demo_diffusion.pipeline.stable_diffusion_pipeline import StableDiffusionPipeline +from demo_diffusion.pipeline.type import PIPELINE_TYPE + +TRT_LOGGER = trt.Logger(trt.Logger.ERROR) + + +class StableCascadePipeline(StableDiffusionPipeline): + """ + Application showcasing the acceleration of Stable Cascade pipelines using NVidia TensorRT. + """ + + def __init__( + self, + version="cascade", + pipeline_type=PIPELINE_TYPE.CASCADE_PRIOR, + latent_dim_scale=10.67, + lite=False, + **kwargs, + ): + """ + Initializes the Stable Cascade pipeline. + + Args: + version (str): + The version of the pipeline. Should be one of [cascade] + pipeline_type (PIPELINE_TYPE): + Type of current pipeline. + latent_dim_scale (float): + Multiplier to determine the VQ latent space size from the image embeddings. If the image embeddings are + height=24 and width=24, the VQ latent shape needs to be height=int(24*10.67)=256 and + width=int(24*10.67)=256 in order to match the training conditions. 
+ lite (bool): + Boolean indicating if the Lite Version of the Stage B and Stage C models is to be used + """ + super().__init__(version=version, pipeline_type=pipeline_type, **kwargs) + self.config["clip_hidden_states"] = True + # from Diffusers: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py#L91C9-L91C41 + self.latent_dim_scale = latent_dim_scale + self.lite = lite + + def initializeModels(self, framework_model_dir, int8, fp8): + # Load text tokenizer(s) + self.tokenizer = make_tokenizer( + self.version, self.pipeline_type, self.hf_token, framework_model_dir + ) + + # Load pipeline models + models_args = { + "version": self.version, + "pipeline": self.pipeline_type, + "device": self.device, + "hf_token": self.hf_token, + "verbose": self.verbose, + "framework_model_dir": framework_model_dir, + "max_batch_size": self.max_batch_size, + } + + self.fp16 = False # TODO: enable FP16 mode for decoder model (requires strongly typed engine) + self.bf16 = True + if "clip" in self.stages: + self.models["clip"] = CLIPWithProjModel( + **models_args, + fp16=self.fp16, + bf16=self.bf16, + output_hidden_states=self.config.get("clip_hidden_states", False), + subfolder="text_encoder", + ) + + if "unet" in self.stages: + self.models["unet"] = UNetCascadeModel( + **models_args, + fp16=self.fp16, + bf16=self.bf16, + lite=self.lite, + do_classifier_free_guidance=self.do_classifier_free_guidance, + ) + + if "vqgan" in self.stages: + self.models["vqgan"] = VQGANModel( + **models_args, + fp16=self.fp16, + bf16=self.bf16, + latent_dim_scale=self.latent_dim_scale, + ) + + def encode_prompt( + self, + prompt, + negative_prompt, + encoder="clip", + pooled_outputs=False, + output_hidden_states=False, + ): + self.profile_start("clip", color="green") + + tokenizer = self.tokenizer + + def tokenize(prompt, output_hidden_states): + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids.type(torch.int32).to(self.device) + attention_mask = text_inputs.attention_mask.type(torch.int32).to( + self.device + ) + + text_hidden_states = None + if self.torch_inference: + outputs = self.torch_models[encoder]( + text_input_ids, + attention_mask=attention_mask, + output_hidden_states=output_hidden_states, + ) + text_embeddings = outputs[0].clone() + if output_hidden_states: + hidden_state_layer = -1 + text_hidden_states = outputs["hidden_states"][ + hidden_state_layer + ].clone() + else: + # NOTE: output tensor for CLIP must be cloned because it will be overwritten when called again for negative prompt + outputs = self.runEngine( + encoder, + {"input_ids": text_input_ids, "attention_mask": attention_mask}, + ) + text_embeddings = outputs["text_embeddings"].clone() + if output_hidden_states: + text_hidden_states = outputs["hidden_states"].clone() + + return text_embeddings, text_hidden_states + + # Tokenize prompt + text_embeddings, text_hidden_states = tokenize(prompt, output_hidden_states) + + if self.do_classifier_free_guidance: + # Tokenize negative prompt + uncond_embeddings, uncond_hidden_states = tokenize( + negative_prompt, output_hidden_states + ) + + # Concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes for classifier free guidance + text_embeddings = torch.cat([text_embeddings, uncond_embeddings]) + + if pooled_outputs: + pooled_output = text_embeddings + + if 
output_hidden_states: + text_embeddings = ( + torch.cat([text_hidden_states, uncond_hidden_states]) + if self.do_classifier_free_guidance + else text_hidden_states + ) + + self.profile_stop("clip") + if pooled_outputs: + return text_embeddings, pooled_output + return text_embeddings + + def denoise_latent( + self, + latents, + pooled_embeddings, + text_embeddings=None, + image_embeds=None, + effnet=None, + denoiser="unet", + timesteps=None, + ): + do_autocast = False + with torch.autocast("cuda", enabled=do_autocast): + self.profile_start("denoise", color="blue") + for step_index, timestep in enumerate(timesteps): + # ratio input required for stable cascade prior + timestep_ratio = timestep.expand(latents.size(0)).to(latents.dtype) + # Expand the latents and timestep_ratio if we are doing classifier free guidance + latent_model_input = ( + torch.cat([latents] * 2) + if self.do_classifier_free_guidance + else latents + ) + timestep_ratio_input = ( + torch.cat([timestep_ratio] * 2) + if self.do_classifier_free_guidance + else timestep_ratio + ) + + params = { + "sample": latent_model_input, + "timestep_ratio": timestep_ratio_input, + "clip_text_pooled": pooled_embeddings, + } + if text_embeddings is not None: + params.update({"clip_text": text_embeddings}) + if image_embeds is not None: + params.update({"clip_img": image_embeds}) + if effnet is not None: + params.update({"effnet": effnet}) + + # Predict the noise residual + if self.torch_inference: + noise_pred = self.torch_models[denoiser](**params)["sample"] + else: + noise_pred = self.runEngine(denoiser, params)["latent"] + + # Perform guidance + if self.do_classifier_free_guidance: + noise_pred_text, noise_pred_uncond = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + + # from diffusers (prepare_extra_step_kwargs) + extra_step_kwargs = {} + if "eta" in set( + inspect.signature(self.scheduler.step).parameters.keys() + ): + # TODO: configurable eta + eta = 0.0 + extra_step_kwargs["eta"] = eta + if "generator" in set( + inspect.signature(self.scheduler.step).parameters.keys() + ): + extra_step_kwargs["generator"] = self.generator + + latents = self.scheduler.step( + noise_pred, + timestep_ratio, + latents, + **extra_step_kwargs, + return_dict=False, + )[0] + + latents = latents.to(dtype=torch.bfloat16 if self.bf16 else torch.float32) + + self.profile_stop("denoise") + return latents + + def decode_latent(self, latents): + self.profile_start("vqgan", color="red") + latents = self.models["vqgan"].scale_factor * latents + if self.torch_inference: + images = self.torch_models["vqgan"](latents)["sample"] + else: + images = self.runEngine("vqgan", {"latent": latents})["images"] + self.profile_stop("vqgan") + return images + + def print_summary(self, denoising_steps, walltime_ms, batch_size): + print("|-----------------|--------------|") + print("| {:^15} | {:^12} |".format("Module", "Latency")) + print("|-----------------|--------------|") + print( + "| {:^15} | {:>9.2f} ms |".format( + "CLIP", + cudart.cudaEventElapsedTime( + self.events["clip"][0], self.events["clip"][1] + )[1], + ) + ) + print( + "| {:^15} | {:>9.2f} ms |".format( + "UNet" + " x " + str(denoising_steps), + cudart.cudaEventElapsedTime( + self.events["denoise"][0], self.events["denoise"][1] + )[1], + ) + ) + if "vqgan" in self.stages: + print( + "| {:^15} | {:>9.2f} ms |".format( + "VQGAN", + cudart.cudaEventElapsedTime( + self.events["vqgan"][0], self.events["vqgan"][1] + )[1], + ) + ) + 
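+ # The "Pipeline" row below reports end-to-end wall time measured on the host, which also includes scheduler and other host-side overhead not captured by the per-module CUDA event timings.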
print("|-----------------|--------------|") + print("| {:^15} | {:>9.2f} ms |".format("Pipeline", walltime_ms)) + print("|-----------------|--------------|") + print("Throughput: {:.2f} image/s".format(batch_size * 1000.0 / walltime_ms)) + + def infer( + self, + prompt, + negative_prompt, + image_height, + image_width, + image_embeddings=None, + warmup=False, + verbose=False, + save_image=True, + ): + """ + Run the diffusion pipeline. + + Args: + prompt (str): + The text prompt to guide image generation. + negative_prompt (str): + The prompt not to guide the image generation. + image_height (int): + Height (in pixels) of the image to be generated. Must be a multiple of 8. + image_width (int): + Width (in pixels) of the image to be generated. Must be a multiple of 8. + image_embeddings (`torch.FloatTensor` or `List[torch.FloatTensor]`): + Image Embeddings either extracted from an image or generated by a Prior Model. + warmup (bool): + Indicate if this is a warmup run. + verbose (bool): + Verbose in logging + save_image (bool): + Save the generated image (if applicable) + """ + if self.pipeline_type.is_cascade_decoder(): + assert image_embeddings is not None, ( + "Image Embeddings are required to run the decoder. Provided None" + ) + assert len(prompt) == len(negative_prompt) + batch_size = len(prompt) + + # Spatial dimensions of latent tensor + latent_height = image_height // 42 + latent_width = image_width // 42 + + if image_embeddings is not None: + assert latent_height == image_embeddings.shape[-2] + assert latent_width == image_embeddings.shape[-1] + + if self.generator and self.seed: + self.generator.manual_seed(self.seed) + + num_inference_steps = self.denoising_steps + + with torch.inference_mode(), trt.Runtime(TRT_LOGGER): + torch.cuda.synchronize() + e2e_tic = time.perf_counter() + + denoise_kwargs = {} + # TODO: support custom timesteps + timesteps = None + if timesteps is not None: + if "timesteps" not in set( + inspect.signature(self.scheduler.set_timesteps).parameters.keys() + ): + raise ValueError( + f"The current scheduler class {self.scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." 
+ ) + self.scheduler.set_timesteps(timesteps=timesteps, device=self.device) + assert self.denoising_steps == len(self.scheduler.timesteps) + else: + self.scheduler.set_timesteps(self.denoising_steps, device=self.device) + timesteps = self.scheduler.timesteps.to(self.device) + if isinstance(self.scheduler, DDPMWuerstchenScheduler): + timesteps = timesteps[:-1] + denoise_kwargs.update({"timesteps": timesteps}) + + # Initialize latents + latents_dtpye = ( + torch.float16 + if self.fp16 + else torch.bfloat16 + if self.bf16 + else torch.float32 + ) + latents = self.initialize_latents( + batch_size=batch_size, + unet_channels=16 + if self.pipeline_type.is_cascade_prior() + else 4, # TODO: can we query "in_channels" from config + latent_height=latent_height + if self.pipeline_type.is_cascade_prior() + else int(latent_height * self.latent_dim_scale), + latent_width=latent_width + if self.pipeline_type.is_cascade_prior() + else int(latent_width * self.latent_dim_scale), + latents_dtype=latents_dtpye, + ) + + # CLIP text encoder(s) + text_embeddings, pooled_embeddings = self.encode_prompt( + prompt, + negative_prompt, + encoder="clip", + pooled_outputs=True, + output_hidden_states=True, + ) + + if self.pipeline_type.is_cascade_prior(): + denoise_kwargs.update({"text_embeddings": text_embeddings}) + + # image embeds + image_embeds_pooled = torch.zeros( + batch_size, 1, 768, device=self.device, dtype=latents_dtpye + ) + image_embeds = ( + torch.cat( + [image_embeds_pooled, torch.zeros_like(image_embeds_pooled)] + ) + if self.do_classifier_free_guidance + else image_embeddings + ) + denoise_kwargs.update({"image_embeds": image_embeds}) + else: + effnet = ( + torch.cat([image_embeddings, torch.zeros_like(image_embeddings)]) + if self.do_classifier_free_guidance + else image_embeddings + ) + denoise_kwargs.update({"effnet": effnet}) + + # UNet denoiser + latents = self.denoise_latent( + latents, + pooled_embeddings.unsqueeze(1), + denoiser="unet", + **denoise_kwargs, + ) + + if not self.return_latents: + images = self.decode_latent(latents) + + torch.cuda.synchronize() + e2e_toc = time.perf_counter() + + walltime_ms = (e2e_toc - e2e_tic) * 1000.0 + if not warmup: + self.print_summary(num_inference_steps, walltime_ms, batch_size) + if not self.return_latents and save_image: + # post-process images + images = ( + ((images) * 255) + .clamp(0, 255) + .detach() + .permute(0, 2, 3, 1) + .round() + .type(torch.uint8) + .cpu() + .numpy() + ) + self.save_image( + images, self.pipeline_type.name.lower(), prompt, self.seed + ) + + return (latents, walltime_ms) if self.return_latents else (images, walltime_ms) diff --git a/flux.1-dev-trt-b200/model/demo_diffusion/pipeline/stable_diffusion_35_pipeline.py b/flux.1-dev-trt-b200/model/demo_diffusion/pipeline/stable_diffusion_35_pipeline.py new file mode 100644 index 000000000..b40ade622 --- /dev/null +++ b/flux.1-dev-trt-b200/model/demo_diffusion/pipeline/stable_diffusion_35_pipeline.py @@ -0,0 +1,862 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import annotations + +import argparse +import inspect +import time +from typing import Any, List + +import tensorrt as trt +import torch +from cuda import cudart +from transformers import PreTrainedTokenizerBase + +from demo_diffusion import path as path_module +from demo_diffusion.model import ( + CLIPWithProjModel, + SD3TransformerModel, + T5Model, + VAEEncoderModel, + VAEModel, + make_tokenizer, +) +from demo_diffusion.pipeline.diffusion_pipeline import DiffusionPipeline +from demo_diffusion.pipeline.type import PIPELINE_TYPE + +TRT_LOGGER = trt.Logger(trt.Logger.ERROR) + + +class StableDiffusion35Pipeline(DiffusionPipeline): + """ + Application showcasing the acceleration of Stable Diffusion 3.5 pipelines using Nvidia TensorRT. + """ + + def __init__( + self, + version: str, + pipeline_type=PIPELINE_TYPE.TXT2IMG, + guidance_scale: float = 7.0, + max_sequence_length: int = 256, + **kwargs, + ): + """ + Initializes the Stable Diffusion 3.5 pipeline. + + Args: + version (str): + The version of the pipeline. Should be one of ['3.5-medium', '3.5-large'] + pipeline_type (PIPELINE_TYPE): + Type of current pipeline. + guidance_scale (`float`, defaults to 7.0): + Guidance scale is enabled by setting as > 1. + Higher guidance scale encourages to generate images that are closely linked to the text prompt, usually at the expense of lower image quality. + max_sequence_length (`int`, defaults to 256): + Maximum sequence length to use with the `prompt`. + """ + super().__init__(version=version, pipeline_type=pipeline_type, **kwargs) + + self.fp16 = True if not self.bf16 else False + + self.force_weakly_typed_t5 = False + self.config["clip_hidden_states"] = True + + self.guidance_scale = guidance_scale + self.do_classifier_free_guidance = self.guidance_scale > 1 + self.max_sequence_length = max_sequence_length + + @classmethod + def FromArgs( + cls, args: argparse.Namespace, pipeline_type: PIPELINE_TYPE + ) -> StableDiffusion35Pipeline: + """Factory method to construct a `StableDiffusion35Pipeline` object from parsed arguments. + + Overrides: + DiffusionPipeline.FromArgs + """ + MAX_BATCH_SIZE = 4 + DEVICE = "cuda" + DO_RETURN_LATENTS = False + + # Resolve all paths. 
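+        # The resolved paths are keyed by this pipeline's model names and a per-version
+        # pipeline UID derived from the parsed CLI arguments.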
+ dd_path = path_module.resolve_path( + cls.get_model_names(pipeline_type), + args, + pipeline_type, + cls._get_pipeline_uid(args.version), + ) + + return cls( + dd_path=dd_path, + version=args.version, + pipeline_type=pipeline_type, + guidance_scale=args.guidance_scale, + max_sequence_length=args.max_sequence_length, + bf16=args.bf16, + low_vram=args.low_vram, + torch_fallback=args.torch_fallback, + weight_streaming=args.ws, + max_batch_size=MAX_BATCH_SIZE, + denoising_steps=args.denoising_steps, + scheduler=args.scheduler, + device=DEVICE, + output_dir=args.output_dir, + hf_token=args.hf_token, + verbose=args.verbose, + nvtx_profile=args.nvtx_profile, + use_cuda_graph=args.use_cuda_graph, + framework_model_dir=args.framework_model_dir, + return_latents=DO_RETURN_LATENTS, + torch_inference=args.torch_inference, + ) + + @classmethod + def get_model_names(cls, pipeline_type: PIPELINE_TYPE) -> List[str]: + """Return a list of model names used by this pipeline. + + Overrides: + DiffusionPipeline.get_model_names + """ + return ["clip_l", "clip_g", "t5", "transformer", "vae"] + + def download_onnx_models( + self, model_name: str, model_config: dict[str, Any] + ) -> None: + raise ValueError( + "ONNX models download is not supported for the Stable Diffusion 3.5 pipeline" + ) + + def load_resources( + self, + image_height: int, + image_width: int, + batch_size: int, + seed: int, + ): + super().load_resources(image_height, image_width, batch_size, seed) + + def _initialize_models(self, framework_model_dir, int8, fp8, fp4): + # Load text tokenizer(s) + self.tokenizer = make_tokenizer( + self.version, + self.pipeline_type, + self.hf_token, + framework_model_dir, + ) + self.tokenizer2 = make_tokenizer( + self.version, + self.pipeline_type, + self.hf_token, + framework_model_dir, + subfolder="tokenizer_2", + ) + self.tokenizer3 = make_tokenizer( + self.version, + self.pipeline_type, + self.hf_token, + framework_model_dir, + subfolder="tokenizer_3", + tokenizer_type="t5", + ) + + # Load pipeline models + models_args = { + "version": self.version, + "pipeline": self.pipeline_type, + "device": self.device, + "hf_token": self.hf_token, + "verbose": self.verbose, + "framework_model_dir": framework_model_dir, + "max_batch_size": self.max_batch_size, + } + + self.bf16 = True if int8 or fp8 or fp4 else self.bf16 + self.fp16 = True if not self.bf16 else False + if "clip_l" in self.stages: + self.models["clip_l"] = CLIPWithProjModel( + **models_args, + fp16=self.fp16, + bf16=self.bf16, + subfolder="text_encoder", + output_hidden_states=self.config.get("clip_hidden_states", False), + ) + + if "clip_g" in self.stages: + self.models["clip_g"] = CLIPWithProjModel( + **models_args, + fp16=self.fp16, + bf16=self.bf16, + subfolder="text_encoder_2", + output_hidden_states=self.config.get("clip_hidden_states", False), + ) + + if "t5" in self.stages: + # Known accuracy issues with FP16 + self.models["t5"] = T5Model( + **models_args, + fp16=self.fp16, + bf16=self.bf16, + subfolder="text_encoder_3", + text_maxlen=self.max_sequence_length, + build_strongly_typed=True, + weight_streaming=self.weight_streaming, + weight_streaming_budget_percentage=self.text_encoder_weight_streaming_budget_percentage, + ) + + if "transformer" in self.stages: + self.models["transformer"] = SD3TransformerModel( + **models_args, + bf16=self.bf16, + fp16=self.fp16, + text_maxlen=self.models["t5"].text_maxlen + + self.models["clip_g"].text_maxlen, + build_strongly_typed=True, + weight_streaming=self.weight_streaming, + 
weight_streaming_budget_percentage=self.denoiser_weight_streaming_budget_percentage, + do_classifier_free_guidance=self.do_classifier_free_guidance, + ) + + if "vae" in self.stages: + self.models["vae"] = VAEModel( + **models_args, fp16=self.fp16, tf32=True, bf16=self.bf16 + ) + + self.vae_scale_factor = ( + 2 ** (len(self.models["vae"].config["block_out_channels"]) - 1) + if "vae" in self.models + else 8 + ) + self.patch_size = ( + self.models["transformer"].config["patch_size"] + if "transformer" in self.stages and self.models["transformer"] is not None + else 2 + ) + + if "vae_encoder" in self.stages: + self.models["vae_encoder"] = VAEEncoderModel( + **models_args, fp16=False, tf32=self.tf32, bf16=self.bf16 + ) + self.vae_latent_channels = ( + self.models["vae"].config["latent_channels"] + if "vae" in self.stages and self.models["vae"] is not None + else 16 + ) + + def print_summary(self, denoising_steps, walltime_ms): + print("|-----------------|--------------|") + print("| {:^15} | {:^12} |".format("Module", "Latency")) + print("|-----------------|--------------|") + if "vae_encoder" in self.stages: + print( + "| {:^15} | {:>9.2f} ms |".format( + "VAE Encoder", + cudart.cudaEventElapsedTime( + self.events["vae_encode"][0], self.events["vae_encode"][1] + )[1], + ) + ) + print( + "| {:^15} | {:>9.2f} ms |".format( + "CLIP-G", + cudart.cudaEventElapsedTime( + self.events["clip_g"][0], self.events["clip_g"][1] + )[1], + ) + ) + print( + "| {:^15} | {:>9.2f} ms |".format( + "CLIP-L", + cudart.cudaEventElapsedTime( + self.events["clip_l"][0], self.events["clip_l"][1] + )[1], + ) + ) + print( + "| {:^15} | {:>9.2f} ms |".format( + "T5", + cudart.cudaEventElapsedTime(self.events["t5"][0], self.events["t5"][1])[ + 1 + ], + ) + ) + print( + "| {:^15} | {:>9.2f} ms |".format( + "MMDiT" + " x " + str(denoising_steps), + cudart.cudaEventElapsedTime( + self.events["transformer"][0], self.events["transformer"][1] + )[1], + ) + ) + print( + "| {:^15} | {:>9.2f} ms |".format( + "VAE Decoder", + cudart.cudaEventElapsedTime( + self.events["vae"][0], self.events["vae"][1] + )[1], + ) + ) + print("|-----------------|--------------|") + print("| {:^15} | {:>9.2f} ms |".format("Pipeline", walltime_ms)) + print("|-----------------|--------------|") + print( + "Throughput: {:.2f} image/s".format(self.batch_size * 1000.0 / walltime_ms) + ) + + @staticmethod + def _tokenize( + tokenizer: PreTrainedTokenizerBase, + prompt: list[str], + max_sequence_length: int, + device: torch.device, + ): + text_input_ids = tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ).input_ids + text_input_ids = text_input_ids.type(torch.int32) + + untruncated_ids = tokenizer( + prompt, + padding="longest", + return_tensors="pt", + ).input_ids.type(torch.int32) + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode( + untruncated_ids[:, max_sequence_length - 1 : -1] + ) + TRT_LOGGER.warning( + "The following part of your input was truncated because `max_sequence_length` is set to " + f" {max_sequence_length} tokens: {removed_text}" + ) + text_input_ids = text_input_ids.to(device) + return text_input_ids + + def _get_prompt_embed( + self, + prompt: list[str], + encoder_name: str, + domain="positive_prompt", + ): + if encoder_name == "clip_l": + tokenizer = self.tokenizer + max_sequence_length = tokenizer.model_max_length + 
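+            # The CLIP-L prompt is truncated to the tokenizer's own maximum length, and hidden
+            # states are requested so that the penultimate hidden layer can be used as the
+            # prompt embedding (see the `hidden_states` handling below).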
output_hidden_states = True + elif encoder_name == "clip_g": + tokenizer = self.tokenizer2 + max_sequence_length = tokenizer.model_max_length + output_hidden_states = True + elif encoder_name == "t5": + tokenizer = self.tokenizer3 + max_sequence_length = self.max_sequence_length + output_hidden_states = False + else: + raise NotImplementedError(f"encoder not found: {encoder_name}") + + self.profile_start(encoder_name, color="green", domain=domain) + + text_input_ids = self._tokenize( + tokenizer=tokenizer, + prompt=prompt, + device=self.device, + max_sequence_length=max_sequence_length, + ) + + text_hidden_states = None + if self.torch_inference or self.torch_fallback[encoder_name]: + outputs = self.torch_models[encoder_name]( + text_input_ids, + output_hidden_states=output_hidden_states, + ) + text_embeddings = outputs[0].clone() + if output_hidden_states: + text_hidden_states = outputs["hidden_states"][-2].clone() + else: + # NOTE: output tensor for the encoder must be cloned because it will be overwritten when called again for prompt2 + outputs = self.run_engine(encoder_name, {"input_ids": text_input_ids}) + text_embeddings = outputs["text_embeddings"].clone() + if output_hidden_states: + text_hidden_states = outputs["hidden_states"].clone() + + self.profile_stop(encoder_name) + return text_hidden_states, text_embeddings + + @staticmethod + def _duplicate_text_embed( + prompt_embed: torch.Tensor, + batch_size: int, + num_images_per_prompt: int, + pooled_prompt_embed: torch.Tensor | None = None, + ): + _, seq_len, _ = prompt_embed.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embed = prompt_embed.repeat(1, num_images_per_prompt, 1) + prompt_embed = prompt_embed.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + + if pooled_prompt_embed is not None: + pooled_prompt_embed = pooled_prompt_embed.repeat( + 1, num_images_per_prompt, 1 + ) + pooled_prompt_embed = pooled_prompt_embed.view( + batch_size * num_images_per_prompt, -1 + ) + + return prompt_embed, pooled_prompt_embed + + def encode_prompt( + self, + prompt: list[str], + negative_prompt: list[str] | None = None, + num_images_per_prompt: int = 1, + ): + clip_l_prompt_embed, clip_l_pooled_embed = self._get_prompt_embed( + prompt=prompt, + encoder_name="clip_l", + ) + prompt_embed, pooled_prompt_embed = self._duplicate_text_embed( + prompt_embed=clip_l_prompt_embed.clone(), + pooled_prompt_embed=clip_l_pooled_embed.clone(), + num_images_per_prompt=num_images_per_prompt, + batch_size=self.batch_size, + ) + + clip_g_prompt_embed, clip_g_pooled_embed = self._get_prompt_embed( + prompt=prompt, + encoder_name="clip_g", + ) + prompt_2_embed, pooled_prompt_2_embed = self._duplicate_text_embed( + prompt_embed=clip_g_prompt_embed.clone(), + pooled_prompt_embed=clip_g_pooled_embed.clone(), + batch_size=self.batch_size, + num_images_per_prompt=num_images_per_prompt, + ) + + _, t5_prompt_embed = self._get_prompt_embed( + prompt=prompt, + encoder_name="t5", + ) + + t5_prompt_embed, _ = self._duplicate_text_embed( + prompt_embed=t5_prompt_embed.clone(), + batch_size=self.batch_size, + num_images_per_prompt=num_images_per_prompt, + ) + + clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1) + clip_prompt_embeds = torch.nn.functional.pad( + clip_prompt_embeds, + (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1]), + ) + prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2) + pooled_prompt_embeds = torch.cat( + [pooled_prompt_embed, 
pooled_prompt_2_embed], dim=-1 + ) + + if negative_prompt is None: + negative_prompt = "" + + clip_l_negative_prompt_embed, clip_l_negative_pooled_embed = ( + self._get_prompt_embed( + prompt=negative_prompt, + encoder_name="clip_l", + ) + ) + negative_prompt_embed, negative_pooled_prompt_embed = ( + self._duplicate_text_embed( + prompt_embed=clip_l_negative_prompt_embed.clone(), + pooled_prompt_embed=clip_l_negative_pooled_embed.clone(), + batch_size=self.batch_size, + num_images_per_prompt=num_images_per_prompt, + ) + ) + + clip_g_negative_prompt_embed, clip_g_negative_pooled_embed = ( + self._get_prompt_embed( + prompt=negative_prompt, + encoder_name="clip_g", + ) + ) + negative_prompt_2_embed, negative_pooled_prompt_2_embed = ( + self._duplicate_text_embed( + prompt_embed=clip_g_negative_prompt_embed.clone(), + pooled_prompt_embed=clip_g_negative_pooled_embed.clone(), + batch_size=self.batch_size, + num_images_per_prompt=num_images_per_prompt, + ) + ) + + _, t5_negative_prompt_embed = self._get_prompt_embed( + prompt=negative_prompt, + encoder_name="t5", + ) + + t5_negative_prompt_embed, _ = self._duplicate_text_embed( + prompt_embed=t5_negative_prompt_embed.clone(), + batch_size=self.batch_size, + num_images_per_prompt=num_images_per_prompt, + ) + + negative_clip_prompt_embeds = torch.cat( + [negative_prompt_embed, negative_prompt_2_embed], dim=-1 + ) + negative_clip_prompt_embeds = torch.nn.functional.pad( + negative_clip_prompt_embeds, + ( + 0, + t5_negative_prompt_embed.shape[-1] + - negative_clip_prompt_embeds.shape[-1], + ), + ) + negative_prompt_embeds = torch.cat( + [negative_clip_prompt_embeds, t5_negative_prompt_embed], dim=-2 + ) + negative_pooled_prompt_embeds = torch.cat( + [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1 + ) + + return ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) + + @staticmethod + def initialize_latents( + batch_size: int, + num_channels_latents: int, + latent_height: int, + latent_width: int, + device: torch.device, + generator: torch.Generator, + dtype=torch.float32, + layout=torch.strided, + ) -> torch.Tensor: + latents_shape = (batch_size, num_channels_latents, latent_height, latent_width) + latents = torch.randn( + latents_shape, + dtype=dtype, + device="cuda", + generator=generator, + layout=layout, + ).to(device) + return latents + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps + @staticmethod + def retrieve_timesteps( + scheduler, + num_inference_steps: int | None = None, + device: str | torch.device | None = None, + timesteps: list[int] | None = None, + sigmas: list[float] | None = None, + **kwargs, + ): + r""" + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. 
+ sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError( + "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values" + ) + if timesteps is not None: + accepts_timesteps = "timesteps" in set( + inspect.signature(scheduler.set_timesteps).parameters.keys() + ) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set( + inspect.signature(scheduler.set_timesteps).parameters.keys() + ) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + def denoise_latents( + self, + latents: torch.Tensor, + prompt_embeds: torch.Tensor, + pooled_prompt_embeds: torch.Tensor, + timesteps: torch.FloatTensor, + guidance_scale: float, + denoiser="transformer", + ) -> torch.Tensor: + do_autocast = self.torch_inference != "" and self.models[denoiser].fp16 + with torch.autocast("cuda", enabled=do_autocast): + self.profile_start(denoiser, color="blue") + + for step_index, timestep in enumerate(timesteps): + # expand the latents as we are doing classifier free guidance + latents_model_input = ( + torch.cat([latents] * 2) + if self.do_classifier_free_guidance + else latents + ) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timestep_inp = timestep.expand(latents_model_input.shape[0]) + + params = { + "hidden_states": latents_model_input, + "timestep": timestep_inp, + "encoder_hidden_states": prompt_embeds, + "pooled_projections": pooled_prompt_embeds, + } + # Predict the noise residual + if self.torch_inference or self.torch_fallback[denoiser]: + noise_pred = self.torch_models[denoiser](**params)["sample"] + else: + noise_pred = self.run_engine(denoiser, params)["latent"] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step( + noise_pred, timestep, latents, return_dict=False + )[0] + + self.profile_stop(denoiser) + return latents + + def decode_latents(self, latents: torch.Tensor, decoder="vae") -> torch.Tensor: + cast_to = ( + torch.float16 + if self.models[decoder].fp16 + else torch.bfloat16 + if self.models[decoder].bf16 + else torch.float32 + ) + latents = latents.to(dtype=cast_to) + 
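+        # Run the VAE decode step: use the TensorRT engine for this stage unless it is
+        # configured to fall back to the PyTorch model.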
self.profile_start(decoder, color="red") + if self.torch_inference or self.torch_fallback[decoder]: + images = self.torch_models[decoder](latents, return_dict=False)[0] + else: + images = self.run_engine(decoder, {"latent": latents})["images"] + self.profile_stop(decoder) + return images + + def infer( + self, + prompt: list[str], + negative_prompt: list[str], + image_height: int, + image_width: int, + warmup=False, + save_image=True, + ): + """ + Run the diffusion pipeline. + + Args: + prompt (list[str]): + The text prompt to guide image generation. + negative_prompt (list[str]): + The prompt not to guide the image generation. + image_height (int): + Height (in pixels) of the image to be generated. Must be a multiple of 8. + image_width (int): + Width (in pixels) of the image to be generated. Must be a multiple of 8. + warmup (bool): + Indicate if this is a warmup run. + save_image (bool): + Save the generated image (if applicable) + """ + assert len(prompt) == len(negative_prompt) + self.batch_size = len(prompt) + + # Spatial dimensions of latent tensor + assert image_height % (self.vae_scale_factor * self.patch_size) == 0, ( + f"image height not supported {image_height}" + ) + assert image_width % (self.vae_scale_factor * self.patch_size) == 0, ( + f"image width not supported {image_width}" + ) + latent_height = int(image_height) // self.vae_scale_factor + latent_width = int(image_width) // self.vae_scale_factor + + if self.generator and self.seed: + self.generator.manual_seed(self.seed) + + with torch.inference_mode(), trt.Runtime(TRT_LOGGER): + torch.cuda.synchronize() + e2e_tic = time.perf_counter() + + # 3. encode inputs + with self.model_memory_manager( + ["clip_g", "clip_l", "t5"], low_vram=self.low_vram + ): + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + negative_prompt=negative_prompt, + num_images_per_prompt=1, + ) + # do classifier free guidance + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat( + [negative_prompt_embeds, prompt_embeds], dim=0 + ) + pooled_prompt_embeds = torch.cat( + [negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0 + ) + + # 4. Prepare latent variables + num_channels_latents = self.models["transformer"].config["in_channels"] + latents = self.initialize_latents( + batch_size=self.batch_size, + num_channels_latents=num_channels_latents, + latent_height=latent_height, + latent_width=latent_width, + device=prompt_embeds.device, + generator=self.generator, + dtype=torch.float16 + if self.fp16 + else torch.bfloat16 + if self.bf16 + else torch.float32, + ) + + # 5. 
Prepare timesteps + timesteps, num_inference_steps = self.retrieve_timesteps( + scheduler=self.scheduler, + num_inference_steps=self.denoising_steps, + device=self.device, + sigmas=None, + ) + + # 7 Denoise + with self.model_memory_manager(["transformer"], low_vram=self.low_vram): + latents = self.denoise_latents( + latents=latents, + prompt_embeds=prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + timesteps=timesteps, + guidance_scale=self.guidance_scale, + ) + + # Decode Latents + latents = ( + latents / self.models["vae"].config["scaling_factor"] + ) + self.models["vae"].config["shift_factor"] + with self.model_memory_manager(["vae"], low_vram=self.low_vram): + images = self.decode_latents(latents) + + torch.cuda.synchronize() + e2e_toc = time.perf_counter() + + walltime_ms = (e2e_toc - e2e_tic) * 1000.0 + if not warmup: + self.print_summary( + num_inference_steps, + walltime_ms, + ) + if save_image: + # post-process images + images = ( + ((images + 1) * 255 / 2) + .clamp(0, 255) + .detach() + .permute(0, 2, 3, 1) + .round() + .type(torch.uint8) + .cpu() + .numpy() + ) + self.save_image( + images, self.pipeline_type.name.lower(), prompt, self.seed + ) + + return images, walltime_ms + + def run( + self, + prompt: list[str], + negative_prompt: list[str], + height: int, + width: int, + batch_count: int, + num_warmup_runs: int, + use_cuda_graph: bool, + **kwargs, + ): + num_warmup_runs = max(1, num_warmup_runs) if use_cuda_graph else num_warmup_runs + if num_warmup_runs > 0: + print("[I] Warming up ..") + for _ in range(num_warmup_runs): + self.infer( + prompt, negative_prompt, height, width, warmup=True, **kwargs + ) + + for _ in range(batch_count): + print("[I] Running StableDiffusion 3.5 pipeline") + if self.nvtx_profile: + cudart.cudaProfilerStart() + self.infer(prompt, negative_prompt, height, width, warmup=False, **kwargs) + if self.nvtx_profile: + cudart.cudaProfilerStop() diff --git a/flux.1-dev-trt-b200/model/demo_diffusion/pipeline/stable_diffusion_3_pipeline.py b/flux.1-dev-trt-b200/model/demo_diffusion/pipeline/stable_diffusion_3_pipeline.py new file mode 100644 index 000000000..0b86039b1 --- /dev/null +++ b/flux.1-dev-trt-b200/model/demo_diffusion/pipeline/stable_diffusion_3_pipeline.py @@ -0,0 +1,807 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#
+import math
+import os
+import pathlib
+import time
+
+import nvtx
+import tensorrt as trt
+import torch
+from cuda import cudart
+
+import demo_diffusion.engine as engine_module
+import demo_diffusion.image as image_module
+from demo_diffusion.model import (
+    SD3_CLIPGModel,
+    SD3_CLIPLModel,
+    SD3_MMDiTModel,
+    SD3_T5XXLModel,
+    SD3_VAEDecoderModel,
+    SD3_VAEEncoderModel,
+    get_clip_embedding_dim,
+)
+from demo_diffusion.pipeline.type import PIPELINE_TYPE
+from demo_diffusion.utils_sd3.other_impls import SD3Tokenizer
+from demo_diffusion.utils_sd3.sd3_impls import SD3LatentFormat, sample_euler
+
+TRT_LOGGER = trt.Logger(trt.Logger.ERROR)
+
+
+class StableDiffusion3Pipeline:
+    """
+    Application showcasing the acceleration of Stable Diffusion 3 pipelines using NVIDIA TensorRT.
+    """
+
+    def __init__(
+        self,
+        version="sd3",
+        pipeline_type=PIPELINE_TYPE.TXT2IMG,
+        max_batch_size=16,
+        shift=1.0,
+        cfg_scale=5,
+        denoising_steps=50,
+        denoising_percentage=0.6,
+        input_image=None,
+        device="cuda",
+        output_dir=".",
+        hf_token=None,
+        verbose=False,
+        nvtx_profile=False,
+        use_cuda_graph=False,
+        framework_model_dir="pytorch_model",
+        torch_inference="",
+    ):
+        """
+        Initializes the Stable Diffusion 3 pipeline.
+
+        Args:
+            version (str):
+                The version of the pipeline. Should be one of ['sd3']
+            pipeline_type (PIPELINE_TYPE):
+                Type of current pipeline.
+            max_batch_size (int):
+                Maximum batch size for dynamic batch engine.
+            shift (float):
+                Shift parameter for the MMDiT model. Default: 1.0
+            cfg_scale (int):
+                Classifier-free guidance (CFG) scale used for denoising. Default: 5
+            denoising_steps (int):
+                Number of denoising steps. Default: 50
+            denoising_percentage (float):
+                Fraction of the denoising steps to run; only applied when an input image is provided. Default: 0.6
+            input_image (torch.Tensor):
+                Input image for conditioning. Default: None
+            device (str):
+                PyTorch device to run inference. Default: 'cuda'
+            output_dir (str):
+                Output directory for log files and image artifacts
+            hf_token (str):
+                HuggingFace User Access Token to use for downloading Stable Diffusion model checkpoints.
+            verbose (bool):
+                Enable verbose logging.
+            nvtx_profile (bool):
+                Insert NVTX profiling markers.
+            use_cuda_graph (bool):
+                Use CUDA graph to capture engine execution and then launch inference
+            framework_model_dir (str):
+                cache directory for framework checkpoints
+            torch_inference (str):
+                Run inference with PyTorch (using specified compilation mode) instead of TensorRT.
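+
+        Illustrative usage (a sketch, not a full recipe): the directory arguments below are
+        placeholders that must point at your ONNX/engine/checkpoint locations, and `prompt` /
+        `negative_prompt` are lists of equal length matching `batch_size`.
+
+            pipeline = StableDiffusion3Pipeline(version="sd3", denoising_steps=50)
+            pipeline.loadEngines(engine_dir, framework_model_dir, onnx_dir, onnx_opset,
+                                 opt_batch_size=1, opt_image_height=1024, opt_image_width=1024)
+            pipeline.activateEngines()
+            pipeline.loadResources(image_height=1024, image_width=1024, batch_size=1, seed=42)
+            images, walltime_ms = pipeline.infer(prompt, negative_prompt, 1024, 1024)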
+ """ + + self.max_batch_size = max_batch_size + self.shift = shift + self.cfg_scale = cfg_scale + self.denoising_steps = denoising_steps + self.input_image = input_image + self.denoising_percentage = ( + denoising_percentage if input_image is not None else 1.0 + ) + + self.framework_model_dir = framework_model_dir + self.output_dir = output_dir + for directory in [self.framework_model_dir, self.output_dir]: + if not os.path.exists(directory): + print(f"[I] Create directory: {directory}") + pathlib.Path(directory).mkdir(parents=True) + + self.hf_token = hf_token + self.device = device + self.verbose = verbose + self.nvtx_profile = nvtx_profile + + self.version = version + + # Pipeline type + self.pipeline_type = pipeline_type + self.stages = ["clip_g", "clip_l", "t5xxl", "mmdit", "vae_decoder"] + if input_image is not None: + self.stages += ["vae_encoder"] + + self.config = {} + self.config["clip_hidden_states"] = True + self.torch_inference = torch_inference + if self.torch_inference: + torch._inductor.config.conv_1x1_as_mm = True + torch._inductor.config.coordinate_descent_tuning = True + torch._inductor.config.epilogue_fusion = False + torch._inductor.config.coordinate_descent_check_all_directions = True + self.use_cuda_graph = use_cuda_graph + + # initialized in loadEngines() + self.models = {} + self.torch_models = {} + self.engine = {} + self.shared_device_memory = None + + # initialized in loadResources() + self.events = {} + self.generator = None + self.markers = {} + self.seed = None + self.stream = None + self.tokenizer = None + + def loadResources(self, image_height, image_width, batch_size, seed): + # Initialize noise generator + if seed: + self.seed = seed + self.generator = torch.Generator(device="cuda").manual_seed(seed) + + # Create CUDA events and stream + for stage in [ + "clip_g", + "clip_l", + "t5xxl", + "denoise", + "vae_encode", + "vae_decode", + ]: + self.events[stage] = [ + cudart.cudaEventCreate()[1], + cudart.cudaEventCreate()[1], + ] + self.stream = cudart.cudaStreamCreate()[1] + + # Allocate TensorRT I/O buffers + if not self.torch_inference: + for model_name, obj in self.models.items(): + if self.torch_fallback[model_name]: + continue + self.engine[model_name].allocate_buffers( + shape_dict=obj.get_shape_dict( + batch_size, image_height, image_width + ), + device=self.device, + ) + + def teardown(self): + for e in self.events.values(): + cudart.cudaEventDestroy(e[0]) + cudart.cudaEventDestroy(e[1]) + + for engine in self.engine.values(): + del engine + + if self.shared_device_memory: + cudart.cudaFree(self.shared_device_memory) + + cudart.cudaStreamDestroy(self.stream) + del self.stream + + def getOnnxPath(self, model_name, onnx_dir, opt=True, suffix=""): + onnx_model_dir = os.path.join( + onnx_dir, model_name + suffix + (".opt" if opt else "") + ) + os.makedirs(onnx_model_dir, exist_ok=True) + return os.path.join(onnx_model_dir, "model.onnx") + + def getEnginePath(self, model_name, engine_dir, enable_refit=False, suffix=""): + return os.path.join( + engine_dir, + model_name + + suffix + + (".refit" if enable_refit else "") + + ".trt" + + trt.__version__ + + ".plan", + ) + + def loadEngines( + self, + engine_dir, + framework_model_dir, + onnx_dir, + onnx_opset, + opt_batch_size, + opt_image_height, + opt_image_width, + static_batch=False, + static_shape=True, + enable_all_tactics=False, + timing_cache=None, + **_kwargs, + ): + """ + Build and load engines for TensorRT accelerated inference. + Export ONNX models first, if applicable. 
+ + Args: + engine_dir (str): + Directory to store the TensorRT engines. + framework_model_dir (str): + Directory to store the framework model ckpt. + onnx_dir (str): + Directory to store the ONNX models. + onnx_opset (int): + ONNX opset version to export the models. + opt_batch_size (int): + Batch size to optimize for during engine building. + opt_image_height (int): + Image height to optimize for during engine building. Must be a multiple of 8. + opt_image_width (int): + Image width to optimize for during engine building. Must be a multiple of 8. + static_batch (bool): + Build engine only for specified opt_batch_size. + static_shape (bool): + Build engine only for specified opt_image_height & opt_image_width. Default = True. + enable_all_tactics (bool): + Enable all tactic sources during TensorRT engine builds. + timing_cache (str): + Path to the timing cache to speed up TensorRT build. + """ + # Create directories if missing + for directory in [engine_dir, onnx_dir]: + if not os.path.exists(directory): + print(f"[I] Create directory: {directory}") + pathlib.Path(directory).mkdir(parents=True) + + # Load pipeline models + models_args = { + "version": self.version, + "pipeline": self.pipeline_type, + "device": self.device, + "hf_token": self.hf_token, + "verbose": self.verbose, + "framework_model_dir": framework_model_dir, + "max_batch_size": self.max_batch_size, + } + + # Load text tokenizer + self.tokenizer = SD3Tokenizer() + + # Load text encoders + if "clip_g" in self.stages: + self.models["clip_g"] = SD3_CLIPGModel( + **models_args, fp16=True, pooled_output=True + ) + + if "clip_l" in self.stages: + self.models["clip_l"] = SD3_CLIPLModel( + **models_args, fp16=True, pooled_output=True + ) + + if "t5xxl" in self.stages: + self.models["t5xxl"] = SD3_T5XXLModel( + **models_args, + fp16=True, + embedding_dim=get_clip_embedding_dim(self.version, self.pipeline_type), + ) + + # Load MMDiT model + if "mmdit" in self.stages: + self.models["mmdit"] = SD3_MMDiTModel( + **models_args, fp16=True, shift=self.shift + ) + + # Load VAE Encoder model + if "vae_encoder" in self.stages: + self.models["vae_encoder"] = SD3_VAEEncoderModel(**models_args, fp16=True) + + # Load VAE Decoder model + if "vae_decoder" in self.stages: + self.models["vae_decoder"] = SD3_VAEDecoderModel(**models_args, fp16=True) + + # Configure pipeline models to load + model_names = self.models.keys() + # Torch fallback + self.torch_fallback = dict( + zip( + model_names, + [ + self.torch_inference or model_name in ("t5xxl") + for model_name in model_names + ], + ) + ) + + onnx_path = dict( + zip( + model_names, + [ + self.getOnnxPath(model_name, onnx_dir, opt=False) + for model_name in model_names + ], + ) + ) + onnx_opt_path = dict( + zip( + model_names, + [self.getOnnxPath(model_name, onnx_dir) for model_name in model_names], + ) + ) + engine_path = dict( + zip( + model_names, + [ + self.getEnginePath(model_name, engine_dir) + for model_name in model_names + ], + ) + ) + + for model_name, obj in self.models.items(): + if self.torch_fallback[model_name]: + continue + # Export models to ONNX + do_export_onnx = not os.path.exists( + engine_path[model_name] + ) and not os.path.exists(onnx_opt_path[model_name]) + if do_export_onnx: + obj.export_onnx( + onnx_path[model_name], + onnx_opt_path[model_name], + onnx_opset, + opt_image_height, + opt_image_width, + static_shape=static_shape, + ) + + # Build TensorRT engines + for model_name, obj in self.models.items(): + if self.torch_fallback[model_name]: + continue + engine = 
engine_module.Engine(engine_path[model_name]) + if not os.path.exists(engine_path[model_name]): + update_output_names = ( + obj.get_output_names() + obj.extra_output_names + if obj.extra_output_names + else None + ) + extra_build_args = {"verbose": self.verbose} + fp16amp = obj.fp16 + engine.build( + onnx_opt_path[model_name], + fp16=fp16amp, + input_profile=obj.get_input_profile( + opt_batch_size, + opt_image_height, + opt_image_width, + static_batch=static_batch, + static_shape=static_shape, + ), + enable_all_tactics=enable_all_tactics, + timing_cache=timing_cache, + update_output_names=update_output_names, + verbose=self.verbose, + ) + self.engine[model_name] = engine + + # Load TensorRT engines + for model_name, obj in self.models.items(): + if self.torch_fallback[model_name]: + continue + self.engine[model_name].load() + + # Load torch models + for model_name, obj in self.models.items(): + if self.torch_fallback[model_name] or model_name == "mmdit": + self.torch_models[model_name] = obj.get_model( + torch_inference=self.torch_inference + ) + + def calculateMaxDeviceMemory(self): + max_device_memory = 0 + for model_name, engine in self.engine.items(): + max_device_memory = max(max_device_memory, engine.engine.device_memory_size) + return max_device_memory + + def activateEngines(self, shared_device_memory=None): + if shared_device_memory is None: + max_device_memory = self.calculateMaxDeviceMemory() + _, shared_device_memory = cudart.cudaMalloc(max_device_memory) + self.shared_device_memory = shared_device_memory + # Load and activate TensorRT engines + for engine in self.engine.values(): + engine.activate(device_memory=self.shared_device_memory) + + def runEngine(self, model_name, feed_dict): + engine = self.engine[model_name] + return engine.infer(feed_dict, self.stream, use_cuda_graph=self.use_cuda_graph) + + def initialize_latents( + self, batch_size, unet_channels, latent_height, latent_width + ): + return ( + torch.ones( + batch_size, unet_channels, latent_height, latent_width, device="cuda" + ) + * 0.0609 + ) + + def profile_start(self, name, color="blue"): + if self.nvtx_profile: + self.markers[name] = nvtx.start_range(message=name, color=color) + if name in self.events: + cudart.cudaEventRecord(self.events[name][0], 0) + + def profile_stop(self, name): + if name in self.events: + cudart.cudaEventRecord(self.events[name][1], 0) + if self.nvtx_profile: + nvtx.end_range(self.markers[name]) + + def print_summary(self, denoising_steps, walltime_ms, batch_size): + print("|-----------------|--------------|") + print("| {:^15} | {:^12} |".format("Module", "Latency")) + print("|-----------------|--------------|") + if "vae_encoder" in self.stages: + print( + "| {:^15} | {:>9.2f} ms |".format( + "VAE Encoder", + cudart.cudaEventElapsedTime( + self.events["vae_encode"][0], self.events["vae_encode"][1] + )[1], + ) + ) + print( + "| {:^15} | {:>9.2f} ms |".format( + "CLIP-G", + cudart.cudaEventElapsedTime( + self.events["clip_g"][0], self.events["clip_g"][1] + )[1], + ) + ) + print( + "| {:^15} | {:>9.2f} ms |".format( + "CLIP-L", + cudart.cudaEventElapsedTime( + self.events["clip_l"][0], self.events["clip_l"][1] + )[1], + ) + ) + print( + "| {:^15} | {:>9.2f} ms |".format( + "T5XXL", + cudart.cudaEventElapsedTime( + self.events["t5xxl"][0], self.events["t5xxl"][1] + )[1], + ) + ) + print( + "| {:^15} | {:>9.2f} ms |".format( + "MMDiT" + " x " + str(denoising_steps), + cudart.cudaEventElapsedTime( + self.events["denoise"][0], self.events["denoise"][1] + )[1], + ) + ) + print( + "| {:^15} | 
{:>9.2f} ms |".format( + "VAE Decoder", + cudart.cudaEventElapsedTime( + self.events["vae_decode"][0], self.events["vae_decode"][1] + )[1], + ) + ) + print("|-----------------|--------------|") + print("| {:^15} | {:>9.2f} ms |".format("Pipeline", walltime_ms)) + print("|-----------------|--------------|") + print("Throughput: {:.2f} image/s".format(batch_size * 1000.0 / walltime_ms)) + + def save_image(self, images, pipeline, prompt, seed): + # Save image + image_name_prefix = ( + pipeline + + "".join( + set( + ["-" + prompt[i].replace(" ", "_")[:10] for i in range(len(prompt))] + ) + ) + + "-" + + str(seed) + + "-" + ) + image_name_suffix = "torch" if self.torch_inference else "trt" + image_module.save_image( + images, self.output_dir, image_name_prefix, image_name_suffix + ) + + def encode_prompt(self, prompt, negative_prompt): + def encode_token_weights(model_name, token_weight_pairs): + self.profile_start(model_name, color="green") + + tokens = list(map(lambda a: a[0], token_weight_pairs[0])) + tokens = torch.tensor([tokens], dtype=torch.int64, device=self.device) + if self.torch_inference or self.torch_fallback[model_name]: + out, pooled = self.torch_models[model_name](tokens) + else: + trt_out = self.runEngine(model_name, {"input_ids": tokens}) + out, pooled = trt_out["text_embeddings"], trt_out["pooled_output"] + + self.profile_stop(model_name) + + if pooled is not None: + first_pooled = pooled[0:1].cuda() + else: + first_pooled = pooled + output = [out[0:1]] + return torch.cat(output, dim=-2).cuda(), first_pooled + + def tokenize(prompt): + tokens = self.tokenizer.tokenize_with_weights(prompt) + l_out, l_pooled = encode_token_weights("clip_l", tokens["l"]) + g_out, g_pooled = encode_token_weights("clip_g", tokens["g"]) + t5_out, _ = encode_token_weights("t5xxl", tokens["t5xxl"]) + lg_out = torch.cat([l_out, g_out], dim=-1) + lg_out = torch.nn.functional.pad(lg_out, (0, 4096 - lg_out.shape[-1])) + + return torch.cat([lg_out, t5_out], dim=-2), torch.cat( + (l_pooled, g_pooled), dim=-1 + ) + + conditioning = tokenize(prompt[0]) + neg_conditioning = tokenize(negative_prompt[0]) + return conditioning, neg_conditioning + + def denoise_latent( + self, latent, conditioning, neg_conditioning, model_name="mmdit" + ): + def get_noise(latent): + return torch.randn( + latent.size(), + dtype=torch.float32, + layout=latent.layout, + generator=self.generator, + device="cuda", + ).to(latent.dtype) + + def get_sigmas(sampling, steps): + start = sampling.timestep(sampling.sigma_max) + end = sampling.timestep(sampling.sigma_min) + timesteps = torch.linspace(start, end, steps) + sigs = [] + for x in range(len(timesteps)): + ts = timesteps[x] + sigs.append(sampling.sigma(ts)) + sigs += [0.0] + return torch.FloatTensor(sigs) + + def max_denoise(sigmas): + max_sigma = float(self.torch_models[model_name].model_sampling.sigma_max) + sigma = float(sigmas[0]) + return math.isclose(max_sigma, sigma, rel_tol=1e-05) or sigma > max_sigma + + def fix_cond(cond): + cond, pooled = (cond[0].half().cuda(), cond[1].half().cuda()) + return {"c_crossattn": cond, "y": pooled} + + def cfg_denoiser(x, timestep, cond, uncond, cond_scale): + # Run cond and uncond in a batch together + sample = torch.cat([x, x]) + sigma = torch.cat([timestep, timestep]) + c_crossattn = torch.cat([cond["c_crossattn"], uncond["c_crossattn"]]) + y = torch.cat([cond["y"], uncond["y"]]) + if self.torch_inference: + with torch.autocast("cuda", dtype=torch.float16): + batched = self.torch_models[model_name]( + sample, sigma, c_crossattn=c_crossattn, 
y=y + ) + else: + input_dict = { + "sample": sample, + "sigma": sigma, + "c_crossattn": c_crossattn, + "y": y, + } + batched = self.runEngine(model_name, input_dict)["latent"] + + # Then split and apply CFG Scaling + pos_out, neg_out = batched.chunk(2) + scaled = neg_out + (pos_out - neg_out) * cond_scale + return scaled + + self.profile_start("denoise", color="blue") + + latent = latent.half().cuda() + noise = get_noise(latent).cuda() + sigmas = get_sigmas( + self.torch_models[model_name].model_sampling, self.denoising_steps + ).cuda() + sigmas = sigmas[int(self.denoising_steps * (1 - self.denoising_percentage)) :] + conditioning = fix_cond(conditioning) + neg_conditioning = fix_cond(neg_conditioning) + + noise_scaled = self.torch_models[model_name].model_sampling.noise_scaling( + sigmas[0], noise, latent, max_denoise(sigmas) + ) + extra_args = { + "cond": conditioning, + "uncond": neg_conditioning, + "cond_scale": self.cfg_scale, + } + latent = sample_euler(cfg_denoiser, noise_scaled, sigmas, extra_args=extra_args) + latent = SD3LatentFormat().process_out(latent) + + self.profile_stop("denoise") + + return latent + + def encode_image(self): + self.input_image = self.input_image.to(self.device) + self.profile_start("vae_encode", color="orange") + if self.torch_inference: + with torch.autocast("cuda", dtype=torch.float16): + latent = self.torch_models["vae_encoder"](self.input_image) + else: + latent = self.runEngine("vae_encoder", {"images": self.input_image})[ + "latent" + ] + + latent = SD3LatentFormat().process_in(latent) + self.profile_stop("vae_encode") + return latent + + def decode_latent(self, latent): + self.profile_start("vae_decode", color="red") + if self.torch_inference: + with torch.autocast("cuda", dtype=torch.float16): + image = self.torch_models["vae_decoder"](latent) + else: + image = self.runEngine("vae_decoder", {"latent": latent})["images"] + image = image.float() + self.profile_stop("vae_decode") + return image + + def infer( + self, + prompt, + negative_prompt, + image_height, + image_width, + warmup=False, + save_image=True, + ): + """ + Run the diffusion pipeline. + + Args: + prompt (str): + The text prompt to guide image generation. + negative_prompt (str): + The prompt not to guide the image generation. + image_height (int): + Height (in pixels) of the image to be generated. Must be a multiple of 8. + image_width (int): + Width (in pixels) of the image to be generated. Must be a multiple of 8. + warmup (bool): + Indicate if this is a warmup run. 
+ save_image (bool): + Save the generated image (if applicable) + """ + assert len(prompt) == len(negative_prompt) + batch_size = len(prompt) + + # Spatial dimensions of latent tensor + latent_height = image_height // 8 + latent_width = image_width // 8 + + if self.generator and self.seed: + self.generator.manual_seed(self.seed) + + with torch.inference_mode(), trt.Runtime(TRT_LOGGER): + torch.cuda.synchronize() + e2e_tic = time.perf_counter() + + # Initialize Latents + latent = self.initialize_latents( + batch_size=batch_size, + unet_channels=16, + latent_height=latent_height, + latent_width=latent_width, + ) + + # Encode input image + if self.input_image is not None: + latent = self.encode_image() + + # Get Conditionings + conditioning, neg_conditioning = self.encode_prompt(prompt, negative_prompt) + + # Denoise + latent = self.denoise_latent(latent, conditioning, neg_conditioning) + + # Decode Latents + images = self.decode_latent(latent) + + torch.cuda.synchronize() + e2e_toc = time.perf_counter() + + walltime_ms = (e2e_toc - e2e_tic) * 1000.0 + if not warmup: + num_inference_steps = int(self.denoising_steps * self.denoising_percentage) + self.print_summary(num_inference_steps, walltime_ms, batch_size) + if save_image: + # post-process images + images = ( + ((images + 1) * 255 / 2) + .clamp(0, 255) + .detach() + .permute(0, 2, 3, 1) + .round() + .type(torch.uint8) + .cpu() + .numpy() + ) + self.save_image( + images, self.pipeline_type.name.lower(), prompt, self.seed + ) + + return images, walltime_ms + + def run( + self, + prompt, + negative_prompt, + height, + width, + batch_size, + batch_count, + num_warmup_runs, + use_cuda_graph, + **kwargs, + ): + # Process prompt + if not isinstance(prompt, list): + raise ValueError( + f"`prompt` must be of type `str` list, but is {type(prompt)}" + ) + prompt = prompt * batch_size + + if not isinstance(negative_prompt, list): + raise ValueError( + f"`--negative-prompt` must be of type `str` list, but is {type(negative_prompt)}" + ) + if len(negative_prompt) == 1: + negative_prompt = negative_prompt * batch_size + + num_warmup_runs = max(1, num_warmup_runs) if use_cuda_graph else num_warmup_runs + if num_warmup_runs > 0: + print("[I] Warming up ..") + for _ in range(num_warmup_runs): + self.infer( + prompt, negative_prompt, height, width, warmup=True, **kwargs + ) + + for _ in range(batch_count): + print("[I] Running StableDiffusion3 pipeline") + if self.nvtx_profile: + cudart.cudaProfilerStart() + self.infer(prompt, negative_prompt, height, width, warmup=False, **kwargs) + if self.nvtx_profile: + cudart.cudaProfilerStop() diff --git a/flux.1-dev-trt-b200/model/demo_diffusion/pipeline/stable_diffusion_pipeline.py b/flux.1-dev-trt-b200/model/demo_diffusion/pipeline/stable_diffusion_pipeline.py new file mode 100644 index 000000000..b860404b4 --- /dev/null +++ b/flux.1-dev-trt-b200/model/demo_diffusion/pipeline/stable_diffusion_pipeline.py @@ -0,0 +1,1581 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# +import inspect +import json +import os +import pathlib +import sys +import time +from hashlib import md5 +from typing import List, Optional + +import modelopt.torch.opt as mto +import modelopt.torch.quantization as mtq +import numpy as np +import nvtx +import tensorrt as trt +import torch +from cuda import cudart +from diffusers import ( + DDIMScheduler, + DDPMScheduler, + DDPMWuerstchenScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + LCMScheduler, + LMSDiscreteScheduler, + PNDMScheduler, + UniPCMultistepScheduler, +) + +import demo_diffusion.engine as engine_module +import demo_diffusion.image as image_module +from demo_diffusion.model import ( + CLIPModel, + CLIPWithProjModel, + SDLoraLoader, + UNetModel, + UNetXLModel, + UNetXLModelControlNet, + VAEEncoderModel, + VAEModel, + get_clip_embedding_dim, + make_scheduler, + make_tokenizer, + merge_loras, + unload_torch_model, +) +from demo_diffusion.pipeline.calibrate import load_calib_prompts +from demo_diffusion.pipeline.type import PIPELINE_TYPE +from demo_diffusion.utils_modelopt import ( + SD_FP8_FP16_DEFAULT_CONFIG, + SD_FP8_FP32_DEFAULT_CONFIG, + check_lora, + filter_func, + generate_fp8_scales, + get_int8_config, + quantize_lvl, + set_fmha, +) + +TRT_LOGGER = trt.Logger(trt.Logger.ERROR) + + +class StableDiffusionPipeline: + SCHEDULER_DEFAULTS = { + "1.4": "PNDM", + "1.5": "PNDM", + "dreamshaper-7": "PNDM", + "2.0-base": "DDIM", + "2.0": "DDIM", + "2.1-base": "PNDM", + "2.1": "DDIM", + "xl-1.0": "Euler", + "xl-turbo": "EulerA", + "svd-xt-1.1": "Euler", + "cascade": "DDPMWuerstchen", + } + """ + Application showcasing the acceleration of Stable Diffusion pipelines using NVidia TensorRT. + """ + + def __init__( + self, + version="1.5", + pipeline_type=PIPELINE_TYPE.TXT2IMG, + max_batch_size=16, + denoising_steps=30, + scheduler=None, + guidance_scale=7.5, + device="cuda", + output_dir=".", + hf_token=None, + verbose=False, + nvtx_profile=False, + use_cuda_graph=False, + vae_scaling_factor=0.18215, + framework_model_dir="pytorch_model", + controlnets=None, + lora_scale: float = 1.0, + lora_weight: Optional[List[float]] = None, + lora_path: Optional[List[str]] = None, + return_latents=False, + torch_inference="", + ): + """ + Initializes the Diffusion pipeline. + + Args: + version (str): + The version of the pipeline. Should be one of [1.4, 1.5, 2.0, 2.0-base, 2.1, 2.1-base] + pipeline_type (PIPELINE_TYPE): + Type of current pipeline. + max_batch_size (int): + Maximum batch size for dynamic batch engine. + denoising_steps (int): + The number of denoising steps. + More denoising steps usually lead to a higher quality image at the expense of slower inference. + scheduler (str): + The scheduler to guide the denoising process. Must be one of [DDIM, DPM, EulerA, Euler, LCM, LMSD, PNDM]. + guidance_scale (float): + Guidance scale is enabled by setting as > 1. + Higher guidance scale encourages to generate images that are closely linked to the text prompt, usually at the expense of lower image quality. + device (str): + PyTorch device to run inference. Default: 'cuda' + output_dir (str): + Output directory for log files and image artifacts + hf_token (str): + HuggingFace User Access Token to use for downloading Stable Diffusion model checkpoints. + verbose (bool): + Enable verbose logging. + nvtx_profile (bool): + Insert NVTX profiling markers. 
+            use_cuda_graph (bool):
+                Use CUDA graph to capture engine execution and then launch inference
+            vae_scaling_factor (float):
+                VAE scaling factor
+            framework_model_dir (str):
+                cache directory for framework checkpoints
+            controlnets (str):
+                Which ControlNet/ControlNets to use.
+            return_latents (bool):
+                Skip decoding the image and return latents instead.
+            torch_inference (str):
+                Run inference with PyTorch (using specified compilation mode) instead of TensorRT.
+        """
+
+        self.denoising_steps = denoising_steps
+        self.guidance_scale = guidance_scale
+        self.do_classifier_free_guidance = guidance_scale > 1.0
+        self.vae_scaling_factor = vae_scaling_factor
+
+        self.max_batch_size = max_batch_size
+
+        self.framework_model_dir = framework_model_dir
+        self.output_dir = output_dir
+        for directory in [self.framework_model_dir, self.output_dir]:
+            if not os.path.exists(directory):
+                print(f"[I] Create directory: {directory}")
+                pathlib.Path(directory).mkdir(parents=True)
+
+        self.hf_token = hf_token
+        self.device = device
+        self.verbose = verbose
+        self.nvtx_profile = nvtx_profile
+
+        self.version = version
+        self.controlnets = controlnets
+
+        # Pipeline type
+        self.pipeline_type = pipeline_type
+        if self.pipeline_type.is_txt2img():
+            self.stages = ["clip", "unet", "vae"]
+        elif self.pipeline_type.is_img2img() or self.pipeline_type.is_inpaint():
+            self.stages = ["vae_encoder", "clip", "unet", "vae"]
+        elif self.pipeline_type.is_sd_xl_base():
+            self.stages = ["clip", "clip2", "unetxl"]
+            if not return_latents:
+                self.stages.append("vae")
+        elif self.pipeline_type.is_sd_xl_refiner():
+            self.stages = ["clip2", "unetxl", "vae"]
+        elif self.pipeline_type.is_img2vid():
+            self.stages = ["clip-vis", "clip-imgfe", "unet-temp", "vae-temp"]
+        elif self.pipeline_type.is_cascade_prior():
+            self.stages = ["clip", "unet"]
+        elif self.pipeline_type.is_cascade_decoder():
+            self.stages = ["clip", "unet", "vqgan"]
+        else:
+            raise ValueError(f"Unsupported pipeline {self.pipeline_type.name}.")
+        self.return_latents = return_latents
+
+        if not scheduler:
+            scheduler = (
+                "UniPC"
+                if self.pipeline_type.is_controlnet()
+                else self.SCHEDULER_DEFAULTS.get(version, "DDIM")
+            )
+            print(f"[I] Autoselected scheduler: {scheduler}")
+
+        scheduler_class_map = {
+            "DDIM": DDIMScheduler,
+            "DDPM": DDPMScheduler,
+            "EulerA": EulerAncestralDiscreteScheduler,
+            "Euler": EulerDiscreteScheduler,
+            "LCM": LCMScheduler,
+            "LMSD": LMSDiscreteScheduler,
+            "PNDM": PNDMScheduler,
+            "UniPC": UniPCMultistepScheduler,
+            "DDPMWuerstchen": DDPMWuerstchenScheduler,
+        }
+        try:
+            scheduler_class = scheduler_class_map[scheduler]
+        except KeyError:
+            raise ValueError(
+                f"Unsupported scheduler {scheduler}. Should be one of {list(scheduler_class_map.keys())}."
+ ) + self.scheduler = make_scheduler( + scheduler_class, version, pipeline_type, hf_token, framework_model_dir + ) + + self.config = {} + if self.pipeline_type.is_sd_xl(): + self.config["clip_hidden_states"] = True + self.torch_inference = torch_inference + if self.torch_inference: + torch._inductor.config.conv_1x1_as_mm = True + torch._inductor.config.coordinate_descent_tuning = True + torch._inductor.config.epilogue_fusion = False + torch._inductor.config.coordinate_descent_check_all_directions = True + self.use_cuda_graph = use_cuda_graph + + # initialized in loadEngines() + self.models = {} + self.torch_models = {} + self.engine = {} + self.shared_device_memory = None + + # initialize lora loader and scales + self.lora_loader = None + self.lora_weights = dict() + if lora_path: + self.lora_loader = SDLoraLoader(lora_path, lora_weight, lora_scale) + assert len(lora_path) == len(lora_weight) + for i, path in enumerate(lora_path): + self.lora_weights[path] = lora_weight[i] + + # initialized in loadResources() + self.events = {} + self.generator = None + self.markers = {} + self.seed = None + self.stream = None + self.tokenizer = None + + def loadResources(self, image_height, image_width, batch_size, seed): + # Initialize noise generator + if seed: + self.seed = seed + self.generator = torch.Generator(device="cuda").manual_seed(seed) + + # Create CUDA events and stream + for stage in ["clip", "denoise", "vae", "vae_encoder", "vqgan"]: + self.events[stage] = [ + cudart.cudaEventCreate()[1], + cudart.cudaEventCreate()[1], + ] + self.stream = cudart.cudaStreamCreate()[1] + + # Allocate TensorRT I/O buffers + if not self.torch_inference: + for model_name, obj in self.models.items(): + self.engine[model_name].allocate_buffers( + shape_dict=obj.get_shape_dict( + batch_size, image_height, image_width + ), + device=self.device, + ) + + def teardown(self): + for e in self.events.values(): + cudart.cudaEventDestroy(e[0]) + cudart.cudaEventDestroy(e[1]) + + for engine in self.engine.values(): + del engine + + if self.shared_device_memory: + cudart.cudaFree(self.shared_device_memory) + + for torch_model in self.torch_models.values(): + del torch_model + + cudart.cudaStreamDestroy(self.stream) + del self.stream + + def cachedModelName(self, model_name): + if self.pipeline_type.is_inpaint(): + model_name += "_inpaint" + return model_name + + def getOnnxPath(self, model_name, onnx_dir, opt=True, suffix=""): + onnx_model_dir = os.path.join( + onnx_dir, + self.cachedModelName(model_name) + suffix + (".opt" if opt else ""), + ) + os.makedirs(onnx_model_dir, exist_ok=True) + return os.path.join(onnx_model_dir, "model.onnx") + + def getEnginePath(self, model_name, engine_dir, enable_refit=False, suffix=""): + return os.path.join( + engine_dir, + self.cachedModelName(model_name) + + suffix + + (".refit" if enable_refit else "") + + ".trt" + + trt.__version__ + + ".plan", + ) + + def getWeightsMapPath(self, model_name, onnx_dir): + onnx_model_dir = os.path.join( + onnx_dir, self.cachedModelName(model_name) + ".opt" + ) + os.makedirs(onnx_model_dir, exist_ok=True) + return os.path.join(onnx_model_dir, "weights_map.json") + + def getRefitNodesPath(self, model_name, onnx_dir, suffix=""): + onnx_model_dir = os.path.join( + onnx_dir, self.cachedModelName(model_name) + ".opt" + ) + os.makedirs(onnx_model_dir, exist_ok=True) + return os.path.join(onnx_model_dir, "refit" + suffix + ".json") + + def getStateDictPath(self, model_name, onnx_dir, suffix=""): + onnx_model_dir = os.path.join( + onnx_dir, 
self.cachedModelName(model_name) + suffix + ) + os.makedirs(onnx_model_dir, exist_ok=True) + return os.path.join(onnx_model_dir, "state_dict.pt") + + def initializeModels(self, framework_model_dir, int8, fp8): + # Load text tokenizer(s) + if not self.pipeline_type.is_sd_xl_refiner(): + self.tokenizer = make_tokenizer( + self.version, self.pipeline_type, self.hf_token, framework_model_dir + ) + if self.pipeline_type.is_sd_xl(): + self.tokenizer2 = make_tokenizer( + self.version, + self.pipeline_type, + self.hf_token, + framework_model_dir, + subfolder="tokenizer_2", + ) + + # Load pipeline models + models_args = { + "version": self.version, + "pipeline": self.pipeline_type, + "device": self.device, + "hf_token": self.hf_token, + "verbose": self.verbose, + "framework_model_dir": framework_model_dir, + "max_batch_size": self.max_batch_size, + } + + if "clip" in self.stages: + subfolder = "text_encoder" + self.models["clip"] = CLIPModel( + **models_args, + fp16=True, + embedding_dim=get_clip_embedding_dim(self.version, self.pipeline_type), + output_hidden_states=self.config.get("clip_hidden_states", False), + subfolder=subfolder, + ) + + if "clip2" in self.stages: + subfolder = "text_encoder_2" + self.models["clip2"] = CLIPWithProjModel( + **models_args, + fp16=True, + output_hidden_states=self.config.get("clip_hidden_states", False), + subfolder=subfolder, + ) + + if "unet" in self.stages: + self.models["unet"] = UNetModel( + **models_args, + fp16=True, + int8=int8, + fp8=fp8, + controlnets=self.controlnets, + do_classifier_free_guidance=self.do_classifier_free_guidance, + ) + + if "unetxl" in self.stages: + if not self.controlnets: + self.models["unetxl"] = UNetXLModel( + **models_args, + fp16=True, + int8=int8, + fp8=fp8, + do_classifier_free_guidance=self.do_classifier_free_guidance, + ) + else: + self.models["unetxl"] = UNetXLModelControlNet( + **models_args, + fp16=True, + int8=int8, + fp8=fp8, + controlnets=self.controlnets, + do_classifier_free_guidance=self.do_classifier_free_guidance, + ) + + vae_fp16 = not self.pipeline_type.is_sd_xl() + + if "vae" in self.stages: + self.models["vae"] = VAEModel(**models_args, fp16=vae_fp16, tf32=True) + + if "vae_encoder" in self.stages: + self.models["vae_encoder"] = VAEEncoderModel(**models_args, fp16=vae_fp16) + + def loadEngines( + self, + engine_dir, + framework_model_dir, + onnx_dir, + onnx_opset, + opt_batch_size, + opt_image_height, + opt_image_width, + optimization_level=3, + static_batch=False, + static_shape=True, + enable_refit=False, + enable_all_tactics=False, + timing_cache=None, + int8=False, + fp8=False, + quantization_level=2.5, + quantization_percentile=1.0, + quantization_alpha=0.8, + calibration_size=32, + calib_batch_size=2, + **_kwargs, + ): + """ + Build and load engines for TensorRT accelerated inference. + Export ONNX models first, if applicable. + + Args: + engine_dir (str): + Directory to store the TensorRT engines. + framework_model_dir (str): + Directory to store the framework model ckpt. + onnx_dir (str): + Directory to store the ONNX models. + onnx_opset (int): + ONNX opset version to export the models. + opt_batch_size (int): + Batch size to optimize for during engine building. + opt_image_height (int): + Image height to optimize for during engine building. Must be a multiple of 8. + opt_image_width (int): + Image width to optimize for during engine building. Must be a multiple of 8. + optimization_level (int): + Optimization level to build the TensorRT engine with. 
+ static_batch (bool): + Build engine only for specified opt_batch_size. + static_shape (bool): + Build engine only for specified opt_image_height & opt_image_width. Default = True. + enable_refit (bool): + Build engines with refit option enabled. + enable_all_tactics (bool): + Enable all tactic sources during TensorRT engine builds. + timing_cache (str): + Path to the timing cache to speed up TensorRT build. + int8 (bool): + Whether to quantize to int8 format or not (SDXL, SD15 and SD21 only). + fp8 (bool): + Whether to quantize to fp8 format or not (SDXL, SD15 and SD21 only). + quantization_level (float): + Controls which layers to quantize. 1: CNN, 2: CNN+FFN, 2.5: CNN+FFN+QKV, 3: CNN+FC + quantization_percentile (float): + Control quantization scaling factors (amax) collecting range, where the minimum amax in + range(n_steps * percentile) will be collected. Recommendation: 1.0 + quantization_alpha (float): + The alpha parameter for SmoothQuant quantization used for linear layers. + Recommendation: 0.8 for SDXL + calibration_size (int): + The number of steps to use for calibrating the model for quantization. + Recommendation: 32, 64, 128 for SDXL + calib_batch_size (int): + The batch size to use for calibration. Defaults to 2. + """ + # Create directories if missing + for directory in [engine_dir, onnx_dir]: + if not os.path.exists(directory): + print(f"[I] Create directory: {directory}") + pathlib.Path(directory).mkdir(parents=True) + + # Initialize models + self.initializeModels(framework_model_dir, int8, fp8) + + # Configure pipeline models to load + model_names = self.models.keys() + lora_suffix = ( + "-" + + "-".join( + [ + str(md5(path.encode("utf-8")).hexdigest()) + + "-" + + ("%.2f" % self.lora_weights[path]) + + "-" + + ("%.2f" % self.lora_loader.scale) + for path in sorted(self.lora_loader.paths) + ] + ) + if self.lora_loader + else "" + ) + # Enable refit and LoRA merging only for UNet & UNetXL for now + do_engine_refit = dict( + zip( + model_names, + [ + not self.pipeline_type.is_sd_xl_refiner() + and enable_refit + and model_name.startswith("unet") + for model_name in model_names + ], + ) + ) + do_lora_merge = dict( + zip( + model_names, + [ + not enable_refit + and self.lora_loader + and model_name.startswith("unet") + for model_name in model_names + ], + ) + ) + # Torch fallback for VAE if specified + torch_fallback = dict( + zip(model_names, [self.torch_inference for model_name in model_names]) + ) + model_suffix = dict( + zip( + model_names, + [ + lora_suffix if do_lora_merge[model_name] else "" + for model_name in model_names + ], + ) + ) + use_int8 = dict.fromkeys(model_names, False) + use_fp8 = dict.fromkeys(model_names, False) + if int8: + assert self.pipeline_type.is_sd_xl_base() or self.version in [ + "1.4", + "1.5", + "2.1", + "2.1-base", + ], ( + "int8 quantization only supported for SDXL, SD1.4, SD1.5 and SD2.1 pipeline" + ) + model_name = "unetxl" if self.pipeline_type.is_sd_xl() else "unet" + use_int8[model_name] = True + model_suffix[model_name] += ( + f"-int8.l{quantization_level}.bs2.s{self.denoising_steps}.c{calibration_size}.p{quantization_percentile}.a{quantization_alpha}" + ) + elif fp8: + assert self.pipeline_type.is_sd_xl() or self.version in [ + "1.4", + "1.5", + "2.1", + "2.1-base", + ], ( + "fp8 quantization only supported for SDXL, SD1.4, SD1.5 and SD2.1 pipeline" + ) + model_name = "unetxl" if self.pipeline_type.is_sd_xl() else "unet" + use_fp8[model_name] = True + model_suffix[model_name] += ( + 
f"-fp8.l{quantization_level}.bs2.s{self.denoising_steps}.c{calibration_size}.p{quantization_percentile}.a{quantization_alpha}" + ) + onnx_path = dict( + zip( + model_names, + [ + self.getOnnxPath( + model_name, onnx_dir, opt=False, suffix=model_suffix[model_name] + ) + for model_name in model_names + ], + ) + ) + onnx_opt_path = dict( + zip( + model_names, + [ + self.getOnnxPath( + model_name, onnx_dir, suffix=model_suffix[model_name] + ) + for model_name in model_names + ], + ) + ) + engine_path = dict( + zip( + model_names, + [ + self.getEnginePath( + model_name, + engine_dir, + do_engine_refit[model_name], + suffix=model_suffix[model_name], + ) + for model_name in model_names + ], + ) + ) + weights_map_path = dict( + zip( + model_names, + [ + ( + self.getWeightsMapPath(model_name, onnx_dir) + if do_engine_refit[model_name] + else None + ) + for model_name in model_names + ], + ) + ) + + for model_name, obj in self.models.items(): + if torch_fallback[model_name]: + continue + # Export models to ONNX and save weights name mapping + do_export_onnx = not os.path.exists( + engine_path[model_name] + ) and not os.path.exists(onnx_opt_path[model_name]) + do_export_weights_map = weights_map_path[model_name] and not os.path.exists( + weights_map_path[model_name] + ) + if do_export_onnx or do_export_weights_map: + # Non-quantized ONNX export + if not use_int8[model_name] and not use_fp8[model_name]: + obj.export_onnx( + onnx_path[model_name], + onnx_opt_path[model_name], + onnx_opset, + opt_image_height, + opt_image_width, + enable_lora_merge=do_lora_merge[model_name], + static_shape=static_shape, + lora_loader=self.lora_loader, + ) + else: + pipeline = obj.get_pipeline() + model = pipeline.unet + if use_fp8[model_name] and quantization_level == 4.0: + set_fmha(model) + + state_dict_path = self.getStateDictPath( + model_name, onnx_dir, suffix=model_suffix[model_name] + ) + if not os.path.exists(state_dict_path): + print( + f"[I] Calibrated weights not found, generating {state_dict_path}" + ) + root_dir = os.path.dirname( + os.path.abspath(sys.modules["__main__"].__file__) + ) + calibration_file = os.path.join( + root_dir, "calibration_data", "calibration-prompts.txt" + ) + calibration_prompts = load_calib_prompts( + calib_batch_size, calibration_file + ) + + # TODO check size > calibration_size + def do_calibrate(pipeline, calibration_prompts, **kwargs): + for i_th, prompts in enumerate(calibration_prompts): + if i_th >= kwargs["calib_size"]: + return + pipeline( + prompt=prompts, + num_inference_steps=kwargs["n_steps"], + negative_prompt=[ + "normal quality, low quality, worst quality, low res, blurry, nsfw, nude" + ] + * len(prompts), + ).images + + def forward_loop(model): + pipeline.unet = model + do_calibrate( + pipeline=pipeline, + calibration_prompts=calibration_prompts, + calib_size=calibration_size // calib_batch_size, + n_steps=self.denoising_steps, + ) + + print( + f"[I] Performing calibration for {calibration_size} steps." 
+                        )
+                        if use_int8[model_name]:
+                            quant_config = get_int8_config(
+                                model,
+                                quantization_level,
+                                quantization_alpha,
+                                quantization_percentile,
+                                self.denoising_steps,
+                            )
+                        elif use_fp8[model_name]:
+                            quant_config = (
+                                SD_FP8_FP32_DEFAULT_CONFIG
+                                if self.version == "2.1"
+                                else SD_FP8_FP16_DEFAULT_CONFIG
+                            )
+
+                        # Handle LoRA
+                        if do_lora_merge[model_name]:
+                            assert self.lora_loader is not None
+                            model = merge_loras(model, self.lora_loader)
+                        check_lora(model)
+                        mtq.quantize(model, quant_config, forward_loop)
+                        mto.save(model, state_dict_path)
+                    else:
+                        mto.restore(model, state_dict_path)
+
+                    print(
+                        f"[I] Generating quantized ONNX model: {onnx_opt_path[model_name]}"
+                    )
+                    if not os.path.exists(onnx_path[model_name]):
+                        quantize_lvl(self.version, model, quantization_level)
+                        mtq.disable_quantizer(model, filter_func)
+                        if use_fp8[model_name]:
+                            generate_fp8_scales(model)
+                    else:
+                        model = None
+                    obj.export_onnx(
+                        onnx_path[model_name],
+                        onnx_opt_path[model_name],
+                        onnx_opset,
+                        opt_image_height,
+                        opt_image_width,
+                        custom_model=model,
+                        static_shape=static_shape,
+                    )
+
+                # FIXME do_export_weights_map needs ONNX graph
+                if do_export_weights_map:
+                    print(f"[I] Saving weights map: {weights_map_path[model_name]}")
+                    obj.export_weights_map(
+                        onnx_opt_path[model_name], weights_map_path[model_name]
+                    )
+
+        # Build TensorRT engines
+        for model_name, obj in self.models.items():
+            if torch_fallback[model_name]:
+                continue
+            engine = engine_module.Engine(engine_path[model_name])
+            if not os.path.exists(engine_path[model_name]):
+                update_output_names = (
+                    obj.get_output_names() + obj.extra_output_names
+                    if obj.extra_output_names
+                    else None
+                )
+                fp16amp = obj.fp16 if not use_fp8[model_name] else False
+                bf16amp = obj.bf16 if not use_fp8[model_name] else False
+                # TF32 can be enabled for all precisions (including INT8/FP8)
+                tf32amp = obj.tf32
+                strongly_typed = False if not use_fp8[model_name] else True
+                int8amp = use_int8.get(model_name, False)
+                precision_constraints = "prefer" if int8amp else "none"
+                engine.build(
+                    onnx_opt_path[model_name],
+                    strongly_typed=strongly_typed,
+                    fp16=fp16amp,
+                    bf16=bf16amp,
+                    tf32=tf32amp,
+                    int8=int8amp,
+                    input_profile=obj.get_input_profile(
+                        opt_batch_size,
+                        opt_image_height,
+                        opt_image_width,
+                        static_batch=static_batch,
+                        static_shape=static_shape,
+                    ),
+                    enable_refit=do_engine_refit[model_name],
+                    enable_all_tactics=enable_all_tactics,
+                    timing_cache=timing_cache,
+                    update_output_names=update_output_names,
+                    verbose=self.verbose,
+                    builder_optimization_level=optimization_level,
+                    precision_constraints=precision_constraints,
+                )
+            self.engine[model_name] = engine
+
+        # Load TensorRT engines
+        for model_name, obj in self.models.items():
+            if torch_fallback[model_name]:
+                continue
+            self.engine[model_name].load()
+            if do_engine_refit[model_name] and self.lora_loader:
+                assert weights_map_path[model_name]
+                with open(weights_map_path[model_name], "r") as fp_wts:
+                    print(f"[I] Loading weights map: {weights_map_path[model_name]} ")
+                    [weights_name_mapping, weights_shape_mapping] = json.load(fp_wts)
+                    refit_weights_path = self.getRefitNodesPath(
+                        model_name, engine_dir, suffix=lora_suffix
+                    )
+                    if not os.path.exists(refit_weights_path):
+                        print(f"[I] Saving refit weights: {refit_weights_path}")
+                        model = merge_loras(obj.get_model(), self.lora_loader)
+                        refit_weights, updated_weight_names = (
+                            engine_module.get_refit_weights(
+                                model.state_dict(),
+                                onnx_opt_path[model_name],
+                                weights_name_mapping,
+                                weights_shape_mapping,
+                            )
+                        )
+                        torch.save(
+
(refit_weights, updated_weight_names), refit_weights_path + ) + unload_torch_model(model) + else: + print(f"[I] Loading refit weights: {refit_weights_path}") + refit_weights, updated_weight_names = torch.load( + refit_weights_path + ) + self.engine[model_name].refit(refit_weights, updated_weight_names) + + # Load torch models + for model_name, obj in self.models.items(): + if torch_fallback[model_name]: + self.torch_models[model_name] = obj.get_model( + torch_inference=self.torch_inference + ) + + def calculateMaxDeviceMemory(self): + max_device_memory = 0 + for model_name, engine in self.engine.items(): + max_device_memory = max(max_device_memory, engine.engine.device_memory_size) + return max_device_memory + + def activateEngines(self, shared_device_memory=None): + if shared_device_memory is None: + max_device_memory = self.calculateMaxDeviceMemory() + _, shared_device_memory = cudart.cudaMalloc(max_device_memory) + self.shared_device_memory = shared_device_memory + # Load and activate TensorRT engines + for engine in self.engine.values(): + engine.activate(device_memory=self.shared_device_memory) + + def runEngine(self, model_name, feed_dict): + engine = self.engine[model_name] + return engine.infer(feed_dict, self.stream, use_cuda_graph=self.use_cuda_graph) + + def initialize_latents( + self, + batch_size, + unet_channels, + latent_height, + latent_width, + latents_dtype=torch.float32, + ): + latents_dtype = latents_dtype # text_embeddings.dtype + latents_shape = (batch_size, unet_channels, latent_height, latent_width) + latents = torch.randn( + latents_shape, + device=self.device, + dtype=latents_dtype, + generator=self.generator, + ) + # Scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def profile_start(self, name, color="blue"): + if self.nvtx_profile: + self.markers[name] = nvtx.start_range(message=name, color=color) + if name in self.events: + cudart.cudaEventRecord(self.events[name][0], 0) + + def profile_stop(self, name): + if name in self.events: + cudart.cudaEventRecord(self.events[name][1], 0) + if self.nvtx_profile: + nvtx.end_range(self.markers[name]) + + def preprocess_images(self, batch_size, images=()): + if not images: + return () + self.profile_start("preprocess", color="pink") + input_images = [] + for image in images: + image = image.to(self.device).float() + if image.shape[0] != batch_size: + image = image.repeat(batch_size, 1, 1, 1) + input_images.append(image) + self.profile_stop("preprocess") + return tuple(input_images) + + def preprocess_controlnet_images(self, batch_size, images=None): + """ + images: List of PIL.Image.Image + """ + if images is None: + return None + self.profile_start("preprocess", color="pink") + images = [ + (np.array(i.convert("RGB")).astype(np.float32) / 255.0)[..., None] + .transpose(3, 2, 0, 1) + .repeat(batch_size, axis=0) + for i in images + ] + # do_classifier_free_guidance + images = [ + torch.cat([torch.from_numpy(i).to(self.device).float()] * 2) for i in images + ] + images = torch.cat([image[None, ...] 
for image in images], dim=0) + self.profile_stop("preprocess") + return images + + def encode_prompt( + self, + prompt, + negative_prompt, + encoder="clip", + pooled_outputs=False, + output_hidden_states=False, + ): + self.profile_start("clip", color="green") + + tokenizer = self.tokenizer2 if encoder == "clip2" else self.tokenizer + + def tokenize(prompt, output_hidden_states): + text_input_ids = ( + tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + .input_ids.type(torch.int32) + .to(self.device) + ) + + text_hidden_states = None + if self.torch_inference: + outputs = self.torch_models[encoder]( + text_input_ids, output_hidden_states=output_hidden_states + ) + text_embeddings = outputs[0].clone() + if output_hidden_states: + text_hidden_states = outputs["hidden_states"][-2].clone() + else: + # NOTE: output tensor for CLIP must be cloned because it will be overwritten when called again for negative prompt + outputs = self.runEngine(encoder, {"input_ids": text_input_ids}) + text_embeddings = outputs["text_embeddings"].clone() + if output_hidden_states: + text_hidden_states = outputs["hidden_states"].clone() + return text_embeddings, text_hidden_states + + # Tokenize prompt + text_embeddings, text_hidden_states = tokenize(prompt, output_hidden_states) + + if self.do_classifier_free_guidance: + # Tokenize negative prompt + uncond_embeddings, uncond_hidden_states = tokenize( + negative_prompt, output_hidden_states + ) + + # Concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes for classifier free guidance + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]).to( + dtype=torch.float16 + ) + + if pooled_outputs: + pooled_output = text_embeddings + + if output_hidden_states: + text_embeddings = ( + torch.cat([uncond_hidden_states, text_hidden_states]).to( + dtype=torch.float16 + ) + if self.do_classifier_free_guidance + else text_hidden_states + ) + + self.profile_stop("clip") + if pooled_outputs: + return text_embeddings, pooled_output + return text_embeddings + + # from diffusers (get_timesteps) + def get_timesteps(self, num_inference_steps, strength, denoising_start=None): + # get the original timestep using init_timestep + if denoising_start is None: + init_timestep = min( + int(num_inference_steps * strength), num_inference_steps + ) + t_start = max(num_inference_steps - init_timestep, 0) + else: + t_start = 0 + + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + + # Strength is irrelevant if we directly request a timestep to start at; + # that is, strength is determined by the denoising_start instead. + if denoising_start is not None: + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (denoising_start * self.scheduler.config.num_train_timesteps) + ) + ) + + num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item() + if self.scheduler.order == 2 and num_inference_steps % 2 == 0: + # if the scheduler is a 2nd order scheduler we might have to do +1 + # because `num_inference_steps` might be even given that every timestep + # (except the highest one) is duplicated. If `num_inference_steps` is even it would + # mean that we cut the timesteps in the middle of the denoising step + # (between 1st and 2nd devirative) which leads to incorrect results. 
By adding 1 + # we ensure that the denoising process always ends after the 2nd derivate step of the scheduler + num_inference_steps = num_inference_steps + 1 + + # because t_n+1 >= t_n, we slice the timesteps starting from the end + timesteps = timesteps[-num_inference_steps:] + return timesteps, num_inference_steps + + return timesteps, num_inference_steps - t_start + + def denoise_latent( + self, + latents, + text_embeddings, + denoiser="unet", + timesteps=None, + step_offset=0, + mask=None, + masked_image_latents=None, + image_guidance=1.5, + controlnet_imgs=None, + controlnet_scales=None, + text_embeds=None, + time_ids=None, + ): + assert image_guidance > 1.0, "Image guidance has to be > 1.0" + + controlnet_imgs = self.preprocess_controlnet_images( + latents.shape[0], controlnet_imgs + ) + + do_autocast = self.torch_inference != "" and self.models[denoiser].fp16 + with torch.autocast("cuda", enabled=do_autocast): + self.profile_start("denoise", color="blue") + for step_index, timestep in enumerate(timesteps): + # Expand the latents if we are doing classifier free guidance + latent_model_input = ( + torch.cat([latents] * 2) + if self.do_classifier_free_guidance + else latents + ) + latent_model_input = self.scheduler.scale_model_input( + latent_model_input, timestep + ) + if isinstance(mask, torch.Tensor): + latent_model_input = torch.cat( + [latent_model_input, mask, masked_image_latents], dim=1 + ) + + # Predict the noise residual + if self.torch_inference: + params = { + "sample": latent_model_input, + "timestep": timestep, + "encoder_hidden_states": text_embeddings, + } + if controlnet_imgs is not None: + params.update( + { + "images": controlnet_imgs, + "controlnet_scales": controlnet_scales, + } + ) + added_cond_kwargs = {} + if text_embeds != None: + added_cond_kwargs.update({"text_embeds": text_embeds}) + if time_ids != None: + added_cond_kwargs.update({"time_ids": time_ids}) + if text_embeds != None or time_ids != None: + params.update({"added_cond_kwargs": added_cond_kwargs}) + noise_pred = self.torch_models[denoiser](**params)["sample"] + else: + timestep_float = ( + timestep.float() + if timestep.dtype != torch.float32 + else timestep + ) + + params = { + "sample": latent_model_input, + "timestep": timestep_float, + "encoder_hidden_states": text_embeddings, + } + if controlnet_imgs is not None: + params.update( + { + "images": controlnet_imgs, + "controlnet_scales": controlnet_scales, + } + ) + if text_embeds != None: + params.update({"text_embeds": text_embeds}) + if time_ids != None: + params.update({"time_ids": time_ids}) + noise_pred = self.runEngine(denoiser, params)["latent"] + + # Perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + + # from diffusers (prepare_extra_step_kwargs) + extra_step_kwargs = {} + if "eta" in set( + inspect.signature(self.scheduler.step).parameters.keys() + ): + # TODO: configurable eta + eta = 0.0 + extra_step_kwargs["eta"] = eta + if "generator" in set( + inspect.signature(self.scheduler.step).parameters.keys() + ): + extra_step_kwargs["generator"] = self.generator + + latents = self.scheduler.step( + noise_pred, + timestep, + latents, + **extra_step_kwargs, + return_dict=False, + )[0] + + latents = 1.0 / self.vae_scaling_factor * latents + latents = latents.to(dtype=torch.float32) + + self.profile_stop("denoise") + return latents + + def encode_image(self, input_image): + 
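+        # NOTE: rough usage sketch with hypothetical names (`pipe` is an initialized pipeline,
+        # `img` a preprocessed [batch, 3, H, W] tensor). encode_image() returns latents that are
+        # already multiplied by vae_scaling_factor, while decode_latent() expects latents with
+        # that factor divided back out (as denoise_latent() does before returning):
+        #
+        #     latents = pipe.encode_image(img)
+        #     images = pipe.decode_latent(latents / pipe.vae_scaling_factor)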
self.profile_start("vae_encoder", color="red") + cast_to = ( + torch.float16 + if self.models["vae_encoder"].fp16 + else torch.bfloat16 + if self.models["vae_encoder"].bf16 + else torch.float32 + ) + input_image = input_image.to(dtype=cast_to) + if self.torch_inference: + image_latents = self.torch_models["vae_encoder"](input_image) + else: + image_latents = self.runEngine("vae_encoder", {"images": input_image})[ + "latent" + ] + image_latents = self.vae_scaling_factor * image_latents + self.profile_stop("vae_encoder") + return image_latents + + def decode_latent(self, latents): + self.profile_start("vae", color="red") + cast_to = ( + torch.float16 + if self.models["vae"].fp16 + else torch.bfloat16 + if self.models["vae"].bf16 + else torch.float32 + ) + latents = latents.to(dtype=cast_to) + + if self.torch_inference: + images = self.torch_models["vae"](latents, return_dict=False)[0] + else: + images = self.runEngine("vae", {"latent": latents})["images"] + self.profile_stop("vae") + return images + + def print_summary(self, denoising_steps, walltime_ms, batch_size): + print("|-----------------|--------------|") + print("| {:^15} | {:^12} |".format("Module", "Latency")) + print("|-----------------|--------------|") + if "vae_encoder" in self.stages: + print( + "| {:^15} | {:>9.2f} ms |".format( + "VAE-Enc", + cudart.cudaEventElapsedTime( + self.events["vae_encoder"][0], self.events["vae_encoder"][1] + )[1], + ) + ) + print( + "| {:^15} | {:>9.2f} ms |".format( + "CLIP", + cudart.cudaEventElapsedTime( + self.events["clip"][0], self.events["clip"][1] + )[1], + ) + ) + print( + "| {:^15} | {:>9.2f} ms |".format( + "UNet" + + ("+CNet" if self.pipeline_type.is_controlnet() else "") + + " x " + + str(denoising_steps), + cudart.cudaEventElapsedTime( + self.events["denoise"][0], self.events["denoise"][1] + )[1], + ) + ) + print( + "| {:^15} | {:>9.2f} ms |".format( + "VAE-Dec", + cudart.cudaEventElapsedTime( + self.events["vae"][0], self.events["vae"][1] + )[1], + ) + ) + print("|-----------------|--------------|") + print("| {:^15} | {:>9.2f} ms |".format("Pipeline", walltime_ms)) + print("|-----------------|--------------|") + print("Throughput: {:.2f} image/s".format(batch_size * 1000.0 / walltime_ms)) + + def save_image(self, images, pipeline, prompt, seed): + # Save image + image_name_prefix = ( + pipeline + + "".join( + set( + ["-" + prompt[i].replace(" ", "_")[:10] for i in range(len(prompt))] + ) + ) + + "-" + + str(seed) + + "-" + ) + image_name_suffix = "torch" if self.torch_inference else "trt" + image_module.save_image( + images, self.output_dir, image_name_prefix, image_name_suffix + ) + + def infer( + self, + prompt, + negative_prompt, + image_height, + image_width, + input_image=None, + image_strength=0.75, + mask_image=None, + controlnet_scales=None, + aesthetic_score=6.0, + negative_aesthetic_score=2.5, + warmup=False, + verbose=False, + save_image=True, + ): + """ + Run the diffusion pipeline. + + Args: + prompt (str): + The text prompt to guide image generation. + negative_prompt (str): + The prompt not to guide the image generation. + image_height (int): + Height (in pixels) of the image to be generated. Must be a multiple of 8. + image_width (int): + Width (in pixels) of the image to be generated. Must be a multiple of 8. + input_image (image): + Input image used to initialize the latents or to be inpainted. + image_strength (float): + Strength of transformation applied to input_image. Must be between 0 and 1. 
+ mask_image (image): + Mask image containg the region to be inpainted. + controlnet_scales (torch.Tensor) + A tensor which containes ControlNet scales, essential for multi ControlNet. + Must be equal to number of Controlnets. + warmup (bool): + Indicate if this is a warmup run. + verbose (bool): + Verbose in logging + save_image (bool): + Save the generated image (if applicable) + """ + assert len(prompt) == len(negative_prompt) + batch_size = len(prompt) + + # Spatial dimensions of latent tensor + latent_height = image_height // 8 + latent_width = image_width // 8 + + if self.generator and self.seed: + self.generator.manual_seed(self.seed) + + num_inference_steps = self.denoising_steps + + with torch.inference_mode(), trt.Runtime(TRT_LOGGER): + torch.cuda.synchronize() + e2e_tic = time.perf_counter() + + # TODO: support custom timesteps + timesteps = None + if timesteps is not None: + if "timesteps" not in set( + inspect.signature(self.scheduler.set_timesteps).parameters.keys() + ): + raise ValueError( + f"The current scheduler class {self.scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + self.scheduler.set_timesteps(timesteps=timesteps, device=self.device) + assert self.denoising_steps == len(self.scheduler.timesteps) + else: + self.scheduler.set_timesteps(self.denoising_steps, device=self.device) + timesteps = self.scheduler.timesteps.to(self.device) + + denoise_kwargs = {} + if not ( + self.pipeline_type.is_img2img() or self.pipeline_type.is_sd_xl_refiner() + ): + # Initialize latents + latents = self.initialize_latents( + batch_size=batch_size, + unet_channels=4, + latent_height=latent_height, + latent_width=latent_width, + ) + if self.pipeline_type.is_controlnet(): + denoise_kwargs.update( + { + "controlnet_imgs": input_image, + "controlnet_scales": controlnet_scales, + } + ) + + # Pre-process and VAE encode input image + if ( + self.pipeline_type.is_img2img() + or self.pipeline_type.is_inpaint() + or self.pipeline_type.is_sd_xl_refiner() + ): + assert input_image != None + # Initialize timesteps and pre-process input image + timesteps, num_inference_steps = self.get_timesteps( + self.denoising_steps, image_strength + ) + denoise_kwargs.update({"timesteps": timesteps}) + if self.pipeline_type.is_img2img() or self.pipeline_type.is_sd_xl_refiner(): + latent_timestep = timesteps[:1].repeat(batch_size) + input_image = self.preprocess_images(batch_size, (input_image,))[0] + # Encode if not a latent + image_latents = ( + input_image + if input_image.shape[1] == 4 + else self.encode_image(input_image) + ) + # Add noise to latents using timesteps + noise = torch.randn( + image_latents.shape, + generator=self.generator, + device=self.device, + dtype=torch.float32, + ) + latents = self.scheduler.add_noise( + image_latents, noise, latent_timestep + ) + elif self.pipeline_type.is_inpaint(): + mask, mask_image = self.preprocess_images( + batch_size, + image_module.prepare_mask_and_masked_image(input_image, mask_image), + ) + mask = torch.nn.functional.interpolate( + mask, size=(latent_height, latent_width) + ) + mask = torch.cat([mask] * 2) + masked_image_latents = self.encode_image(mask_image) + masked_image_latents = torch.cat([masked_image_latents] * 2) + denoise_kwargs.update( + {"mask": mask, "masked_image_latents": masked_image_latents} + ) + + # CLIP text encoder(s) + if self.pipeline_type.is_sd_xl(): + text_embeddings2, pooled_embeddings2 = self.encode_prompt( + prompt, + 
negative_prompt, + encoder="clip2", + pooled_outputs=True, + output_hidden_states=True, + ) + + # Merge text embeddings + if self.pipeline_type.is_sd_xl_base(): + text_embeddings = self.encode_prompt( + prompt, negative_prompt, output_hidden_states=True + ) + text_embeddings = torch.cat( + [text_embeddings, text_embeddings2], dim=-1 + ) + else: + text_embeddings = text_embeddings2 + + # Time embeddings + def _get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + dtype, + aesthetic_score=None, + negative_aesthetic_score=None, + ): + if ( + self.pipeline_type.is_sd_xl_refiner() + ): # self.requires_aesthetics_score: + add_time_ids = list( + original_size + crops_coords_top_left + (aesthetic_score,) + ) + if self.do_classifier_free_guidance: + add_neg_time_ids = list( + original_size + + crops_coords_top_left + + (negative_aesthetic_score,) + ) + else: + add_time_ids = list( + original_size + crops_coords_top_left + target_size + ) + if self.do_classifier_free_guidance: + add_neg_time_ids = list( + original_size + crops_coords_top_left + target_size + ) + add_time_ids = torch.tensor( + [add_time_ids], dtype=dtype, device=self.device + ) + if self.do_classifier_free_guidance: + add_neg_time_ids = torch.tensor( + [add_neg_time_ids], dtype=dtype, device=self.device + ) + add_time_ids = torch.cat( + [add_neg_time_ids, add_time_ids], dim=0 + ) + return add_time_ids + + original_size = (image_height, image_width) + crops_coords_top_left = (0, 0) + target_size = (image_height, image_width) + if self.pipeline_type.is_sd_xl_refiner(): + add_time_ids = _get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + dtype=text_embeddings.dtype, + aesthetic_score=aesthetic_score, + negative_aesthetic_score=negative_aesthetic_score, + ) + else: + add_time_ids = _get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + dtype=text_embeddings.dtype, + ) + add_time_ids = add_time_ids.repeat(batch_size, 1) + denoise_kwargs.update( + {"text_embeds": pooled_embeddings2, "time_ids": add_time_ids} + ) + else: + text_embeddings = self.encode_prompt(prompt, negative_prompt) + + # UNet denoiser + (optional) ControlNet(s) + denoiser = "unetxl" if self.pipeline_type.is_sd_xl() else "unet" + latents = self.denoise_latent( + latents, text_embeddings, denoiser=denoiser, **denoise_kwargs + ) + + # VAE decode latent (if applicable) + if self.return_latents: + latents = latents * self.vae_scaling_factor + else: + images = self.decode_latent(latents) + + torch.cuda.synchronize() + e2e_toc = time.perf_counter() + + walltime_ms = (e2e_toc - e2e_tic) * 1000.0 + if not warmup: + self.print_summary(num_inference_steps, walltime_ms, batch_size) + if not self.return_latents and save_image: + # post-process images + images = ( + ((images + 1) * 255 / 2) + .clamp(0, 255) + .detach() + .permute(0, 2, 3, 1) + .round() + .type(torch.uint8) + .cpu() + .numpy() + ) + self.save_image( + images, self.pipeline_type.name.lower(), prompt, self.seed + ) + + return (latents, walltime_ms) if self.return_latents else (images, walltime_ms) + + def run( + self, + prompt, + negative_prompt, + height, + width, + batch_size, + batch_count, + num_warmup_runs, + use_cuda_graph, + **kwargs, + ): + # Process prompt + if not isinstance(prompt, list): + raise ValueError( + f"`prompt` must be of type `str` list, but is {type(prompt)}" + ) + prompt = prompt * batch_size + + if not isinstance(negative_prompt, list): + raise ValueError( + f"`--negative-prompt` must be of type `str` list, but is 
{type(negative_prompt)}" + ) + if len(negative_prompt) == 1: + negative_prompt = negative_prompt * batch_size + + num_warmup_runs = max(1, num_warmup_runs) if use_cuda_graph else num_warmup_runs + if num_warmup_runs > 0: + print("[I] Warming up ..") + for _ in range(num_warmup_runs): + self.infer( + prompt, negative_prompt, height, width, warmup=True, **kwargs + ) + + for _ in range(batch_count): + print("[I] Running StableDiffusion pipeline") + if self.nvtx_profile: + cudart.cudaProfilerStart() + self.infer(prompt, negative_prompt, height, width, warmup=False, **kwargs) + if self.nvtx_profile: + cudart.cudaProfilerStop() diff --git a/flux.1-dev-trt-b200/model/demo_diffusion/pipeline/stable_video_diffusion_pipeline.py b/flux.1-dev-trt-b200/model/demo_diffusion/pipeline/stable_video_diffusion_pipeline.py new file mode 100644 index 000000000..e301dd883 --- /dev/null +++ b/flux.1-dev-trt-b200/model/demo_diffusion/pipeline/stable_video_diffusion_pipeline.py @@ -0,0 +1,1026 @@ +# +# Copyright 2024 The HuggingFace Inc. team. +# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import os +import pathlib +import random +import sys +import time +from typing import Optional + +import modelopt.torch.opt as mto +import modelopt.torch.quantization as mtq +import tensorrt as trt +import torch +from cuda import cudart +from diffusers.image_processor import VaeImageProcessor +from diffusers.utils.torch_utils import randn_tensor +from tqdm.auto import tqdm + +import demo_diffusion.engine as engine_module +import demo_diffusion.image as image_module +from demo_diffusion.model import ( + CLIPImageProcessorModel, + CLIPVisionWithProjModel, + UNetTemporalModel, + VAEDecTemporalModel, +) +from demo_diffusion.pipeline.calibrate import load_calibration_images +from demo_diffusion.pipeline.stable_diffusion_pipeline import StableDiffusionPipeline +from demo_diffusion.pipeline.type import PIPELINE_TYPE +from demo_diffusion.utils_modelopt import ( + SD_FP8_FP16_DEFAULT_CONFIG, + check_lora, + filter_func, + generate_fp8_scales, + quantize_lvl, +) + +TRT_LOGGER = trt.Logger(trt.Logger.ERROR) + + +def _GiB(val): + return val * 1 << 30 + + +def _append_dims(x, target_dims): + """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" + dims_to_append = target_dims - x.ndim + if dims_to_append < 0: + raise ValueError( + f"input has {x.ndim} dims but target_dims is {target_dims}, which is less" + ) + return x[(...,) + (None,) * dims_to_append] + + +class StableVideoDiffusionPipeline(StableDiffusionPipeline): + """ + Application showcasing the acceleration of Stable Video Diffusion pipelines using NVidia TensorRT. 
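+
+    Extends StableDiffusionPipeline with image-conditioned video generation (default checkpoint
+    svd-xt-1.1, IMG2VID pipeline type), including an optional low-VRAM mode that activates one
+    TensorRT engine at a time.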
+ """ + + def __init__( + self, + version="svd-xt-1.1", + pipeline_type=PIPELINE_TYPE.IMG2VID, + min_guidance_scale: float = 1.0, + max_guidance_scale: float = 3.0, + decode_chunk_size: Optional[int] = None, + **kwargs, + ): + """ + Initializes the Diffusion pipeline. + + Args: + version (str): + The version of the pipeline. Should be one of [svd-xt-1.1] + pipeline_type (PIPELINE_TYPE): + Type of current pipeline. + min_guidance_scale (`float`, *optional*, defaults to 1.0): + The minimum guidance scale. Used for the classifier free guidance with first frame. + max_guidance_scale (`float`, *optional*, defaults to 3.0): + The maximum guidance scale. Used for the classifier free guidance with last frame. + `max_guidance_scale = 1` corresponds to doing no classifier free guidance. + decode_chunk_size (`int`, *optional*): + The number of frames to decode at a time. The higher the chunk size, the higher the temporal consistency + between frames, but also the higher the memory consumption. By default, the decoder will decode all frames at once + for maximal quality. Reduce `decode_chunk_size` to reduce memory usage. + """ + super().__init__(version=version, pipeline_type=pipeline_type, **kwargs) + self.min_guidance_scale = min_guidance_scale + self.max_guidance_scale = max_guidance_scale + self.do_classifier_free_guidance = max_guidance_scale > 1 + # FIXME vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.vae_scale_factor = 8 + # FIXME num_frames = self.config.num_frames + select_num_frames = { + "svd-xt-1.1": 25, + } + self.num_frames = select_num_frames.get(version, 14) + # TODO decode_chunk_size from args + self.decode_chunk_size = 8 if not decode_chunk_size else decode_chunk_size + # TODO: scaling_factor = vae.config.scaling_factor + self.scaling_factor = 0.18215 + + # TODO user configurable cuda_device_id + cuda_device_id = 0 + vram_size = cudart.cudaGetDeviceProperties(cuda_device_id)[1].totalGlobalMem + self.low_vram = vram_size < _GiB(40) + if self.low_vram: + print( + f"[W] WARNING low VRAM ({vram_size / _GiB(1):.2f} GB) mode selected. Certain optimizations may be skipped." 
+ ) + if self.use_cuda_graph and self.low_vram: + print("[W] WARNING CUDA graph disabled in low VRAM mode.") + self.use_cuda_graph = False + + self.config = {} + if self.pipeline_type.is_img2vid(): + self.config["clip_vis_torch_fallback"] = True + self.config["clip_imgfe_torch_fallback"] = True + self.config["vae_temp_torch_fallback"] = True + + # initialized in loadEngines() + self.max_shared_device_memory_size = 0 + + def loadResources(self, image_height, image_width, batch_size, seed): + # Initialize noise generator + self.seed = seed + self.generator = ( + torch.Generator(device="cuda").manual_seed(seed) if seed else None + ) + + # Create CUDA events and stream + for stage in ["clip", "denoise", "vae", "vae_encoder"]: + self.events[stage] = [ + cudart.cudaEventCreate()[1], + cudart.cudaEventCreate()[1], + ] + self.stream = cudart.cudaStreamCreate()[1] + + # Allocate shared device memory for TensorRT engines + if not self.low_vram and not self.torch_inference: + for model_name in self.models.keys(): + if not self.torch_fallback[model_name]: + self.max_shared_device_memory_size = max( + self.max_shared_device_memory_size, + self.engine[model_name].engine.device_memory_size, + ) + self.shared_device_memory = cudart.cudaMalloc( + self.max_shared_device_memory_size + )[1] + # Activate TensorRT engines + for model_name in self.models.keys(): + if not self.torch_fallback[model_name]: + self.engine[model_name].activate( + device_memory=self.shared_device_memory + ) + alloc_shape = self.models[model_name].get_shape_dict( + batch_size, image_height, image_width + ) + self.engine[model_name].allocate_buffers( + shape_dict=alloc_shape, device=self.device + ) + + def loadEngines( + self, + engine_dir, + framework_model_dir, + onnx_dir, + onnx_opset, + opt_batch_size, + opt_image_height, + opt_image_width, + static_batch=False, + static_shape=True, + enable_refit=False, + enable_all_tactics=False, + timing_cache=None, + fp8=False, + quantization_level=0.0, + calibration_size=32, + calib_batch_size=2, + **_kwargs, + ): + """ + Build and load engines for TensorRT accelerated inference. + Export ONNX models first, if applicable. + + Args: + engine_dir (str): + Directory to store the TensorRT engines. + framework_model_dir (str): + Directory to store the framework model ckpt. + onnx_dir (str): + Directory to store the ONNX models. + onnx_opset (int): + ONNX opset version to export the models. + opt_batch_size (int): + Batch size to optimize for during engine building. + opt_image_height (int): + Image height to optimize for during engine building. Must be a multiple of 8. + opt_image_width (int): + Image width to optimize for during engine building. Must be a multiple of 8. + static_batch (bool): + Build engine only for specified opt_batch_size. + static_shape (bool): + Build engine only for specified opt_image_height & opt_image_width. Default = True. + enable_refit (bool): + Build engines with refit option enabled. + enable_all_tactics (bool): + Enable all tactic sources during TensorRT engine builds. + timing_cache (str): + Path to the timing cache to speed up TensorRT build. + fp8 (bool): + Whether to quantize to fp8 format or not. + quantization_level (float): + Controls which layers to quantize. + calibration_size (int): + The number of steps to use for calibrating the model for quantization. + calib_batch_size (int): + The batch size to use for calibration. Defaults to 2. 
+ """ + # Create directories if missing + for directory in [engine_dir, onnx_dir]: + if not os.path.exists(directory): + print(f"[I] Create directory: {directory}") + pathlib.Path(directory).mkdir(parents=True) + + # Load pipeline models + models_args = { + "version": self.version, + "pipeline": self.pipeline_type, + "device": self.device, + "hf_token": self.hf_token, + "verbose": self.verbose, + "framework_model_dir": framework_model_dir, + "max_batch_size": self.max_batch_size, + } + if "clip-vis" in self.stages: + self.models["clip-vis"] = CLIPVisionWithProjModel( + **models_args, subfolder="image_encoder" + ) + if "clip-imgfe" in self.stages: + self.models["clip-imgfe"] = CLIPImageProcessorModel( + **models_args, subfolder="feature_extractor" + ) + if "unet-temp" in self.stages: + self.models["unet-temp"] = UNetTemporalModel( + **models_args, + fp16=True, + fp8=fp8, + num_frames=self.num_frames, + do_classifier_free_guidance=self.do_classifier_free_guidance, + ) + if "vae-temp" in self.stages: + self.models["vae-temp"] = VAEDecTemporalModel( + **models_args, decode_chunk_size=self.decode_chunk_size + ) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + # Configure pipeline models to load + model_names = self.models.keys() + self.torch_fallback = dict( + zip( + model_names, + [ + self.torch_inference + or self.config.get( + model_name.replace("-", "_") + "_torch_fallback", False + ) + for model_name in model_names + ], + ) + ) + onnx_path = dict( + zip( + model_names, + [ + self.getOnnxPath(model_name, onnx_dir, opt=False) + for model_name in model_names + ], + ) + ) + onnx_opt_path = dict( + zip( + model_names, + [self.getOnnxPath(model_name, onnx_dir) for model_name in model_names], + ) + ) + engine_path = dict( + zip( + model_names, + [ + self.getEnginePath(model_name, engine_dir) + for model_name in model_names + ], + ) + ) + do_engine_refit = dict( + zip( + model_names, + [ + enable_refit and model_name.startswith("unet") + for model_name in model_names + ], + ) + ) + + # Quantization. 
+ model_suffix = dict(zip(model_names, ["" for model_name in model_names])) + use_fp8 = dict.fromkeys(model_names, False) + if fp8: + model_name = "unet-temp" + use_fp8[model_name] = True + model_suffix[model_name] += ( + f"-fp8.l{quantization_level}.bs2.s{self.denoising_steps}.c{calibration_size}" + ) + onnx_path = { + model_name: self.getOnnxPath( + model_name, onnx_dir, opt=False, suffix=model_suffix[model_name] + ) + for model_name in model_names + } + onnx_opt_path = { + model_name: self.getOnnxPath( + model_name, onnx_dir, suffix=model_suffix[model_name] + ) + for model_name in model_names + } + engine_path = { + model_name: self.getEnginePath( + model_name, + engine_dir, + do_engine_refit[model_name], + suffix=model_suffix[model_name], + ) + for model_name in model_names + } + weights_map_path = { + model_name: ( + self.getWeightsMapPath(model_name, onnx_dir) + if do_engine_refit[model_name] + else None + ) + for model_name in model_names + } + + # Export models to ONNX + for model_name, obj in self.models.items(): + if self.torch_fallback[model_name]: + continue + do_export_onnx = not os.path.exists( + engine_path[model_name] + ) and not os.path.exists(onnx_opt_path[model_name]) + do_export_weights_map = weights_map_path[model_name] and not os.path.exists( + weights_map_path[model_name] + ) + if do_export_onnx or do_export_weights_map: + if use_fp8[model_name]: + pipeline = obj.get_pipeline() + model = pipeline.unet + + state_dict_path = self.getStateDictPath( + model_name, onnx_dir, suffix=model_suffix[model_name] + ) + if not os.path.exists(state_dict_path): + # Load calibration images + print( + f"[I] Calibrated weights not found, generating {state_dict_path}" + ) + root_dir = os.path.dirname( + os.path.abspath(sys.modules["__main__"].__file__) + ) + calibration_image_folder = os.path.join( + root_dir, "calibration_data", "calibration-images" + ) + calibration_image_list = load_calibration_images( + calibration_image_folder + ) + print("Number of images loaded:", len(calibration_image_list)) + + # TODO check size > calibration_size + def do_calibrate(pipeline, calibration_images, **kwargs): + for i_th, image in enumerate(calibration_images): + if i_th >= kwargs["calib_size"]: + return + pipeline( + image=image, + num_inference_steps=kwargs["n_steps"], + ).frames[0] + + def forward_loop(model): + pipeline.unet = model + do_calibrate( + pipeline=pipeline, + calibration_images=calibration_image_list, + calib_size=calibration_size // calib_batch_size, + n_steps=self.denoising_steps, + ) + + print( + f"[I] Performing calibration for {calibration_size} steps." + ) + if use_fp8[model_name]: + quant_config = SD_FP8_FP16_DEFAULT_CONFIG + check_lora(model) + mtq.quantize(model, quant_config, forward_loop) + mto.save(model, state_dict_path) + else: + mto.restore(model, state_dict_path) + + print( + f"[I] Generating quantized ONNX model: {onnx_opt_path[model_name]}" + ) + if not os.path.exists(onnx_path[model_name]): + """ + Error: Torch bug, ONNX export failed due to unknown kernel shape in QuantConv3d. + TRT_FP8QuantizeLinear and TRT_FP8DequantizeLinear operations in UNetSpatioTemporalConditionModel for svd + cause issues. Inputs on different devices (CUDA vs CPU) may contribute to the problem. 
+ """ + quantize_lvl( + self.version, + model, + quantization_level, + enable_conv_3d=False, + ) + mtq.disable_quantizer(model, filter_func) + if use_fp8[model_name]: + generate_fp8_scales(model) + else: + model = None + + obj.export_onnx( + onnx_path[model_name], + onnx_opt_path[model_name], + onnx_opset, + opt_image_height, + opt_image_width, + custom_model=model, + static_shape=static_shape, + ) + else: + obj.export_onnx( + onnx_path[model_name], + onnx_opt_path[model_name], + onnx_opset, + opt_image_height, + opt_image_width, + ) + + # Clean model cache + torch.cuda.empty_cache() + + # Build TensorRT engines + for model_name, obj in self.models.items(): + if self.torch_fallback[model_name]: + continue + engine = engine_module.Engine(engine_path[model_name]) + if not os.path.exists(engine_path[model_name]): + update_output_names = ( + obj.get_output_names() + obj.extra_output_names + if obj.extra_output_names + else None + ) + engine.build( + onnx_opt_path[model_name], + fp16=True, + input_profile=obj.get_input_profile( + opt_batch_size, + opt_image_height, + opt_image_width, + static_batch=static_batch, + static_shape=static_shape, + ), + enable_refit=do_engine_refit[model_name], + enable_all_tactics=enable_all_tactics, + timing_cache=timing_cache, + update_output_names=update_output_names, + native_instancenorm=False, + ) + self.engine[model_name] = engine + + # Load TensorRT engines + for model_name in self.models.keys(): + if not self.torch_fallback[model_name]: + self.engine[model_name].load() + + def activateEngines(self, model_name, alloc_shape=None): + if not self.torch_fallback[model_name]: + device_memory_update = self.low_vram and not self.shared_device_memory + if device_memory_update: + assert not self.use_cuda_graph + # Reclaim GPU memory from torch cache + torch.cuda.empty_cache() + self.shared_device_memory = cudart.cudaMalloc( + self.max_shared_device_memory_size + )[1] + # Create TensorRT execution context + if not self.engine[model_name].context: + assert not self.use_cuda_graph + self.engine[model_name].activate( + device_memory=self.shared_device_memory + ) + if device_memory_update: + self.engine[model_name].reactivate( + device_memory=self.shared_device_memory + ) + if alloc_shape and not self.engine[model_name].tensors: + assert not self.use_cuda_graph + self.engine[model_name].allocate_buffers( + shape_dict=alloc_shape, device=self.device + ) + else: + # Load torch model + if model_name not in self.torch_models: + self.torch_models[model_name] = self.models[model_name].get_model( + torch_inference=self.torch_inference + ) + + def deactivateEngines(self, model_name, release_model=True): + if not release_model: + return + if not self.torch_fallback[model_name]: + assert not self.use_cuda_graph + self.engine[model_name].deallocate_buffers() + self.engine[model_name].deactivate() + # Shared device memory deallocated only in low VRAM mode + if self.low_vram and self.shared_device_memory: + cudart.cudaFree(self.shared_device_memory) + self.shared_device_memory = None + else: + del self.torch_models[model_name] + + def print_summary(self, denoising_steps, walltime_ms, batch_size, num_frames): + print("|-----------------|--------------|") + print("| {:^15} | {:^12} |".format("Module", "Latency")) + print("|-----------------|--------------|") + print( + "| {:^15} | {:>9.2f} ms |".format( + "VAE-Enc", + cudart.cudaEventElapsedTime( + self.events["vae_encoder"][0], self.events["vae_encoder"][1] + )[1], + ) + ) + print( + "| {:^15} | {:>9.2f} ms |".format( + "CLIP", + 
cudart.cudaEventElapsedTime( + self.events["clip"][0], self.events["clip"][1] + )[1], + ) + ) + print( + "| {:^15} | {:>9.2f} ms |".format( + "UNet" + + ("+CNet" if self.pipeline_type.is_controlnet() else "") + + " x " + + str(denoising_steps), + cudart.cudaEventElapsedTime( + self.events["denoise"][0], self.events["denoise"][1] + )[1], + ) + ) + print( + "| {:^15} | {:>9.2f} ms |".format( + "VAE-Dec", + cudart.cudaEventElapsedTime( + self.events["vae"][0], self.events["vae"][1] + )[1], + ) + ) + print("|-----------------|--------------|") + print("| {:^15} | {:>9.2f} ms |".format("Pipeline", walltime_ms)) + print("|-----------------|--------------|") + print( + "Throughput: {:.2f} videos/min ({} frames)".format( + batch_size * 60000.0 / walltime_ms, num_frames + ) + ) + + def save_video(self, frames, pipeline, seed): + video_name_prefix = "-".join( + [pipeline, "fp16", str(seed), str(random.randint(1000, 9999))] + ) + video_name_suffix = "torch" if self.torch_inference else "trt" + video_path = video_name_prefix + "-" + video_name_suffix + ".gif" + print(f"Saving video to: {video_path}") + frames[0].save( + os.path.join(self.output_dir, video_path), + save_all=True, + optimize=False, + append_images=frames[1:], + loop=0, + ) + + def _encode_image(self, image, num_videos_per_prompt, do_classifier_free_guidance): + dtype = next(self.torch_models["clip-vis"].parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.image_processor.pil_to_numpy(image) + image = self.image_processor.numpy_to_pt(image) + + # We normalize the image before resizing to match with the original implementation. + # Then we unnormalize it after resizing. + image = image * 2.0 - 1.0 + image = image_module.resize_with_antialiasing(image, (224, 224)) + image = (image + 1.0) / 2.0 + + # Normalize the image with for CLIP input + image = self.torch_models["clip-imgfe"]( + images=image, + do_normalize=True, + do_center_crop=False, + do_resize=False, + do_rescale=False, + return_tensors="pt", + ).pixel_values + + image = image.to(device=self.device, dtype=dtype) + image_embeddings = self.torch_models["clip-vis"](image).image_embeds + image_embeddings = image_embeddings.unsqueeze(1) + + # duplicate image embeddings for each generation per prompt, using mps friendly method + bs_embed, seq_len, _ = image_embeddings.shape + image_embeddings = image_embeddings.repeat(1, num_videos_per_prompt, 1) + image_embeddings = image_embeddings.view( + bs_embed * num_videos_per_prompt, seq_len, -1 + ) + + if do_classifier_free_guidance: + negative_image_embeddings = torch.zeros_like(image_embeddings) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + image_embeddings = torch.cat([negative_image_embeddings, image_embeddings]) + + return image_embeddings + + def _encode_vae_image( + self, + image: torch.Tensor, + device, + num_videos_per_prompt, + do_classifier_free_guidance, + ): + image = image.to(device=device) + image_latents = self.torch_models["vae-temp"].encode(image).latent_dist.mode() + + if do_classifier_free_guidance: + negative_image_latents = torch.zeros_like(image_latents) + + # For classifier free guidance, we need to do two forward passes. 
+ # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + image_latents = torch.cat([negative_image_latents, image_latents]) + + # duplicate image_latents for each generation per prompt, using mps friendly method + image_latents = image_latents.repeat(num_videos_per_prompt, 1, 1, 1) + + return image_latents + + def _get_add_time_ids( + self, + fps, + motion_bucket_id, + noise_aug_strength, + dtype, + batch_size, + num_videos_per_prompt, + do_classifier_free_guidance, + ): + add_time_ids = [fps, motion_bucket_id, noise_aug_strength] + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + add_time_ids = add_time_ids.repeat(batch_size * num_videos_per_prompt, 1) + + if do_classifier_free_guidance: + add_time_ids = torch.cat([add_time_ids, add_time_ids]) + + return add_time_ids + + def prepare_latents( + self, + batch_size, + num_frames, + num_channels_latents, + height, + width, + dtype, + device, + latents=None, + ): + shape = ( + batch_size, + num_frames, + num_channels_latents // 2, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) + + if latents is None: + latents = randn_tensor( + shape, generator=self.generator, device=device, dtype=dtype + ) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def decode_latents(self, latents, num_frames, decode_chunk_size): + # [batch, frames, channels, height, width] -> [batch*frames, channels, height, width] + latents = latents.flatten(0, 1) + + latents = 1 / self.scaling_factor * latents + + # decode decode_chunk_size frames at a time to avoid OOM + frames = [] + for i in range(0, latents.shape[0], decode_chunk_size): + num_frames_in = latents[i : i + decode_chunk_size].shape[0] + # TODO only pass num_frames_in if it's expected + if self.torch_fallback["vae-temp"]: + frame = ( + self.torch_models["vae-temp"] + .decode( + latents[i : i + decode_chunk_size], num_frames=num_frames_in + ) + .sample + ) + else: + params = { + "latent": latents[i : i + decode_chunk_size], + # FIXME segfault + #'num_frames_in': torch.Tensor([num_frames_in]).to(device=latents.device, dtype=torch.int64), + } + frame = self.runEngine("vae-temp", params)["frames"] + frames.append(frame) + frames = torch.cat(frames, dim=0) + + # [batch*frames, channels, height, width] -> [batch, channels, frames, height, width] + frames = frames.reshape(-1, num_frames, *frames.shape[1:]).permute( + 0, 2, 1, 3, 4 + ) + + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + frames = frames.float() + return frames + + def infer( + self, + input_image, + image_height: int, + image_width: int, + fps: int = 7, + motion_bucket_id: int = 127, + noise_aug_strength: int = 0.02, + num_videos_per_prompt: Optional[int] = 1, + warmup: bool = False, + save_video: bool = True, + ): + """ + Run the video diffusion pipeline. + + Args: + input_image (image): + Input image used to initialize the latents or to be inpainted. + image_height (int): + Height (in pixels) of the image to be generated. Must be a multiple of 8. + image_width (int): + Width (in pixels) of the image to be generated. Must be a multiple of 8. + fps (`int`, *optional*, defaults to 7): + Frames per second. The rate at which the generated images shall be exported to a video after generation. 
+ Note that Stable Diffusion Video's UNet was micro-conditioned on fps-1 during training. + motion_bucket_id (`int`, *optional*, defaults to 127): + The motion bucket ID. Used as conditioning for the generation. The higher the number the more motion will be in the video. + noise_aug_strength (`int`, *optional*, defaults to 0.02): + The amount of noise added to the init image, the higher it is the less the video will look like the init image. Increase it for more motion. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + warmup (bool): + Indicate if this is a warmup run. + save_video (bool): + Save the video image. + """ + + if self.generator and self.seed: + self.generator.manual_seed(self.seed) + + # TODO + batch_size = 1 + # Fast warmup + denoising_steps = 1 if warmup else self.denoising_steps + + torch.cuda.synchronize() + e2e_tic = time.perf_counter() + + class LoadModelContext: + def __init__(ctx, model_name, alloc_shape=None, release_model=False): + ctx.model_name = model_name + ctx.release_model = release_model + ctx.alloc_shape = alloc_shape + + def __enter__(ctx): + self.activateEngines(ctx.model_name, alloc_shape=ctx.alloc_shape) + + def __exit__(ctx, exc_type, exc_val, exc_tb): + self.deactivateEngines(ctx.model_name, release_model=ctx.release_model) + + # Release model opportunistically in TensorRT pipeline only in low VRAM mode + release_model = self.low_vram and not self.torch_inference + with torch.inference_mode(), torch.autocast("cuda"), trt.Runtime(TRT_LOGGER): + with ( + LoadModelContext("clip-imgfe", release_model=release_model), + LoadModelContext("clip-vis", release_model=release_model), + ): + self.profile_start("clip", color="green") + image_embeddings = self._encode_image( + input_image, num_videos_per_prompt, self.do_classifier_free_guidance + ) + self.profile_stop("clip") + # NOTE Stable Diffusion Video was conditioned on fps - 1 + fps = fps - 1 + + self.profile_start("preprocess", color="pink") + input_image = self.image_processor.preprocess( + input_image, height=image_height, width=image_width + ).to(self.device) + noise = randn_tensor( + input_image.shape, + generator=self.generator, + device=input_image.device, + dtype=input_image.dtype, + ) + input_image = input_image + noise_aug_strength * noise + self.profile_stop("preprocess") + + # TODO + # assert self.torch_models['vae-temp'].dtype == torch.float32 + + with LoadModelContext("vae-temp"): + self.profile_start("vae_encoder", color="red") + image_latents = self._encode_vae_image( + input_image, + self.device, + num_videos_per_prompt, + self.do_classifier_free_guidance, + ) + image_latents = image_latents.to(image_embeddings.dtype) + self.profile_stop("vae_encoder") + + # Repeat the image latents for each frame so we can concatenate them with the noise + # image_latents [batch, channels, height, width] ->[batch, num_frames, channels, height, width] + image_latents = image_latents.unsqueeze(1).repeat( + 1, self.num_frames, 1, 1, 1 + ) + + # Get Added Time IDs + added_time_ids = self._get_add_time_ids( + fps, + motion_bucket_id, + noise_aug_strength, + image_embeddings.dtype, + batch_size, + num_videos_per_prompt, + self.do_classifier_free_guidance, + ) + added_time_ids = added_time_ids.to(self.device) + + # Prepare timesteps + self.scheduler.set_timesteps(denoising_steps, device=self.device) + timesteps = self.scheduler.timesteps + + # Prepare latent variables + latents = self.prepare_latents( + batch_size * num_videos_per_prompt, + self.num_frames, + 
8, # TODO: num_channels_latents = unet.config.in_channels + image_height, + image_width, + image_embeddings.dtype, + input_image.device, + None, # pre-generated latents + ) + + # Prepare guidance scale + guidance_scale = torch.linspace( + self.min_guidance_scale, self.max_guidance_scale, self.num_frames + ).unsqueeze(0) + guidance_scale = guidance_scale.to(self.device, latents.dtype) + guidance_scale = guidance_scale.repeat( + batch_size * num_videos_per_prompt, 1 + ) + guidance_scale = _append_dims(guidance_scale, latents.ndim) + + # Denoising loop + num_warmup_steps = len(timesteps) - denoising_steps * self.scheduler.order + unet_shape_dict = self.models["unet-temp"].get_shape_dict( + batch_size, image_height, image_width + ) + with ( + LoadModelContext( + "unet-temp", + alloc_shape=unet_shape_dict, + release_model=release_model, + ), + tqdm(total=denoising_steps) as progress_bar, + ): + self.profile_start("denoise", color="blue") + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = ( + torch.cat([latents] * 2) + if self.do_classifier_free_guidance + else latents + ) + latent_model_input = self.scheduler.scale_model_input( + latent_model_input, t + ) + + # Concatenate image_latents over channels dimention + latent_model_input = torch.cat( + [latent_model_input, image_latents], dim=2 + ) + + # predict the noise residual + if self.torch_fallback["unet-temp"]: + noise_pred = self.torch_models["unet-temp"]( + latent_model_input, + t, + encoder_hidden_states=image_embeddings, + added_time_ids=added_time_ids, + return_dict=False, + )[0] + else: + params = { + "sample": latent_model_input, + "timestep": t, + "encoder_hidden_states": image_embeddings, + "added_time_ids": added_time_ids, + } + noise_pred = self.runEngine("unet-temp", params)["latent"] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_cond - noise_pred_uncond + ) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents).prev_sample + + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps + and (i + 1) % self.scheduler.order == 0 + ): + progress_bar.update() + self.profile_stop("denoise") + + with ( + torch.inference_mode(), + trt.Runtime(TRT_LOGGER), + LoadModelContext("vae-temp"), + ): + self.profile_start("vae", color="red") + self.torch_models["vae-temp"].to(dtype=torch.float16) + frames = self.decode_latents( + latents, self.num_frames, self.decode_chunk_size + ) + frames = image_module.tensor2vid( + frames, self.image_processor, output_type="pil" + ) + self.profile_stop("vae") + + torch.cuda.synchronize() + + if warmup: + return + + e2e_toc = time.perf_counter() + walltime_ms = (e2e_toc - e2e_tic) * 1000.0 + self.print_summary(denoising_steps, walltime_ms, batch_size, len(frames[0])) + if save_video: + self.save_video(frames[0], self.pipeline_type.name.lower(), self.seed) + + return frames, walltime_ms + + def run( + self, + input_image, + height, + width, + batch_size, + batch_count, + num_warmup_runs, + use_cuda_graph, + **kwargs, + ): + num_warmup_runs = max(1, num_warmup_runs) if use_cuda_graph else num_warmup_runs + if num_warmup_runs > 0: + print("[I] Warming up ..") + for _ in range(num_warmup_runs): + self.infer(input_image, height, width, warmup=True) + + for _ in range(batch_count): + print("[I] Running StableDiffusion pipeline") + if self.nvtx_profile: 
+ cudart.cudaProfilerStart() + self.infer(input_image, height, width, warmup=False) + if self.nvtx_profile: + cudart.cudaProfilerStop() diff --git a/flux.1-dev-trt-b200/model/demo_diffusion/pipeline/type.py b/flux.1-dev-trt-b200/model/demo_diffusion/pipeline/type.py new file mode 100644 index 000000000..aca8df766 --- /dev/null +++ b/flux.1-dev-trt-b200/model/demo_diffusion/pipeline/type.py @@ -0,0 +1,64 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import enum + + +class PIPELINE_TYPE(enum.Enum): + TXT2IMG = enum.auto() + IMG2IMG = enum.auto() + IMG2VID = enum.auto() + INPAINT = enum.auto() + CONTROLNET = enum.auto() + XL_CONTROLNET = enum.auto() + XL_BASE = enum.auto() + XL_REFINER = enum.auto() + CASCADE_PRIOR = enum.auto() + CASCADE_DECODER = enum.auto() + + def is_txt2img(self): + return self in (self.TXT2IMG, self.CONTROLNET) + + def is_img2img(self): + return self == self.IMG2IMG + + def is_img2vid(self): + return self == self.IMG2VID + + def is_inpaint(self): + return self == self.INPAINT + + def is_controlnet(self): + return self in (self.CONTROLNET, self.XL_CONTROLNET) + + def is_sd_xl_base(self): + return self in (self.XL_BASE, self.XL_CONTROLNET) + + def is_sd_xl_refiner(self): + return self == self.XL_REFINER + + def is_sd_xl(self): + return self.is_sd_xl_base() or self.is_sd_xl_refiner() + + def is_cascade_prior(self): + return self == self.CASCADE_PRIOR + + def is_cascade_decoder(self): + return self == self.CASCADE_DECODER + + def is_cascade(self): + return self.is_cascade_prior() or self.is_cascade_decoder() diff --git a/flux.1-dev-trt-b200/model/demo_diffusion/utils_modelopt.py b/flux.1-dev-trt-b200/model/demo_diffusion/utils_modelopt.py new file mode 100755 index 000000000..b6dfccadb --- /dev/null +++ b/flux.1-dev-trt-b200/model/demo_diffusion/utils_modelopt.py @@ -0,0 +1,943 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import os +import re +from collections import defaultdict +from random import choice, shuffle +from typing import Set + +import modelopt.torch.quantization as mtq +import numpy as np +import onnx +import onnx_graphsurgeon as gs +import torch +import torch.nn.functional as F +from diffusers.models.attention_processor import ( + Attention, + AttnProcessor, + FluxAttnProcessor2_0, +) +from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear +from modelopt.torch.quantization import utils as quant_utils +from modelopt.torch.quantization.calib.max import MaxCalibrator +from PIL import Image +from torch.utils.data import Dataset, Sampler + +USE_PEFT = True +try: + from peft.tuners.lora.layer import Conv2d as PEFTLoRAConv2d + from peft.tuners.lora.layer import Linear as PEFTLoRALinear +except ModuleNotFoundError: + USE_PEFT = False + + +class PercentileCalibrator(MaxCalibrator): + def __init__( + self, num_bits=8, axis=None, unsigned=False, track_amax=False, **kwargs + ): + super().__init__(num_bits, axis, unsigned, track_amax) + self.percentile = kwargs["percentile"] + self.total_step = kwargs["total_step"] + self.collect_method = kwargs["collect_method"] + self.data = {} + self.i = 0 + + def collect(self, x): + """Tracks the absolute max of all tensors. + + Args: + x: A tensor + + Raises: + RuntimeError: If amax shape changes + """ + # Swap axis to reduce. + axis = self._axis if isinstance(self._axis, (list, tuple)) else [self._axis] + # Handle negative axis. + axis = [x.dim() + i if isinstance(i, int) and i < 0 else i for i in axis] + reduce_axis = [] + for i in range(x.dim()): + if i not in axis: + reduce_axis.append(i) + local_amax = quant_utils.reduce_amax(x, axis=reduce_axis).detach() + _cur_step = self.i % self.total_step + if _cur_step not in self.data.keys(): + self.data[_cur_step] = local_amax + else: + if self.collect_method == "global_min": + self.data[_cur_step] = torch.min(self.data[_cur_step], local_amax) + elif self.collect_method == "min-max" or self.collect_method == "mean-max": + self.data[_cur_step] = torch.max(self.data[_cur_step], local_amax) + else: + self.data[_cur_step] += local_amax + if self._track_amax: + raise NotImplementedError + self.i += 1 + + def compute_amax(self): + """Return the absolute max of all tensors collected.""" + up_lim = int(self.total_step * self.percentile) + if self.collect_method == "min-mean": + amaxs_values = [self.data[i] / self.total_step for i in range(0, up_lim)] + else: + amaxs_values = [self.data[i] for i in range(0, up_lim)] + if self.collect_method == "mean-max": + act_amax = torch.vstack(amaxs_values).mean(axis=0)[0] + else: + act_amax = torch.vstack(amaxs_values).min(axis=0)[0] + self._calib_amax = act_amax + return self._calib_amax + + def __str__(self): + s = "PercentileCalibrator" + return s.format(**self.__dict__) + + def __repr__(self): + s = "PercentileCalibrator(" + s += super(MaxCalibrator, self).__repr__() + s += " calib_amax={_calib_amax}" + if self._track_amax: + s += " amaxs={_amaxs}" + s += ")" + return s.format(**self.__dict__) + + +def filter_func(name): + pattern = re.compile( + r".*(time_emb_proj|time_embedding|conv_in|conv_out|conv_shortcut|add_embedding|pos_embed|time_text_embed|context_embedder|norm_out|proj_out).*" + ) + return pattern.match(name) is not None + + +def filter_func_no_proj_out(name): # used for Flux + pattern = re.compile( + r".*(time_emb_proj|time_embedding|conv_in|conv_out|conv_shortcut|add_embedding|pos_embed|time_text_embed|context_embedder|norm_out|x_embedder).*" + ) + 
return pattern.match(name) is not None + + +def quantize_lvl( + model_id, backbone, quant_level=2.5, linear_only=False, enable_conv_3d=True +): + """ + We should disable the unwanted quantizer when exporting the onnx + Because in the current modelopt setting, it will load the quantizer amax for all the layers even + if we didn't add that unwanted layer into the config during the calibration + """ + for name, module in backbone.named_modules(): + if isinstance(module, torch.nn.Conv2d): + if linear_only: + module.input_quantizer.disable() + module.weight_quantizer.disable() + else: + module.input_quantizer.enable() + module.weight_quantizer.enable() + elif isinstance(module, torch.nn.Linear): + if ( + (quant_level >= 2 and "ff.net" in name) + or ( + quant_level >= 2.5 + and ("to_q" in name or "to_k" in name or "to_v" in name) + ) + or quant_level >= 3 + ) and name != "proj_out": # Disable the final output layer from flux model + module.input_quantizer.enable() + module.weight_quantizer.enable() + else: + module.input_quantizer.disable() + module.weight_quantizer.disable() + elif isinstance(module, torch.nn.Conv3d) and not enable_conv_3d: + """ + Error: Torch bug, ONNX export failed due to unknown kernel shape in QuantConv3d. + TRT_FP8QuantizeLinear and TRT_FP8DequantizeLinear operations in UNetSpatioTemporalConditionModel for svd + cause issues. Inputs on different devices (CUDA vs CPU) may contribute to the problem. + """ + module.input_quantizer.disable() + module.weight_quantizer.disable() + elif isinstance(module, Attention): + # TRT only supports FP8 MHA with head_size % 16 == 0. + head_size = int(module.inner_dim / module.heads) + if quant_level >= 4 and head_size % 16 == 0: + module.q_bmm_quantizer.enable() + module.k_bmm_quantizer.enable() + module.v_bmm_quantizer.enable() + module.softmax_quantizer.enable() + if model_id.startswith("flux.1"): + if name.startswith("transformer_blocks"): + module.bmm2_output_quantizer.enable() + else: + module.bmm2_output_quantizer.disable() + setattr(module, "_disable_fp8_mha", False) + else: + module.q_bmm_quantizer.disable() + module.k_bmm_quantizer.disable() + module.v_bmm_quantizer.disable() + module.softmax_quantizer.disable() + module.bmm2_output_quantizer.disable() + setattr(module, "_disable_fp8_mha", True) + + +def fp8_mha_disable(backbone, quantized_mha_output: bool = True): + def mha_filter_func(name): + pattern = re.compile( + r".*(q_bmm_quantizer|k_bmm_quantizer|v_bmm_quantizer|softmax_quantizer).*" + if quantized_mha_output + else r".*(q_bmm_quantizer|k_bmm_quantizer|v_bmm_quantizer|softmax_quantizer|bmm2_output_quantizer).*" + ) + return pattern.match(name) is not None + + if hasattr(F, "scaled_dot_product_attention"): + mtq.disable_quantizer(backbone, mha_filter_func) + + +def get_int8_config( + model, + quant_level=3, + alpha=0.8, + percentile=1.0, + num_inference_steps=20, + collect_method="min-mean", +): + quant_config = { + "quant_cfg": { + "*lm_head*": {"enable": False}, + "*output_layer*": {"enable": False}, + "*output_quantizer": {"enable": False}, + "default": {"num_bits": 8, "axis": None}, + }, + "algorithm": {"method": "smoothquant", "alpha": alpha}, + } + for name, module in model.named_modules(): + w_name = f"{name}*weight_quantizer" + i_name = f"{name}*input_quantizer" + + if ( + w_name in quant_config["quant_cfg"].keys() + or i_name in quant_config["quant_cfg"].keys() + ): + continue + if filter_func(name): + continue + if isinstance(module, (torch.nn.Linear, LoRACompatibleLinear)): + if ( + (quant_level >= 2 and 
"ff.net" in name) + or ( + quant_level >= 2.5 + and ("to_q" in name or "to_k" in name or "to_v" in name) + ) + or quant_level == 3 + ): + quant_config["quant_cfg"][w_name] = {"num_bits": 8, "axis": 0} + quant_config["quant_cfg"][i_name] = {"num_bits": 8, "axis": -1} + elif isinstance(module, (torch.nn.Conv2d, LoRACompatibleConv)): + quant_config["quant_cfg"][w_name] = {"num_bits": 8, "axis": 0} + quant_config["quant_cfg"][i_name] = { + "num_bits": 8, + "axis": None, + "calibrator": ( + PercentileCalibrator, + (), + { + "num_bits": 8, + "axis": None, + "percentile": percentile, + "total_step": num_inference_steps, + "collect_method": collect_method, + }, + ), + } + return quant_config + + +SD_FP8_FP16_DEFAULT_CONFIG = { + "quant_cfg": { + "*weight_quantizer": { + "num_bits": (4, 3), + "axis": None, + "trt_high_precision_dtype": "Half", + }, + "*input_quantizer": { + "num_bits": (4, 3), + "axis": None, + "trt_high_precision_dtype": "Half", + }, + "*output_quantizer": {"enable": False}, + "*q_bmm_quantizer": { + "num_bits": (4, 3), + "axis": None, + "trt_high_precision_dtype": "Half", + }, + "*k_bmm_quantizer": { + "num_bits": (4, 3), + "axis": None, + "trt_high_precision_dtype": "Half", + }, + "*v_bmm_quantizer": { + "num_bits": (4, 3), + "axis": None, + "trt_high_precision_dtype": "Half", + }, + "*softmax_quantizer": { + "num_bits": (4, 3), + "axis": None, + "trt_high_precision_dtype": "Half", + }, + "default": {"enable": False}, + }, + "algorithm": "max", +} + +SD_FP8_BF16_DEFAULT_CONFIG = { + "quant_cfg": { + "*weight_quantizer": { + "num_bits": (4, 3), + "axis": None, + "trt_high_precision_dtype": "BFloat16", + }, + "*input_quantizer": { + "num_bits": (4, 3), + "axis": None, + "trt_high_precision_dtype": "BFloat16", + }, + "*output_quantizer": {"enable": False}, + "*q_bmm_quantizer": { + "num_bits": (4, 3), + "axis": None, + "trt_high_precision_dtype": "BFloat16", + }, + "*k_bmm_quantizer": { + "num_bits": (4, 3), + "axis": None, + "trt_high_precision_dtype": "BFloat16", + }, + "*v_bmm_quantizer": { + "num_bits": (4, 3), + "axis": None, + "trt_high_precision_dtype": "BFloat16", + }, + "*softmax_quantizer": { + "num_bits": (4, 3), + "axis": None, + "trt_high_precision_dtype": "BFloat16", + }, + "default": {"enable": False}, + }, + "algorithm": "max", +} + +SD_FP8_BF16_FLUX_MMDIT_BMM2_FP8_OUTPUT_CONFIG = { + "quant_cfg": { + "*weight_quantizer": { + "num_bits": (4, 3), + "axis": None, + "trt_high_precision_dtype": "BFloat16", + }, + "*input_quantizer": { + "num_bits": (4, 3), + "axis": None, + "trt_high_precision_dtype": "BFloat16", + }, + "*output_quantizer": {"enable": False}, + "*q_bmm_quantizer": { + "num_bits": (4, 3), + "axis": None, + "trt_high_precision_dtype": "BFloat16", + }, + "*k_bmm_quantizer": { + "num_bits": (4, 3), + "axis": None, + "trt_high_precision_dtype": "BFloat16", + }, + "*v_bmm_quantizer": { + "num_bits": (4, 3), + "axis": None, + "trt_high_precision_dtype": "BFloat16", + }, + "*softmax_quantizer": { + "num_bits": (4, 3), + "axis": None, + "trt_high_precision_dtype": "BFloat16", + }, + "transformer_blocks*bmm2_output_quantizer": { + "num_bits": (4, 3), + "axis": None, + "trt_high_precision_dtype": "BFloat16", + }, + "default": {"enable": False}, + }, + "algorithm": "max", +} + +SD_FP8_FP32_DEFAULT_CONFIG = { + "quant_cfg": { + "*weight_quantizer": { + "num_bits": (4, 3), + "axis": None, + "trt_high_precision_dtype": "Float", + }, + "*input_quantizer": { + "num_bits": (4, 3), + "axis": None, + "trt_high_precision_dtype": "Float", + }, + "*output_quantizer": 
{"enable": False}, + "*q_bmm_quantizer": { + "num_bits": (4, 3), + "axis": None, + "trt_high_precision_dtype": "Float", + }, + "*k_bmm_quantizer": { + "num_bits": (4, 3), + "axis": None, + "trt_high_precision_dtype": "Float", + }, + "*v_bmm_quantizer": { + "num_bits": (4, 3), + "axis": None, + "trt_high_precision_dtype": "Float", + }, + "*softmax_quantizer": { + "num_bits": (4, 3), + "axis": None, + "trt_high_precision_dtype": "Float", + }, + "default": {"enable": False}, + }, + "algorithm": "max", +} + + +def set_fmha(denoiser, is_flux=False): + for name, module in denoiser.named_modules(): + if isinstance(module, Attention): + if is_flux: + module.set_processor(FluxAttnProcessor2_0()) + else: + module.set_processor(AttnProcessor()) + + +def check_lora(model): + for name, module in model.named_modules(): + if isinstance(module, (LoRACompatibleConv, LoRACompatibleLinear)): + assert module.lora_layer is None, ( + f"To quantize {name}, LoRA layer should be fused/merged. Please fuse the LoRA layer before quantization." + ) + elif USE_PEFT and isinstance(module, (PEFTLoRAConv2d, PEFTLoRALinear)): + assert module.merged, ( + f"To quantize {name}, LoRA layer should be fused/merged. Please fuse the LoRA layer before quantization." + ) + + +def generate_fp8_scales(unet): + # temporary solution due to a known bug in torch.onnx._dynamo_export + for _, module in unet.named_modules(): + if isinstance(module, (torch.nn.Linear, torch.nn.Conv2d)) and ( + hasattr(module.input_quantizer, "_amax") + and module.input_quantizer is not None + ): + module.input_quantizer._num_bits = 8 + module.weight_quantizer._num_bits = 8 + module.input_quantizer._amax = module.input_quantizer._amax * (127 / 448.0) + module.weight_quantizer._amax = module.weight_quantizer._amax * ( + 127 / 448.0 + ) + elif isinstance(module, Attention) and ( + hasattr(module.q_bmm_quantizer, "_amax") + and module.q_bmm_quantizer is not None + ): + module.q_bmm_quantizer._num_bits = 8 + module.q_bmm_quantizer._amax = module.q_bmm_quantizer._amax * (127 / 448.0) + module.k_bmm_quantizer._num_bits = 8 + module.k_bmm_quantizer._amax = module.k_bmm_quantizer._amax * (127 / 448.0) + module.v_bmm_quantizer._num_bits = 8 + module.v_bmm_quantizer._amax = module.v_bmm_quantizer._amax * (127 / 448.0) + module.softmax_quantizer._num_bits = 8 + module.softmax_quantizer._amax = module.softmax_quantizer._amax * ( + 127 / 448.0 + ) + + +def get_parent_nodes(node): + """ + Returns list of input producer nodes for the given node. + """ + parents = [] + for tensor in node.inputs: + # If the tensor is not a constant or graph input and has a producer, + # the producer is a parent of node `node` + if len(tensor.inputs) == 1: + parents.append(tensor.inputs[0]) + return parents + + +def get_child_nodes(node): + """ + Returns list of output consumer nodes for the given node. + """ + children = [] + for tensor in node.outputs: + for consumer in tensor.outputs: # Traverse all consumer of the tensor + children.append(consumer) + return children + + +def has_path_type(node, graph, path_type, is_forward, wild_card_types, path_nodes): + """ + Return pattern nodes for the given path_type. 
+ """ + if not path_type: + # All types matched + return True + + # Check if current non-wild node type does not match the expected path type + node_type = node.op + is_match = node_type == path_type[0] + is_wild_match = node_type in wild_card_types + if not is_match and not is_wild_match: + return False + + if is_match: + path_nodes.append(node) + next_path_type = path_type[1:] + else: + next_path_type = path_type[:] + + if is_forward: + next_level_nodes = get_child_nodes(node) + else: + next_level_nodes = get_parent_nodes(node) + + # Check if any child (forward path) or parent (backward path) can match the remaining path types + for next_node in next_level_nodes: + sub_path = [] + if has_path_type( + next_node, graph, next_path_type, is_forward, wild_card_types, sub_path + ): + path_nodes.extend(sub_path) + return True + + # Path type matches if there is no remaining types to match + return not next_path_type + + +def insert_cast(graph, input_tensor, attrs): + """ + Create a cast layer using tensor as input. + """ + output_tensor = gs.Variable( + name=f"{input_tensor.name}/Cast_output", dtype=attrs["to"] + ) + next_node_list = input_tensor.outputs.copy() + graph.layer( + op="Cast", + name=f"{input_tensor.name}/Cast", + inputs=[input_tensor], + outputs=[output_tensor], + attrs=attrs, + ) + + # use cast output as input to next node + for next_node in next_node_list: + for idx, next_input in enumerate(next_node.inputs): + if next_input.name == input_tensor.name: + next_node.inputs[idx] = output_tensor + + +def convert_zp_fp8(onnx_graph): + """ + Convert Q/DQ zero datatype from INT8 to FP8. + """ + # Find all zero constant nodes + qdq_zero_nodes = set() + for node in onnx_graph.graph.node: + if node.op_type == "QuantizeLinear": + if len(node.input) > 2: + qdq_zero_nodes.add(node.input[2]) + + print(f"Found {len(qdq_zero_nodes)} QDQ pairs") + + # Convert zero point datatype from INT8 to FP8. + for node in onnx_graph.graph.node: + if node.output[0] in qdq_zero_nodes: + node.attribute[0].t.data_type = onnx.TensorProto.FLOAT8E4M3FN + + return onnx_graph + + +def cast_resize_io(graph): + """ + After all activations and weights are converted to fp16, we will + add cast nodes to Resize nodes I/O because Resize need to be run in fp32. + """ + resize_nodes = [node for node in graph.nodes if node.op == "Resize"] + + print(f"Found {len(resize_nodes)} Resize nodes to fix") + for resize_node in resize_nodes: + for i, input_tensor in enumerate(resize_node.inputs): + SIZES_INPUT_INDEX = 3 # Optional input "sizes" at index 3 must be in INT64. Skip cast for this input. + if i != SIZES_INPUT_INDEX and input_tensor.name: + insert_cast(graph, input_tensor=input_tensor, attrs={"to": np.float32}) + for output_tensor in resize_node.outputs: + if output_tensor.name: + insert_cast(graph, input_tensor=output_tensor, attrs={"to": np.float16}) + + +def cast_fp8_mha_io(graph): + r""" + Insert three cast ops. + The first cast will be added before the input0 of MatMul to cast fp16 to fp32. + The second cast will be added before the input1 of MatMul to cast fp16 to fp32. + The third cast will be added after the output of MatMul to cast fp32 back to fp16. + Q Q + | | + DQ DQ + | | + Cast Cast + (fp16 to fp32) (fp16 to fp32) + \ / + \ / + \ / + MatMul + | + Cast (fp32 to fp16) + | + Q + | + DQ + The insertion of Cast ops in the FP8 MHA part actually forbids the MHAs to run + with FP16 accumulation because TensorRT only has FP32 accumulation kernels for FP8 MHAs. + """ + # Find FP8 MHA pattern. 
+ # Match FP8 MHA: Q -> DQ -> BMM1 -> (Mul/Div) -> (Add) -> Softmax -> (Cast) -> Q -> DQ -> BMM2 -> Q -> DQ + softmax_bmm1_chain_type = [ + "Softmax", + "MatMul", + "DequantizeLinear", + "QuantizeLinear", + ] + softmax_bmm2_chain_type = [ + "Softmax", + "QuantizeLinear", + "DequantizeLinear", + "MatMul", + "QuantizeLinear", + "DequantizeLinear", + ] + wild_card_types = [ + "Div", + "Mul", + "ConstMul", + "Add", + "BiasAdd", + "Reshape", + "Transpose", + "Flatten", + "Cast", + ] + + fp8_mha_partitions = [] + for node in graph.nodes: + if node.op == "Softmax": + fp8_mha_partition = [] + if has_path_type( + node, + graph, + softmax_bmm1_chain_type, + False, + wild_card_types, + fp8_mha_partition, + ) and has_path_type( + node, + graph, + softmax_bmm2_chain_type, + True, + wild_card_types, + fp8_mha_partition, + ): + if ( + len(fp8_mha_partition) == 10 + and fp8_mha_partition[1].op == "MatMul" + and fp8_mha_partition[7].op == "MatMul" + ): + fp8_mha_partitions.append(fp8_mha_partition) + + print(f"Found {len(fp8_mha_partitions)} FP8 attentions") + + # Insert Cast nodes for BMM1 and BMM2. + for fp8_mha_partition in fp8_mha_partitions: + bmm1_node = fp8_mha_partition[1] + insert_cast(graph, input_tensor=bmm1_node.inputs[0], attrs={"to": np.float32}) + insert_cast(graph, input_tensor=bmm1_node.inputs[1], attrs={"to": np.float32}) + insert_cast(graph, input_tensor=bmm1_node.outputs[0], attrs={"to": np.float16}) + + bmm2_node = fp8_mha_partition[7] + insert_cast(graph, input_tensor=bmm2_node.inputs[0], attrs={"to": np.float32}) + insert_cast(graph, input_tensor=bmm2_node.inputs[1], attrs={"to": np.float32}) + insert_cast(graph, input_tensor=bmm2_node.outputs[0], attrs={"to": np.float16}) + + +def set_quant_precision(quant_config, precision: str = "Half"): + for key in quant_config["quant_cfg"]: + if "trt_high_precision_dtype" in quant_config["quant_cfg"][key]: + quant_config["quant_cfg"][key]["trt_high_precision_dtype"] = precision + + +def convert_fp16_io(graph): + """ + Convert graph I/O to FP16. + """ + for input_tensor in graph.inputs: + input_tensor.dtype = onnx.TensorProto.FLOAT16 + for output_tensor in graph.outputs: + output_tensor.dtype = onnx.TensorProto.FLOAT16 + + +def random_resize(cur_size: int): + """ + Randomly selects a new resolution for an image based on its current aspect ratio. + + This function determines the current aspect ratio of an image, selects a new aspect ratio + from predefined choices depending on whether the current aspect ratio is square, + portrait, or landscape, and returns the corresponding resolution from a provided mapping. + + Parameters: + cur_size (int): A tuple (width, height) representing the current resolution of the image. + resolution_to_aspects (dict[float, tuple[int, int]]): A mapping of aspect ratios (floats) + to their corresponding resolutions as tuples of (width, height). + + Returns: + tuple[int, int]: A tuple (new_width, new_height) representing the newly selected resolution. + + Raises: + KeyError: If the chosen aspect ratio is not present in the `resolution_to_aspects` dictionary. + + Notes: + - For square images (aspect ratio = 1), the function selects from aspect ratios 1.25, 0.8, 1.5, and 0.667. + - For landscape images (aspect ratio > 1), the function selects from aspect ratios 1.778, 1.25, and 1.5. + - For portrait images (aspect ratio < 1), the function selects from aspect ratios 0.563, 0.8, and 0.667. 
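+
+    Example (illustrative):
+        random_resize((1024, 1024)) returns one of (896, 1152), (1152, 896),
+        (832, 1216) or (1216, 832), picked at random from the square-image choices.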
+ """ + resolution_to_aspects = { + 1.0: (1024, 1024), + 1.778: (768, 1344), + 0.563: (1344, 768), + 1.25: (896, 1152), + 0.8: (1152, 896), + 1.5: (832, 1216), + 0.667: (1216, 832), + } + + cur_aspect_ratio = round(cur_size[1] / cur_size[0], 3) + + if cur_aspect_ratio == 1: + new_aspect_ratio = choice((1.25, 0.8, 1.5, 0.667)) + new_res = resolution_to_aspects[new_aspect_ratio] + elif cur_aspect_ratio > 1: + new_aspect_ratio = choice((1.778, 1.25, 1.5)) + new_res = resolution_to_aspects[new_aspect_ratio] + else: + # cur_aspect_ratio < 1 + new_aspect_ratio = choice((0.563, 0.8, 0.667)) + new_res = resolution_to_aspects[new_aspect_ratio] + + return new_res + + +class PromptImageDataset(Dataset): + def __init__( + self, + root_dir, + ): + """ + Args: + root_dir (str): Directory with all the images and the prompt file. + """ + self.root_dir = root_dir + self.possible_resolutions = {1024, 768, 1344, 896, 832, 1216} + self.global_idx_template = "{} | {} | {}" + + self.prompts_by_size = defaultdict(list) + self.images_by_size = defaultdict(list) + self.images = [] + self.prompts = [] + self.images_size = [] + # self.global_idx_2_group = dict() + # self.global_idx_to_group_idx = dict() + self.group_to_global_idx = {} + + for idx, file in enumerate(os.listdir(os.path.join(self.root_dir, "prompts"))): + if not file.endswith(".txt"): + continue + file_name = os.path.splitext(file)[0] + image_path = os.path.join( + self.root_dir, + "inputs", + f"{file_name}.png", + ) + + with ( + Image.open(image_path) as img, + open(os.path.join(self.root_dir, "prompts", file), "r") as f, + ): + prompt = "\n".join(f.readlines()) + + std_img_size = ( + self.closest_value(img.size[0], self.possible_resolutions), + self.closest_value(img.size[1], self.possible_resolutions), + ) + + self.images_by_size[std_img_size].append(image_path) + self.prompts_by_size[std_img_size].append(prompt) + + self.images.append(image_path) + self.prompts.append(prompt) + self.images_size.append(std_img_size) + + # create a unique key that map group and index inside the group to a global index + in_group_idx = len(self.images_by_size[std_img_size]) - 1 + group_idx_key = self.global_idx_template.format( + std_img_size[0], std_img_size[1], in_group_idx + ) + self.group_to_global_idx[group_idx_key] = len(self.images) - 1 + + assert len(self.images) == len(self.prompts) + assert len(self.images) == len(self.group_to_global_idx) + + @staticmethod + def closest_value(target: int, candidates: Set[int]): + """ + Find the closest value to the target from a set of candidate values. + + Args: + target (int): The integer to compare against. + candidates (set): A set of integers as candidates. + + Returns: + int: The closest value from the candidates. + """ + if not candidates: + raise ValueError("The candidates set cannot be empty.") + + # Use the min function with a key that computes the absolute difference + return min(candidates, key=lambda x: abs(x - target)) + + def __len__(self): + return len(self.images) + + def __getitem__(self, idx): + """ + Returns: + image (Tensor): Transformed image. + prompt (str): Corresponding text prompt. 
+ """ + if torch.is_tensor(idx): + idx = idx.tolist() + + prompt = self.prompts[idx] + image = self.images[idx] + image_size = self.images_size[idx] + return image, prompt, image_size + + +class SameSizeSampler(Sampler): + def __init__(self, dataset: PromptImageDataset, batch_size: int): + """ + Custom sampler that creates batches of images with the same size + + Args: + dataset (SameSizeImageDataset): Dataset to sample from + batch_size (int): Number of images per batch + """ + super().__init__(dataset) + self.dataset = dataset + self.batch_size = batch_size + + # Prepare size groups with indices + self.size_groups = {} + for size, image_paths in self.dataset.images_by_size.items(): + # Create a list of indices for this size group + self.size_groups[size] = list(range(len(image_paths))) + + def __iter__(self): + """ + Iteration method that yields indices for batches of same-size images + """ + # Create a copy of size groups to shuffle + size_groups_copy = { + std_img_size: indices.copy() + for std_img_size, indices in self.size_groups.items() + } + + # Shuffle each size group + for std_img_size, indices in size_groups_copy.items(): + shuffle(indices) + + # Iterate through size groups + for std_img_size, indices in size_groups_copy.items(): + # Batch indices of the same size + for i in range(0, len(indices), self.batch_size): + # Yield batch indices for this size + batch_group_idxs = indices[i : min(i + self.batch_size, len(indices))] + for in_group_idx in batch_group_idxs: + group_idx_key = self.dataset.global_idx_template.format( + std_img_size[0], std_img_size[1], in_group_idx + ) + batch_global_idx = self.dataset.group_to_global_idx[group_idx_key] + # batch_global_idxs.append(batch_global_idx) + yield batch_global_idx + + def __len__(self): + """ + Total number of batches + """ + return len(self.dataset.images) // self.batch_size + + +def custom_collate(data): + """ + Custom collate function to handle batches of same-size images + + Args: + dataset (SameSizeImageDataset): Dataset instance + batch (list): List of global indices + + Returns: + tuple: Batched images and their size + """ + # Group images by their size + images, prompts, image_sizes = tuple(map(list, zip(*data))) + assert len(images) > 0 + new_img_size = random_resize(image_sizes[0]) + batch_images = [] + for image in images: + with Image.open(image) as image: + image = image.convert("RGB").resize( + size=new_img_size, resample=Image.LANCZOS + ) + image = np.array(image) + image = np.transpose(image, axes=(-1, 0, 1)) + image = torch.from_numpy(image).float() / 127.5 - 1.0 + batch_images.append(image) + + batch_images = torch.stack(batch_images, dim=0) + return batch_images, prompts + + +def infinite_dataloader(dataloader): + while True: + for batch in dataloader: + yield batch diff --git a/flux.1-dev-trt-b200/model/demo_diffusion/utils_sd3/__init__.py b/flux.1-dev-trt-b200/model/demo_diffusion/utils_sd3/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/flux.1-dev-trt-b200/model/demo_diffusion/utils_sd3/mmdit.py b/flux.1-dev-trt-b200/model/demo_diffusion/utils_sd3/mmdit.py new file mode 100644 index 000000000..bd5264705 --- /dev/null +++ b/flux.1-dev-trt-b200/model/demo_diffusion/utils_sd3/mmdit.py @@ -0,0 +1,822 @@ +# MIT License + +# Copyright (c) 2024 Stability AI + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the 
rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import math +from typing import Dict, Optional + +import numpy as np +import torch +import torch.nn as nn +from einops import rearrange, repeat + +from demo_diffusion.utils_sd3.other_impls import Mlp, attention + + +class PatchEmbed(nn.Module): + """2D Image to Patch Embedding""" + + def __init__( + self, + img_size: Optional[int] = 224, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 768, + flatten: bool = True, + bias: bool = True, + strict_img_size: bool = True, + dynamic_img_pad: bool = False, + dtype=None, + device=None, + ): + super().__init__() + self.patch_size = (patch_size, patch_size) + if img_size is not None: + self.img_size = (img_size, img_size) + self.grid_size = tuple( + [s // p for s, p in zip(self.img_size, self.patch_size)] + ) + self.num_patches = self.grid_size[0] * self.grid_size[1] + else: + self.img_size = None + self.grid_size = None + self.num_patches = None + + # flatten spatial dim and transpose to channels last, kept for bwd compat + self.flatten = flatten + self.strict_img_size = strict_img_size + self.dynamic_img_pad = dynamic_img_pad + + self.proj = nn.Conv2d( + in_chans, + embed_dim, + kernel_size=patch_size, + stride=patch_size, + bias=bias, + dtype=dtype, + device=device, + ) + + def forward(self, x): + B, C, H, W = x.shape + x = self.proj(x) + if self.flatten: + x = x.flatten(2).transpose(1, 2) # NCHW -> NLC + return x + + +def modulate(x, shift, scale): + if shift is None: + shift = torch.zeros_like(scale) + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +################################################################################# +# Sine/Cosine Positional Embedding Functions # +################################################################################# + + +def get_2d_sincos_pos_embed( + embed_dim, + grid_size, + cls_token=False, + extra_tokens=0, + scaling_factor=None, + offset=None, +): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid_h = np.arange(grid_size, dtype=np.float32) + grid_w = np.arange(grid_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + if scaling_factor is not None: + grid = grid / scaling_factor + if offset is not None: + grid = grid - offset + grid = grid.reshape([2, 1, grid_size, grid_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token and extra_tokens > 0: + pos_embed = np.concatenate( + [np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0 + ) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert 
embed_dim % 2 == 0 + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + return np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + + +################################################################################# +# Embedding Layers for Timesteps and Class Labels # +################################################################################# + + +class TimestepEmbedder(nn.Module): + """Embeds scalar timesteps into vector representations.""" + + def __init__( + self, hidden_size, frequency_embedding_size=256, dtype=None, device=None + ): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear( + frequency_embedding_size, + hidden_size, + bias=True, + dtype=dtype, + device=device, + ), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. 
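+        Note: the output concatenates cosine terms followed by sine terms; if
+        `dim` is odd, a zero column is appended to reach the requested width.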
+ """ + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) + * torch.arange(start=0, end=half, dtype=torch.float32) + / half + ).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat( + [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 + ) + if torch.is_floating_point(t): + embedding = embedding.to(dtype=t.dtype) + return embedding + + def forward(self, t, dtype, **kwargs): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(dtype) + t_emb = self.mlp(t_freq) + return t_emb + + +class VectorEmbedder(nn.Module): + """Embeds a flat vector of dimension input_dim""" + + def __init__(self, input_dim: int, hidden_size: int, dtype=None, device=None): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(input_dim, hidden_size, bias=True, dtype=dtype, device=device), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.mlp(x) + + +################################################################################# +# Core DiT Model # +################################################################################# + + +def split_qkv(qkv, head_dim): + qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, -1, head_dim).movedim(2, 0) + return qkv[0], qkv[1], qkv[2] + + +def optimized_attention(qkv, num_heads): + return attention(qkv[0], qkv[1], qkv[2], num_heads) + + +class SelfAttention(nn.Module): + ATTENTION_MODES = ("xformers", "torch", "torch-hb", "math", "debug") + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_scale: Optional[float] = None, + attn_mode: str = "xformers", + pre_only: bool = False, + qk_norm: Optional[str] = None, + rmsnorm: bool = False, + dtype=None, + device=None, + ): + super().__init__() + self.num_heads = num_heads + self.head_dim = dim // num_heads + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device) + if not pre_only: + self.proj = nn.Linear(dim, dim, dtype=dtype, device=device) + assert attn_mode in self.ATTENTION_MODES + self.attn_mode = attn_mode + self.pre_only = pre_only + + if qk_norm == "rms": + self.ln_q = RMSNorm( + self.head_dim, + elementwise_affine=True, + eps=1.0e-6, + dtype=dtype, + device=device, + ) + self.ln_k = RMSNorm( + self.head_dim, + elementwise_affine=True, + eps=1.0e-6, + dtype=dtype, + device=device, + ) + elif qk_norm == "ln": + self.ln_q = nn.LayerNorm( + self.head_dim, + elementwise_affine=True, + eps=1.0e-6, + dtype=dtype, + device=device, + ) + self.ln_k = nn.LayerNorm( + self.head_dim, + elementwise_affine=True, + eps=1.0e-6, + dtype=dtype, + device=device, + ) + elif qk_norm is None: + self.ln_q = nn.Identity() + self.ln_k = nn.Identity() + else: + raise ValueError(qk_norm) + + def pre_attention(self, x: torch.Tensor): + B, L, C = x.shape + qkv = self.qkv(x) + q, k, v = split_qkv(qkv, self.head_dim) + q = self.ln_q(q).reshape(q.shape[0], q.shape[1], -1) + k = self.ln_k(k).reshape(q.shape[0], q.shape[1], -1) + return (q, k, v) + + def post_attention(self, x: torch.Tensor) -> torch.Tensor: + assert not self.pre_only + x = self.proj(x) + return x + + def forward(self, x: torch.Tensor) -> torch.Tensor: + (q, k, v) = self.pre_attention(x) + x = attention(q, k, v, self.num_heads) + x = self.post_attention(x) + return x + + +class RMSNorm(torch.nn.Module): + def __init__( + self, + dim: int, + elementwise_affine: bool = 
False, + eps: float = 1e-6, + device=None, + dtype=None, + ): + """ + Initialize the RMSNorm normalization layer. + Args: + dim (int): The dimension of the input tensor. + eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. + Attributes: + eps (float): A small value added to the denominator for numerical stability. + weight (nn.Parameter): Learnable scaling parameter. + """ + super().__init__() + self.eps = eps + self.learnable_scale = elementwise_affine + if self.learnable_scale: + self.weight = nn.Parameter(torch.empty(dim, device=device, dtype=dtype)) + else: + self.register_parameter("weight", None) + + def _norm(self, x): + """ + Apply the RMSNorm normalization to the input tensor. + Args: + x (torch.Tensor): The input tensor. + Returns: + torch.Tensor: The normalized tensor. + """ + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + """ + Forward pass through the RMSNorm layer. + Args: + x (torch.Tensor): The input tensor. + Returns: + torch.Tensor: The output tensor after applying RMSNorm. + """ + x = self._norm(x) + if self.learnable_scale: + return x * self.weight.to(device=x.device, dtype=x.dtype) + else: + return x + + +class SwiGLUFeedForward(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int, + ffn_dim_multiplier: Optional[float] = None, + ): + """ + Initialize the FeedForward module. + + Args: + dim (int): Input dimension. + hidden_dim (int): Hidden dimension of the feedforward layer. + multiple_of (int): Value to ensure hidden dimension is a multiple of this value. + ffn_dim_multiplier (float, optional): Custom multiplier for hidden dimension. Defaults to None. + + Attributes: + w1 (ColumnParallelLinear): Linear transformation for the first layer. + w2 (RowParallelLinear): Linear transformation for the second layer. + w3 (ColumnParallelLinear): Linear transformation for the third layer. 
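+
+        Note: as implemented below, the layer computes w2(silu(w1(x)) * w3(x));
+        hidden_dim is first scaled by 2/3 (and by ffn_dim_multiplier if given)
+        and then rounded up to the nearest multiple of `multiple_of`.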
+ + """ + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + # custom dim factor multiplier + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + + def forward(self, x): + return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x)) + + +class DismantledBlock(nn.Module): + """A DiT block with gated adaptive layer norm (adaLN) conditioning.""" + + ATTENTION_MODES = ("xformers", "torch", "torch-hb", "math", "debug") + + def __init__( + self, + hidden_size: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: str = "xformers", + qkv_bias: bool = False, + pre_only: bool = False, + rmsnorm: bool = False, + scale_mod_only: bool = False, + swiglu: bool = False, + qk_norm: Optional[str] = None, + dtype=None, + device=None, + **block_kwargs, + ): + super().__init__() + assert attn_mode in self.ATTENTION_MODES + if not rmsnorm: + self.norm1 = nn.LayerNorm( + hidden_size, + elementwise_affine=False, + eps=1e-6, + dtype=dtype, + device=device, + ) + else: + self.norm1 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.attn = SelfAttention( + dim=hidden_size, + num_heads=num_heads, + qkv_bias=qkv_bias, + attn_mode=attn_mode, + pre_only=pre_only, + qk_norm=qk_norm, + rmsnorm=rmsnorm, + dtype=dtype, + device=device, + ) + if not pre_only: + if not rmsnorm: + self.norm2 = nn.LayerNorm( + hidden_size, + elementwise_affine=False, + eps=1e-6, + dtype=dtype, + device=device, + ) + else: + self.norm2 = RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + if not pre_only: + if not swiglu: + self.mlp = Mlp( + in_features=hidden_size, + hidden_features=mlp_hidden_dim, + act_layer=nn.GELU(approximate="tanh"), + dtype=dtype, + device=device, + ) + else: + self.mlp = SwiGLUFeedForward( + dim=hidden_size, hidden_dim=mlp_hidden_dim, multiple_of=256 + ) + self.scale_mod_only = scale_mod_only + if not scale_mod_only: + n_mods = 6 if not pre_only else 2 + else: + n_mods = 4 if not pre_only else 1 + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear( + hidden_size, n_mods * hidden_size, bias=True, dtype=dtype, device=device + ), + ) + self.pre_only = pre_only + + def pre_attention(self, x: torch.Tensor, c: torch.Tensor): + assert x is not None, "pre_attention called with None input" + if not self.pre_only: + if not self.scale_mod_only: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + self.adaLN_modulation(c).chunk(6, dim=1) + ) + else: + shift_msa = None + shift_mlp = None + scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation( + c + ).chunk(4, dim=1) + qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa)) + return qkv, (x, gate_msa, shift_mlp, scale_mlp, gate_mlp) + else: + if not self.scale_mod_only: + shift_msa, scale_msa = self.adaLN_modulation(c).chunk(2, dim=1) + else: + shift_msa = None + scale_msa = self.adaLN_modulation(c) + qkv = self.attn.pre_attention(modulate(self.norm1(x), shift_msa, scale_msa)) + return qkv, None + + def post_attention(self, attn, x, gate_msa, shift_mlp, scale_mlp, gate_mlp): + assert not self.pre_only + x = x + gate_msa.unsqueeze(1) * self.attn.post_attention(attn) + x = x + gate_mlp.unsqueeze(1) * self.mlp( + modulate(self.norm2(x), shift_mlp, scale_mlp) + ) + return x + 
+ def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor: + assert not self.pre_only + (q, k, v), intermediates = self.pre_attention(x, c) + attn = attention(q, k, v, self.attn.num_heads) + return self.post_attention(attn, *intermediates) + + +def block_mixing(context, x, context_block, x_block, c): + assert context is not None, "block_mixing called with None context" + context_qkv, context_intermediates = context_block.pre_attention(context, c) + + x_qkv, x_intermediates = x_block.pre_attention(x, c) + + o = [] + for t in range(3): + o.append(torch.cat((context_qkv[t], x_qkv[t]), dim=1)) + q, k, v = tuple(o) + + attn = attention(q, k, v, x_block.attn.num_heads) + context_attn, x_attn = ( + attn[:, : context_qkv[0].shape[1]], + attn[:, context_qkv[0].shape[1] :], + ) + + if not context_block.pre_only: + context = context_block.post_attention(context_attn, *context_intermediates) + else: + context = None + x = x_block.post_attention(x_attn, *x_intermediates) + return context, x + + +class JointBlock(nn.Module): + """just a small wrapper to serve as a fsdp unit""" + + def __init__(self, *args, **kwargs): + super().__init__() + pre_only = kwargs.pop("pre_only") + qk_norm = kwargs.pop("qk_norm", None) + self.context_block = DismantledBlock( + *args, pre_only=pre_only, qk_norm=qk_norm, **kwargs + ) + self.x_block = DismantledBlock(*args, pre_only=False, qk_norm=qk_norm, **kwargs) + + def forward(self, *args, **kwargs): + return block_mixing( + *args, context_block=self.context_block, x_block=self.x_block, **kwargs + ) + + +class FinalLayer(nn.Module): + """ + The final layer of DiT. + """ + + def __init__( + self, + hidden_size: int, + patch_size: int, + out_channels: int, + total_out_channels: Optional[int] = None, + dtype=None, + device=None, + ): + super().__init__() + self.norm_final = nn.LayerNorm( + hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device + ) + self.linear = ( + nn.Linear( + hidden_size, + patch_size * patch_size * out_channels, + bias=True, + dtype=dtype, + device=device, + ) + if (total_out_channels is None) + else nn.Linear( + hidden_size, total_out_channels, bias=True, dtype=dtype, device=device + ) + ) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear( + hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device + ), + ) + + def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor: + shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + + +class MMDiT(nn.Module): + """Diffusion model with a Transformer backbone.""" + + def __init__( + self, + input_size: int = 32, + patch_size: int = 2, + in_channels: int = 4, + depth: int = 28, + mlp_ratio: float = 4.0, + learn_sigma: bool = False, + adm_in_channels: Optional[int] = None, + context_embedder_config: Optional[Dict] = None, + register_length: int = 0, + attn_mode: str = "torch", + rmsnorm: bool = False, + scale_mod_only: bool = False, + swiglu: bool = False, + out_channels: Optional[int] = None, + pos_embed_scaling_factor: Optional[float] = None, + pos_embed_offset: Optional[float] = None, + pos_embed_max_size: Optional[int] = None, + num_patches=None, + qk_norm: Optional[str] = None, + qkv_bias: bool = True, + dtype=None, + device=None, + ): + super().__init__() + self.dtype = dtype + self.learn_sigma = learn_sigma + self.in_channels = in_channels + default_out_channels = in_channels * 2 if learn_sigma else in_channels + self.out_channels = ( + out_channels if 
out_channels is not None else default_out_channels + ) + self.patch_size = patch_size + self.pos_embed_scaling_factor = pos_embed_scaling_factor + self.pos_embed_offset = pos_embed_offset + self.pos_embed_max_size = pos_embed_max_size + + # apply magic --> this defines a head_size of 64 + hidden_size = 64 * depth + num_heads = depth + + self.num_heads = num_heads + + self.x_embedder = PatchEmbed( + input_size, + patch_size, + in_channels, + hidden_size, + bias=True, + strict_img_size=self.pos_embed_max_size is None, + dtype=dtype, + device=device, + ) + self.t_embedder = TimestepEmbedder(hidden_size, dtype=dtype, device=device) + + if adm_in_channels is not None: + assert isinstance(adm_in_channels, int) + self.y_embedder = VectorEmbedder( + adm_in_channels, hidden_size, dtype=dtype, device=device + ) + + self.context_embedder = nn.Identity() + if context_embedder_config is not None: + if context_embedder_config["target"] == "torch.nn.Linear": + self.context_embedder = nn.Linear( + **context_embedder_config["params"], dtype=dtype, device=device + ) + + self.register_length = register_length + if self.register_length > 0: + self.register = nn.Parameter( + torch.randn(1, register_length, hidden_size, dtype=dtype, device=device) + ) + + # num_patches = self.x_embedder.num_patches + # Will use fixed sin-cos embedding: + # just use a buffer already + if num_patches is not None: + self.register_buffer( + "pos_embed", + torch.zeros(1, num_patches, hidden_size, dtype=dtype, device=device), + ) + else: + self.pos_embed = None + + self.joint_blocks = nn.ModuleList( + [ + JointBlock( + hidden_size, + num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + attn_mode=attn_mode, + pre_only=i == depth - 1, + rmsnorm=rmsnorm, + scale_mod_only=scale_mod_only, + swiglu=swiglu, + qk_norm=qk_norm, + dtype=dtype, + device=device, + ) + for i in range(depth) + ] + ) + + self.final_layer = FinalLayer( + hidden_size, patch_size, self.out_channels, dtype=dtype, device=device + ) + + def cropped_pos_embed(self, hw): + assert self.pos_embed_max_size is not None + p = self.x_embedder.patch_size[0] + h, w = hw + # patched size + h = h // p + w = w // p + assert h <= self.pos_embed_max_size, (h, self.pos_embed_max_size) + assert w <= self.pos_embed_max_size, (w, self.pos_embed_max_size) + top = (self.pos_embed_max_size - h) // 2 + left = (self.pos_embed_max_size - w) // 2 + spatial_pos_embed = rearrange( + self.pos_embed, + "1 (h w) c -> 1 h w c", + h=self.pos_embed_max_size, + w=self.pos_embed_max_size, + ) + spatial_pos_embed = spatial_pos_embed[:, top : top + h, left : left + w, :] + spatial_pos_embed = rearrange(spatial_pos_embed, "1 h w c -> 1 (h w) c") + return spatial_pos_embed + + def unpatchify(self, x, hw=None): + """ + x: (N, T, patch_size**2 * C) + imgs: (N, H, W, C) + """ + c = self.out_channels + p = self.x_embedder.patch_size[0] + if hw is None: + h = w = int(x.shape[1] ** 0.5) + else: + h, w = hw + h = h // p + w = w // p + assert h * w == x.shape[1] + + x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) + x = torch.einsum("nhwpqc->nchpwq", x) + imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p)) + return imgs + + def forward_core_with_concat( + self, + x: torch.Tensor, + c_mod: torch.Tensor, + context: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if self.register_length > 0: + context = torch.cat( + ( + repeat(self.register, "1 ... 
-> b ...", b=x.shape[0]), + context if context is not None else torch.Tensor([]).type_as(x), + ), + 1, + ) + + # context is B, L', D + # x is B, L, D + for block in self.joint_blocks: + context, x = block(context, x, c=c_mod) + + x = self.final_layer(x, c_mod) # (N, T, patch_size ** 2 * out_channels) + return x + + def forward( + self, + x: torch.Tensor, + t: torch.Tensor, + y: Optional[torch.Tensor] = None, + context: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Forward pass of DiT. + x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) + t: (N,) tensor of diffusion timesteps + y: (N,) tensor of class labels + """ + hw = x.shape[-2:] + x = self.x_embedder(x) + self.cropped_pos_embed(hw) + c = self.t_embedder(t, dtype=x.dtype) # (N, D) + if y is not None: + y = self.y_embedder(y) # (N, D) + c = c + y # (N, D) + + context = self.context_embedder(context) + + x = self.forward_core_with_concat(x, c, context) + + x = self.unpatchify(x, hw=hw) # (N, out_channels, H, W) + return x diff --git a/flux.1-dev-trt-b200/model/demo_diffusion/utils_sd3/other_impls.py b/flux.1-dev-trt-b200/model/demo_diffusion/utils_sd3/other_impls.py new file mode 100644 index 000000000..0a17d8c86 --- /dev/null +++ b/flux.1-dev-trt-b200/model/demo_diffusion/utils_sd3/other_impls.py @@ -0,0 +1,813 @@ +# MIT License + +# Copyright (c) 2024 Stability AI + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+ +import torch +import math +import numpy as np +from torch import nn +from transformers import CLIPTokenizer, T5TokenizerFast + + +def load_into(f, model, prefix, device, dtype=None): + """Just a debugging-friendly hack to apply the weights in a safetensors file to the pytorch module.""" + for key in f.keys(): + if key.startswith(prefix) and not key.startswith("loss."): + path = key[len(prefix) :].split(".") + obj = model + for p in path: + if obj is list: + obj = obj[int(p)] + else: + obj = getattr(obj, p, None) + if obj is None: + print( + f"Skipping key '{key}' in safetensors file as '{p}' does not exist in python model" + ) + break + if obj is None: + continue + try: + tensor = f.get_tensor(key).to(device=device) + if dtype is not None: + tensor = tensor.to(dtype=dtype) + obj.requires_grad_(False) + obj.set_(tensor) + except Exception as e: + print(f"Failed to load key '{key}' in safetensors file: {e}") + raise e + + +def preprocess_image_sd3(image): + image.convert("RGB") + image_np = np.array(image).astype(np.float32) / 255.0 + image_np = np.moveaxis(image_np, 2, 0) + batch_images = np.expand_dims(image_np, axis=0).repeat(1, axis=0) + image_torch = torch.from_numpy(batch_images) + image_torch = 2.0 * image_torch - 1.0 + + return image_torch + + +################################################################################################# +### Core/Utility +################################################################################################# + + +def attention(q, k, v, heads, mask=None): + """Convenience wrapper around a basic attention operation""" + b, _, dim_head = q.shape + dim_head //= heads + q, k, v = map(lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2), (q, k, v)) + out = torch.nn.functional.scaled_dot_product_attention( + q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False + ) + return out.transpose(1, 2).reshape(b, -1, heads * dim_head) + + +class Mlp(nn.Module): + """MLP as used in Vision Transformer, MLP-Mixer and related networks""" + + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + bias=True, + dtype=None, + device=None, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + + self.fc1 = nn.Linear( + in_features, hidden_features, bias=bias, dtype=dtype, device=device + ) + self.act = act_layer + self.fc2 = nn.Linear( + hidden_features, out_features, bias=bias, dtype=dtype, device=device + ) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.fc2(x) + return x + + +################################################################################################# +### CLIP +################################################################################################# + + +class CLIPAttention(torch.nn.Module): + def __init__(self, embed_dim, heads, dtype, device): + super().__init__() + self.heads = heads + self.q_proj = nn.Linear( + embed_dim, embed_dim, bias=True, dtype=dtype, device=device + ) + self.k_proj = nn.Linear( + embed_dim, embed_dim, bias=True, dtype=dtype, device=device + ) + self.v_proj = nn.Linear( + embed_dim, embed_dim, bias=True, dtype=dtype, device=device + ) + self.out_proj = nn.Linear( + embed_dim, embed_dim, bias=True, dtype=dtype, device=device + ) + + def forward(self, x, mask=None): + q = self.q_proj(x) + k = self.k_proj(x) + v = self.v_proj(x) + out = attention(q, k, v, self.heads, mask) + return self.out_proj(out) + + +ACTIVATIONS = { + "quick_gelu": lambda a: a * 
torch.sigmoid(1.702 * a), + "gelu": torch.nn.functional.gelu, +} + + +class CLIPLayer(torch.nn.Module): + def __init__( + self, + embed_dim, + heads, + intermediate_size, + intermediate_activation, + dtype, + device, + ): + super().__init__() + self.layer_norm1 = nn.LayerNorm(embed_dim, dtype=dtype, device=device) + self.self_attn = CLIPAttention(embed_dim, heads, dtype, device) + self.layer_norm2 = nn.LayerNorm(embed_dim, dtype=dtype, device=device) + # self.mlp = CLIPMLP(embed_dim, intermediate_size, intermediate_activation, dtype, device) + self.mlp = Mlp( + embed_dim, + intermediate_size, + embed_dim, + act_layer=ACTIVATIONS[intermediate_activation], + dtype=dtype, + device=device, + ) + + def forward(self, x, mask=None): + x += self.self_attn(self.layer_norm1(x), mask) + x += self.mlp(self.layer_norm2(x)) + return x + + +class CLIPEncoder(torch.nn.Module): + def __init__( + self, + num_layers, + embed_dim, + heads, + intermediate_size, + intermediate_activation, + dtype, + device, + ): + super().__init__() + self.layers = torch.nn.ModuleList( + [ + CLIPLayer( + embed_dim, + heads, + intermediate_size, + intermediate_activation, + dtype, + device, + ) + for i in range(num_layers) + ] + ) + + def forward(self, x, mask=None, intermediate_output=None): + if intermediate_output is not None: + if intermediate_output < 0: + intermediate_output = len(self.layers) + intermediate_output + intermediate = None + for i, l in enumerate(self.layers): + x = l(x, mask) + if i == intermediate_output: + intermediate = x.clone() + return x, intermediate + + +class CLIPEmbeddings(torch.nn.Module): + def __init__( + self, embed_dim, vocab_size=49408, num_positions=77, dtype=None, device=None + ): + super().__init__() + self.token_embedding = torch.nn.Embedding( + vocab_size, embed_dim, dtype=dtype, device=device + ) + self.position_embedding = torch.nn.Embedding( + num_positions, embed_dim, dtype=dtype, device=device + ) + + def forward(self, input_tokens): + return self.token_embedding(input_tokens) + self.position_embedding.weight + + +class CLIPTextModel_(torch.nn.Module): + def __init__(self, config_dict, dtype, device): + num_layers = config_dict["num_hidden_layers"] + embed_dim = config_dict["hidden_size"] + heads = config_dict["num_attention_heads"] + intermediate_size = config_dict["intermediate_size"] + intermediate_activation = config_dict["hidden_act"] + super().__init__() + self.embeddings = CLIPEmbeddings(embed_dim, dtype=dtype, device=device) + self.encoder = CLIPEncoder( + num_layers, + embed_dim, + heads, + intermediate_size, + intermediate_activation, + dtype, + device, + ) + self.final_layer_norm = nn.LayerNorm(embed_dim, dtype=dtype, device=device) + + def forward( + self, input_tokens, intermediate_output=None, final_layer_norm_intermediate=True + ): + x = self.embeddings(input_tokens) + causal_mask = ( + torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device) + .fill_(float("-inf")) + .triu_(1) + ) + x, i = self.encoder( + x, mask=causal_mask, intermediate_output=intermediate_output + ) + x = self.final_layer_norm(x) + if i is not None and final_layer_norm_intermediate: + i = self.final_layer_norm(i) + pooled_output = x[ + torch.arange(x.shape[0], device=x.device), + input_tokens.to(dtype=torch.int, device=x.device).argmax(dim=-1), + ] + return x, i, pooled_output + + +class CLIPTextModel(torch.nn.Module): + def __init__(self, config_dict, dtype, device): + super().__init__() + self.num_layers = config_dict["num_hidden_layers"] + self.text_model = 
CLIPTextModel_(config_dict, dtype, device) + embed_dim = config_dict["hidden_size"] + self.text_projection = nn.Linear( + embed_dim, embed_dim, bias=False, dtype=dtype, device=device + ) + + # WAR for RuntimeError: a leaf Variable that requires grad is being used in an in-place operation. + with torch.no_grad(): + self.text_projection.weight.copy_(torch.eye(embed_dim)) + self.dtype = dtype + + def get_input_embeddings(self): + return self.text_model.embeddings.token_embedding + + def set_input_embeddings(self, embeddings): + self.text_model.embeddings.token_embedding = embeddings + + def forward(self, *args, **kwargs): + x = self.text_model(*args, **kwargs) + out = self.text_projection(x[2]) + return (x[0], x[1], out, x[2]) + + +class SDTokenizer: + def __init__( + self, + max_length=77, + pad_with_end=True, + tokenizer=None, + has_start_token=True, + pad_to_max_length=True, + min_length=None, + ): + self.tokenizer = tokenizer + self.max_length = max_length + self.min_length = min_length + empty = self.tokenizer("")["input_ids"] + if has_start_token: + self.tokens_start = 1 + self.start_token = empty[0] + self.end_token = empty[1] + else: + self.tokens_start = 0 + self.start_token = None + self.end_token = empty[0] + self.pad_with_end = pad_with_end + self.pad_to_max_length = pad_to_max_length + vocab = self.tokenizer.get_vocab() + self.inv_vocab = {v: k for k, v in vocab.items()} + self.max_word_length = 8 + + def tokenize_with_weights(self, text: str): + """Tokenize the text, with weight values - presume 1.0 for all and ignore other features here. The details aren't relevant for a reference impl, and weights themselves has weak effect on SD3.""" + if self.pad_with_end: + pad_token = self.end_token + else: + pad_token = 0 + batch = [] + if self.start_token is not None: + batch.append((self.start_token, 1.0)) + to_tokenize = text.replace("\n", " ").split(" ") + to_tokenize = [x for x in to_tokenize if x != ""] + for word in to_tokenize: + batch.extend( + [ + (t, 1) + for t in self.tokenizer(word)["input_ids"][self.tokens_start : -1] + ] + ) + batch.append((self.end_token, 1.0)) + if self.pad_to_max_length: + batch.extend([(pad_token, 1.0)] * (self.max_length - len(batch))) + if self.min_length is not None and len(batch) < self.min_length: + batch.extend([(pad_token, 1.0)] * (self.min_length - len(batch))) + if len(batch) > self.max_length: + batch = batch[: self.max_length] + return [batch] + + +class SDXLClipGTokenizer(SDTokenizer): + def __init__(self, tokenizer): + super().__init__(pad_with_end=False, tokenizer=tokenizer) + + +class SD3Tokenizer: + def __init__(self): + clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") + self.clip_l = SDTokenizer(tokenizer=clip_tokenizer) + self.clip_g = SDXLClipGTokenizer(clip_tokenizer) + self.t5xxl = T5XXLTokenizer() + + def tokenize_with_weights(self, text: str): + out = {} + out["g"] = self.clip_g.tokenize_with_weights(text) + out["l"] = self.clip_l.tokenize_with_weights(text) + out["t5xxl"] = self.t5xxl.tokenize_with_weights(text) + return out + + +class ClipTokenWeightEncoder: + def encode_token_weights(self, token_weight_pairs): + tokens = list(map(lambda a: a[0], token_weight_pairs[0])) + + # model inference + tokens = torch.tensor([tokens], dtype=torch.int64, device="cuda") + out, pooled = self(tokens) + + if pooled is not None: + first_pooled = pooled[0:1].cuda() + else: + first_pooled = pooled + output = [out[0:1]] + + return torch.cat(output, dim=-2).cuda(), first_pooled + + +class SDClipModel(torch.nn.Module, 
ClipTokenWeightEncoder): + """Uses the CLIP transformer encoder for text (from huggingface)""" + + LAYERS = ["last", "pooled", "hidden"] + + def __init__( + self, + device="cuda", + max_length=77, + layer="last", + layer_idx=None, + textmodel_json_config=None, + dtype=None, + model_class=CLIPTextModel, + special_tokens={"start": 49406, "end": 49407, "pad": 49407}, + layer_norm_hidden_state=True, + return_projected_pooled=True, + ): + super().__init__() + assert layer in self.LAYERS + self.transformer = model_class(textmodel_json_config, dtype, device) + self.num_layers = self.transformer.num_layers + self.max_length = max_length + self.transformer = self.transformer.eval() + for param in self.parameters(): + param.requires_grad = False + self.layer = layer + self.layer_idx = None + self.special_tokens = special_tokens + self.logit_scale = torch.nn.Parameter(torch.tensor(4.6055)) + self.layer_norm_hidden_state = layer_norm_hidden_state + self.return_projected_pooled = return_projected_pooled + if layer == "hidden": + assert layer_idx is not None + assert abs(layer_idx) < self.num_layers + self.set_clip_options({"layer": layer_idx}) + self.options_default = ( + self.layer, + self.layer_idx, + self.return_projected_pooled, + ) + + def set_clip_options(self, options): + layer_idx = options.get("layer", self.layer_idx) + self.return_projected_pooled = options.get( + "projected_pooled", self.return_projected_pooled + ) + if layer_idx is None or abs(layer_idx) > self.num_layers: + self.layer = "last" + else: + self.layer = "hidden" + self.layer_idx = layer_idx + + def forward(self, tokens): + backup_embeds = self.transformer.get_input_embeddings() + outputs = self.transformer( + tokens, + intermediate_output=self.layer_idx, + final_layer_norm_intermediate=self.layer_norm_hidden_state, + ) + self.transformer.set_input_embeddings(backup_embeds) + if self.layer == "last": + z = outputs[0] + else: + z = outputs[1] + pooled_output = None + if len(outputs) >= 3: + if ( + not self.return_projected_pooled + and len(outputs) >= 4 + and outputs[3] is not None + ): + pooled_output = outputs[3].float() + elif outputs[2] is not None: + pooled_output = outputs[2].float() + return z.float(), pooled_output + + +class SDXLClipG(SDClipModel): + """Wraps the CLIP-G model into the SD-CLIP-Model interface""" + + def __init__( + self, config, device="cuda", layer="penultimate", layer_idx=None, dtype=None + ): + if layer == "penultimate": + layer = "hidden" + layer_idx = -2 + super().__init__( + device=device, + layer=layer, + layer_idx=layer_idx, + textmodel_json_config=config, + dtype=dtype, + special_tokens={"start": 49406, "end": 49407, "pad": 0}, + layer_norm_hidden_state=False, + ) + + +class T5XXLModel(SDClipModel): + """Wraps the T5-XXL model into the SD-CLIP-Model interface for convenience""" + + def __init__(self, config, device="cuda", layer="last", layer_idx=None, dtype=None): + super().__init__( + device=device, + layer=layer, + layer_idx=layer_idx, + textmodel_json_config=config, + dtype=dtype, + special_tokens={"end": 1, "pad": 0}, + model_class=T5, + ) + + +################################################################################################# +### T5 implementation, for the T5-XXL text encoder portion, largely pulled from upstream impl +################################################################################################# + + +class T5XXLTokenizer(SDTokenizer): + """Wraps the T5 Tokenizer from HF into the SDTokenizer interface""" + + def __init__(self): + super().__init__( + 
pad_with_end=False, + tokenizer=T5TokenizerFast.from_pretrained("google/t5-v1_1-xxl"), + has_start_token=False, + pad_to_max_length=False, + max_length=99999999, + min_length=77, + ) + + +class T5LayerNorm(torch.nn.Module): + def __init__(self, hidden_size, eps=1e-6, dtype=None, device=None): + super().__init__() + self.weight = torch.nn.Parameter( + torch.ones(hidden_size, dtype=dtype, device=device) + ) + self.variance_epsilon = eps + + def forward(self, x): + variance = x.pow(2).mean(-1, keepdim=True) + x = x * torch.rsqrt(variance + self.variance_epsilon) + return self.weight.to(device=x.device, dtype=x.dtype) * x + + +class T5DenseGatedActDense(torch.nn.Module): + def __init__(self, model_dim, ff_dim, dtype, device): + super().__init__() + self.wi_0 = nn.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device) + self.wi_1 = nn.Linear(model_dim, ff_dim, bias=False, dtype=dtype, device=device) + self.wo = nn.Linear(ff_dim, model_dim, bias=False, dtype=dtype, device=device) + + def forward(self, x): + hidden_gelu = torch.nn.functional.gelu(self.wi_0(x), approximate="tanh") + hidden_linear = self.wi_1(x) + x = hidden_gelu * hidden_linear + x = self.wo(x) + return x + + +class T5LayerFF(torch.nn.Module): + def __init__(self, model_dim, ff_dim, dtype, device): + super().__init__() + self.DenseReluDense = T5DenseGatedActDense(model_dim, ff_dim, dtype, device) + self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device) + + def forward(self, x): + forwarded_states = self.layer_norm(x) + forwarded_states = self.DenseReluDense(forwarded_states) + x += forwarded_states + return x + + +class T5Attention(torch.nn.Module): + def __init__( + self, model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device + ): + super().__init__() + # Mesh TensorFlow initialization to avoid scaling before softmax + self.q = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device) + self.k = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device) + self.v = nn.Linear(model_dim, inner_dim, bias=False, dtype=dtype, device=device) + self.o = nn.Linear(inner_dim, model_dim, bias=False, dtype=dtype, device=device) + self.num_heads = num_heads + self.relative_attention_bias = None + if relative_attention_bias: + self.relative_attention_num_buckets = 32 + self.relative_attention_max_distance = 128 + self.relative_attention_bias = torch.nn.Embedding( + self.relative_attention_num_buckets, + self.num_heads, + device=device, + dtype=dtype, + ) + + @staticmethod + def _relative_position_bucket( + relative_position, bidirectional=True, num_buckets=32, max_distance=128 + ): + """ + Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 + + Translate relative position to a bucket number for relative attention. The relative position is defined as + memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for + small absolute relative_position and larger buckets for larger absolute relative_positions. All relative + positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. 
+ This should allow for more graceful generalization to longer sequences than the model has been trained on + + Args: + relative_position: an int32 Tensor + bidirectional: a boolean - whether the attention is bidirectional + num_buckets: an integer + max_distance: an integer + + Returns: + a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) + """ + relative_buckets = 0 + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > 0).to(torch.long) * num_buckets + relative_position = torch.abs(relative_position) + else: + relative_position = -torch.min( + relative_position, torch.zeros_like(relative_position) + ) + # now relative_position is in the range [0, inf) + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + relative_position_if_large = max_exact + ( + torch.log(relative_position.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_position_if_large = torch.min( + relative_position_if_large, + torch.full_like(relative_position_if_large, num_buckets - 1), + ) + relative_buckets += torch.where( + is_small, relative_position, relative_position_if_large + ) + return relative_buckets + + def compute_bias(self, query_length, key_length, device): + """Compute binned relative position bias""" + context_position = torch.arange(query_length, dtype=torch.long, device=device)[ + :, None + ] + memory_position = torch.arange(key_length, dtype=torch.long, device=device)[ + None, : + ] + relative_position = ( + memory_position - context_position + ) # shape (query_length, key_length) + relative_position_bucket = self._relative_position_bucket( + relative_position, # shape (query_length, key_length) + bidirectional=True, + num_buckets=self.relative_attention_num_buckets, + max_distance=self.relative_attention_max_distance, + ) + values = self.relative_attention_bias( + relative_position_bucket + ) # shape (query_length, key_length, num_heads) + values = values.permute([2, 0, 1]).unsqueeze( + 0 + ) # shape (1, num_heads, query_length, key_length) + return values + + def forward(self, x, past_bias=None): + q = self.q(x) + k = self.k(x) + v = self.v(x) + if self.relative_attention_bias is not None: + past_bias = self.compute_bias(x.shape[1], x.shape[1], x.device) + if past_bias is not None: + mask = past_bias + out = attention( + q, k * ((k.shape[-1] / self.num_heads) ** 0.5), v, self.num_heads, mask + ) + return self.o(out), past_bias + + +class T5LayerSelfAttention(torch.nn.Module): + def __init__( + self, + model_dim, + inner_dim, + ff_dim, + num_heads, + relative_attention_bias, + dtype, + device, + ): + super().__init__() + self.SelfAttention = T5Attention( + model_dim, inner_dim, num_heads, relative_attention_bias, dtype, device + ) + self.layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device) + + def forward(self, x, past_bias=None): + output, past_bias = self.SelfAttention(self.layer_norm(x), past_bias=past_bias) + x += output + return x, past_bias + + +class T5Block(torch.nn.Module): + def __init__( + self, + model_dim, + inner_dim, + ff_dim, + num_heads, + relative_attention_bias, + dtype, + device, + ): + super().__init__() + self.layer = torch.nn.ModuleList() + self.layer.append( + T5LayerSelfAttention( + model_dim, + inner_dim, + ff_dim, + num_heads, + 
relative_attention_bias, + dtype, + device, + ) + ) + self.layer.append(T5LayerFF(model_dim, ff_dim, dtype, device)) + + def forward(self, x, past_bias=None): + x, past_bias = self.layer[0](x, past_bias) + x = self.layer[-1](x) + return x, past_bias + + +class T5Stack(torch.nn.Module): + def __init__( + self, + num_layers, + model_dim, + inner_dim, + ff_dim, + num_heads, + vocab_size, + dtype, + device, + ): + super().__init__() + self.embed_tokens = torch.nn.Embedding( + vocab_size, model_dim, device=device, dtype=dtype + ) + self.block = torch.nn.ModuleList( + [ + T5Block( + model_dim, + inner_dim, + ff_dim, + num_heads, + relative_attention_bias=(i == 0), + dtype=dtype, + device=device, + ) + for i in range(num_layers) + ] + ) + self.final_layer_norm = T5LayerNorm(model_dim, dtype=dtype, device=device) + + def forward( + self, input_ids, intermediate_output=None, final_layer_norm_intermediate=True + ): + intermediate = None + x = self.embed_tokens(input_ids) + past_bias = None + for i, l in enumerate(self.block): + x, past_bias = l(x, past_bias) + if i == intermediate_output: + intermediate = x.clone() + x = self.final_layer_norm(x) + if intermediate is not None and final_layer_norm_intermediate: + intermediate = self.final_layer_norm(intermediate) + return x, intermediate + + +class T5(torch.nn.Module): + def __init__(self, config_dict, dtype, device): + super().__init__() + self.num_layers = config_dict["num_layers"] + self.encoder = T5Stack( + self.num_layers, + config_dict["d_model"], + config_dict["d_model"], + config_dict["d_ff"], + config_dict["num_heads"], + config_dict["vocab_size"], + dtype, + device, + ) + self.dtype = dtype + + def get_input_embeddings(self): + return self.encoder.embed_tokens + + def set_input_embeddings(self, embeddings): + self.encoder.embed_tokens = embeddings + + def forward(self, *args, **kwargs): + return self.encoder(*args, **kwargs) diff --git a/flux.1-dev-trt-b200/model/demo_diffusion/utils_sd3/sd3_impls.py b/flux.1-dev-trt-b200/model/demo_diffusion/utils_sd3/sd3_impls.py new file mode 100644 index 000000000..45ecba842 --- /dev/null +++ b/flux.1-dev-trt-b200/model/demo_diffusion/utils_sd3/sd3_impls.py @@ -0,0 +1,589 @@ +# MIT License + +# Copyright (c) 2024 Stability AI + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
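+
+# NOTE: This module is adapted from the Stability AI SD3 reference implementation
+# (per the MIT license above). It wraps the MMDiT diffusion model (BaseModel,
+# CFGDenoiser), defines the discrete-flow sigma schedule and a basic Euler sampler,
+# and implements the VAE encoder/decoder plus the SD3 latent scale/shift helpers.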
+ +import math + +import einops +import torch +from PIL import Image + +from demo_diffusion.utils_sd3.mmdit import MMDiT + +################################################################################################# +### MMDiT Model Wrapping +################################################################################################# + + +class ModelSamplingDiscreteFlow(torch.nn.Module): + """Helper for sampler scheduling (ie timestep/sigma calculations) for Discrete Flow models""" + + def __init__(self, shift=1.0): + super().__init__() + self.shift = shift + timesteps = 1000 + ts = self.sigma(torch.arange(1, timesteps + 1, 1)) + self.register_buffer("sigmas", ts) + + @property + def sigma_min(self): + return self.sigmas[0] + + @property + def sigma_max(self): + return self.sigmas[-1] + + def timestep(self, sigma): + return sigma * 1000 + + def sigma(self, timestep: torch.Tensor): + timestep = timestep / 1000.0 + if self.shift == 1.0: + return timestep + return self.shift * timestep / (1 + (self.shift - 1) * timestep) + + def calculate_denoised(self, sigma, model_output, model_input): + sigma = sigma.view(sigma.shape[:1] + (1,) * (model_output.ndim - 1)) + return model_input - model_output * sigma + + def noise_scaling(self, sigma, noise, latent_image, max_denoise=False): + return sigma * noise + (1.0 - sigma) * latent_image + + +class BaseModel(torch.nn.Module): + """Wrapper around the core MM-DiT model""" + + def __init__( + self, shift=1.0, device=None, dtype=torch.float32, file=None, prefix="" + ): + super().__init__() + # Important configuration values can be quickly determined by checking shapes in the source file + # Some of these will vary between models (eg 2B vs 8B primarily differ in their depth, but also other details change) + patch_size = file.get_tensor(f"{prefix}x_embedder.proj.weight").shape[2] + depth = file.get_tensor(f"{prefix}x_embedder.proj.weight").shape[0] // 64 + num_patches = file.get_tensor(f"{prefix}pos_embed").shape[1] + pos_embed_max_size = round(math.sqrt(num_patches)) + adm_in_channels = file.get_tensor(f"{prefix}y_embedder.mlp.0.weight").shape[1] + context_shape = file.get_tensor(f"{prefix}context_embedder.weight").shape + context_embedder_config = { + "target": "torch.nn.Linear", + "params": { + "in_features": context_shape[1], + "out_features": context_shape[0], + }, + } + self.diffusion_model = MMDiT( + input_size=None, + pos_embed_scaling_factor=None, + pos_embed_offset=None, + pos_embed_max_size=pos_embed_max_size, + patch_size=patch_size, + in_channels=16, + depth=depth, + num_patches=num_patches, + adm_in_channels=adm_in_channels, + context_embedder_config=context_embedder_config, + device=device, + dtype=dtype, + ) + self.model_sampling = ModelSamplingDiscreteFlow(shift=shift) + + def forward(self, x, sigma, c_crossattn=None, y=None): + dtype = self.get_dtype() + timestep = self.model_sampling.timestep(sigma).float() + model_output = self.diffusion_model( + x.to(dtype), timestep, context=c_crossattn.to(dtype), y=y.to(dtype) + ).float() + return self.model_sampling.calculate_denoised(sigma, model_output, x) + + def get_dtype(self): + return self.diffusion_model.dtype + + +class CFGDenoiser(torch.nn.Module): + """Helper for applying CFG Scaling to diffusion outputs""" + + def __init__(self, model): + super().__init__() + self.model = model + + def forward(self, x, timestep, cond, uncond, cond_scale): + # Run cond and uncond in a batch together + batched = self.model( + torch.cat([x, x]), + torch.cat([timestep, timestep]), + 
c_crossattn=torch.cat([cond["c_crossattn"], uncond["c_crossattn"]]), + y=torch.cat([cond["y"], uncond["y"]]), + ) + # Then split and apply CFG Scaling + pos_out, neg_out = batched.chunk(2) + scaled = neg_out + (pos_out - neg_out) * cond_scale + return scaled + + +class SD3LatentFormat: + """Latents are slightly shifted from center - this class must be called after VAE Decode to correct for the shift""" + + def __init__(self): + self.scale_factor = 1.5305 + self.shift_factor = 0.0609 + + def process_in(self, latent): + return (latent - self.shift_factor) * self.scale_factor + + def process_out(self, latent): + return (latent / self.scale_factor) + self.shift_factor + + def decode_latent_to_preview(self, x0): + """Quick RGB approximate preview of sd3 latents""" + factors = torch.tensor( + [ + [-0.0645, 0.0177, 0.1052], + [0.0028, 0.0312, 0.0650], + [0.1848, 0.0762, 0.0360], + [0.0944, 0.0360, 0.0889], + [0.0897, 0.0506, -0.0364], + [-0.0020, 0.1203, 0.0284], + [0.0855, 0.0118, 0.0283], + [-0.0539, 0.0658, 0.1047], + [-0.0057, 0.0116, 0.0700], + [-0.0412, 0.0281, -0.0039], + [0.1106, 0.1171, 0.1220], + [-0.0248, 0.0682, -0.0481], + [0.0815, 0.0846, 0.1207], + [-0.0120, -0.0055, -0.0867], + [-0.0749, -0.0634, -0.0456], + [-0.1418, -0.1457, -0.1259], + ], + device="cuda", + ) + latent_image = x0[0].permute(1, 2, 0).cuda() @ factors + + latents_ubyte = ( + ((latent_image + 1) / 2) + .clamp(0, 1) # change scale from -1..1 to 0..1 + .mul(0xFF) # to 0..255 + .byte() + ).cuda() + + return Image.fromarray(latents_ubyte.numpy()) + + +################################################################################################# +### K-Diffusion Sampling +################################################################################################# + + +def append_dims(x, target_dims): + """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" + dims_to_append = target_dims - x.ndim + return x[(...,) + (None,) * dims_to_append] + + +def to_d(x, sigma, denoised): + """Converts a denoiser output to a Karras ODE derivative.""" + return (x - denoised) / append_dims(sigma, x.ndim) + + +@torch.no_grad() +@torch.autocast("cuda", dtype=torch.float16) +def sample_euler(func, x, sigmas, extra_args=None): + """Implements Algorithm 2 (Euler steps) from Karras et al. 
(2022).""" + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + for i in range(len(sigmas) - 1): + sigma_hat = sigmas[i] + denoised = func(x, sigma_hat * s_in, **extra_args) + d = to_d(x, sigma_hat, denoised) + dt = sigmas[i + 1] - sigma_hat + # Euler method + x = x + d * dt + return x + + +################################################################################################# +### VAE +################################################################################################# + + +def Normalize(in_channels, num_groups=32, dtype=torch.float32, device=None): + return torch.nn.GroupNorm( + num_groups=num_groups, + num_channels=in_channels, + eps=1e-6, + affine=True, + dtype=dtype, + device=device, + ) + + +class ResnetBlock(torch.nn.Module): + def __init__( + self, *, in_channels, out_channels=None, dtype=torch.float32, device=None + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + + self.norm1 = Normalize(in_channels, dtype=dtype, device=device) + self.conv1 = torch.nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + dtype=dtype, + device=device, + ) + self.norm2 = Normalize(out_channels, dtype=dtype, device=device) + self.conv2 = torch.nn.Conv2d( + out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + dtype=dtype, + device=device, + ) + if self.in_channels != self.out_channels: + self.nin_shortcut = torch.nn.Conv2d( + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + dtype=dtype, + device=device, + ) + else: + self.nin_shortcut = None + self.swish = torch.nn.SiLU(inplace=True) + + def forward(self, x): + hidden = x + hidden = self.norm1(hidden) + hidden = self.swish(hidden) + hidden = self.conv1(hidden) + hidden = self.norm2(hidden) + hidden = self.swish(hidden) + hidden = self.conv2(hidden) + if self.in_channels != self.out_channels: + x = self.nin_shortcut(x) + return x + hidden + + +class AttnBlock(torch.nn.Module): + def __init__(self, in_channels, dtype=torch.float32, device=None): + super().__init__() + self.norm = Normalize(in_channels, dtype=dtype, device=device) + self.q = torch.nn.Conv2d( + in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0, + dtype=dtype, + device=device, + ) + self.k = torch.nn.Conv2d( + in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0, + dtype=dtype, + device=device, + ) + self.v = torch.nn.Conv2d( + in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0, + dtype=dtype, + device=device, + ) + self.proj_out = torch.nn.Conv2d( + in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0, + dtype=dtype, + device=device, + ) + + def forward(self, x): + hidden = self.norm(x) + q = self.q(hidden) + k = self.k(hidden) + v = self.v(hidden) + b, c, h, w = q.shape + q, k, v = map( + lambda x: einops.rearrange(x, "b c h w -> b 1 (h w) c").contiguous(), + (q, k, v), + ) + hidden = torch.nn.functional.scaled_dot_product_attention( + q, k, v + ) # scale is dim ** -0.5 per default + hidden = einops.rearrange(hidden, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b) + hidden = self.proj_out(hidden) + return x + hidden + + +class Downsample(torch.nn.Module): + def __init__(self, in_channels, dtype=torch.float32, device=None): + super().__init__() + self.conv = torch.nn.Conv2d( + in_channels, + in_channels, + kernel_size=3, + stride=2, + padding=0, + dtype=dtype, + 
device=device, + ) + + def forward(self, x): + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + return x + + +class Upsample(torch.nn.Module): + def __init__(self, in_channels, dtype=torch.float32, device=None): + super().__init__() + self.conv = torch.nn.Conv2d( + in_channels, + in_channels, + kernel_size=3, + stride=1, + padding=1, + dtype=dtype, + device=device, + ) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + x = self.conv(x) + return x + + +class VAEEncoder(torch.nn.Module): + def __init__( + self, + ch=128, + ch_mult=(1, 2, 4, 4), + num_res_blocks=2, + in_channels=3, + z_channels=16, + dtype=torch.float32, + device=None, + ): + super().__init__() + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + # downsampling + self.conv_in = torch.nn.Conv2d( + in_channels, + ch, + kernel_size=3, + stride=1, + padding=1, + dtype=dtype, + device=device, + ) + in_ch_mult = (1,) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = torch.nn.ModuleList() + for i_level in range(self.num_resolutions): + block = torch.nn.ModuleList() + attn = torch.nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + dtype=dtype, + device=device, + ) + ) + block_in = block_out + down = torch.nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, dtype=dtype, device=device) + self.down.append(down) + # middle + self.mid = torch.nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, out_channels=block_in, dtype=dtype, device=device + ) + self.mid.attn_1 = AttnBlock(block_in, dtype=dtype, device=device) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, out_channels=block_in, dtype=dtype, device=device + ) + # end + self.norm_out = Normalize(block_in, dtype=dtype, device=device) + self.conv_out = torch.nn.Conv2d( + block_in, + 2 * z_channels, + kernel_size=3, + stride=1, + padding=1, + dtype=dtype, + device=device, + ) + self.swish = torch.nn.SiLU(inplace=True) + + def forward(self, x): + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1]) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + # middle + h = hs[-1] + h = self.mid.block_1(h) + h = self.mid.attn_1(h) + h = self.mid.block_2(h) + # end + h = self.norm_out(h) + h = self.swish(h) + h = self.conv_out(h) + return h + + +class VAEDecoder(torch.nn.Module): + def __init__( + self, + ch=128, + out_ch=3, + ch_mult=(1, 2, 4, 4), + num_res_blocks=2, + resolution=256, + z_channels=16, + dtype=torch.float32, + device=None, + ): + super().__init__() + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + # z to block_in + self.conv_in = torch.nn.Conv2d( + z_channels, + block_in, + kernel_size=3, + stride=1, + padding=1, + dtype=dtype, + device=device, + ) + # middle + self.mid = torch.nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, out_channels=block_in, dtype=dtype, device=device + ) + self.mid.attn_1 = AttnBlock(block_in, dtype=dtype, 
device=device) + self.mid.block_2 = ResnetBlock( + in_channels=block_in, out_channels=block_in, dtype=dtype, device=device + ) + # upsampling + self.up = torch.nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = torch.nn.ModuleList() + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + dtype=dtype, + device=device, + ) + ) + block_in = block_out + up = torch.nn.Module() + up.block = block + if i_level != 0: + up.upsample = Upsample(block_in, dtype=dtype, device=device) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + # end + self.norm_out = Normalize(block_in, dtype=dtype, device=device) + self.conv_out = torch.nn.Conv2d( + block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1, + dtype=dtype, + device=device, + ) + self.swish = torch.nn.SiLU(inplace=True) + + def forward(self, z): + # z to block_in + hidden = self.conv_in(z) + # middle + hidden = self.mid.block_1(hidden) + hidden = self.mid.attn_1(hidden) + hidden = self.mid.block_2(hidden) + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + hidden = self.up[i_level].block[i_block](hidden) + if i_level != 0: + hidden = self.up[i_level].upsample(hidden) + # end + hidden = self.norm_out(hidden) + hidden = self.swish(hidden) + hidden = self.conv_out(hidden) + return hidden + + +class SDVAE(torch.nn.Module): + def __init__(self, dtype=torch.float32, device=None): + super().__init__() + self.encoder = VAEEncoder(dtype=dtype, device=device) + self.decoder = VAEDecoder(dtype=dtype, device=device) + + @torch.autocast("cuda", dtype=torch.float16) + def decode(self, latent): + return self.decoder(latent) + + @torch.autocast("cuda", dtype=torch.float16) + def encode(self, image): + hidden = self.encoder(image) + mean, logvar = torch.chunk(hidden, 2, dim=1) + logvar = torch.clamp(logvar, -30.0, 20.0) + std = torch.exp(0.5 * logvar) + return mean + std * torch.randn_like(mean) diff --git a/flux.1-dev-trt-b200/model/model.py b/flux.1-dev-trt-b200/model/model.py new file mode 100644 index 000000000..c99bda24c --- /dev/null +++ b/flux.1-dev-trt-b200/model/model.py @@ -0,0 +1,328 @@ +import argparse +import base64 +import os +import sys +import time +from io import BytesIO +from typing import Any + +import torch +from cuda import cudart +from huggingface_hub import snapshot_download +from PIL import Image + +# Add the current directory to Python path so demo_diffusion can be imported +current_dir = os.path.dirname(os.path.abspath(__file__)) +if current_dir not in sys.path: + sys.path.insert(0, current_dir) + +from demo_diffusion import dd_argparse +from demo_diffusion import pipeline as pipeline_module + + +class Model: + def __init__(self, **kwargs): + self._model = None + self.pipe = None + self.args = None + self._secrets = kwargs.get("secrets") + + def load(self): + """Load the Flux model with TensorRT engines and perform warmup runs.""" + print("[I] Initializing Flux txt2img model using TensorRT") + + if self._secrets and "hf_access_token" in self._secrets: + os.environ["HF_TOKEN"] = self._secrets["hf_access_token"] + + # Download TensorRT engine files (with error handling for private repos) + try: + snapshot_download( + repo_id="baseten-admin/flux.1-dev-trt-10.13.0.35.engine-B200", + local_dir="/app/data", + ) + print("[I] TensorRT engine files downloaded successfully") + except Exception as e: + 
print(f"[W] Could not download TensorRT engines from Hugging Face: {e}") + print("[I] Assuming TensorRT engine files are already present in /app/data") + + # Create arguments for the pipeline + self.args = self._create_args() + + # Initialize the pipeline + self.pipe = pipeline_module.FluxPipeline.FromArgs( + self.args, pipeline_type=pipeline_module.PIPELINE_TYPE.TXT2IMG + ) + + # Load TensorRT engines and pytorch modules + print("[I] Loading TensorRT engines and ONNX models...") + _, kwargs_load_engine, _ = dd_argparse.process_pipeline_args(self.args) + self.pipe.load_engines( + framework_model_dir="/app/data", + **kwargs_load_engine, + ) + + # Allocate device memory + if self.pipe.low_vram: + self.pipe.device_memory_sizes = self.pipe.get_device_memory_sizes() + else: + _, shared_device_memory = cudart.cudaMalloc( + self.pipe.calculate_max_device_memory() + ) + self.pipe.activate_engines(shared_device_memory) + + # Load resources for default dimensions + self.pipe.load_resources( + self.args.height, self.args.width, self.args.batch_size, self.args.seed + ) + + # Perform warmup runs + print("[I] Performing warmup runs...") + self._perform_warmup_runs() + print("[I] Model loaded successfully") + + def _create_args(self) -> argparse.Namespace: + """Create argument namespace for the Flux pipeline.""" + parser = argparse.ArgumentParser(description="Flux Txt2Img Model") + + # Add all necessary arguments + parser = dd_argparse.add_arguments(parser) + + # Add only Flux-specific arguments that aren't already in dd_argparse + parser.add_argument( + "--prompt2", + default=None, + nargs="*", + help="Text prompt(s) to be sent to the T5 tokenizer and text encoder. If not defined, prompt will be used instead", + ) + parser.add_argument( + "--max_sequence_length", + type=int, + default=512, + help="Maximum sequence length to use with the prompt. 
Can be up to 512 for the dev and 256 for the schnell variant.", + ) + parser.add_argument( + "--t5-ws-percentage", + type=int, + default=None, + help="Set runtime weight streaming budget as the percentage of the size of streamable weights for the T5 model.", + ) + parser.add_argument( + "--transformer-ws-percentage", + type=int, + default=None, + help="Set runtime weight streaming budget as the percentage of the size of streamable weights for the transformer model.", + ) + + # Set default values for required arguments + default_args = [ + "a beautiful landscape", # prompt + "--version", + "flux.1-dev", + "--height", + "1024", + "--width", + "1024", + "--batch-size", + "1", + "--batch-count", + "1", + "--denoising-steps", + "50", + "--guidance-scale", + "3.5", + "--max_sequence_length", + "512", + "--fp4", + "--download-onnx-models", # Required for FP4 models since native export is not supported + "--onnx-dir", + "/app/data/onnx", # Directory for ONNX files + "--engine-dir", + "/app/data/engine", # Directory for TensorRT engines + "--custom-engine-paths", + "clip:/app/data/clip/engine_trt10.13.0.35.plan,t5:/app/data/t5/engine_trt10.13.0.35.plan,transformer:/app/data/transformer_fp4/engine_trt10.13.0.35.plan,vae:/app/data/vae/engine_trt10.13.0.35.plan", + "--num-warmup-runs", + "2", + "--seed", + "42", + "--framework-model-dir", + "/app/data", # Point to the directory containing the model files + ] + + return parser.parse_args(default_args) + + def _perform_warmup_runs(self): + """Perform warmup runs to optimize CUDA kernels.""" + warmup_prompt = "a beautiful landscape with mountains and a lake, photorealistic, high quality" + + for i in range(self.args.num_warmup_runs): + print(f"[I] Warmup run {i + 1}/{self.args.num_warmup_runs}") + + # Create warmup arguments + kwargs_run_demo = { + "prompt": [warmup_prompt], + "prompt2": [warmup_prompt], + "height": self.args.height, + "width": self.args.width, + "batch_count": 1, + "num_warmup_runs": 0, # Don't do nested warmups + "use_cuda_graph": self.args.use_cuda_graph, + } + + # Run warmup inference + self.pipe.run(**kwargs_run_demo) + + def predict(self, model_input: Any) -> Any: + """Generate image from text prompt.""" + # Extract parameters from model input + prompt = model_input.pop("prompt", "a beautiful landscape") + negative_prompt = model_input.pop("negative_prompt", "") + height = model_input.pop("height", 1024) + width = model_input.pop("width", 1024) + num_inference_steps = model_input.pop("num_inference_steps", 50) + guidance_scale = model_input.pop("guidance_scale", 3.5) + seed = model_input.pop("seed", None) + batch_size = model_input.pop("batch_size", 1) + batch_count = model_input.pop("batch_count", 1) + + # If prompt is a list, use its length as the batch size + if isinstance(prompt, list) and len(prompt) > 1: + batch_size = len(prompt) + + # Validate dimensions + if height % 8 != 0 or width % 8 != 0: + raise ValueError("Height and width must be multiples of 8") + + # Set seed if provided + if seed is not None: + torch.manual_seed(seed) + self.args.seed = seed + + # Update pipeline resources if dimensions changed + if ( + height != self.args.height + or width != self.args.width + or batch_size != self.args.batch_size + ): + self.pipe.load_resources(height, width, batch_size, self.args.seed) + self.args.height = height + self.args.width = width + self.args.batch_size = batch_size + + # Prepare prompts + if not isinstance(prompt, list): + prompt = [prompt] + # Only duplicate if we have 1 prompt but need multiple + if len(prompt) == 1 
and batch_size > 1: + prompt = prompt * batch_size + + # Use prompt2 if provided, otherwise use prompt + prompt2 = model_input.pop("prompt2", None) + if prompt2 is None: + prompt2 = prompt + elif not isinstance(prompt2, list): + prompt2 = [prompt2] + # Only duplicate if we have 1 prompt2 but need multiple + if len(prompt2) == 1 and batch_size > 1: + prompt2 = prompt2 * batch_size + + # Prepare negative prompt + if not isinstance(negative_prompt, list): + negative_prompt = [negative_prompt] + # Only duplicate if we have 1 negative_prompt but need multiple + if len(negative_prompt) == 1 and batch_size > 1: + negative_prompt = negative_prompt * batch_size + + # Update guidance scale + self.args.guidance_scale = guidance_scale + + # Update denoising steps + self.args.denoising_steps = num_inference_steps + + start_time = time.time() + + try: + # Run inference directly using the infer method + latents, walltime_ms = self.pipe.infer( + prompt=prompt, + prompt2=prompt2, + image_height=height, + image_width=width, + warmup=False, + save_image=False, # Don't save to file, we'll handle it ourselves + ) + + # Process the returned latents the same way the pipeline does when save_image=True + # The latents returned are raw tensor data that need to be processed into images + processed_images = ( + ((latents + 1) * 255 / 2) + .clamp(0, 255) + .detach() + .permute(0, 2, 3, 1) + .round() + .type(torch.uint8) + .cpu() + .numpy() + ) + + # Convert numpy arrays to PIL Images + pil_images = [] + for image in processed_images: + pil_image = Image.fromarray(image) + pil_images.append(pil_image) + + # Convert images to base64 + b64_images = [] + for image in pil_images: + b64_images.append(self.convert_to_b64(image)) + + end_time = time.time() - start_time + + print( + f"[I] Generated {len(processed_images)} images in {end_time:.2f} seconds" + ) + + # Return results + if len(b64_images) == 1: + return { + "status": "success", + "data": b64_images[0], + "time": end_time, + "prompt": prompt[0] if len(prompt) == 1 else prompt, + "negative_prompt": negative_prompt[0] + if len(negative_prompt) == 1 + else negative_prompt, + "height": height, + "width": width, + "num_inference_steps": num_inference_steps, + "guidance_scale": guidance_scale, + "seed": seed, + } + else: + return { + "status": "success", + "data": b64_images, + "time": end_time, + "prompt": prompt, + "negative_prompt": negative_prompt, + "height": height, + "width": width, + "num_inference_steps": num_inference_steps, + "guidance_scale": guidance_scale, + "seed": seed, + } + + except Exception as e: + end_time = time.time() - start_time + print(f"[E] Error during inference: {str(e)}") + return { + "status": "error", + "error": str(e), + "time": end_time, + } + + def convert_to_b64(self, image: Image.Image) -> str: + """Convert PIL image to base64 string.""" + buffered = BytesIO() + image.save(buffered, format="JPEG", quality=95) + img_b64 = base64.b64encode(buffered.getvalue()).decode("utf-8") + return img_b64 diff --git a/flux.1-dev-trt-b200/requirements.txt b/flux.1-dev-trt-b200/requirements.txt new file mode 100644 index 000000000..6e6230b0c --- /dev/null +++ b/flux.1-dev-trt-b200/requirements.txt @@ -0,0 +1,276 @@ +absl-py==2.3.0 +accelerate==1.9.0 +annotated-types==0.7.0 +anyio==4.9.0 +anykeystore==0.2 +apex==0.9.10.dev0 +argon2-cffi==25.1.0 +argon2-cffi-bindings==21.2.0 +arrow==1.3.0 +asttokens==3.0.0 +astunparse==1.6.3 +async-lru==2.0.5 +attrs==25.3.0 +audioread==3.0.1 +babel==2.17.0 +beautifulsoup4==4.13.4 +black==25.1.0 +bleach==6.2.0 
+blis==0.7.11 +catalogue==2.0.10 +certifi==2025.4.26 +cffi==1.17.1 +charset-normalizer==3.4.2 +click==8.2.1 +cloudpathlib==0.21.1 +cmake==3.31.6 +colored==2.3.0 +coloredlogs==15.0.1 +comm==0.2.2 +confection==0.1.5 +contourpy==1.3.2 +controlnet-aux==0.0.6 +cppimport==22.8.2 +cryptacular==1.6.2 +cycler==0.12.1 +cymem==2.0.11 +Cython==3.1.1 +debugpy==1.8.14 +decorator==5.2.1 +defusedxml==0.7.1 +diffusers @ git+https://github.com/huggingface/diffusers.git@3335e2262d47e7d7e311a44dea7f454b5f01b643 +dill==0.4.0 +dm-tree==0.1.9 +einops==0.8.1 +execnet==2.1.1 +executing==2.2.0 +expecttest==0.3.0 +fastjsonschema==2.21.1 +filelock==3.18.0 +flatbuffers==25.2.10 +fonttools==4.58.1 +fqdn==1.5.1 +fsspec==2025.5.1 +ftfy==6.3.1 +gast==0.6.0 +greenlet==3.0.3 +grpcio==1.62.1 +h11==0.16.0 +hf-xet==1.1.5 +hf_transfer==0.1.9 +httpcore==1.0.9 +httpx==0.28.1 +huggingface-hub==0.33.4 +humanfriendly==10.0 +hupper==1.12.1 +hypothesis==6.130.8 +idna==3.10 +imageio==2.37.0 +importlib_metadata==8.7.0 +iniconfig==2.1.0 +intel-openmp==2021.4.0 +ipykernel==6.29.5 +ipython==9.3.0 +ipython_pygments_lexers==1.1.1 +isoduration==20.11.0 +isort==6.0.1 +jedi==0.19.2 +Jinja2==3.1.6 +joblib==1.5.1 +json5==0.12.0 +jsonpointer==3.0.0 +jsonschema==4.24.0 +jsonschema-specifications==2025.4.1 +jupyter-events==0.12.0 +jupyter-lsp==2.2.5 +jupyter_client==8.6.3 +jupyter_core==5.8.1 +jupyter_server==2.16.0 +jupyter_server_terminals==0.5.3 +jupyterlab==4.4.3 +jupyterlab_code_formatter==3.0.2 +jupyterlab_pygments==0.3.0 +jupyterlab_server==2.27.3 +jupytext==1.17.2 +kiwisolver==1.4.8 +langcodes==3.5.0 +language_data==1.3.0 +lazy_loader==0.4 +librosa==0.11.0 +lightning-utilities==0.14.3 +lintrunner==0.12.7 +llvmlite==0.42.0 +looseversion==1.3.0 +Mako==1.3.10 +marisa-trie==1.2.1 +Markdown==3.8 +markdown-it-py==3.0.0 +MarkupSafe==3.0.2 +matplotlib==3.10.3 +matplotlib-inline==0.1.7 +mdit-py-plugins==0.4.2 +mdurl==0.1.2 +mistune==3.1.3 +mkl==2021.1.1 +mkl-devel==2021.1.1 +mkl-include==2021.1.1 +mock==5.2.0 +mpmath==1.3.0 +msgpack==1.1.0 +murmurhash==1.0.13 +mypy_extensions==1.1.0 +nbclient==0.10.2 +nbconvert==7.16.6 +nbformat==5.10.4 +nest-asyncio==1.6.0 +networkx==3.5 +ninja==1.11.1.4 +notebook==7.4.3 +notebook_shim==0.2.4 +numba==0.59.1 +numpy==1.26.4 +nvfuser==0.2.27a0+9bf5aca +nvidia-cuda-runtime-cu12==12.9.79 +nvidia-dali-cuda120==1.50.0 +nvidia-modelopt==0.29.0 +nvidia-modelopt-core==0.29.0 +nvidia-nvcomp-cu12==4.2.0.14 +nvidia-nvimgcodec-cu12==0.5.0.13 +nvidia-nvjpeg-cu12==12.4.0.16 +nvidia-nvjpeg2k-cu12==0.8.1.40 +nvidia-nvtiff-cu12==0.5.0.67 +nvidia-resiliency-ext==0.4.0 +oauthlib==3.3.1 +onnx-graphsurgeon==0.5.2 +onnxconverter-common==1.15.0 +onnxmltools==1.13.0 +onnxruntime==1.19.2 +onnxruntime-gpu==1.20.2 +opencv-python==4.11.0.86 +opencv-python-headless==4.8.0.74 +opt_einsum==3.4.0 +optree==0.16.0 +overrides==7.7.0 +packaging==23.2 +pandocfilters==1.5.1 +parso==0.8.4 +PasteDeploy==3.1.0 +pathspec==0.12.1 +pbkdf2==1.3 +peft==0.13.0 +pexpect==4.9.0 +pillow==11.2.1 +plaster==1.1.2 +plaster-pastedeploy==1.0.1 +platformdirs==4.3.8 +pluggy==1.6.0 +polygraphy==0.49.9 +pooch==1.8.2 +preshed==3.0.10 +prometheus_client==0.22.1 +prompt_toolkit==3.0.51 +protobuf==4.24.4 +psutil==7.0.0 +ptyprocess==0.7.0 +PuLP==3.2.1 +pure_eval==0.2.3 +pybind11==2.13.6 +pybind11_global==2.13.6 +pycparser==2.22 +pydantic==2.11.5 +pydantic_core==2.33.2 +Pygments==2.19.1 +pynvim==0.5.0 +pyparsing==3.2.3 +pyramid==2.0.2 +pyramid-mailer==0.15.1 +pytest==8.1.1 +pytest-flakefinder==1.1.0 +pytest-rerunfailures==15.1 +pytest-shard==0.1.2 +pytest-xdist==3.7.0 
+python-dateutil==2.9.0.post0 +python-hostlist==2.2.1 +python-json-logger==3.3.0 +python3-openid==3.2.0 +PyYAML==6.0.2 +pyzmq==26.4.0 +referencing==0.36.2 +regex==2024.11.6 +repoze.sendmail==4.4.1 +requests==2.32.3 +requests-oauthlib==2.0.0 +rfc3339-validator==0.1.4 +rfc3986-validator==0.1.1 +rich==14.0.0 +rpds-py==0.25.1 +safetensors==0.5.3 +scikit-image==0.25.2 +scikit-learn==1.6.1 +scipy==1.15.3 +Send2Trash==1.8.3 +sentencepiece==0.2.0 +setuptools==78.1.1 +shellingham==1.5.4 +six==1.16.0 +smart-open==7.1.0 +sniffio==1.3.1 +sortedcontainers==2.4.0 +soundfile==0.13.1 +soupsieve==2.7 +soxr==0.5.0.post1 +spacy==3.7.5 +spacy-legacy==3.0.12 +spacy-loggers==1.0.5 +SQLAlchemy==2.0.41 +srsly==2.5.1 +stack-data==0.6.3 +supervisor==4.2.5 +sympy==1.14.0 +tabulate==0.9.0 +tbb==2021.13.1 +tensorboard==2.16.2 +tensorboard-data-server==0.7.2 +tensorrt_cu12==10.13.0.35 +tensorrt_cu12_bindings==10.13.0.35 +tensorrt_cu12_libs==10.13.0.35 +terminado==0.18.1 +thinc==8.2.5 +threadpoolctl==3.6.0 +tifffile==2025.6.11 +timm==1.0.18 +tinycss2==1.4.0 +tokenizers==0.19.1 +torchprofile==0.0.4 +tornado==6.5.1 +tqdm==4.67.1 +traitlets==5.14.3 +transaction==5.0 +transformers==4.42.2 +translationstring==1.4 +typer==0.16.0 +types-dataclasses==0.6.6 +types-python-dateutil==2.9.0.20250516 +typing-inspection==0.4.1 +typing_extensions==4.14.0 +uri-template==1.3.0 +urllib3==2.0.7 +velruse==1.1.1 +venusian==3.1.1 +wasabi==1.1.3 +wcwidth==0.2.13 +weasel==0.4.1 +webcolors==24.11.1 +webencodings==0.5.1 +WebOb==1.8.9 +websocket-client==1.8.0 +Werkzeug==3.1.3 +wheel==0.45.1 +wrapt==1.17.2 +WTForms==3.2.1 +wtforms-recaptcha==0.3.2 +xdoctest==1.0.2 +zipp==3.22.0 +zope.deprecation==5.1 +zope.interface==7.2 +zope.sqlalchemy==3.1 diff --git a/flux.1-dev-trt-b200/show.py b/flux.1-dev-trt-b200/show.py new file mode 100644 index 000000000..24b78dab0 --- /dev/null +++ b/flux.1-dev-trt-b200/show.py @@ -0,0 +1,18 @@ +""" +truss predict -d '{"prompt": "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"}' | python show.py +""" + +import base64 +import json +import os +import sys + +resp = sys.stdin.read() +image = json.loads(resp)["data"] +img = base64.b64decode(image) + +file_name = f"{image[-10:].replace('/', '')}.jpeg" +img_file = open(file_name, "wb") +img_file.write(img) +img_file.close() +os.system(f"open {file_name}") diff --git a/flux.1-dev-trt-b200/show_batch.py b/flux.1-dev-trt-b200/show_batch.py new file mode 100644 index 000000000..e28493ec6 --- /dev/null +++ b/flux.1-dev-trt-b200/show_batch.py @@ -0,0 +1,40 @@ +""" +Modified show.py script to handle batch responses with multiple images +Usage: curl ... 
| python show_batch.py +""" + +import base64 +import json +import os +import sys + +resp = sys.stdin.read() +response_data = json.loads(resp) + +# Check if the response contains multiple images or a single image +if "data" in response_data: + data = response_data["data"] + + # If data is a list, we have multiple images + if isinstance(data, list): + print(f"Received {len(data)} images from batch request") + for i, image_b64 in enumerate(data): + img = base64.b64decode(image_b64) + file_name = f"batch_image_{i + 1}_{image_b64[-10:].replace('/', '')}.jpeg" + img_file = open(file_name, "wb") + img_file.write(img) + img_file.close() + print(f"Saved image {i + 1} as {file_name}") + os.system(f"open {file_name}") + else: + # Single image case + img = base64.b64decode(data) + file_name = f"single_image_{data[-10:].replace('/', '')}.jpeg" + img_file = open(file_name, "wb") + img_file.write(img) + img_file.close() + print(f"Saved single image as {file_name}") + os.system(f"open {file_name}") +else: + print("Error: No 'data' field found in response") + print("Response:", response_data)
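+
+# NOTE: `open <file>` is the macOS image-viewer command; on Linux, substituting
+# `xdg-open` (or simply opening the saved .jpeg files manually) should work instead.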