
Commit b5752db

update metax platform deploy (#690)
1 parent bdde5ef commit b5752db

File tree

8 files changed: +87 -15 lines


lightx2v_platform/base/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -3,8 +3,8 @@
 from lightx2v_platform.base.ascend_npu import NpuDevice
 from lightx2v_platform.base.cambricon_mlu import MluDevice
 from lightx2v_platform.base.hygon_dcu import HygonDcuDevice
-from lightx2v_platform.base.metax import MetaxDevice
 from lightx2v_platform.base.mthreads_musa import MusaDevice
+from lightx2v_platform.base.metax_cuda import MetaxDevice
 from lightx2v_platform.base.nvidia import CudaDevice

 __all__ = [

lightx2v_platform/base/metax.py

Lines changed: 0 additions & 7 deletions
This file was deleted.
lightx2v_platform/base/metax_cuda.py

Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
+import torch
+import torch.distributed as dist
+
+from lightx2v_platform.base.nvidia import CudaDevice
+from lightx2v_platform.registry_factory import PLATFORM_DEVICE_REGISTER
+
+
+@PLATFORM_DEVICE_REGISTER("metax_cuda")
+class MetaxDevice(CudaDevice):
+    name = "cuda"
+
+    @staticmethod
+    def init_device_env():
+        pass
+
+    @staticmethod
+    def is_available() -> bool:
+        try:
+            import torch
+
+            return torch.cuda.is_available()
+        except ImportError:
+            return False
+
+    @staticmethod
+    def get_device() -> str:
+        return "cuda"
+
+    @staticmethod
+    def init_parallel_env():
+        dist.init_process_group(backend="nccl")
+        torch.cuda.set_device(dist.get_rank())
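
For context: the class above registers itself under the key "metax_cuda" and otherwise behaves like the stock CUDA backend, presumably because the Metax software stack exposes a CUDA-compatible PyTorch (it inherits CudaDevice, reports name "cuda", and uses the NCCL backend). Below is a minimal usage sketch, not part of this commit; the dict-style registry lookup and the torchrun-style launch (which provides RANK/WORLD_SIZE/MASTER_ADDR for dist.init_process_group) are assumptions.

# Hedged usage sketch (not part of the commit). Assumes the registry supports
# dict-style lookup by the key passed to the decorator, and that the process
# is launched with torchrun so the NCCL process group can initialize.
import os

from lightx2v_platform.registry_factory import PLATFORM_DEVICE_REGISTER

platform = os.environ.get("PLATFORM", "metax_cuda")  # the run scripts export PLATFORM
device_cls = PLATFORM_DEVICE_REGISTER[platform]      # assumed dict-style access

device_cls.init_device_env()                         # no-op for metax_cuda
if device_cls.is_available():
    device_cls.init_parallel_env()                   # NCCL process group + per-rank CUDA device
    print(f"running on {device_cls.get_device()}")   # prints "cuda"
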

lightx2v_platform/ops/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -13,3 +13,5 @@
 elif PLATFORM == "ascend_npu":
     from .attn.ascend_npu import *
     from .mm.ascend_npu import *
+elif PLATFORM == "metax_cuda":
+    from .attn.metax_cuda import *
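
The two added lines extend the platform switch in lightx2v_platform/ops/__init__.py: each branch star-imports one backend's ops, so only the selected platform's kernels get registered at import time. A self-contained sketch of that gating pattern follows; every name in it is a placeholder for illustration, not part of lightx2v.

# Illustration of the env-gated registration pattern used above; all names
# here are placeholders, not the project's real internals.
import os

ATTN_OPS = {}  # stand-in for a registry like ATTN_WEIGHT_REGISTER

PLATFORM = os.environ.get("PLATFORM", "nvidia")

if PLATFORM == "metax_cuda":
    ATTN_OPS["metax_sage_attn2"] = "metax sage-attention kernel"
elif PLATFORM == "ascend_npu":
    ATTN_OPS["npu_attn"] = "ascend attention kernel"

print(ATTN_OPS)  # only the selected backend's ops end up registered
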
lightx2v_platform/ops/attn/metax_cuda/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+from .metax_sage_attn2 import *

lightx2v_platform/ops/attn/metax_cuda/metax_sage_attn2.py

Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
+from loguru import logger
+
+from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER
+from lightx2v_platform.ops.attn.template import AttnWeightTemplate
+
+try:
+    from sageattention import sageattn
+except ImportError:
+    logger.info("sageattn not found, please install sageattention first")
+    sageattn = None
+
+
+@ATTN_WEIGHT_REGISTER("metax_sage_attn2")
+class MetaxSageAttn2Weight(AttnWeightTemplate):
+    def __init__(self):
+        self.config = {}
+
+    def apply(
+        self,
+        q,
+        k,
+        v,
+        cu_seqlens_q=None,
+        cu_seqlens_kv=None,
+        max_seqlen_q=None,
+        max_seqlen_kv=None,
+        model_cls=None,
+    ):
+        q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
+        if len(q.shape) == 3:
+            bs = 1
+            q, k, v = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)
+        elif len(q.shape) == 4:
+            bs = q.shape[0]
+        x = (
+            sageattn(
+                q,
+                k,
+                v,
+                tensor_layout="NHD",
+            )[0]
+            .view(bs * max_seqlen_q, -1)
+            .type(q.dtype)
+        )
+        return x
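
As context for the apply() path above: sageattn is called with tensor_layout="NHD", i.e. (batch, seq_len, num_heads, head_dim), so 3-D inputs of shape (seq_len, num_heads, head_dim) get a batch dimension unsqueezed and the result is flattened back to (seq_len, num_heads * head_dim). A hedged usage sketch follows; the dict-style registry access is an assumption, and it needs a CUDA-capable device with the sageattention package installed.

# Hedged usage sketch, not part of the commit. Assumes ATTN_WEIGHT_REGISTER
# supports dict-style lookup and that sageattention is installed on a CUDA device.
import torch

from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER

seq_len, num_heads, head_dim = 128, 8, 64
q = torch.randn(seq_len, num_heads, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

attn = ATTN_WEIGHT_REGISTER["metax_sage_attn2"]()  # assumed dict-style access
out = attn.apply(q, k, v, max_seqlen_q=seq_len, max_seqlen_kv=seq_len)
print(out.shape)  # expected: torch.Size([128, 512]), i.e. (seq_len, num_heads * head_dim)
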

scripts/platforms/metax/qwen_image_i2i_2511.sh

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@
 lightx2v_path=
 model_path=

-export PLATFORM=metax
+export PLATFORM="metax_cuda"
 export CUDA_VISIBLE_DEVICES=0

 # set environment variables

scripts/platforms/metax/run_wan21_t2v.sh

Lines changed: 5 additions & 6 deletions
@@ -1,13 +1,12 @@
 #!/bin/bash

-# System management interface: mx-smi
-
 # set path and first
-lightx2v_path=
-model_path=
+lightx2v_path=/path/to/LightX2v
+model_path=/path/to/model

-export PLATFORM=metax
-export CUDA_VISIBLE_DEVICES=0
+# export CUDA_VISIBLE_DEVICES=5
+export PLATFORM="metax_cuda"
+export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

 # set environment variables
 source ${lightx2v_path}/scripts/base/base.sh

0 commit comments
