"""
AMD ROCm Device implementation for LightX2V.

AMD ROCm provides CUDA-compatible APIs through HIP (the Heterogeneous-computing
Interface for Portability). This module handles AMD-specific optimizations,
including:
- Disabling cuDNN for faster VAE convolution
- An sgl_kernel compatibility layer backed by the aiter library (required on AMD)
"""

import os
import sys

import torch
import torch.distributed as dist

from loguru import logger

from lightx2v_platform.registry_factory import PLATFORM_DEVICE_REGISTER

# Detect the AMD ROCm platform: ROCm builds of torch expose torch.version.hip.
IS_AMD_ROCM = hasattr(torch.version, "hip") and torch.version.hip is not None

# aiter installation info
AITER_REPO = "https://github.com/ROCm/aiter.git"
AITER_COMMIT = "a7d3bf8cd47afbaf6a6133c1f12e3b01d2c27b0e"
AITER_INSTALL_CMD = f"""
# Copy-paste install command for aiter (AMD ROCm optimized kernels):
git clone {AITER_REPO} /tmp/aiter && \\
cd /tmp/aiter && \\
git checkout {AITER_COMMIT} && \\
pip install -e .
"""


class AiterSglKernelCompat:
    """
    Compatibility layer that exposes aiter through the sgl_kernel interface.

    This class wraps aiter functions to match sgl_kernel's API, allowing
    existing code to run unchanged on AMD GPUs.

    Note: this layer is REQUIRED on AMD ROCm, as the upstream sgl_kernel
    does not support AMD GPUs.
    """

    def __init__(self, aiter_module):
        # Bind the aiter entry points once so the hot-path methods below
        # avoid repeated attribute lookups on the module.
        self._aiter = aiter_module
        self._gemm_a8w8 = aiter_module.gemm_a8w8_CK
        self._pertoken_quant = aiter_module.pertoken_quant
        self._dtypes = aiter_module.dtypes
        self._rms_norm = aiter_module.rms_norm
        logger.info("Using aiter as sgl_kernel backend (AMD ROCm optimized)")

    def rmsnorm(self, input, weight, eps):
        """RMSNorm compatible with sgl_kernel.rmsnorm(input, weight, eps)."""
        return self._rms_norm(input, weight, eps)

    def fp8_scaled_mm(self, input_quant, weight, input_scale, weight_scale, dtype, bias=None):
        """FP8 GEMM compatible with sgl_kernel.fp8_scaled_mm."""
        return self._gemm_a8w8(input_quant, weight, input_scale, weight_scale, bias, dtype)

    def int8_scaled_mm(self, input_quant, weight, input_scale, weight_scale, dtype, bias=None):
        """INT8 GEMM compatible with sgl_kernel.int8_scaled_mm."""
        return self._gemm_a8w8(input_quant, weight, input_scale, weight_scale, bias, dtype)

    def sgl_per_token_quant_fp8(self, x, out, scale):
        """Per-token FP8 quantization compatible with sgl_kernel.sgl_per_token_quant_fp8."""
        q, s = self._pertoken_quant(x, quant_dtype=self._dtypes.fp8)
        out.copy_(q)
        scale.copy_(s)

    def sgl_per_token_group_quant_fp8(self, x, out, scale, group_size=128, eps=1e-10, fp8_min=-448.0, fp8_max=448.0):
        """Per-token, per-group FP8 quantization compatible with sgl_kernel.sgl_per_token_group_quant_fp8.

        Pure-torch implementation; assumes x.shape[1] is divisible by group_size.
        """
        m, k = x.shape
        x_view = x.view(m, -1, group_size)
        # Per-group absolute max, clamped to eps to avoid division by zero.
        x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(eps)
        q = (x_view * (fp8_max / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, k)
        s = (x_amax / fp8_max).view(m, -1)
        out.copy_(q)
        scale.copy_(s)


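# A minimal sanity check for the group-quant path above (a sketch, not part of
# the shim's public surface). It assumes a torch build with float8_e4m3fn
# support and verifies that dequantizing q * scale roughly reconstructs x.
def _check_group_quant_roundtrip(m=4, k=256, group_size=128):
    x = torch.randn(m, k, dtype=torch.float32)
    out = torch.empty(m, k, dtype=torch.float8_e4m3fn)
    scale = torch.empty(m, k // group_size, dtype=torch.float32)
    # The method body is pure torch, so it can be exercised without aiter.
    AiterSglKernelCompat.sgl_per_token_group_quant_fp8(None, x, out, scale, group_size=group_size)
    deq = out.float().view(m, -1, group_size) * scale.unsqueeze(2)
    # e4m3 carries ~2^-4 relative rounding error for normalized values.
    assert torch.allclose(deq.view(m, k), x, atol=1e-2, rtol=0.07)

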
def _get_aiter_sgl_kernel():
    """Get the aiter-based sgl_kernel compatibility layer."""
    try:
        import aiter

        return AiterSglKernelCompat(aiter)
    except ImportError:
        logger.error(
            f"\n{'=' * 60}\n"
            "ERROR: AMD ROCm detected but aiter is not installed.\n"
            "aiter is REQUIRED for LightX2V to work on AMD GPUs.\n"
            "\nPlease install aiter:\n"
            f"{AITER_INSTALL_CMD}\n"
            f"{'=' * 60}\n"
        )
        raise ImportError(
            "aiter is required for AMD ROCm support. "
            f"Please install: pip install git+{AITER_REPO}@{AITER_COMMIT}"
        )


@PLATFORM_DEVICE_REGISTER("amd_rocm")
class AmdRocmDevice:
    """
    AMD ROCm Device implementation for LightX2V.

    AMD ROCm exposes CUDA-compatible APIs through HIP, so this class mostly
    defers to torch's CUDA paths while applying AMD-specific optimizations.
    """

    name = "amd_rocm"

    @staticmethod
    def init_device_env():
        """
        Initialize AMD ROCm optimizations.

        Called from lightx2v_platform.set_ai_device when the platform is amd_rocm:
        1. Disable cuDNN for faster VAE convolution.
        2. Inject aiter as the sgl_kernel compatibility layer (REQUIRED on AMD).
        """
        logger.info("AMD ROCm platform detected, initializing optimizations...")

        # Disable cuDNN for faster VAE conv computation.
        torch.backends.cudnn.enabled = False
        logger.info(" - cudnn disabled for faster VAE convolution")

        # Inject the aiter shim both for future imports (via sys.modules) and
        # into modules that already hold a reference to sgl_kernel.
        sgl_kernel = _get_aiter_sgl_kernel()
        sys.modules["sgl_kernel"] = sgl_kernel
        for mod in list(sys.modules.values()):
            if mod is not None and hasattr(mod, "sgl_kernel"):
                setattr(mod, "sgl_kernel", sgl_kernel)
        logger.info(" - aiter sgl_kernel compatibility layer enabled (RMSNorm, GEMM)")

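    # After init_device_env() runs, later `import sgl_kernel` statements resolve
    # to the shim. A hypothetical illustration (caller names are ours, not part
    # of this module):
    #
    #     AmdRocmDevice.init_device_env()
    #     import sgl_kernel                  # -> AiterSglKernelCompat instance
    #     y = sgl_kernel.rmsnorm(x, weight, 1e-6)
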
    @staticmethod
    def is_available() -> bool:
        """Check whether this torch build targets ROCm and a GPU is visible."""
        return IS_AMD_ROCM and torch.cuda.is_available()

    @staticmethod
    def get_device() -> str:
        """Get the device type string; ROCm reuses torch's 'cuda' device type."""
        return "cuda"

    @staticmethod
    def init_parallel_env():
        """Initialize the distributed parallel environment for AMD ROCm."""
        dist.init_process_group(backend="nccl")
        # Prefer LOCAL_RANK (set by torchrun) so multi-node jobs bind the
        # correct GPU; fall back to the global rank for single-node launches.
        local_rank = int(os.environ.get("LOCAL_RANK", dist.get_rank()))
        torch.cuda.set_device(local_rank)

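# Typical launch for init_parallel_env (an assumption about the runner, not
# something this module enforces): torchrun sets RANK/WORLD_SIZE/LOCAL_RANK,
# and on ROCm builds the "nccl" backend is serviced by RCCL.
#
#     torchrun --nproc_per_node=8 your_script.py   # your_script.py is hypothetical
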

# Export the public names of this module.
__all__ = [
    "IS_AMD_ROCM",
    "AITER_REPO",
    "AITER_COMMIT",
    "AITER_INSTALL_CMD",
    "AiterSglKernelCompat",
    "AmdRocmDevice",
]
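

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the public API): report what
    # this process sees, without initializing anything.
    logger.info(f"IS_AMD_ROCM={IS_AMD_ROCM}, device_available={AmdRocmDevice.is_available()}")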