Skip to content

Commit bb88671

Browse files
authored
Unify NPU vector core count helpers (#1052)
## Summary Both `get_npu_core_count` and `get_npu_multi_processor_count` currently serve the same purpose: retrieving the number of NPU vector cores. In addition, `get_npu_multi_processor_count` requires torch_npu >= v7.2.0. This PR keeps a single implementation (`get_npu_core_count`) for consistency. In the future, we may further simplify this logic by replacing the helper with `torch.npu.get_device_properties()`, aligning the NPU with other backends. ## Testing Done All affected operators were tested with `pytest`. Hardware Type: Ascend 910B4 - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [ ] run `make test-convergence` to ensure convergence
1 parent a5be02e commit bb88671

File tree

6 files changed

+10
-21
lines changed

6 files changed

+10
-21
lines changed

src/liger_kernel/ops/dyt.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66

77
from liger_kernel.ops.utils import compare_version
88
from liger_kernel.ops.utils import ensure_contiguous
9+
from liger_kernel.ops.utils import get_npu_core_count
910
from liger_kernel.ops.utils import infer_device
10-
from liger_kernel.utils import get_npu_multi_processor_count
1111
from liger_kernel.utils import is_npu_available
1212

1313
if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
@@ -128,7 +128,7 @@ def liger_dyt_bwd(dy, x, alpha, gamma, beta):
128128
elif device == "xpu":
129129
NUM_SMS = torch.xpu.get_device_properties(x.device).gpu_subslice_count
130130
elif device == "npu":
131-
NUM_SMS = get_npu_multi_processor_count()
131+
NUM_SMS = get_npu_core_count()
132132
da = torch.zeros(NUM_SMS, triton.cdiv(N, 512), dtype=torch.float32, device=x.device)
133133
dg = torch.empty(NUM_SMS, N, dtype=torch.float32, device=x.device)
134134
db = torch.empty(NUM_SMS, N, dtype=torch.float32, device=x.device) if HAVE_BETA else None

src/liger_kernel/ops/fused_add_rms_norm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@
88
from liger_kernel.ops.utils import calculate_settings
99
from liger_kernel.ops.utils import compare_version
1010
from liger_kernel.ops.utils import ensure_contiguous
11+
from liger_kernel.ops.utils import get_npu_core_count
1112
from liger_kernel.ops.utils import set_large_grf_mode
1213
from liger_kernel.ops.utils import torch_to_triton_dtype
13-
from liger_kernel.utils import get_npu_multi_processor_count
1414
from liger_kernel.utils import is_npu_available
1515

1616
if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
@@ -290,7 +290,7 @@ def fused_add_rms_norm_backward(dY, dS_out, S, W, RSTD, offset, casting_mode, BL
290290
elif S.device.type == "xpu":
291291
sm_count = torch.xpu.get_device_properties(S.device).gpu_eu_count
292292
elif S.device.type == "npu":
293-
sm_count = get_npu_multi_processor_count()
293+
sm_count = get_npu_core_count()
294294

295295
# fp32 for numerical stability especially.
296296
_dW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)

src/liger_kernel/ops/layer_norm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
from liger_kernel.ops.utils import calculate_settings
99
from liger_kernel.ops.utils import compare_version
1010
from liger_kernel.ops.utils import ensure_contiguous
11+
from liger_kernel.ops.utils import get_npu_core_count
1112
from liger_kernel.ops.utils import set_large_grf_mode
12-
from liger_kernel.utils import get_npu_multi_processor_count
1313
from liger_kernel.utils import is_npu_available
1414

1515
if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
@@ -251,7 +251,7 @@ def layer_norm_backward(dY, X, W, B, Mean, RSTD):
251251
elif X.device.type == "xpu":
252252
sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
253253
elif X.device.type == "npu":
254-
sm_count = get_npu_multi_processor_count()
254+
sm_count = get_npu_core_count()
255255

256256
# fp32 for numerical stability especially.
257257
_DW = torch.empty((sm_count, n_cols), dtype=torch.float32, device=W.device)

src/liger_kernel/ops/poly_norm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
from liger_kernel.ops.utils import calculate_settings
88
from liger_kernel.ops.utils import compare_version
99
from liger_kernel.ops.utils import ensure_contiguous
10+
from liger_kernel.ops.utils import get_npu_core_count
1011
from liger_kernel.ops.utils import set_large_grf_mode
11-
from liger_kernel.utils import get_npu_multi_processor_count
1212
from liger_kernel.utils import is_npu_available
1313

1414
if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
@@ -287,7 +287,7 @@ def poly_norm_backward(dY, X, W, RSTD, BLOCK_SIZE, num_warps, in_place):
287287
elif X.device.type == "xpu":
288288
sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
289289
elif X.device.type == "npu":
290-
sm_count = get_npu_multi_processor_count()
290+
sm_count = get_npu_core_count()
291291

292292
# Allocate or reuse gradients
293293
if in_place is True:

src/liger_kernel/ops/rms_norm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,9 @@
2020
from liger_kernel.ops.utils import calculate_settings
2121
from liger_kernel.ops.utils import compare_version
2222
from liger_kernel.ops.utils import ensure_contiguous
23+
from liger_kernel.ops.utils import get_npu_core_count
2324
from liger_kernel.ops.utils import set_large_grf_mode
2425
from liger_kernel.ops.utils import torch_to_triton_dtype
25-
from liger_kernel.utils import get_npu_multi_processor_count
2626
from liger_kernel.utils import is_npu_available
2727

2828
if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
@@ -494,7 +494,7 @@ def rms_norm_backward(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warp
494494
elif X.device.type == "xpu":
495495
sm_count = torch.xpu.get_device_properties(X.device).gpu_eu_count
496496
elif X.device.type == "npu":
497-
sm_count = get_npu_multi_processor_count()
497+
sm_count = get_npu_core_count()
498498

499499
if W is not None:
500500
# fp32 for numerical stability especially.

src/liger_kernel/utils.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -65,17 +65,6 @@ def is_npu_available() -> bool:
6565
return False
6666

6767

68-
def get_npu_multi_processor_count() -> int:
69-
"""Return a heuristic multi-processor count for NPU."""
70-
if is_npu_available():
71-
NPU_MULTI_PROCESSOR_COUNT = 48
72-
dev_props = torch.npu.get_device_properties()
73-
# The vector_core_num attribute is supported in the torch.npu v7.2.0 release version.
74-
return dev_props.vector_core_num if hasattr(dev_props, "vector_core_num") else NPU_MULTI_PROCESSOR_COUNT
75-
# Reasonable default to avoid division by zero
76-
return 1
77-
78-
7968
def transformers_version_dispatch(
8069
required_version: str,
8170
before_fn,

0 commit comments

Comments
 (0)