Commit 62c0544

Add Ascend NPU device support. (#955)
## Summary

This PR is the first step in adapting Liger Kernel to the Ascend NPU: adding NPU device support. For details, refer to [[RFC] Native Ascend NPU Support for Liger Kernel](#954), Section 2.1: **Device Support Integration**.

## Details

Key modifications:

1. Add the installation method and basic function adaptation for NPU.
2. Import directly from `triton.language.math` on NPU to avoid errors caused by interfaces that do not exist there.

## Testing Done

We verified on **Atlas 800T A3**; basic test cases such as `test_softmax` and `test_swiglu` pass. We will continue to improve coverage going forward.

(Screenshot of the passing test run attached to the PR.)

- Hardware Type: <BLANK>
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence
1 parent 7d54ff8 commit 62c0544
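
For readers following the installation step described in the Details above, the sketch below is a minimal smoke check that is not part of this commit. It assumes `torch_npu==2.6.0` and `triton-ascend` are installed as listed in the `setup.py` change, and that torch_npu mirrors the `torch.cuda` device API under `torch.npu`.

```python
# Hedged sketch (not part of this commit): confirm the Ascend NPU is visible
# to PyTorch before running the Liger Kernel NPU tests.
import torch
import torch_npu  # noqa: F401  # importing torch_npu registers the torch.npu namespace

if torch.npu.is_available():
    print(f"Found {torch.npu.device_count()} Ascend NPU device(s)")
    x = torch.randn(4, 8, device="npu")
    # Rows of a softmax should sum to ~1.0; a quick sanity check on the device.
    print(torch.softmax(x, dim=-1).sum(dim=-1))
else:
    print("No Ascend NPU visible; Liger Kernel will fall back to other backends.")
```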

File tree

13 files changed: +84 -10 lines


setup.py

Lines changed: 20 additions & 1 deletion

```diff
@@ -24,6 +24,8 @@ def get_default_dependencies():
         return [
             "torch>=2.6.0",
         ]
+    elif platform == "npu":
+        return ["torch_npu==2.6.0", "triton-ascend"]


 def get_optional_dependencies():
@@ -67,7 +69,21 @@ def is_xpu_available():
     return False


-def get_platform() -> Literal["cuda", "rocm", "cpu", "xpu"]:
+def is_ascend_available() -> bool:
+    """Best-effort Ascend detection.
+
+    Checks for common Ascend environment variables and a possible `npu-smi`
+    utility if present.
+    """
+    try:
+        subprocess.run(["npu-smi", "info"], check=True)
+        return True
+    except (subprocess.SubprocessError, FileNotFoundError):
+        pass
+    return False
+
+
+def get_platform() -> Literal["cuda", "rocm", "cpu", "xpu", "npu"]:
     """
     Detect whether the system has NVIDIA or AMD GPU without torch dependency.
     """
@@ -86,6 +102,9 @@ def get_platform() -> Literal["cuda", "rocm", "cpu", "xpu"]:
     if is_xpu_available():
         print("Intel GPU detected")
         return "xpu"
+    elif is_ascend_available():
+        print("Ascend NPU detected")
+        return "npu"
     else:
         print("No GPU detected")
         return "cpu"
```

src/liger_kernel/ops/cross_entropy.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -10,8 +10,9 @@
 from liger_kernel.ops.utils import element_mul_kernel
 from liger_kernel.ops.utils import is_hip
 from liger_kernel.utils import infer_device
+from liger_kernel.utils import is_npu_available

-if compare_version("triton", operator.ge, "3.0.0"):
+if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
     try:
         # typical import path with dispatch available
         from triton.language.extra.libdevice import tanh
```
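
This is the guarded-import pattern the rest of the diff repeats: when `is_npu_available()` is true, the `triton.language.extra.libdevice` path is skipped so that triton-ascend never touches an interface it does not provide. A minimal sketch of the full pattern follows, assuming the NPU case falls through to `triton.language.math` as described in the commit message (the else branch itself is not shown in this hunk).

```python
import operator

from liger_kernel.ops.utils import compare_version
from liger_kernel.utils import is_npu_available

if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
    try:
        # Typical import path when libdevice dispatch is available (CUDA/ROCm/XPU).
        from triton.language.extra.libdevice import tanh
    except ModuleNotFoundError:
        # Older Triton 3.x layouts keep libdevice under the CUDA namespace.
        from triton.language.extra.cuda.libdevice import tanh
else:
    # On Ascend NPU (or Triton < 3.0.0), import the plain math implementation,
    # which triton-ascend does provide.
    from triton.language.math import tanh
```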

src/liger_kernel/ops/dyt.py

Lines changed: 5 additions & 2 deletions

```diff
@@ -7,8 +7,10 @@
 from liger_kernel.ops.utils import compare_version
 from liger_kernel.ops.utils import ensure_contiguous
 from liger_kernel.ops.utils import infer_device
+from liger_kernel.utils import get_npu_multi_processor_count
+from liger_kernel.utils import is_npu_available

-if compare_version("triton", operator.ge, "3.0.0"):
+if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
     try:
         # typical import path with dispatch available
         from triton.language.extra.libdevice import tanh
@@ -125,7 +127,8 @@ def liger_dyt_bwd(dy, x, alpha, gamma, beta):
         NUM_SMS = torch.cuda.get_device_properties(x.device).multi_processor_count
     elif device == "xpu":
         NUM_SMS = torch.xpu.get_device_properties(x.device).gpu_subslice_count
-
+    elif device == "npu":
+        NUM_SMS = get_npu_multi_processor_count()
     da = torch.zeros(NUM_SMS, triton.cdiv(N, 512), dtype=torch.float32, device=x.device)
     dg = torch.empty(NUM_SMS, N, dtype=torch.float32, device=x.device)
     db = torch.empty(NUM_SMS, N, dtype=torch.float32, device=x.device) if HAVE_BETA else None
```
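
Several backward kernels in this diff size their reduction buffers by the number of streaming multiprocessors (or the closest NPU analogue). The helper `get_npu_multi_processor_count()` lives in `liger_kernel.utils` and its body is not part of this excerpt; the sketch below is only a hypothetical illustration of one way such a helper could be written, assuming torch_npu mirrors the CUDA device-properties API. The attribute name and the fallback constant are assumptions.

```python
import torch


def get_npu_multi_processor_count(device=None) -> int:
    """Hypothetical sketch: report how many compute cores the current NPU exposes."""
    _FALLBACK_CORES = 20  # assumption: conservative default, not a measured value
    try:
        props = torch.npu.get_device_properties(device or torch.npu.current_device())
        # Assumption: torch_npu exposes a CUDA-style multi_processor_count field;
        # fall back to the default if it does not.
        return int(getattr(props, "multi_processor_count", _FALLBACK_CORES))
    except (AttributeError, RuntimeError):
        return _FALLBACK_CORES
```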

src/liger_kernel/ops/fused_add_rms_norm.py

Lines changed: 5 additions & 1 deletion

```diff
@@ -9,8 +9,10 @@
 from liger_kernel.ops.utils import compare_version
 from liger_kernel.ops.utils import ensure_contiguous
 from liger_kernel.ops.utils import torch_to_triton_dtype
+from liger_kernel.utils import get_npu_multi_processor_count
+from liger_kernel.utils import is_npu_available

-if compare_version("triton", operator.ge, "3.0.0"):
+if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
     try:
         # typical import path with dispatch available
         from triton.language.extra.libdevice import rsqrt
@@ -293,6 +295,8 @@ def fused_add_rms_norm_backward(dY, dS_out, S, W, RSTD, offset, casting_mode, BL
         sm_count = torch.cuda.get_device_properties(S.device).multi_processor_count
     elif S.device.type == "xpu":
         sm_count = torch.xpu.get_device_properties(S.device).gpu_eu_count
+    elif S.device.type == "npu":
+        sm_count = get_npu_multi_processor_count()

     # fp32 for numerical stability especially.
     _dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
```

src/liger_kernel/ops/geglu.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -7,8 +7,9 @@
 from liger_kernel.ops.utils import calculate_settings
 from liger_kernel.ops.utils import compare_version
 from liger_kernel.ops.utils import ensure_contiguous
+from liger_kernel.utils import is_npu_available

-if compare_version("triton", operator.ge, "3.0.0"):
+if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
     try:
         # typical import path with dispatch available
         from triton.language.extra.libdevice import tanh
```

src/liger_kernel/ops/group_norm.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -6,8 +6,9 @@

 from liger_kernel.ops.utils import compare_version
 from liger_kernel.ops.utils import ensure_contiguous
+from liger_kernel.utils import is_npu_available

-if compare_version("triton", operator.ge, "3.0.0"):
+if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
     try:
         # typical import path with dispatch available
         from triton.language.extra.libdevice import rsqrt
```

src/liger_kernel/ops/layer_norm.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -8,8 +8,9 @@
 from liger_kernel.ops.utils import calculate_settings
 from liger_kernel.ops.utils import compare_version
 from liger_kernel.ops.utils import ensure_contiguous
+from liger_kernel.utils import is_npu_available

-if compare_version("triton", operator.ge, "3.0.0"):
+if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
     try:
         # typical import path with dispatch available
         from triton.language.extra.libdevice import rsqrt
```

src/liger_kernel/ops/poly_norm.py

Lines changed: 5 additions & 1 deletion

```diff
@@ -7,8 +7,10 @@
 from liger_kernel.ops.utils import calculate_settings
 from liger_kernel.ops.utils import compare_version
 from liger_kernel.ops.utils import ensure_contiguous
+from liger_kernel.utils import get_npu_multi_processor_count
+from liger_kernel.utils import is_npu_available

-if compare_version("triton", operator.ge, "3.0.0"):
+if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
     try:
         from triton.language.extra.libdevice import rsqrt
     except ModuleNotFoundError:
@@ -290,6 +292,8 @@ def poly_norm_backward(dY, X, W, RSTD, BLOCK_SIZE, num_warps, in_place):
         sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
     elif X.device.type == "xpu":
         sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
+    elif X.device.type == "npu":
+        sm_count = get_npu_multi_processor_count()

     # Allocate or reuse gradients
     if in_place is True:
```

src/liger_kernel/ops/rms_norm.py

Lines changed: 5 additions & 1 deletion

```diff
@@ -21,8 +21,10 @@
 from liger_kernel.ops.utils import compare_version
 from liger_kernel.ops.utils import ensure_contiguous
 from liger_kernel.ops.utils import torch_to_triton_dtype
+from liger_kernel.utils import get_npu_multi_processor_count
+from liger_kernel.utils import is_npu_available

-if compare_version("triton", operator.ge, "3.0.0"):
+if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
     try:
         # typical import path with dispatch available
         from triton.language.extra.libdevice import rsqrt
@@ -450,6 +452,8 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
         sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count
     elif X.device.type == "xpu":
         sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
+    elif X.device.type == "npu":
+        sm_count = get_npu_multi_processor_count()

     # fp32 for numerical stability especially.
     _dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)
```

src/liger_kernel/ops/utils.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -78,6 +78,8 @@ def get_amp_custom_fwd_bwd() -> Callable:
             functools.partial(torch.amp.custom_fwd, device_type=device),
             functools.partial(torch.amp.custom_bwd, device_type=device),
         )
+    if hasattr(torch, "npu") and getattr(torch.npu, "amp", None) is not None:
+        return torch.npu.amp.custom_fwd, torch.npu.amp.custom_bwd
     return torch.cuda.amp.custom_fwd, torch.cuda.amp.custom_bwd

```
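
The returned pair is used to decorate custom autograd functions so that autocast casting rules apply on whichever backend is active; on NPU builds the new branch hands back `torch.npu.amp.custom_fwd`/`custom_bwd`. A minimal usage sketch follows; the `IllustrativeScale` function is a toy example, not a kernel from this repository.

```python
import torch

from liger_kernel.ops.utils import get_amp_custom_fwd_bwd

# Pick the autocast decorators that match the detected backend.
amp_custom_fwd, amp_custom_bwd = get_amp_custom_fwd_bwd()


class IllustrativeScale(torch.autograd.Function):
    """Toy example only: scales its input while honouring autocast on any backend."""

    @staticmethod
    @amp_custom_fwd
    def forward(ctx, x: torch.Tensor, scale: float) -> torch.Tensor:
        ctx.scale = scale
        return x * scale

    @staticmethod
    @amp_custom_bwd
    def backward(ctx, grad_output: torch.Tensor):
        # Gradient for x only; the scale argument gets no gradient.
        return grad_output * ctx.scale, None
```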
