Skip to content

Commit 969d4ab

Browse files
authored
[NPU] Add fused_linear_cross_entropy operator (#1164)
## Summary To address the UB overflow issue observed in the benchmark, we introduced an operator with an NPU-friendly implementation of fused linear cross entropy. This fused operator relies on several underlying operations (e.g., large matrix multiplication, softmax, and cross entropy), so its current benchmark performance is not yet optimal. Further optimization may be needed. ## Testing Done Device: Atlas A3 `python -m pytest ./test/transformers/test_fused_linear_cross_entropy.py` <img width="3270" height="499" alt="image" src="https://github.com/user-attachments/assets/7f8a63df-f325-43fe-80b9-6268c5f10e29" /> - Hardware Type: <BLANK> - [ ] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [ ] run `make test-convergence` to ensure convergence
1 parent 1c013e2 commit 969d4ab

File tree

2 files changed

+414
-0
lines changed

2 files changed

+414
-0
lines changed

src/liger_kernel/ops/backends/_ascend/ops/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@
2626
from liger_kernel.ops.backends._ascend.ops.fused_add_rms_norm import LigerFusedAddRMSNormFunction
2727
from liger_kernel.ops.backends._ascend.ops.fused_add_rms_norm import fused_add_rms_norm_backward
2828
from liger_kernel.ops.backends._ascend.ops.fused_add_rms_norm import fused_add_rms_norm_forward
29+
from liger_kernel.ops.backends._ascend.ops.fused_linear_cross_entropy import LigerFusedLinearCrossEntropyFunction
30+
from liger_kernel.ops.backends._ascend.ops.fused_linear_cross_entropy import fused_linear_cross_entropy_backward
31+
from liger_kernel.ops.backends._ascend.ops.fused_linear_cross_entropy import fused_linear_cross_entropy_forward
2932
from liger_kernel.ops.backends._ascend.ops.fused_linear_jsd import LigerFusedLinearJSDFunction
3033
from liger_kernel.ops.backends._ascend.ops.fused_linear_jsd import fused_linear_jsd_backward
3134
from liger_kernel.ops.backends._ascend.ops.fused_linear_jsd import fused_linear_jsd_forward
@@ -140,4 +143,7 @@
140143
"sparsemax_backward",
141144
"LigerFusedNeighborhoodAttentionFunction",
142145
"fused_neighborhood_attention_forward",
146+
"LigerFusedLinearCrossEntropyFunction",
147+
"fused_linear_cross_entropy_forward",
148+
"fused_linear_cross_entropy_backward",
143149
]

0 commit comments

Comments
 (0)