Skip to content

Commit 4dd540e

Browse files
authored
[NPU]: adjust MAX_FUSED_SIZE for NPU devices in group_norm (#1003)
- Set MAX_FUSED_SIZE to 16384 for NPU devices, keep 65536 for others - Improve group_norm test cases with fixed, representative parameters - Remove random test parameters in favor of deterministic test cases Hardware Type: Ascend910B4 - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [ ] run `make test-convergence` to ensure convergence
1 parent a45cfbf commit 4dd540e

File tree

2 files changed

+10
-12
lines changed

2 files changed

+10
-12
lines changed

src/liger_kernel/ops/group_norm.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from liger_kernel.ops.utils import compare_version
88
from liger_kernel.ops.utils import ensure_contiguous
9+
from liger_kernel.utils import infer_device
910
from liger_kernel.utils import is_npu_available
1011

1112
if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
@@ -18,7 +19,10 @@
1819
else:
1920
from triton.language.math import rsqrt
2021

21-
MAX_FUSED_SIZE = 65536
22+
if infer_device() == "npu":
23+
MAX_FUSED_SIZE = 16384 # 8192
24+
else:
25+
MAX_FUSED_SIZE = 65536
2226

2327

2428
@triton.jit

test/transformers/test_group_norm.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
import random
2-
31
import pytest
42
import torch
53

@@ -8,19 +6,15 @@
86

97
device = infer_device()
108

11-
random_batch_size = random.randint(1, 16)
12-
random_num_groups = random.randint(1, 32)
13-
random_num_channels = random_num_groups * random.randint(1, 16)
14-
random_hidden_size = random.randint(1, 8192)
15-
169

1710
@pytest.mark.parametrize(
1811
"batch_size, num_channels, num_groups, hidden_size",
1912
[
20-
(1, 1, 1, 3),
21-
(1, 4, 2, 4),
22-
(16, 12, 3, 4096),
23-
(random_batch_size, random_num_channels, random_num_groups, random_hidden_size),
13+
(1, 1, 1, 3), # minimal
14+
(1, 32, 32, 4), # group == channel
15+
(16, 32, 1, 4096), # single group
16+
(2, 63, 21, 2163), # non-aligned hidden
17+
(16, 48, 12, 8192), # large hidden
2418
],
2519
)
2620
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)