55
66import torch
77import torch .nn .functional as F
8+ import triton
9+ import triton .language as tl
810
911from megatron .core .jit import jit_fuser
1012from megatron .core .utils import nvtx_decorator
@@ -190,20 +192,51 @@ def backward(ctx, grad_output):
190192
class WeightedSwiGLUFunction(torch.autograd.Function):
    """Autograd function for token-wise weighted SwiGLU.

    Dispatches to a Triton kernel when ``num_tokens_tensor`` is supplied
    (so only the valid prefix of a padded token buffer is processed) and
    otherwise falls back to the JIT-fused implementation.
    """

    @staticmethod
    def forward(ctx, input, weights, fp8_input_store, num_tokens_tensor=None):
        """Forward pass for weighted SwiGLU.

        Args:
            input: [total_tokens, hidden_size * 2] activation tensor.
            weights: [total_tokens, 1] per-token scaling weights.
            fp8_input_store: whether to keep the activation saved for
                backward in FP8 (float8_e4m3fn) to reduce memory.
            num_tokens_tensor: optional scalar tensor holding the actual
                token count; when given (and input is 2-D) the Triton
                path is used.
        """
        # Only the copy kept for backward is (optionally) down-cast; the
        # live computation runs in the original dtype.
        saved_input = input.to(torch.float8_e4m3fn) if fp8_input_store else input

        use_triton = num_tokens_tensor is not None and input.dim() == 2
        if use_triton:
            output = weighted_swiglu_triton(input, weights, num_tokens_tensor)
            ctx.save_for_backward(saved_input, weights, num_tokens_tensor)
        else:
            # Fallback to the JIT fused implementation.
            output = weighted_swiglu(input, weights)
            ctx.save_for_backward(saved_input, weights)

        ctx.use_triton = use_triton
        ctx.ori_input_dtype = input.dtype
        ctx.fp8_input_store = fp8_input_store
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Backward pass; returns grads for (input, weights, None, None)."""
        saved = ctx.saved_tensors
        # Restore the original dtype if the activation was stored in FP8.
        input = saved[0].to(ctx.ori_input_dtype) if ctx.fp8_input_store else saved[0]
        weights = saved[1]
        if ctx.use_triton:
            grad_input, grad_weights = weighted_swiglu_triton_back(
                grad_output, input, weights, saved[2]
            )
            return grad_input, grad_weights, None, None
        grad_input, grad_weights = weighted_swiglu_back(grad_output, input, weights)
        return grad_input, grad_weights, None, None
207240
208241
209242def bias_swiglu_impl (input , bias , fp8_input_store = False , cpu_offload_input = False ):
@@ -236,7 +269,7 @@ def bias_swiglu_impl(input, bias, fp8_input_store=False, cpu_offload_input=False
236269 return output if len (ori_shape ) == 2 else output .view (ori_shape [0 ], ori_shape [1 ], - 1 )
237270
238271
239- def weighted_bias_swiglu_impl (input , bias , weights , fp8_input_store = False ):
272+ def weighted_bias_swiglu_impl (input , bias , weights , fp8_input_store = False , num_tokens_tensor = None ):
240273 """
241274 Token-wise-weighted bias swiglu fusion.
242275 """
@@ -246,10 +279,241 @@ def weighted_bias_swiglu_impl(input, bias, weights, fp8_input_store=False):
246279 if bias is not None :
247280 raise NotImplementedError ("Bias is not supported for weighted swiglu fusion" )
248281 else :
249- output = WeightedSwiGLUFunction .apply (input , weights , fp8_input_store )
282+ output = WeightedSwiGLUFunction .apply (input , weights , fp8_input_store , num_tokens_tensor )
250283
251284 return output if len (ori_shape ) == 2 else output .view (ori_shape [0 ], ori_shape [1 ], - 1 )
252285
253286
254287# bias_swiglu_impl = BiasSwiGLUFunction.apply
255288# swiglu_impl = SwiGLUFunction.apply
289+
@triton.jit
def _weighted_swiglu_fwd_kernel(
    input_ptr,
    weights_ptr,
    output_ptr,
    num_tokens_ptr,
    hidden_size: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Weighted SwiGLU forward: out[t] = SiLU(x[t, :H]) * x[t, H:] * w[t].

    Each program walks token rows with a stride equal to the grid size,
    touching only the first ``*num_tokens_ptr`` rows of the buffer.
    """
    block_id = tl.program_id(axis=0)
    stride = tl.num_programs(axis=0)

    # Actual token count is a runtime scalar read from device memory.
    valid_tokens = tl.load(num_tokens_ptr)

    offs = tl.arange(0, BLOCK_SIZE)
    row = block_id
    while row < valid_tokens:
        # Per-token weight, promoted once to fp32.
        w = tl.load(weights_ptr + row).to(tl.float32)

        gate_base = row * (hidden_size * 2)
        up_base = gate_base + hidden_size
        out_base = row * hidden_size

        # Walk the hidden dimension in BLOCK_SIZE chunks.
        for start in range(0, hidden_size, BLOCK_SIZE):
            cols = start + offs
            mask = cols < hidden_size

            # Gate half and value half of the fused activation.
            gate = tl.load(input_ptr + gate_base + cols, mask=mask, other=0.0)
            up = tl.load(input_ptr + up_base + cols, mask=mask, other=0.0)

            # Compute SiLU in fp32 (tl.sigmoid requires fp32), then scale.
            g32 = gate.to(tl.float32)
            silu = g32 * tl.sigmoid(g32)
            val = silu * up.to(tl.float32) * w

            # Cast back to the input dtype on store.
            tl.store(output_ptr + out_base + cols, val.to(gate.dtype), mask=mask)

        row += stride
352+
@triton.jit
def _weighted_swiglu_bwd_kernel(
    grad_output_ptr,
    input_ptr,
    weights_ptr,
    grad_input_ptr,
    grad_weights_ptr,
    num_tokens_ptr,
    hidden_size: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Weighted SwiGLU backward for the first ``*num_tokens_ptr`` rows.

    With g = gate half, u = value half, s = sigmoid(g):
        d/dg = go * w * s * (1 + g * (1 - s)) * u
        d/du = go * w * g * s
        d/dw = sum_h go * (g * s) * u
    """
    block_id = tl.program_id(axis=0)
    stride = tl.num_programs(axis=0)

    # Actual token count (runtime scalar).
    valid_tokens = tl.load(num_tokens_ptr)

    offs = tl.arange(0, BLOCK_SIZE)
    row = block_id
    while row < valid_tokens:
        w = tl.load(weights_ptr + row).to(tl.float32)

        # fp32 accumulator for this token's weight gradient.
        wgrad = 0.0

        for start in range(0, hidden_size, BLOCK_SIZE):
            cols = start + offs
            mask = cols < hidden_size

            gate_off = row * (hidden_size * 2) + cols
            up_off = gate_off + hidden_size
            out_off = row * hidden_size + cols

            go = tl.load(grad_output_ptr + out_off, mask=mask, other=0.0)
            gate = tl.load(input_ptr + gate_off, mask=mask, other=0.0)
            up = tl.load(input_ptr + up_off, mask=mask, other=0.0)

            # Promote everything to fp32 (tl.sigmoid requires fp32).
            g32 = gate.to(tl.float32)
            u32 = up.to(tl.float32)
            go32 = go.to(tl.float32)

            sig = tl.sigmoid(g32)
            silu = g32 * sig

            # d(SiLU(g))/dg = sig * (1 + g * (1 - sig))
            dgate = go32 * w * sig * (1.0 + g32 * (1.0 - sig)) * u32
            dup = go32 * w * silu

            # Input gradients stored in the original dtype.
            tl.store(grad_input_ptr + gate_off, dgate.to(gate.dtype), mask=mask)
            tl.store(grad_input_ptr + up_off, dup.to(up.dtype), mask=mask)

            # swiglu(x) * grad_out, reduced over the chunk.
            wgrad += tl.sum(silu * u32 * go32)

        # One scalar weight gradient per token, written after all chunks.
        tl.store(grad_weights_ptr + row, wgrad)

        row += stride
446+
def weighted_swiglu_triton(input, weights, num_tokens_tensor):
    """Triton implementation of weighted SwiGLU forward pass.

    Args:
        input: [total_tokens, hidden_size * 2]
        weights: [total_tokens, 1]
        num_tokens_tensor: scalar tensor with the actual token count

    Returns:
        output: [total_tokens, hidden_size]
    """
    assert input.dim() == 2, "Input must be 2D [total_tokens, hidden_size*2]"
    assert weights.dim() == 2 and weights.size(1) == 1, "Weights must be [total_tokens, 1]"

    total_tokens = input.size(0)
    hidden_size = input.size(1) // 2

    # NOTE(review): rows >= num_tokens are left uninitialized here; this
    # presumes downstream consumers ignore the padding region — confirm.
    output = torch.empty((total_tokens, hidden_size), dtype=input.dtype, device=input.device)

    # Grid-stride launch: at most 4096 programs, each striding over tokens.
    BLOCK_SIZE = 128
    grid = (min(total_tokens, 4096),)

    _weighted_swiglu_fwd_kernel[grid](
        input,
        weights,
        output,
        num_tokens_tensor,
        hidden_size=hidden_size,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    return output
482+
def weighted_swiglu_triton_back(grad_output, input, weights, num_tokens_tensor):
    """Triton implementation of weighted SwiGLU backward pass.

    Args:
        grad_output: [total_tokens, hidden_size]
        input: [total_tokens, hidden_size * 2]
        weights: [total_tokens, 1]
        num_tokens_tensor: scalar tensor with the actual token count

    Returns:
        grad_input: [total_tokens, hidden_size * 2]
        grad_weights: [total_tokens, 1]
    """
    total_tokens, hidden_size_2 = input.shape
    hidden_size = hidden_size_2 // 2

    # Zero-initialize the gradients: the kernel writes only rows
    # < num_tokens, so with empty_like the padding rows would hold
    # uninitialized garbage that the preceding GEMM's weight-gradient
    # computation (which consumes every row of grad_input) would fold
    # into real gradients.
    grad_input = torch.zeros_like(input)
    grad_weights = torch.zeros_like(weights)

    # Grid-stride launch matching the forward configuration.
    BLOCK_SIZE = 128
    grid = (min(total_tokens, 4096),)

    _weighted_swiglu_bwd_kernel[grid](
        grad_output,
        input,
        weights,
        grad_input,
        grad_weights,
        num_tokens_tensor,
        hidden_size=hidden_size,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    return grad_input, grad_weights
0 commit comments