
Commit 44c2c31

zheliuyu and lancerts authored
[NPU]: Adjust MAX_FUSED_SIZE when using fused_linear_cross_entropy (#985)
## Summary
Adjust MAX_FUSED_SIZE to avoid unified buffer (UB) overflow when using fused_linear_cross_entropy on NPU.

## Testing Done
- Hardware Type: Ascend NPU A2
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence

### Compare
`pytest test/transformers/test_fused_linear_cross_entropy.py`

```
Original code: 105 passed, 16 failed. All failures were due to UB overflow.
Adjusted: 121 passed
```

Co-authored-by: Shao Tang <[email protected]>
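For context on why the cap matters (not part of the diff): in a Triton cross-entropy kernel, the block size along the vocabulary dimension is typically derived from MAX_FUSED_SIZE. Below is a minimal sketch; the helper name `pick_block_size` and the example vocab size are illustrative, not taken from the repository, and only the 2048 and 65536 // 2 caps come from this patch:

```python
import triton

def pick_block_size(V: int, max_fused_size: int) -> int:
    # Triton block sizes must be powers of two: round the vocab
    # dimension V up, then cap it so one program's tile fits in the
    # device's fast on-chip memory (the unified buffer on Ascend NPUs).
    return min(max_fused_size, triton.next_power_of_2(V))

# Illustrative vocab size; the caps are the ones this patch selects.
assert pick_block_size(151936, max_fused_size=2048) == 2048         # npu
assert pick_block_size(151936, max_fused_size=65536 // 2) == 32768  # default
```

A 2048-wide tile loads far less data per program than a 32768-wide one, which is what avoids the UB overflow seen in the failing tests.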
1 parent 6c2565b commit 44c2c31

File tree

2 files changed: +9 −2 lines changed


src/liger_kernel/ops/cross_entropy.py

Lines changed: 7 additions & 1 deletion
```diff
@@ -289,7 +289,13 @@ def liger_cross_entropy_kernel(
 # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
 # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
 # The optimal maximum block size depends on your hardware, your kernel, and your dtype
-MAX_FUSED_SIZE = 4096 if infer_device() == "xpu" else 65536 // 2  # the best size we found by manually tuning
+# the best size we found by manually tuning on xpu and npu.
+if infer_device() == "xpu":
+    MAX_FUSED_SIZE = 4096
+elif infer_device() == "npu":
+    MAX_FUSED_SIZE = 2048
+else:
+    MAX_FUSED_SIZE = 65536 // 2


 def cross_entropy_forward(
```
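Both hunks dispatch on `infer_device()` from `liger_kernel.utils`. A hedged sketch of what such a helper typically looks like; the repository's actual implementation may differ:

```python
import torch

def infer_device() -> str:
    # Sketch only: probe accelerator backends in a fixed order.
    if torch.cuda.is_available():
        return "cuda"  # NVIDIA CUDA (ROCm builds also report as cuda)
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return "xpu"   # Intel GPUs
    if hasattr(torch, "npu") and torch.npu.is_available():
        return "npu"   # Ascend NPUs via the torch_npu plugin
    return "cpu"
```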

src/liger_kernel/ops/fused_linear_cross_entropy.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -6,11 +6,12 @@
 from liger_kernel.ops.utils import amp_custom_fwd
 from liger_kernel.ops.utils import element_mul_kernel
 from liger_kernel.ops.utils import is_hip
+from liger_kernel.utils import infer_device

 # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
 # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
 # The optimal maximum block size depends on your hardware, your kernel, and your dtype
-MAX_FUSED_SIZE = 65536 // 2
+MAX_FUSED_SIZE = 2048 if infer_device() == "npu" else 65536 // 2


 def fused_linear_cross_entropy_forward(
```
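One design note on the patched one-liner: MAX_FUSED_SIZE is a module-level constant, so the device is inferred once at import time rather than per call. A tiny illustrative check mirroring the patched line (the printed values are expectations, not captured output):

```python
from liger_kernel.utils import infer_device

# Mirrors the patched module-level constant: evaluated once on import.
MAX_FUSED_SIZE = 2048 if infer_device() == "npu" else 65536 // 2

print(infer_device(), MAX_FUSED_SIZE)
# expected: "npu 2048" on Ascend A2, "cuda 32768" on NVIDIA GPUs
```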
