
Commit 4ad5a50

Add AMD ROCm GPU support (#661)
# AMD ROCm Support for LightX2V

## Summary

This PR adds AMD ROCm support for LightX2V, enabling high-performance video generation on AMD GPUs. The implementation leverages the `aiter` library for optimized kernels.

## End-to-End Performance

### Wan2.1-T2V-1.3B (33 frames, 480×848, 20 steps)

| Configuration | Time | Speedup |
|---------------|------|---------|
| After PR (`flash_attn2`) | 13.96s | 1.00x |
| **After PR (`aiter_attn`)** | **6.31s** | **2.21x** |

### Qwen-Image-Edit-2511 (1024×1024, 8 steps)

| Configuration | Time | Speedup |
|---------------|------|---------|
| After PR (`flash_attn2`) | 4.83s | 1.00x |
| **After PR (`aiter_attn`)** | **2.86s** | **1.69x** |

> **Note**: Before this PR, LightX2V did not work on AMD ROCm at all due to missing `sgl_kernel` support. This PR enables AMD support and provides significant performance optimizations.

## Optimizations Added

| Optimization | Description | How to Enable |
|--------------|-------------|---------------|
| **sgl_kernel replacement** | Uses `aiter` for RMSNorm and FP8/INT8 GEMM | Automatic (required on AMD) |
| **VAE cudnn disabled** | Disables cudnn for faster convolution | Automatic |
| **aiter_attn** | Optimized Flash Attention using aiter FA3 | `attn_mode="aiter_attn"` |

### Kernel-Level Performance

| Kernel | Speedup (vs. baseline) |
|--------|------------------------|
| Flash Attention (aiter vs. flash_attn) | 5.5x - 5.7x |
| RMSNorm (aiter vs. torch) | 3x - 4.2x |
| VAE Conv (cudnn disabled) | ~1.2x |

## Usage

### Environment Setup

Set the platform environment variable for AMD ROCm:

```bash
export PLATFORM=amd_rocm
```

### Wan Models (Text-to-Video)

```python
from lightx2v import LightX2VPipeline

# Initialize pipeline
pipe = LightX2VPipeline(
    model_path="/path/to/Wan2.1-T2V-14B",
    model_cls="wan2.1",
    task="t2v",
)

# Create generator with aiter_attn for AMD ROCm optimization
pipe.create_generator(
    attn_mode="aiter_attn",  # Use aiter_attn for 2.21x speedup on AMD
    infer_steps=50,
    height=480,
    width=832,
    num_frames=81,
    guidance_scale=5.0,
    sample_shift=5.0,
)

# Generate video
pipe.generate(
    seed=42,
    prompt="Two anthropomorphic cats in comfy boxing gear fight on a spotlighted stage.",
    negative_prompt="",
    save_result_path="output.mp4",
)
```

### Qwen-Image-Edit Models (Image-to-Image)

```python
from lightx2v import LightX2VPipeline

# Initialize pipeline
pipe = LightX2VPipeline(
    model_path="/path/to/Qwen-Image-Edit-2511",
    model_cls="qwen-image-edit-2511",
    task="i2i",
)

# Create generator with aiter_attn for AMD ROCm optimization
pipe.create_generator(
    attn_mode="aiter_attn",  # Use aiter_attn for 1.69x speedup on AMD
    auto_resize=True,
    infer_steps=8,
    guidance_scale=1,
)

# Generate image
pipe.generate(
    seed=42,
    image_path="input.png",
    prompt="Replace the shirt with a light blue shirt.",
    negative_prompt="",
    save_result_path="output.png",
)
```

## Installation

### Prerequisites

For the AMD ROCm platform, `aiter` is **required**:

```bash
git clone https://github.com/ROCm/aiter.git /tmp/aiter && \
    cd /tmp/aiter && \
    git checkout a7d3bf8cd47afbaf6a6133c1f12e3b01d2c27b0e && \
    pip install -e .
```

### Docker Support

Build the AMD ROCm image:

```bash
docker build -f dockerfiles/Dockerfile.mi350 -t lightx2v:rocm .
```

Run the container:

```bash
docker run --device=/dev/kfd --device=/dev/dri \
    --group-add video --cap-add=SYS_PTRACE \
    --security-opt seccomp=unconfined \
    -v /path/to/models:/models \
    lightx2v:rocm python your_script.py
```

## Platform Detection

AMD ROCm is automatically detected:

```python
IS_AMD_ROCM = hasattr(torch.version, "hip") and torch.version.hip is not None
```

When AMD ROCm is detected:

1. `aiter` is injected as an `sgl_kernel` replacement (required for AMD support)
2. `cudnn` is automatically disabled (VAE optimization)
3. Users can set `attn_mode="aiter_attn"` for additional Flash Attention optimization
1 parent e831d84 commit 4ad5a50

File tree

7 files changed, +328 / -4 lines changed


dockerfiles/Dockerfile.mi350

Lines changed: 48 additions & 0 deletions
```dockerfile
# Dockerfile for LightX2V on AMD ROCm platform
# Base image: SGLang with ROCm 7.0.0 for MI300X
FROM lmsysorg/sglang:v0.5.6.post2-rocm700-mi35x

LABEL maintainer="LightX2V Contributors"
LABEL description="LightX2V video generation framework with AMD ROCm support"

# Set working directory
WORKDIR /workspace

# Install system dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
    git \
    ffmpeg \
    libsm6 \
    libxext6 \
    && rm -rf /var/lib/apt/lists/*

# Install aiter (AMD ROCm optimized kernels)
# Commit: a7d3bf8cd47afbaf6a6133c1f12e3b01d2c27b0e
ARG AITER_COMMIT=a7d3bf8cd47afbaf6a6133c1f12e3b01d2c27b0e
RUN git clone https://github.com/ROCm/aiter.git /tmp/aiter && \
    cd /tmp/aiter && \
    git checkout ${AITER_COMMIT} && \
    pip install --no-cache-dir -e . && \
    rm -rf /tmp/aiter/.git

# Install flash-attn for ROCm
RUN pip install --no-cache-dir flash-attn --no-build-isolation

# Copy LightX2V source
COPY . /workspace/LightX2V

# Install LightX2V dependencies
WORKDIR /workspace/LightX2V
RUN pip install --no-cache-dir -r requirements.txt

# Install LightX2V
RUN pip install --no-cache-dir -e .

# Set environment variables for AMD ROCm
ENV HIP_VISIBLE_DEVICES=0
ENV ROCM_PATH=/opt/rocm
ENV HSA_FORCE_FINE_GRAIN_PCIE=1

# Default command
CMD ["python", "-c", "from lightx2v import LightX2VPipeline; print('LightX2V AMD ROCm ready!')"]
```

lightx2v/pipeline.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -203,7 +203,7 @@ def set_infer_config(
             self.self_attn_1_type = attn_mode
             self.cross_attn_1_type = attn_mode
             self.cross_attn_2_type = attn_mode
-        elif self.model_cls in ["hunyuan_video_1.5", "hunyuan_video_1.5_distill"]:
+        elif self.model_cls in ["hunyuan_video_1.5", "hunyuan_video_1.5_distill", "qwen_image"]:
             self.attn_type = attn_mode
 
     def set_infer_config_json(self, config_json):
```

lightx2v_platform/base/__init__.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -1,7 +1,8 @@
 from lightx2v_platform.base.base import check_ai_device, init_ai_device
+from lightx2v_platform.base.amd_rocm import AmdRocmDevice
 from lightx2v_platform.base.cambricon_mlu import MluDevice
 from lightx2v_platform.base.hygon_dcu import HygonDcuDevice
 from lightx2v_platform.base.metax import MetaxDevice
 from lightx2v_platform.base.nvidia import CudaDevice
 
-__all__ = ["init_ai_device", "check_ai_device", "CudaDevice", "MluDevice", "MetaxDevice", "HygonDcuDevice"]
+__all__ = ["init_ai_device", "check_ai_device", "CudaDevice", "MluDevice", "MetaxDevice", "HygonDcuDevice", "AmdRocmDevice"]
```

lightx2v_platform/base/amd_rocm.py

Lines changed: 161 additions & 0 deletions
```python
"""
AMD ROCm Device implementation for LightX2V.

AMD ROCm provides CUDA-compatible APIs through HIP (Heterogeneous-computing Interface for Portability).
This module handles AMD-specific optimizations including:
- Disabling cudnn for faster VAE convolution
- sgl_kernel compatibility layer using aiter library (required on AMD)
"""

import sys
import torch
import torch.distributed as dist

from loguru import logger

from lightx2v_platform.registry_factory import PLATFORM_DEVICE_REGISTER

# Detect AMD ROCm platform
IS_AMD_ROCM = hasattr(torch.version, "hip") and torch.version.hip is not None

# aiter installation info
AITER_REPO = "https://github.com/ROCm/aiter.git"
AITER_COMMIT = "a7d3bf8cd47afbaf6a6133c1f12e3b01d2c27b0e"
AITER_INSTALL_CMD = f"""
# One-line install command for aiter (AMD ROCm optimized kernels):
git clone {AITER_REPO} /tmp/aiter && \\
cd /tmp/aiter && \\
git checkout {AITER_COMMIT} && \\
pip install -e .
"""


class AiterSglKernelCompat:
    """
    Compatibility layer to use aiter with sgl_kernel interface.

    This class wraps aiter functions to match sgl_kernel's API,
    allowing existing code to work seamlessly on AMD GPUs.

    Note: This is REQUIRED on AMD ROCm as the original sgl_kernel
    does not support AMD GPUs.
    """

    def __init__(self, aiter_module):
        self._aiter = aiter_module
        self._gemm_a8w8 = aiter_module.gemm_a8w8_CK
        self._pertoken_quant = aiter_module.pertoken_quant
        self._dtypes = aiter_module.dtypes
        self._rms_norm = aiter_module.rms_norm
        logger.info("Using aiter as sgl_kernel backend (AMD ROCm optimized)")

    def rmsnorm(self, input, weight, eps):
        """RMSNorm compatible with sgl_kernel.rmsnorm(input, weight, eps)"""
        return self._rms_norm(input, weight, eps)

    def fp8_scaled_mm(self, input_quant, weight, input_scale, weight_scale, dtype, bias=None):
        """FP8 GEMM compatible with sgl_kernel.fp8_scaled_mm"""
        return self._gemm_a8w8(input_quant, weight, input_scale, weight_scale, bias, dtype)

    def int8_scaled_mm(self, input_quant, weight, input_scale, weight_scale, dtype, bias=None):
        """INT8 GEMM compatible with sgl_kernel.int8_scaled_mm"""
        return self._gemm_a8w8(input_quant, weight, input_scale, weight_scale, bias, dtype)

    def sgl_per_token_quant_fp8(self, x, out, scale):
        """Per-token FP8 quantization compatible with sgl_kernel.sgl_per_token_quant_fp8"""
        q, s = self._pertoken_quant(x, quant_dtype=self._dtypes.fp8)
        out.copy_(q)
        scale.copy_(s)

    def sgl_per_token_group_quant_fp8(self, x, out, scale, group_size=128, eps=1e-10, fp8_min=-448.0, fp8_max=448.0):
        """Per-token per-group FP8 quantization compatible with sgl_kernel.sgl_per_token_group_quant_fp8"""
        m, k = x.shape
        x_view = x.view(m, -1, group_size)
        x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(eps)
        q = (x_view * (fp8_max / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, k)
        s = (x_amax / fp8_max).view(m, -1)
        out.copy_(q)
        scale.copy_(s)


def _get_aiter_sgl_kernel():
    """Get aiter-based sgl_kernel compatibility layer."""
    try:
        import aiter
        return AiterSglKernelCompat(aiter)
    except ImportError:
        logger.error(
            f"\n{'='*60}\n"
            f"ERROR: AMD ROCm detected but aiter is not installed.\n"
            f"aiter is REQUIRED for LightX2V to work on AMD GPUs.\n"
            f"\nPlease install aiter:\n"
            f"{AITER_INSTALL_CMD}\n"
            f"{'='*60}\n"
        )
        raise ImportError(
            "aiter is required for AMD ROCm support. "
            f"Please install: pip install git+{AITER_REPO}@{AITER_COMMIT}"
        )


@PLATFORM_DEVICE_REGISTER("amd_rocm")
class AmdRocmDevice:
    """
    AMD ROCm Device implementation for LightX2V.

    AMD ROCm uses CUDA-compatible APIs through HIP.
    This class provides AMD-specific optimizations.
    """

    name = "amd_rocm"

    @staticmethod
    def init_device_env():
        """
        Initialize AMD ROCm optimizations.

        This is called from lightx2v_platform.set_ai_device when platform is amd_rocm.
        1. Disable cudnn for faster VAE convolution
        2. Inject aiter as sgl_kernel compatibility layer (REQUIRED on AMD)
        """
        logger.info("AMD ROCm platform detected, initializing optimizations...")

        # Disable cudnn for faster VAE conv computation
        torch.backends.cudnn.enabled = False
        logger.info(" - cudnn disabled for faster VAE convolution")

        # Inject aiter as sgl_kernel compatibility layer (REQUIRED)
        sgl_kernel = _get_aiter_sgl_kernel()
        sys.modules["sgl_kernel"] = sgl_kernel
        # Update any module that already imported sgl_kernel
        for mod_name, mod in list(sys.modules.items()):
            if mod is not None and hasattr(mod, 'sgl_kernel'):
                setattr(mod, 'sgl_kernel', sgl_kernel)
        logger.info(" - aiter sgl_kernel compatibility layer enabled (RMSNorm, GEMM)")

    @staticmethod
    def is_available() -> bool:
        """Check if AMD ROCm is available."""
        return IS_AMD_ROCM and torch.cuda.is_available()

    @staticmethod
    def get_device() -> str:
        """Get the device type string. Returns 'cuda' for ROCm compatibility."""
        return "cuda"

    @staticmethod
    def init_parallel_env():
        """Initialize distributed parallel environment for AMD ROCm."""
        dist.init_process_group(backend="nccl")
        torch.cuda.set_device(dist.get_rank())


# Export constants
__all__ = [
    "IS_AMD_ROCM",
    "AITER_REPO",
    "AITER_COMMIT",
    "AITER_INSTALL_CMD",
    "AiterSglKernelCompat",
    "AmdRocmDevice",
]
```
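To make the injection concrete, here is a small sketch of what call sites see after `init_device_env` runs: the name `sgl_kernel` now resolves to the `AiterSglKernelCompat` instance, so existing `sgl_kernel.rmsnorm(...)` calls hit aiter's `rms_norm` on AMD. The tensor sizes and dtype below are illustrative assumptions, and this only runs on a ROCm machine with `aiter` installed:

```python
import torch
from lightx2v_platform.base.amd_rocm import AmdRocmDevice

AmdRocmDevice.init_device_env()  # installs the aiter-backed compat object under "sgl_kernel"

import sgl_kernel  # binds the injected AiterSglKernelCompat instance from sys.modules

x = torch.randn(4, 1024, device="cuda", dtype=torch.bfloat16)  # illustrative shape
w = torch.ones(1024, device="cuda", dtype=torch.bfloat16)

y = sgl_kernel.rmsnorm(x, w, 1e-6)  # dispatched to aiter.rms_norm
print(y.shape)
```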

lightx2v_platform/ops/__init__.py

Lines changed: 4 additions & 2 deletions
```diff
@@ -6,6 +6,8 @@
     from .attn.cambricon_mlu import *
     from .mm.cambricon_mlu import *
 elif AI_DEVICE == "cuda":
-    # Check if running on Hygon DCU platform
-    if os.getenv("PLATFORM") == "hygon_dcu":
+    platform = os.getenv("PLATFORM")
+    if platform == "hygon_dcu":
         from .attn.hygon_dcu import *
+    elif platform == "amd_rocm":
+        from .attn.amd_rocm import *
```
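Because the `amd_rocm` branch above only runs when the environment variable is visible at import time, `PLATFORM=amd_rocm` has to be set before `lightx2v_platform.ops` is first imported. A minimal sketch of that ordering (setting the variable in-process here merely mirrors the documented `export PLATFORM=amd_rocm`):

```python
import os

# Must be set before lightx2v_platform.ops is imported; otherwise the CUDA branch
# is taken without pulling in the amd_rocm attention registration.
os.environ["PLATFORM"] = "amd_rocm"

import lightx2v_platform.ops  # noqa: F401  (triggers `from .attn.amd_rocm import *`)
```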
lightx2v_platform/ops/attn/amd_rocm/__init__.py

Lines changed: 2 additions & 0 deletions

```python
from .flash_attn import *
```
lightx2v_platform/ops/attn/amd_rocm/flash_attn.py

Lines changed: 110 additions & 0 deletions

```python
"""
AMD ROCm optimized attention using aiter library.
Provides significantly faster attention computation on AMD GPUs (2.5x-6x speedup).
Internally uses FA3 (fmha_v3) when conditions are met.
"""

import torch
from loguru import logger

from lightx2v_platform.ops.attn.template import AttnWeightTemplate
from lightx2v_platform.registry_factory import PLATFORM_ATTN_WEIGHT_REGISTER

# Detect AMD ROCm platform
IS_AMD_ROCM = hasattr(torch.version, "hip") and torch.version.hip is not None

# aiter installation info
AITER_REPO = "https://github.com/ROCm/aiter.git"
AITER_COMMIT = "a7d3bf8cd47afbaf6a6133c1f12e3b01d2c27b0e"
AITER_INSTALL_CMD = f"""
# One-line install command for aiter (AMD ROCm optimized kernels):
git clone {AITER_REPO} /tmp/aiter && \\
cd /tmp/aiter && \\
git checkout {AITER_COMMIT} && \\
pip install -e .
"""

# Try to import aiter (AMD ROCm optimized)
aiter_flash_attn_varlen_func = None
AITER_AVAILABLE = False
AITER_IMPORT_ERROR = None

try:
    from aiter import flash_attn_varlen_func as aiter_flash_attn_varlen_func

    AITER_AVAILABLE = True
    logger.info("aiter flash_attn_varlen_func found (AMD ROCm optimized)")
except ImportError as e:
    AITER_IMPORT_ERROR = str(e)
    if IS_AMD_ROCM:
        logger.warning(
            f"aiter not found on AMD ROCm platform. "
            f"For optimal performance, please install aiter:\n{AITER_INSTALL_CMD}"
        )
    else:
        logger.debug("aiter not found (only available on AMD ROCm platform)")


@PLATFORM_ATTN_WEIGHT_REGISTER("aiter_attn")
class AiterAttnWeight(AttnWeightTemplate):
    """
    AMD ROCm optimized attention using aiter library.

    Performance:
    - 2.5x-6x faster than flash_attn package on AMD GPUs
    - Automatically uses FA3 (fmha_v3) when conditions are met

    Requirements:
    - aiter library (AMD ROCm)
    - AMD GPU with ROCm support
    """

    def __init__(self):
        self.config = {}

        # Check platform first
        if not IS_AMD_ROCM:
            raise RuntimeError(
                "aiter_attn is only available on AMD ROCm platform.\n"
                "Current platform is not AMD ROCm (torch.version.hip is not set).\n"
                "For NVIDIA GPUs, please use 'flash_attn2' or 'flash_attn3' instead."
            )

        # Check aiter availability
        if not AITER_AVAILABLE:
            raise ImportError(
                f"aiter is not installed on AMD ROCm platform.\n"
                f"Import error: {AITER_IMPORT_ERROR}\n"
                f"Please install aiter for optimal performance:\n{AITER_INSTALL_CMD}"
            )

    def apply(
        self,
        q,
        k,
        v,
        cu_seqlens_q=None,
        cu_seqlens_kv=None,
        max_seqlen_q=None,
        max_seqlen_kv=None,
        model_cls=None,
    ):
        if len(q.shape) == 3:
            bs = 1
        elif len(q.shape) == 4:
            bs = q.shape[0]
            q = q.reshape(-1, q.shape[-2], q.shape[-1])
            k = k.reshape(-1, k.shape[-2], k.shape[-1])
            v = v.reshape(-1, v.shape[-2], v.shape[-1])

        x = aiter_flash_attn_varlen_func(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_kv,
            max_seqlen_q,
            max_seqlen_kv,
        ).reshape(bs * max_seqlen_q, -1)
        return x
```
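For reference, a hedged sketch of driving `AiterAttnWeight.apply` directly with a single packed sequence; the import path follows the inferred file location above, the sequence length, head count, and head dimension are made-up values, and it only runs on an AMD GPU with `aiter` installed:

```python
import torch
from lightx2v_platform.ops.attn.amd_rocm.flash_attn import AiterAttnWeight

attn = AiterAttnWeight()  # raises RuntimeError/ImportError off-platform or without aiter

seqlen, num_heads, head_dim = 128, 8, 64  # illustrative sizes
q = torch.randn(seqlen, num_heads, head_dim, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)
cu_seqlens = torch.tensor([0, seqlen], device="cuda", dtype=torch.int32)

out = attn.apply(
    q, k, v,
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_kv=cu_seqlens,
    max_seqlen_q=seqlen,
    max_seqlen_kv=seqlen,
)
print(out.shape)  # (seqlen, num_heads * head_dim) after the final reshape
```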
