Skip to content

Commit e56bdf5

Browse files
committed
optimize performance
1 parent 11b6a56 commit e56bdf5

File tree

1 file changed

+17
-45
lines changed

1 file changed

+17
-45
lines changed

src/liger_kernel/ops/backends/_ascend/ops/fused_linear_cross_entropy.py

Lines changed: 17 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -13,18 +13,11 @@ def get_optimal_block_size(n_cols, has_gradients=True):
1313
"""
1414
Calculate optimal Block Size using compute_default_tiling_strategy
1515
"""
16-
# Cross entropy is more memory intensive than swiglu because it needs softmax computation
17-
# Forward needs online softmax calculation, backward needs more memory for intermediate variables
18-
# 10.0 and 16.0 are empirical values based on Atlas 800I A2 UB (192KB)
1916
multiplier = 12.0 if has_gradients else 8.0
20-
21-
# Call calculation function
22-
# Treat input as 1D (n_cols,), only tiling on dim 0
2317
tile_shapes = compute_default_tiling_strategy(
2418
safety_margin=0.9, dtype_size=4, memory_multiplier=multiplier, shapes=((n_cols,),), tiling_dims=(0,)
2519
)
2620

27-
# Parse result
2821
if tile_shapes and len(tile_shapes) > 0:
2922
block_size = tile_shapes[0][0]
3023
return block_size
@@ -37,14 +30,10 @@ def get_optimal_block_size_element_mul(n_cols, dtype_size):
3730
Calculate optimal Block Size using compute_default_tiling_strategy for element-wise multiplication in backward pass
3831
"""
3932
multiplier = 3.0
40-
41-
# Call calculation function
42-
# Treat input as 1D (n_cols,), only tiling on dim 0
4333
tile_shapes = compute_default_tiling_strategy(
4434
safety_margin=0.9, dtype_size=dtype_size, memory_multiplier=multiplier, shapes=((n_cols,),), tiling_dims=(0,)
4535
)
4636

47-
# Parse result
4837
if tile_shapes and len(tile_shapes) > 0:
4938
block_size = tile_shapes[0][0]
5039
return block_size
@@ -77,16 +66,7 @@ def fused_linear_cross_entropy_forward(
7766
f"return_predicted_tokens must be True or False. Got: {return_predicted_tokens}"
7867
)
7968
device = _input.device
80-
8169
input_requires_grad = _input.requires_grad
82-
83-
# inputs have shape: BT x H
84-
# materialized activations will have shape: BT x V
85-
# the increase in memory = BT x V
86-
# reduction can be achieved by partitioning the number of tokens BT into smaller chunks.
87-
# for ex: if we were to achieve the same memory consumption as BT x H, then the chunk size should be:
88-
# inc_factor = (V+H-1)//H, chunk_size = (BT + inc_factor - 1)//inc_factor
89-
# for ex: BT = 4096*4, V = 32000, H = 4096 ==> inc_factor = 8, chunk_size = 2048
9070
BT, H = _input.shape
9171
V = weight.shape[0]
9272
BLOCK_SIZE = get_optimal_block_size(V, has_gradients=_input.requires_grad)
@@ -130,13 +110,15 @@ def fused_linear_cross_entropy_forward(
130110
if ce_weight.stride(-1) != 1:
131111
ce_weight = ce_weight.contiguous()
132112

113+
num_cores = get_npu_core_count()
114+
logits = _input @ weight.t() # BT x V
115+
133116
for chunk_id in range(num_chunks):
134117
start_idx = chunk_id * chunk_size
135118
end_idx = min((chunk_id + 1) * chunk_size, BT)
136-
_input_chunk = _input[start_idx:end_idx] # chunk_size x H
119+
# # when doing matmul, use the original precision
137120

138-
# when doing matmul, use the original precision
139-
logits_chunk = _input_chunk @ weight.t() # chunk_size x V
121+
logits_chunk = logits[start_idx:end_idx] # chunk_size x V
140122
if bias is not None:
141123
logits_chunk = logits_chunk + bias
142124

@@ -183,10 +165,7 @@ def fused_linear_cross_entropy_forward(
183165
# ensure _input and target are contiguous
184166
logits_chunk = logits_chunk.contiguous()
185167
target_chunk = target_chunk.contiguous()
186-
num_cores = get_npu_core_count()
187168

188-
# Here we calculate the gradient of logits_chunk in place so we can save memory.
189-
# Grid size is capped at NPU core count; the kernel uses a grid-stride loop
190169
liger_cross_entropy_kernel[(min(n_rows, num_cores),)](
191170
X_ptr=logits_chunk,
192171
X_stride=logits_chunk.stride(-2),
@@ -247,31 +226,26 @@ def fused_linear_cross_entropy_forward(
247226
if input_requires_grad:
248227
grad_input[start_idx:end_idx] = grad_logits_chunk @ weight
249228

250-
if grad_weight is not None and input_requires_grad:
251-
grad_weight += torch.mm(grad_logits_chunk.t(), _input_chunk).float()
252-
253-
if bias is not None and input_requires_grad:
254-
torch.add(
255-
input=grad_bias,
256-
other=grad_logits_chunk.sum(dim=0),
257-
out=grad_bias,
258-
alpha=1.0,
259-
)
260-
261-
# Need extra calculations for backward if reduction=='none'. Not supporting reduction='none' now.
262-
# if reduction == "none":
263-
# loss = loss_1d
264-
# z_loss = z_loss_1d if return_z_loss else None
229+
if bias is not None:
230+
logits[start_idx:end_idx] = grad_logits_chunk
231+
232+
if grad_weight is not None and input_requires_grad:
233+
grad_weight = logits.t() @ _input
234+
if bias is not None and input_requires_grad:
235+
torch.add(
236+
input=grad_bias,
237+
other=logits.sum(dim=0),
238+
out=grad_bias,
239+
alpha=1.0,
240+
)
265241

266242
if reduction == "none":
267-
# Return per-token losses
268243
loss = loss_1d
269244
z_loss = z_loss_1d if return_z_loss else None
270245
token_accuracy = token_accuracy_1d if return_token_accuracy else None
271246
else:
272247
loss = torch.sum(loss_1d)
273248
z_loss = torch.sum(z_loss_1d) if return_z_loss else None
274-
# For accuracy, we compute the mean across all non-ignored tokens
275249
token_accuracy = torch.sum(token_accuracy_1d) / total_n_non_ignore if return_token_accuracy else None
276250

277251
predicted_tokens = predicted_tokens_1d if return_predicted_tokens else None
@@ -286,8 +260,6 @@ def fused_linear_cross_entropy_forward(
286260
def fused_linear_cross_entropy_backward(grad_output, grad_input, grad_weight, grad_bias):
287261
# If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
288262
if not torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
289-
# We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
290-
# for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
291263
BT, H = grad_input.shape
292264
n_rows = BT
293265
BLOCK_SIZE = get_optimal_block_size_element_mul(H, grad_output.element_size())

0 commit comments

Comments
 (0)