@@ -13,18 +13,11 @@ def get_optimal_block_size(n_cols, has_gradients=True):
1313 """
1414 Calculate optimal Block Size using compute_default_tiling_strategy
1515 """
16- # Cross entropy is more memory intensive than swiglu because it needs softmax computation
17- # Forward needs online softmax calculation, backward needs more memory for intermediate variables
18- # 10.0 and 16.0 are empirical values based on Atlas 800I A2 UB (192KB)
1916 multiplier = 12.0 if has_gradients else 8.0
20-
21- # Call calculation function
22- # Treat input as 1D (n_cols,), only tiling on dim 0
2317 tile_shapes = compute_default_tiling_strategy (
2418 safety_margin = 0.9 , dtype_size = 4 , memory_multiplier = multiplier , shapes = ((n_cols ,),), tiling_dims = (0 ,)
2519 )
2620
27- # Parse result
2821 if tile_shapes and len (tile_shapes ) > 0 :
2922 block_size = tile_shapes [0 ][0 ]
3023 return block_size
@@ -37,14 +30,10 @@ def get_optimal_block_size_element_mul(n_cols, dtype_size):
3730 Calculate optimal Block Size using compute_default_tiling_strategy for element-wise multiplication in backward pass
3831 """
3932 multiplier = 3.0
40-
41- # Call calculation function
42- # Treat input as 1D (n_cols,), only tiling on dim 0
4333 tile_shapes = compute_default_tiling_strategy (
4434 safety_margin = 0.9 , dtype_size = dtype_size , memory_multiplier = multiplier , shapes = ((n_cols ,),), tiling_dims = (0 ,)
4535 )
4636
47- # Parse result
4837 if tile_shapes and len (tile_shapes ) > 0 :
4938 block_size = tile_shapes [0 ][0 ]
5039 return block_size
@@ -77,16 +66,7 @@ def fused_linear_cross_entropy_forward(
7766 f"return_predicted_tokens must be True or False. Got: { return_predicted_tokens } "
7867 )
7968 device = _input .device
80-
8169 input_requires_grad = _input .requires_grad
82-
83- # inputs have shape: BT x H
84- # materialized activations will have shape: BT x V
85- # the increase in memory = BT x V
86- # reduction can be achieved by partitioning the number of tokens BT into smaller chunks.
87- # for ex: if we were to achieve the same memory consumption as BT x H, then the chunk size should be:
88- # inc_factor = (V+H-1)//H, chunk_size = (BT + inc_factor - 1)//inc_factor
89- # for ex: BT = 4096*4, V = 32000, H = 4096 ==> inc_factor = 8, chunk_size = 2048
9070 BT , H = _input .shape
9171 V = weight .shape [0 ]
9272 BLOCK_SIZE = get_optimal_block_size (V , has_gradients = _input .requires_grad )
@@ -130,13 +110,15 @@ def fused_linear_cross_entropy_forward(
130110 if ce_weight .stride (- 1 ) != 1 :
131111 ce_weight = ce_weight .contiguous ()
132112
113+ num_cores = get_npu_core_count ()
114+ logits = _input @ weight .t () # BT x V
115+
133116 for chunk_id in range (num_chunks ):
134117 start_idx = chunk_id * chunk_size
135118 end_idx = min ((chunk_id + 1 ) * chunk_size , BT )
136- _input_chunk = _input [ start_idx : end_idx ] # chunk_size x H
119+ # when doing matmul, use the original precision
137120
138- # when doing matmul, use the original precision
139- logits_chunk = _input_chunk @ weight .t () # chunk_size x V
121+ logits_chunk = logits [start_idx :end_idx ] # chunk_size x V
140122 if bias is not None :
141123 logits_chunk = logits_chunk + bias
142124
@@ -183,10 +165,7 @@ def fused_linear_cross_entropy_forward(
183165 # ensure _input and target are contiguous
184166 logits_chunk = logits_chunk .contiguous ()
185167 target_chunk = target_chunk .contiguous ()
186- num_cores = get_npu_core_count ()
187168
188- # Here we calculate the gradient of logits_chunk in place so we can save memory.
189- # Grid size is capped at NPU core count; the kernel uses a grid-stride loop
190169 liger_cross_entropy_kernel [(min (n_rows , num_cores ),)](
191170 X_ptr = logits_chunk ,
192171 X_stride = logits_chunk .stride (- 2 ),
@@ -247,31 +226,26 @@ def fused_linear_cross_entropy_forward(
247226 if input_requires_grad :
248227 grad_input [start_idx :end_idx ] = grad_logits_chunk @ weight
249228
250- if grad_weight is not None and input_requires_grad :
251- grad_weight += torch .mm (grad_logits_chunk .t (), _input_chunk ).float ()
252-
253- if bias is not None and input_requires_grad :
254- torch .add (
255- input = grad_bias ,
256- other = grad_logits_chunk .sum (dim = 0 ),
257- out = grad_bias ,
258- alpha = 1.0 ,
259- )
260-
261- # Need extra calculations for backward if reduction=='none'. Not supporting reduction='none' now.
262- # if reduction == "none":
263- # loss = loss_1d
264- # z_loss = z_loss_1d if return_z_loss else None
229+ if bias is not None :
230+ logits [start_idx :end_idx ] = grad_logits_chunk
231+
232+ if grad_weight is not None and input_requires_grad :
233+ grad_weight = logits .t () @ _input
234+ if bias is not None and input_requires_grad :
235+ torch .add (
236+ input = grad_bias ,
237+ other = logits .sum (dim = 0 ),
238+ out = grad_bias ,
239+ alpha = 1.0 ,
240+ )
265241
266242 if reduction == "none" :
267- # Return per-token losses
268243 loss = loss_1d
269244 z_loss = z_loss_1d if return_z_loss else None
270245 token_accuracy = token_accuracy_1d if return_token_accuracy else None
271246 else :
272247 loss = torch .sum (loss_1d )
273248 z_loss = torch .sum (z_loss_1d ) if return_z_loss else None
274- # For accuracy, we compute the mean across all non-ignored tokens
275249 token_accuracy = torch .sum (token_accuracy_1d ) / total_n_non_ignore if return_token_accuracy else None
276250
277251 predicted_tokens = predicted_tokens_1d if return_predicted_tokens else None
@@ -286,8 +260,6 @@ def fused_linear_cross_entropy_forward(
286260def fused_linear_cross_entropy_backward (grad_output , grad_input , grad_weight , grad_bias ):
287261 # If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
288262 if not torch .equal (grad_output , torch .tensor (1.0 , device = grad_output .device )):
289- # We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
290- # for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
291263 BT , H = grad_input .shape
292264 n_rows = BT
293265 BLOCK_SIZE = get_optimal_block_size_element_mul (H , grad_output .element_size ())