Skip to content

Commit 224a660

Browse files
authored
[NPU]: optimize GEGLU implementation with flatten 1D approach (#1031)
- Refactor Ascend GEGLU kernels to use a flatten-1D grid-stride loop pattern instead of the row-based tiling approach, for better performance
- Simplify block-size calculation using `compute_default_tiling_strategy`
- Align type-conversion logic with the GPU version for consistency
- Add quack's distance-based comparison method for NPU + bfloat16 tests
- Apply scaled weight initialization (`1/sqrt(in_features)`) following quack's recommendation for better numerical stability
- Improve test robustness by comparing custom_bf16-vs-fp32 distance against ref_bf16-vs-fp32 distance (threshold: 2x + 1e-6)

Hardware Type: Ascend 910B4
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence
1 parent 15f9338 commit 224a660

File tree

2 files changed

+239
-195
lines changed

2 files changed

+239
-195
lines changed
Lines changed: 119 additions & 194 deletions
Original file line numberDiff line numberDiff line change
@@ -1,266 +1,191 @@
1-
"""
2-
UB-aware GEGLU implementation for Ascend NPU.
3-
4-
This implementation automatically adjusts block sizes to fit within UB constraints,
5-
preventing UB overflow errors when running on Ascend NPU.
6-
7-
It reuses the original kernels when possible, and only uses tiling when necessary.
8-
"""
9-
10-
import operator
11-
121
import torch
132
import triton
143
import triton.language as tl
154

5+
from triton.language.math import tanh
6+
167
from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy
17-
from liger_kernel.ops.utils import calculate_settings
18-
from liger_kernel.ops.utils import compare_version
198
from liger_kernel.ops.utils import ensure_contiguous
20-
from liger_kernel.utils import is_npu_available
21-
22-
if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available():
23-
try:
24-
from triton.language.extra.libdevice import tanh
25-
except ModuleNotFoundError:
26-
from triton.language.extra.cuda.libdevice import tanh
27-
else:
28-
from triton.language.math import tanh
9+
from liger_kernel.ops.utils import get_npu_core_count
2910

3011

3112
@triton.jit
def _geglu_forward_kernel_flat(a_ptr, b_ptr, c_ptr, total_elements, BLOCK_SIZE: tl.constexpr, NUM_STAGES: tl.constexpr):
    """GEGLU forward kernel over a flattened 1-D view of the inputs.

    Computes ``c = GELU_tanh(a) * b`` elementwise. Each program walks the
    flat index space in a grid-stride loop: program ``pid`` starts at
    ``pid * BLOCK_SIZE`` and advances by ``num_programs * BLOCK_SIZE`` per
    iteration, so any grid size covers all ``total_elements``.
    """
    pid = tl.program_id(0)
    num_progs = tl.num_programs(0)

    # Grid-stride loop parameters.
    block_start = pid * BLOCK_SIZE
    grid_stride = num_progs * BLOCK_SIZE

    # Constants of the tanh approximation of GELU.
    sqrt_2_over_pi = 0.7978845608028654  # sqrt(2 / pi)
    gelu_coeff = 0.044715

    for tile_start in tl.range(block_start, total_elements, grid_stride, num_stages=NUM_STAGES):
        offsets = tile_start + tl.arange(0, BLOCK_SIZE)
        in_bounds = offsets < total_elements

        # The gate activation is computed in fp32 for numerical stability;
        # b keeps its storage dtype until the final multiply.
        a_val = tl.load(a_ptr + offsets, mask=in_bounds, other=0.0).to(tl.float32)
        b_val = tl.load(b_ptr + offsets, mask=in_bounds, other=0.0)

        # tanh approximation of GELU:
        #   0.5 * a * (1 + tanh(sqrt(2 / pi) * (a + 0.044715 * a^3)))
        tanh_arg = sqrt_2_over_pi * (a_val + gelu_coeff * a_val * a_val * a_val)
        gelu_gate = 0.5 * a_val * (1.0 + tanh(tanh_arg))

        # Cast the fp32 gate back to b's dtype BEFORE the multiply so the
        # conversion order matches the GPU implementation.
        result = gelu_gate.cast(b_val.dtype) * b_val
        tl.store(c_ptr + offsets, result, mask=in_bounds)
6445

6546

6647
@triton.jit
def _geglu_backward_kernel_flat(
    dc_ptr, a_ptr, b_ptr, da_ptr, db_ptr, total_elements, BLOCK_SIZE: tl.constexpr, NUM_STAGES: tl.constexpr
):
    """GEGLU backward kernel over a flattened 1-D view.

    Recomputes the forward GELU activation (recompute-to-save-memory)
    and writes the gradients ``da`` and ``db`` for upstream gradient ``dc``.
    Uses the same grid-stride loop layout as the forward kernel.
    """
    pid = tl.program_id(0)
    num_progs = tl.num_programs(0)
    block_start = pid * BLOCK_SIZE
    grid_stride = num_progs * BLOCK_SIZE

    # Constants of the tanh approximation of GELU.
    sqrt_2_over_pi = 0.7978845608028654  # sqrt(2 / pi)
    gelu_coeff = 0.044715

    for tile_start in tl.range(block_start, total_elements, grid_stride, num_stages=NUM_STAGES):
        offsets = tile_start + tl.arange(0, BLOCK_SIZE)
        in_bounds = offsets < total_elements

        dc = tl.load(dc_ptr + offsets, mask=in_bounds, other=0.0)
        a = tl.load(a_ptr + offsets, mask=in_bounds, other=0.0).to(tl.float32)
        b = tl.load(b_ptr + offsets, mask=in_bounds, other=0.0)

        # Recompute the forward activation instead of having saved it.
        tanh_arg = sqrt_2_over_pi * (a + gelu_coeff * a * a * a)
        tanh_result = tanh(tanh_arg)
        geglu_a = 0.5 * a * (1 + tanh_result)
        # Round-trip through dc's dtype so the recomputed activation has the
        # same precision as the value actually produced in the forward pass.
        geglu_a = geglu_a.to(dc.dtype).to(tl.float32)

        # dL/db = dc * GELU(a)
        db = dc.cast(tl.float32) * geglu_a

        # dL/da = dc * b * (0.5 * (1 + tanh(z))
        #                   + 0.5 * a * (1 - tanh(z)^2) * sqrt(2/pi) * (1 + 3 * 0.044715 * a^2))
        # where z = sqrt(2/pi) * (a + 0.044715 * a^3)
        term1 = 0.5 * (1.0 + tanh_result)
        tanh_sq = tanh_result * tanh_result
        term2 = 0.5 * a * (1.0 - tanh_sq) * (sqrt_2_over_pi * (1.0 + 3.0 * gelu_coeff * a * a))
        da = dc * b * (term1 + term2)

        tl.store(da_ptr + offsets, da, mask=in_bounds)
        tl.store(db_ptr + offsets, db.to(dc.dtype), mask=in_bounds)
11093

11194

112-
def geglu_forward(a, b):
95+
def get_optimal_block_size(total_elements, is_backward=False):
    """Pick a kernel block size that fits the NPU's UB budget.

    Args:
        total_elements: Total number of elements the kernel will process.
        is_backward: True for the backward kernel, which keeps more
            intermediate buffers alive simultaneously and therefore needs a
            larger per-element memory multiplier.

    Returns:
        Block size in elements (at least 256), or the 2048 fallback when
        the tiling strategy yields no usable answer.
    """
    # Per-element UB footprint multiplier from peak-memory analysis;
    # backward holds more live fp32 intermediates than forward.
    memory_multiplier = 6.0 if is_backward else 3.0

    # Treat the workload as a 1-D (total_elements,) shape, tiling dim 0 only.
    # dtype_size=4 budgets for fp32 intermediates regardless of input dtype.
    tile_shapes = compute_default_tiling_strategy(
        safety_margin=0.9,
        dtype_size=4,
        memory_multiplier=memory_multiplier,
        shapes=((total_elements,),),
        tiling_dims=(0,),
    )

    # Guard BOTH the outer and inner tuples: the strategy may return an
    # empty/None result, and indexing an empty inner tuple would raise.
    if tile_shapes and tile_shapes[0]:
        # NOTE(review): the 256 floor assumes UB can always hold 256
        # elements' worth of buffers — confirm against the smallest UB size.
        return max(256, tile_shapes[0][0])
    return 2048
176127

177128

178-
def geglu_backward(a, b, dc):
129+
def geglu_forward(a, b):
    """GEGLU forward pass for NPU using the flattened 1-D kernel.

    Args:
        a: Gate tensor (any shape).
        b: Up-projection tensor; assumed same shape/dtype as ``a`` —
           the kernel indexes both with one flat offset.

    Returns:
        Tensor ``c = GELU_tanh(a) * b`` with the same shape as ``a``.
    """
    # Flat indexing requires contiguous memory.
    if not a.is_contiguous():
        a = a.contiguous()
    if not b.is_contiguous():
        b = b.contiguous()

    total_elements = a.numel()
    c = torch.empty_like(a)

    # Empty input: nothing to launch, and a zero-sized grid is invalid.
    if total_elements == 0:
        return c

    block_size = get_optimal_block_size(total_elements, is_backward=False)

    # One program per core at most; fewer when the workload is small.
    num_cores = get_npu_core_count()
    grid_size = min(num_cores, (total_elements + block_size - 1) // block_size)

    _geglu_forward_kernel_flat[(grid_size,)](a, b, c, total_elements, BLOCK_SIZE=block_size, NUM_STAGES=3, num_warps=4)
    return c
148+
181149

182-
Automatically adjusts block size to fit within UB constraints.
150+
def geglu_backward(a, b, dc):
    """GEGLU backward pass for NPU using the flattened 1-D kernel.

    Args:
        a: Gate tensor saved from the forward pass.
        b: Up-projection tensor saved from the forward pass.
        dc: Upstream gradient w.r.t. the forward output.

    Returns:
        Tuple ``(grad_a, grad_b)`` with the same shapes as ``a`` and ``b``.
    """
    # Flat indexing requires contiguous memory.
    if not dc.is_contiguous():
        dc = dc.contiguous()
    if not a.is_contiguous():
        a = a.contiguous()
    if not b.is_contiguous():
        b = b.contiguous()

    total_elements = dc.numel()
    grad_a = torch.empty_like(a)
    grad_b = torch.empty_like(b)

    # Empty input: nothing to launch, and a zero-sized grid is invalid.
    if total_elements == 0:
        return grad_a, grad_b

    block_size = get_optimal_block_size(total_elements, is_backward=True)

    # One program per core at most; fewer when the workload is small.
    num_cores = get_npu_core_count()
    grid_size = min(num_cores, (total_elements + block_size - 1) // block_size)

    _geglu_backward_kernel_flat[(grid_size,)](
        dc, a, b, grad_a, grad_b, total_elements, BLOCK_SIZE=block_size, NUM_STAGES=3, num_warps=4
    )
    return grad_a, grad_b
249174

250175

251176
class LigerGELUMulFunction(torch.autograd.Function):
    """Autograd wrapper for the NPU GEGLU (GELU-gated multiply) kernels."""

    @staticmethod
    @ensure_contiguous
    def forward(ctx, a, b):
        # Save the raw inputs; backward recomputes the activation from them.
        out = geglu_forward(a, b)
        ctx.save_for_backward(a, b)
        return out

    @staticmethod
    @ensure_contiguous
    def backward(ctx, dc):
        a, b = ctx.saved_tensors
        return geglu_backward(a, b, dc)

0 commit comments

Comments
 (0)