|
1 | | -""" |
2 | | -UB-aware GEGLU implementation for Ascend NPU. |
3 | | -
|
4 | | -This implementation automatically adjusts block sizes to fit within UB constraints, |
5 | | -preventing UB overflow errors when running on Ascend NPU. |
6 | | -
|
7 | | -It reuses the original kernels when possible, and only uses tiling when necessary. |
8 | | -""" |
9 | | - |
10 | | -import operator |
11 | | - |
12 | 1 | import torch |
13 | 2 | import triton |
14 | 3 | import triton.language as tl |
15 | 4 |
|
| 5 | +from triton.language.math import tanh |
| 6 | + |
16 | 7 | from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy |
17 | | -from liger_kernel.ops.utils import calculate_settings |
18 | | -from liger_kernel.ops.utils import compare_version |
19 | 8 | from liger_kernel.ops.utils import ensure_contiguous |
20 | | -from liger_kernel.utils import is_npu_available |
21 | | - |
22 | | -if compare_version("triton", operator.ge, "3.0.0") and not is_npu_available(): |
23 | | - try: |
24 | | - from triton.language.extra.libdevice import tanh |
25 | | - except ModuleNotFoundError: |
26 | | - from triton.language.extra.cuda.libdevice import tanh |
27 | | -else: |
28 | | - from triton.language.math import tanh |
| 9 | +from liger_kernel.ops.utils import get_npu_core_count |
29 | 10 |
|
30 | 11 |
|
@triton.jit
def _geglu_forward_kernel_flat(a_ptr, b_ptr, c_ptr, total_elements, BLOCK_SIZE: tl.constexpr, NUM_STAGES: tl.constexpr):
    """
    High-performance GEGLU forward kernel using flatten 1D approach.

    Computes c = GELU(a) * b elementwise over flat 1D buffers, using the
    tanh approximation of GELU. Uses grid-stride loop pattern for optimal
    performance on NPU: each program processes blocks of BLOCK_SIZE elements
    starting at its program id and striding by (num_programs * BLOCK_SIZE).

    Args:
        a_ptr: pointer to the gate input (upcast to fp32 for the GELU math).
        b_ptr: pointer to the multiplicative input (kept in its own dtype).
        c_ptr: pointer to the output buffer (same dtype as b).
        total_elements: number of elements to process.
        BLOCK_SIZE: elements handled per loop iteration (constexpr).
        NUM_STAGES: pipeline stages passed to tl.range (constexpr).
    """
    pid = tl.program_id(0)
    num_progs = tl.num_programs(0)

    # Grid-Stride Loop
    start_idx = pid * BLOCK_SIZE
    stride = num_progs * BLOCK_SIZE

    # Constants for GELU tanh approximation
    sqrt_2_over_pi = 0.7978845608028654  # sqrt(2 / pi)
    gelu_coeff = 0.044715

    for idx in tl.range(start_idx, total_elements, stride, num_stages=NUM_STAGES):
        offsets = idx + tl.arange(0, BLOCK_SIZE)
        # Mask off the tail of the last (partial) block.
        mask = offsets < total_elements

        # a is computed in fp32 for numerical stability of the tanh path.
        a_val = tl.load(a_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
        b_val = tl.load(b_ptr + offsets, mask=mask, other=0.0)

        # tanh approximation form of GELU is computed with:
        # 0.5 * a * (1 + tanh(sqrt(2 / pi) * (a + 0.044715 * a^3)))
        a_cubed = a_val * a_val * a_val
        tanh_arg = sqrt_2_over_pi * (a_val + gelu_coeff * a_cubed)
        tanh_result = tanh(tanh_arg)
        geglu_a = 0.5 * a_val * (1.0 + tanh_result)
        # Downcast GELU(a) back to b's dtype before the product, matching
        # the original kernel's rounding behavior.
        c_row = geglu_a.cast(b_val.dtype) * b_val
        tl.store(c_ptr + offsets, c_row, mask=mask)
64 | 45 |
|
65 | 46 |
|
@triton.jit
def _geglu_backward_kernel_flat(
    dc_ptr, a_ptr, b_ptr, da_ptr, db_ptr, total_elements, BLOCK_SIZE: tl.constexpr, NUM_STAGES: tl.constexpr
):
    """
    High-performance GEGLU backward kernel using flatten 1D approach.

    Recomputes GELU(a) from the saved forward inputs (recomputation saves
    memory versus stashing activations) and writes da and db to separate
    output buffers. Uses grid-stride loop pattern for optimal performance
    on NPU.

    Args:
        dc_ptr: pointer to the upstream gradient.
        a_ptr: pointer to the saved gate input.
        b_ptr: pointer to the saved multiplicative input.
        da_ptr: pointer to the output gradient w.r.t. a.
        db_ptr: pointer to the output gradient w.r.t. b.
        total_elements: number of elements to process.
        BLOCK_SIZE: elements handled per loop iteration (constexpr).
        NUM_STAGES: pipeline stages passed to tl.range (constexpr).
    """
    pid = tl.program_id(0)
    num_progs = tl.num_programs(0)
    start_idx = pid * BLOCK_SIZE
    stride = num_progs * BLOCK_SIZE

    # Constants for GELU tanh approximation
    sqrt_2_over_pi = 0.7978845608028654  # sqrt(2 / pi)
    gelu_coeff = 0.044715

    for idx in tl.range(start_idx, total_elements, stride, num_stages=NUM_STAGES):
        offsets = idx + tl.arange(0, BLOCK_SIZE)
        # Mask off the tail of the last (partial) block.
        mask = offsets < total_elements

        dc = tl.load(dc_ptr + offsets, mask=mask, other=0.0)
        a = tl.load(a_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
        b = tl.load(b_ptr + offsets, mask=mask, other=0.0)

        # recomputation to save memory
        a_cubed = a * a * a
        tanh_arg = sqrt_2_over_pi * (a + gelu_coeff * a_cubed)
        tanh_result = tanh(tanh_arg)
        geglu_a = 0.5 * a * (1 + tanh_result)
        # Round-trip through dc's dtype so db matches the forward pass's
        # downcast-then-multiply rounding exactly.
        geglu_a = geglu_a.to(dc.dtype).to(tl.float32)

        db = dc.cast(tl.float32) * geglu_a

        # Gradient w.r.t. a can be computed with:
        # b * (0.5 * (1 + tanh(z)) + 0.5 * a * (1 - tanh(z)^2) * (sqrt(2/pi) * (1 + 3 * 0.044715 * a^2)))
        # where z = sqrt(2/pi) * (a + 0.044715 * a^3)
        term1 = 0.5 * (1.0 + tanh_result)
        tanh_sq = tanh_result * tanh_result
        a_sq = a * a
        term2 = 0.5 * a * (1.0 - tanh_sq) * (sqrt_2_over_pi * (1.0 + 3.0 * gelu_coeff * a_sq))
        da = dc * b * (term1 + term2)

        tl.store(da_ptr + offsets, da, mask=mask)
        tl.store(db_ptr + offsets, db.to(dc.dtype), mask=mask)
110 | 93 |
|
111 | 94 |
|
def get_optimal_block_size(total_elements, is_backward=False):
    """
    Calculate optimal block size using compute_default_tiling_strategy.

    The input is treated as a flat 1D tensor of shape (total_elements,)
    that can only be tiled along dimension 0.

    Args:
        total_elements: Total number of elements to process.
        is_backward: Whether this is for backward pass. The backward kernel
            keeps more fp32 intermediates live, so it needs a larger
            per-element memory budget.

    Returns:
        Optimal block size (int) for the flat GEGLU kernels, at least 256.
    """
    # Memory multiplier based on peak memory usage analysis of the kernels:
    # backward holds more simultaneous buffers than forward.
    memory_multiplier = 6.0 if is_backward else 3.0

    tile_shapes = compute_default_tiling_strategy(
        safety_margin=0.9,
        dtype_size=4,  # kernels upcast to float32 internally
        memory_multiplier=memory_multiplier,
        shapes=((total_elements,),),
        tiling_dims=(0,),
    )

    # Strategy returns ((block_size,),). Guard the inner tuple too: a result
    # like ((),) would otherwise raise IndexError on tile_shapes[0][0].
    if tile_shapes and tile_shapes[0]:
        return max(256, tile_shapes[0][0])
    # Fallback when no tiling recommendation is produced.
    return 2048
176 | 127 |
|
177 | 128 |
|
def geglu_forward(a, b):
    """
    High-performance GEGLU forward pass for NPU using flatten 1D approach.

    Flattens both inputs to 1D, launches the grid-stride forward kernel on
    at most one program per NPU core, and returns c = GELU(a) * b with the
    same shape/dtype as ``a``.
    """
    # Kernels index flat buffers, so contiguous layout is required.
    a = a if a.is_contiguous() else a.contiguous()
    b = b if b.is_contiguous() else b.contiguous()

    numel = a.numel()
    out = torch.empty_like(a)

    blk = get_optimal_block_size(numel, is_backward=False)

    # One program per core at most; fewer if the tensor is small.
    cores = get_npu_core_count()
    grid = min(cores, (numel + blk - 1) // blk)

    _geglu_forward_kernel_flat[(grid,)](a, b, out, numel, BLOCK_SIZE=blk, NUM_STAGES=3, num_warps=4)
    return out
| 148 | + |
181 | 149 |
|
182 | | - Automatically adjusts block size to fit within UB constraints. |
def geglu_backward(a, b, dc):
    """
    High-performance GEGLU backward pass for NPU using flatten 1D approach.

    Flattens all tensors to 1D and launches the grid-stride backward kernel,
    writing the gradients into freshly-allocated buffers.

    Returns:
        (grad_a, grad_b): gradients w.r.t. ``a`` and ``b``.
    """
    # Kernels index flat buffers, so contiguous layout is required.
    dc = dc if dc.is_contiguous() else dc.contiguous()
    a = a if a.is_contiguous() else a.contiguous()
    b = b if b.is_contiguous() else b.contiguous()

    numel = dc.numel()
    grad_a = torch.empty_like(a)
    grad_b = torch.empty_like(b)

    # Backward needs a larger per-element memory budget than forward.
    blk = get_optimal_block_size(numel, is_backward=True)

    # One program per core at most; fewer if the tensor is small.
    cores = get_npu_core_count()
    grid = min(cores, (numel + blk - 1) // blk)

    _geglu_backward_kernel_flat[(grid,)](
        dc, a, b, grad_a, grad_b, numel, BLOCK_SIZE=blk, NUM_STAGES=3, num_warps=4
    )
    return grad_a, grad_b
249 | 174 |
|
250 | 175 |
|
class LigerGELUMulFunction(torch.autograd.Function):
    """High-performance GEGLU function for Ascend NPU."""

    @staticmethod
    @ensure_contiguous
    def forward(ctx, a, b):
        # Inputs are saved (not the output) because backward recomputes
        # GELU(a) from them to save memory.
        out = geglu_forward(a, b)
        ctx.save_for_backward(a, b)
        return out

    @staticmethod
    @ensure_contiguous
    def backward(ctx, dc):
        saved_a, saved_b = ctx.saved_tensors
        return geglu_backward(saved_a, saved_b, dc)
0 commit comments