
Commit 1a343fc

fix amd ci (#667)
1 parent 4ad5a50 commit 1a343fc

File tree

5 files changed: +23 -40 lines changed

dockerfiles/Dockerfile.mi350

Lines changed: 0 additions & 1 deletion
@@ -45,4 +45,3 @@ ENV HSA_FORCE_FINE_GRAIN_PCIE=1
 
 # Default command
 CMD ["python", "-c", "from lightx2v import LightX2VPipeline; print('LightX2V AMD ROCm ready!')"]
-

lightx2v_platform/base/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-from lightx2v_platform.base.base import check_ai_device, init_ai_device
+from lightx2v_platform.base.base import check_ai_device, init_ai_device  # noqa
 from lightx2v_platform.base.amd_rocm import AmdRocmDevice
 from lightx2v_platform.base.cambricon_mlu import MluDevice
 from lightx2v_platform.base.hygon_dcu import HygonDcuDevice
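
For context on this one-line change: the import exists only to re-export check_ai_device and init_ai_device at the package level, so a linter's unused-import rule would otherwise flag it, and the # noqa marker suppresses that. A minimal sketch of the consumer side, assuming only the package layout visible in this diff:

# Hypothetical consumer module: the names resolve through the package-level
# re-export above, which is why the seemingly "unused" submodule import has to
# stay (the `# noqa` keeps lint from flagging or auto-removing it).
from lightx2v_platform.base import check_ai_device, init_ai_device  # noqa: F401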

lightx2v_platform/base/amd_rocm.py

Lines changed: 18 additions & 25 deletions
@@ -8,9 +8,9 @@
 """
 
 import sys
+
 import torch
 import torch.distributed as dist
-
 from loguru import logger
 
 from lightx2v_platform.registry_factory import PLATFORM_DEVICE_REGISTER
@@ -33,40 +33,40 @@
 class AiterSglKernelCompat:
     """
     Compatibility layer to use aiter with sgl_kernel interface.
-
+
     This class wraps aiter functions to match sgl_kernel's API,
     allowing existing code to work seamlessly on AMD GPUs.
-
+
     Note: This is REQUIRED on AMD ROCm as the original sgl_kernel
     does not support AMD GPUs.
     """
-
+
     def __init__(self, aiter_module):
         self._aiter = aiter_module
         self._gemm_a8w8 = aiter_module.gemm_a8w8_CK
         self._pertoken_quant = aiter_module.pertoken_quant
         self._dtypes = aiter_module.dtypes
         self._rms_norm = aiter_module.rms_norm
         logger.info("Using aiter as sgl_kernel backend (AMD ROCm optimized)")
-
+
     def rmsnorm(self, input, weight, eps):
         """RMSNorm compatible with sgl_kernel.rmsnorm(input, weight, eps)"""
         return self._rms_norm(input, weight, eps)
-
+
     def fp8_scaled_mm(self, input_quant, weight, input_scale, weight_scale, dtype, bias=None):
         """FP8 GEMM compatible with sgl_kernel.fp8_scaled_mm"""
         return self._gemm_a8w8(input_quant, weight, input_scale, weight_scale, bias, dtype)
-
+
     def int8_scaled_mm(self, input_quant, weight, input_scale, weight_scale, dtype, bias=None):
         """INT8 GEMM compatible with sgl_kernel.int8_scaled_mm"""
         return self._gemm_a8w8(input_quant, weight, input_scale, weight_scale, bias, dtype)
-
+
     def sgl_per_token_quant_fp8(self, x, out, scale):
         """Per-token FP8 quantization compatible with sgl_kernel.sgl_per_token_quant_fp8"""
         q, s = self._pertoken_quant(x, quant_dtype=self._dtypes.fp8)
         out.copy_(q)
         scale.copy_(s)
-
+
     def sgl_per_token_group_quant_fp8(self, x, out, scale, group_size=128, eps=1e-10, fp8_min=-448.0, fp8_max=448.0):
         """Per-token per-group FP8 quantization compatible with sgl_kernel.sgl_per_token_group_quant_fp8"""
         m, k = x.shape
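
To make the intent of AiterSglKernelCompat concrete, below is a minimal, self-contained sketch of the same adapter idea: an object that exposes sgl_kernel-style call signatures while delegating to an aiter-style backend. A plain-PyTorch stand-in replaces aiter here so the sketch runs without an AMD GPU; the real class instead wires up aiter.rms_norm, aiter.gemm_a8w8_CK, and aiter.pertoken_quant as shown in the diff.

import types

import torch


def _reference_rms_norm(x, weight, eps):
    # Plain-PyTorch stand-in for aiter.rms_norm, for illustration only.
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    return (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight


# Stand-in backend exposing an aiter-like attribute.
fake_aiter = types.SimpleNamespace(rms_norm=_reference_rms_norm)


class RmsNormShim:
    """sgl_kernel-style interface backed by an aiter-style module (sketch)."""

    def __init__(self, aiter_module):
        self._rms_norm = aiter_module.rms_norm

    def rmsnorm(self, input, weight, eps):
        # Same argument order as sgl_kernel.rmsnorm(input, weight, eps).
        return self._rms_norm(input, weight, eps)


shim = RmsNormShim(fake_aiter)
x = torch.randn(4, 64)
w = torch.ones(64)
print(shim.rmsnorm(x, w, 1e-6).shape)  # torch.Size([4, 64])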
@@ -82,20 +82,13 @@ def _get_aiter_sgl_kernel():
     """Get aiter-based sgl_kernel compatibility layer."""
     try:
         import aiter
+
         return AiterSglKernelCompat(aiter)
     except ImportError:
         logger.error(
-            f"\n{'='*60}\n"
-            f"ERROR: AMD ROCm detected but aiter is not installed.\n"
-            f"aiter is REQUIRED for LightX2V to work on AMD GPUs.\n"
-            f"\nPlease install aiter:\n"
-            f"{AITER_INSTALL_CMD}\n"
-            f"{'='*60}\n"
-        )
-        raise ImportError(
-            "aiter is required for AMD ROCm support. "
-            f"Please install: pip install git+{AITER_REPO}@{AITER_COMMIT}"
+            f"\n{'=' * 60}\nERROR: AMD ROCm detected but aiter is not installed.\naiter is REQUIRED for LightX2V to work on AMD GPUs.\n\nPlease install aiter:\n{AITER_INSTALL_CMD}\n{'=' * 60}\n"
         )
+        raise ImportError(f"aiter is required for AMD ROCm support. Please install: pip install git+{AITER_REPO}@{AITER_COMMIT}")
 
 
 @PLATFORM_DEVICE_REGISTER("amd_rocm")
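
The reflowed error path above follows a common guard pattern for optional accelerator backends: attempt the import, log a human-readable install hint, then raise so callers fail fast instead of running without the kernel library. A generic sketch of that pattern, using placeholder names rather than the project's real constants:

from loguru import logger

INSTALL_CMD = "pip install <backend>"  # placeholder; the real module uses AITER_INSTALL_CMD


def get_backend():
    try:
        import some_optional_backend  # placeholder name, not a real package

        return some_optional_backend
    except ImportError:
        logger.error(f"\n{'=' * 60}\nbackend not installed; install with:\n{INSTALL_CMD}\n{'=' * 60}\n")
        raise ImportError(f"backend is required; install with: {INSTALL_CMD}")


try:
    get_backend()
except ImportError as exc:
    print(exc)  # demonstrates the failure path when the backend is absent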
@@ -113,24 +106,24 @@ class AmdRocmDevice:
     def init_device_env():
         """
         Initialize AMD ROCm optimizations.
-
+
         This is called from lightx2v_platform.set_ai_device when platform is amd_rocm.
         1. Disable cudnn for faster VAE convolution
         2. Inject aiter as sgl_kernel compatibility layer (REQUIRED on AMD)
         """
         logger.info("AMD ROCm platform detected, initializing optimizations...")
-
+
         # Disable cudnn for faster VAE conv computation
         torch.backends.cudnn.enabled = False
         logger.info(" - cudnn disabled for faster VAE convolution")
-
+
         # Inject aiter as sgl_kernel compatibility layer (REQUIRED)
         sgl_kernel = _get_aiter_sgl_kernel()
         sys.modules["sgl_kernel"] = sgl_kernel
         # Update any module that already imported sgl_kernel
         for mod_name, mod in list(sys.modules.items()):
-            if mod is not None and hasattr(mod, 'sgl_kernel'):
-                setattr(mod, 'sgl_kernel', sgl_kernel)
+            if mod is not None and hasattr(mod, "sgl_kernel"):
+                setattr(mod, "sgl_kernel", sgl_kernel)
         logger.info(" - aiter sgl_kernel compatibility layer enabled (RMSNorm, GEMM)")
 
     @staticmethod
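
The key mechanism in init_device_env is registering the compatibility object under the sgl_kernel module name. Below is a minimal sketch of that injection technique with a made-up name (fake_kernel) so it can run anywhere: once an object sits in sys.modules, later import statements return it, and modules that already imported the name can be patched with setattr, just like the loop above does for sgl_kernel.

import sys
import types

# Stand-in for the compatibility object; the real code registers an
# AiterSglKernelCompat instance here.
shim = types.SimpleNamespace(rmsnorm=lambda x, w, eps: x)

sys.modules["fake_kernel"] = shim  # future `import fake_kernel` now returns the shim
for mod_name, mod in list(sys.modules.items()):
    if mod is not None and hasattr(mod, "fake_kernel"):
        setattr(mod, "fake_kernel", shim)  # patch modules that imported it earlier

import fake_kernel  # resolves to the shim registered above

print(fake_kernel is shim)  # True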
@@ -153,7 +146,7 @@ def init_parallel_env():
 # Export constants
 __all__ = [
     "IS_AMD_ROCM",
-    "AITER_REPO",
+    "AITER_REPO",
     "AITER_COMMIT",
     "AITER_INSTALL_CMD",
     "AiterSglKernelCompat",

lightx2v_platform/ops/attn/amd_rocm/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -1,2 +1 @@
 from .flash_attn import *
-

lightx2v_platform/ops/attn/amd_rocm/flash_attn.py

Lines changed: 4 additions & 12 deletions
@@ -37,10 +37,7 @@
 except ImportError as e:
     AITER_IMPORT_ERROR = str(e)
     if IS_AMD_ROCM:
-        logger.warning(
-            f"aiter not found on AMD ROCm platform. "
-            f"For optimal performance, please install aiter:\n{AITER_INSTALL_CMD}"
-        )
+        logger.warning(f"aiter not found on AMD ROCm platform. For optimal performance, please install aiter:\n{AITER_INSTALL_CMD}")
     else:
         logger.debug("aiter not found (only available on AMD ROCm platform)")
 
@@ -61,22 +58,18 @@ class AiterAttnWeight(AttnWeightTemplate):
 
     def __init__(self):
         self.config = {}
-
+
         # Check platform first
         if not IS_AMD_ROCM:
             raise RuntimeError(
                 "aiter_attn is only available on AMD ROCm platform.\n"
                 "Current platform is not AMD ROCm (torch.version.hip is not set).\n"
                 "For NVIDIA GPUs, please use 'flash_attn2' or 'flash_attn3' instead."
             )
-
+
         # Check aiter availability
         if not AITER_AVAILABLE:
-            raise ImportError(
-                f"aiter is not installed on AMD ROCm platform.\n"
-                f"Import error: {AITER_IMPORT_ERROR}\n"
-                f"Please install aiter for optimal performance:\n{AITER_INSTALL_CMD}"
-            )
+            raise ImportError(f"aiter is not installed on AMD ROCm platform.\nImport error: {AITER_IMPORT_ERROR}\nPlease install aiter for optimal performance:\n{AITER_INSTALL_CMD}")
 
     def apply(
         self,
@@ -107,4 +100,3 @@ def apply(
             max_seqlen_kv,
         ).reshape(bs * max_seqlen_q, -1)
         return x
-
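
The IS_AMD_ROCM flag used throughout this file is defined outside the hunks shown here. A hedged sketch of how such a platform gate is commonly derived, consistent with the hint in the error message above ("torch.version.hip is not set"):

import torch

# Assumed shape of the platform flag: ROCm builds of PyTorch expose a non-None
# torch.version.hip, while CUDA/CPU builds leave it as None.
IS_AMD_ROCM = getattr(torch.version, "hip", None) is not None

print("AMD ROCm build of PyTorch" if IS_AMD_ROCM else "not a ROCm build")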

0 commit comments
