Skip to content

Commit e465237

Browse files
authored
[NPU] Add softmax implementation (#1087)
## Summary This PR adds a Softmax implementation for NPU. It includes a single-block forward kernel for smaller column sizes, as well as a multi-block kernel for large column sizes to avoid NPU UB overflow. ## Testing Done Tests run with `python -m pytest test/transformers/test_softmax.py`. Hardware Type: Atlas 800I A2 (32G) - [ ] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [ ] run `make test-convergence` to ensure convergence
1 parent 5c9cecc commit e465237

File tree

2 files changed

+192
-0
lines changed

2 files changed

+192
-0
lines changed

src/liger_kernel/ops/backends/_ascend/ops/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@
3838
from liger_kernel.ops.backends._ascend.ops.rope import LigerRopeFunction
3939
from liger_kernel.ops.backends._ascend.ops.rope import rope_backward
4040
from liger_kernel.ops.backends._ascend.ops.rope import rope_forward
41+
from liger_kernel.ops.backends._ascend.ops.softmax import LigerSoftmaxFunction
42+
from liger_kernel.ops.backends._ascend.ops.softmax import softmax_backward
43+
from liger_kernel.ops.backends._ascend.ops.softmax import softmax_forward
4144
from liger_kernel.ops.backends._ascend.ops.swiglu import LigerSiLUMulFunction
4245
from liger_kernel.ops.backends._ascend.ops.swiglu import swiglu_backward
4346
from liger_kernel.ops.backends._ascend.ops.swiglu import swiglu_forward
@@ -76,4 +79,7 @@
7679
"LigerKLDivLossFunction",
7780
"kldiv_forward_triton",
7881
"kldiv_backward_triton",
82+
"LigerSoftmaxFunction",
83+
"softmax_forward",
84+
"softmax_backward",
7985
]
Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
import torch
2+
import triton
3+
import triton.language as tl
4+
5+
from liger_kernel.ops.utils import ensure_contiguous
6+
from liger_kernel.ops.utils import get_npu_core_count
7+
8+
9+
@triton.jit
def _softmax_multi_block_forward_kernel(
    Y_ptr,
    Y_row_stride,
    X_ptr,
    X_row_stride,
    n_rows,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    """Row-wise softmax forward for rows wider than a single block.

    Uses the online-softmax recurrence: the first pass keeps a running
    maximum and a rescaled running sum of exponentials, so no full row
    ever has to be resident at once; the second pass reloads each block
    and writes exp(x - max) / sum.

    Args:
        Y_ptr: Pointer to the output tensor.
        Y_row_stride: Element stride between output rows.
        X_ptr: Pointer to the input tensor.
        X_row_stride: Element stride between input rows.
        n_rows: Number of rows to process.
        n_cols: Number of columns per row.
        BLOCK_SIZE: Columns handled per inner-loop iteration.
    """
    first_row = tl.program_id(0)
    rows_per_step = tl.num_programs(0)

    for row in tl.range(first_row, n_rows, rows_per_step):
        in_row_ptr = X_ptr + row * X_row_stride
        offsets = tl.arange(0, BLOCK_SIZE)
        running_max = float("-inf")
        running_sum = 0.0

        # Pass 1: fold each block into the running max/sum. Masked lanes
        # load -inf so they contribute exp(-inf) == 0 to the sum.
        for base in tl.range(0, n_cols, BLOCK_SIZE):
            cols = base + offsets
            valid = cols < n_cols
            x_blk = tl.load(
                in_row_ptr + cols, mask=valid, other=float("-inf"), eviction_policy="evict_first", cache_modifier=".ca"
            )
            new_max = tl.maximum(running_max, tl.max(x_blk, axis=0))
            # Rescale the accumulated sum to the new max before adding this block.
            running_sum = running_sum * tl.exp(running_max - new_max) + tl.sum(tl.exp(x_blk - new_max), axis=0)
            running_max = new_max

        # Pass 2: reload, normalize, and stream each block to the output row.
        for base in tl.range(0, n_cols, BLOCK_SIZE):
            cols = base + offsets
            valid = cols < n_cols
            x_blk = tl.load(
                in_row_ptr + cols, mask=valid, other=float("-inf"), eviction_policy="evict_first", cache_modifier=".ca"
            )
            out_blk = tl.exp(x_blk - running_max) / running_sum
            tl.store(Y_ptr + row * Y_row_stride + cols, out_blk, mask=valid, cache_modifier=".cs")
62+
63+
64+
@triton.jit
def _softmax_multi_block_backward_kernel(
    dy_ptr,
    dy_stride,
    y_ptr,
    y_stride,
    dx_ptr,
    dx_stride,
    n_rows,
    n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Multi-block softmax backward kernel using two-pass algorithm.

    Computes gradient: dx = y * (dy - sum(dy * y)), where the sum is a
    per-row reduction accumulated block by block in the first pass.

    Args:
        dy_ptr: Gradient output pointer
        dy_stride: Stride for gradient output rows
        y_ptr: Forward output pointer
        y_stride: Stride for forward output rows
        dx_ptr: Gradient input pointer
        dx_stride: Stride for gradient input rows
        n_rows: Number of rows to process
        n_cols: Number of columns per row
        BLOCK_SIZE: Block size for column processing
    """
    row_start = tl.program_id(0)
    col_offsets = tl.arange(0, BLOCK_SIZE)
    row_step = tl.num_programs(0)

    for row_idx in tl.range(row_start, n_rows, row_step):
        dy_start_ptr = dy_ptr + row_idx * dy_stride
        y_start_ptr = y_ptr + row_idx * y_stride
        # FIX: the reduction must restart at zero for every row. Previously
        # `acc` was initialized once per program (outside this loop), so when
        # a program handled more than one row the sum(dy * y) of earlier rows
        # leaked into later rows, producing wrong gradients.
        acc = 0.0

        # Pass 1: accumulate the per-row reduction sum(dy * y).
        for start in tl.range(0, n_cols, BLOCK_SIZE):
            idx = start + col_offsets
            mask = idx < n_cols
            dy_blk = tl.load(dy_start_ptr + idx, mask=mask, other=0.0, eviction_policy="evict_first")
            y_blk = tl.load(
                y_start_ptr + idx, mask=mask, other=0.0, eviction_policy="evict_first", cache_modifier=".ca"
            )
            acc += tl.sum(dy_blk * y_blk, axis=0)

        # Pass 2: reload each block and emit dx = y * (dy - acc).
        for start in tl.range(0, n_cols, BLOCK_SIZE):
            idx = start + col_offsets
            mask = idx < n_cols
            dy_blk = tl.load(dy_start_ptr + idx, mask=mask, other=0.0)
            y_blk = tl.load(y_start_ptr + idx, mask=mask, other=0.0, cache_modifier=".ca")
            dx_blk = y_blk * (dy_blk - acc)
            tl.store(dx_ptr + row_idx * dx_stride + idx, dx_blk, mask=mask, cache_modifier=".wb")
117+
118+
119+
def softmax_forward(x):
    """Apply softmax over the last dimension of ``x`` with the NPU kernel.

    The input is flattened to 2-D (rows x columns), the kernel is launched
    over at most one program per NPU core, and the result is reshaped back.

    Args:
        x: Input tensor of any leading shape; softmax runs over the last dim.

    Returns:
        Tuple of (softmax output with ``x``'s shape, BLOCK_SIZE used), so
        the backward pass can reuse the same block size.
    """
    *leading_dims, n_cols = x.shape
    flat_x = x.contiguous().view(-1, n_cols)
    n_rows = flat_x.shape[0]

    # Cap the tile width so one block never overflows the NPU unified buffer.
    MAX_FUSED_BLOCK_SIZE = 8192
    block_size = min(triton.next_power_of_2(n_cols), MAX_FUSED_BLOCK_SIZE)

    flat_y = torch.empty_like(flat_x)
    grid = (min(get_npu_core_count(), n_rows),)

    _softmax_multi_block_forward_kernel[grid](
        flat_y, flat_y.stride(0), flat_x, flat_x.stride(0), n_rows, n_cols, BLOCK_SIZE=block_size
    )

    return flat_y.view(*leading_dims, n_cols), block_size
137+
138+
139+
def softmax_backward(
    dy: torch.Tensor,
    y: torch.Tensor,
    BLOCK_SIZE: int,
) -> torch.Tensor:
    """Compute the softmax gradient dx = y * (dy - sum(dy * y)) on NPU.

    Args:
        dy: Upstream gradient, same shape as the forward output.
        y: Softmax output saved from the forward pass.
        BLOCK_SIZE: Block size chosen by :func:`softmax_forward`.

    Returns:
        Gradient with respect to the softmax input, shaped like ``dy``.
    """
    *leading_dims, n_cols = dy.shape
    flat_dy = dy.contiguous().view(-1, n_cols)
    flat_y = y.contiguous().view(-1, n_cols)
    n_rows = flat_dy.shape[0]
    flat_dx = torch.empty_like(flat_dy)

    # One program per core at most; each program strides over multiple rows.
    grid = (min(get_npu_core_count(), n_rows),)

    _softmax_multi_block_backward_kernel[grid](
        flat_dy,
        flat_dy.stride(0),
        flat_y,
        flat_y.stride(0),
        flat_dx,
        flat_dx.stride(0),
        n_rows,
        n_cols,
        BLOCK_SIZE=BLOCK_SIZE,
    )

    return flat_dx.view(*leading_dims, n_cols)
166+
167+
168+
class LigerSoftmaxFunction(torch.autograd.Function):
    """Autograd binding for the NPU softmax forward/backward kernels."""

    @staticmethod
    @ensure_contiguous
    def forward(ctx, input_: torch.Tensor):
        """Run softmax over the last dim; stash output and block size for backward."""
        out, block_size = softmax_forward(input_)
        ctx.save_for_backward(out)
        ctx.BLOCK_SIZE = block_size
        return out

    @staticmethod
    @ensure_contiguous
    def backward(ctx, grad_output):
        """Propagate gradients through the saved softmax output."""
        (saved_out,) = ctx.saved_tensors
        return softmax_backward(grad_output, saved_out, ctx.BLOCK_SIZE)

0 commit comments

Comments
 (0)