Skip to content

Commit 0b3ccfd

Browse files
committed
Check in dynamic-shape-aware SwiGLU triton kernel
1 parent d324947 commit 0b3ccfd

File tree

2 files changed

+276
-10
lines changed

2 files changed

+276
-10
lines changed

megatron/core/fusions/fused_bias_swiglu.py

Lines changed: 274 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
import torch
77
import torch.nn.functional as F
8+
import triton
9+
import triton.language as tl
810

911
from megatron.core.jit import jit_fuser
1012
from megatron.core.utils import nvtx_decorator
@@ -190,20 +192,51 @@ def backward(ctx, grad_output):
190192

191193
class WeightedSwiGLUFunction(torch.autograd.Function):
    """Autograd function for token-wise weighted SwiGLU.

    Supports two execution paths:
      * a dynamic-shape-aware Triton path, taken when ``num_tokens_tensor``
        is provided (valid token count is read on-device by the kernel), and
      * the JIT-fused ``weighted_swiglu`` fallback.
    The saved activation may optionally be stored in FP8 to reduce memory.
    """

    @staticmethod
    def forward(ctx, input, weights, fp8_input_store, num_tokens_tensor=None):
        """Forward pass for weighted SwiGLU.

        Args:
            input: [total_tokens, hidden_size * 2]
            weights: [total_tokens, 1]
            fp8_input_store: Whether to store the saved input in FP8
            num_tokens_tensor: Optional scalar tensor with actual token count
                (uses Triton if provided)
        """
        # Down-cast the activation saved for backward when FP8 storage is on;
        # the forward math itself still runs on the original-precision input.
        input_for_backward = input.to(torch.float8_e4m3fn) if fp8_input_store else input

        # Use Triton implementation if num_tokens_tensor provided and input is 2D.
        if num_tokens_tensor is not None and input.dim() == 2:
            output = weighted_swiglu_triton(input, weights, num_tokens_tensor)
            # num_tokens_tensor is saved so backward launches with the same
            # dynamic token count.
            ctx.save_for_backward(input_for_backward, weights, num_tokens_tensor)
            ctx.use_triton = True
        else:
            # Fallback to JIT fused implementation.
            output = weighted_swiglu(input, weights)
            ctx.save_for_backward(input_for_backward, weights)
            ctx.use_triton = False

        ctx.ori_input_dtype = input.dtype
        ctx.fp8_input_store = fp8_input_store
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Backward pass for weighted SwiGLU.

        Returns gradients matching forward's four inputs:
        (grad_input, grad_weights, None, None) — the last two positions are
        the non-differentiable ``fp8_input_store`` / ``num_tokens_tensor``.
        """
        if ctx.use_triton:
            # Triton backward path.
            input, weights, num_tokens_tensor = ctx.saved_tensors
            # Restore the original dtype if the activation was stored in FP8.
            input = input.to(ctx.ori_input_dtype) if ctx.fp8_input_store else input
            grad_input, grad_weights = weighted_swiglu_triton_back(
                grad_output, input, weights, num_tokens_tensor
            )
            return grad_input, grad_weights, None, None
        else:
            # JIT fused backward path.
            input, weights = ctx.saved_tensors
            input = input.to(ctx.ori_input_dtype) if ctx.fp8_input_store else input
            tmp, wgrad = weighted_swiglu_back(grad_output, input, weights)
            return tmp, wgrad, None, None
207240

208241

209242
def bias_swiglu_impl(input, bias, fp8_input_store=False, cpu_offload_input=False):
@@ -236,7 +269,7 @@ def bias_swiglu_impl(input, bias, fp8_input_store=False, cpu_offload_input=False
236269
return output if len(ori_shape) == 2 else output.view(ori_shape[0], ori_shape[1], -1)
237270

238271

239-
def weighted_bias_swiglu_impl(input, bias, weights, fp8_input_store=False):
272+
def weighted_bias_swiglu_impl(input, bias, weights, fp8_input_store=False, num_tokens_tensor=None):
240273
"""
241274
Token-wise-weighted bias swiglu fusion.
242275
"""
@@ -246,10 +279,241 @@ def weighted_bias_swiglu_impl(input, bias, weights, fp8_input_store=False):
246279
if bias is not None:
247280
raise NotImplementedError("Bias is not supported for weighted swiglu fusion")
248281
else:
249-
output = WeightedSwiGLUFunction.apply(input, weights, fp8_input_store)
282+
output = WeightedSwiGLUFunction.apply(input, weights, fp8_input_store, num_tokens_tensor)
250283

251284
return output if len(ori_shape) == 2 else output.view(ori_shape[0], ori_shape[1], -1)
252285

253286

254287
# bias_swiglu_impl = BiasSwiGLUFunction.apply
255288
# swiglu_impl = SwiGLUFunction.apply
289+
290+
@triton.jit
def _weighted_swiglu_fwd_kernel(
    input_ptr,
    weights_ptr,
    output_ptr,
    num_tokens_ptr,
    hidden_size: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Triton kernel for weighted SwiGLU forward pass.

    Processes tokens in strided pattern, only operating on valid tokens
    (rows >= num_tokens are left untouched).
    Formula: output = SiLU(input[:, :H]) * input[:, H:] * weights

    NOTE(review): all offsets are raw row-major arithmetic — assumes
    input/weights/output are contiguous; confirm at the launch site.
    """
    pid = tl.program_id(axis=0)
    num_blocks = tl.num_programs(axis=0)

    # Load actual number of tokens from device memory — this is what makes
    # the kernel dynamic-shape aware without a host/device sync.
    num_tokens = tl.load(num_tokens_ptr)

    # Strided access: each program handles tokens [pid, pid + num_blocks, ...]
    token_idx = pid
    while token_idx < num_tokens:
        # Per-token scalar weight (weights is [total_tokens, 1]).
        weight = tl.load(weights_ptr + token_idx)

        # Process the hidden dimension in BLOCK_SIZE chunks.
        for h_offset in range(0, hidden_size, BLOCK_SIZE):
            h_mask = (h_offset + tl.arange(0, BLOCK_SIZE)) < hidden_size

            # Load input chunks: gate (first half) and value (second half).
            input_offset_1 = token_idx * (hidden_size * 2) + h_offset
            input_offset_2 = token_idx * (hidden_size * 2) + hidden_size + h_offset

            y1 = tl.load(
                input_ptr + input_offset_1 + tl.arange(0, BLOCK_SIZE), mask=h_mask, other=0.0
            )
            y2 = tl.load(
                input_ptr + input_offset_2 + tl.arange(0, BLOCK_SIZE), mask=h_mask, other=0.0
            )

            # SwiGLU: SiLU(y1) * y2 * weight, with SiLU(x) = x * sigmoid(x).
            # Cast to fp32 for the sigmoid computation (required by Triton)
            # and to keep precision for low-precision input dtypes.
            y1_fp32 = y1.to(tl.float32)
            y2_fp32 = y2.to(tl.float32)
            weight_fp32 = weight.to(tl.float32)

            sigmoid_y1 = tl.sigmoid(y1_fp32)
            silu_y1 = y1_fp32 * sigmoid_y1
            result = silu_y1 * y2_fp32 * weight_fp32

            # Store output, cast back to the original input dtype.
            output_offset = token_idx * hidden_size + h_offset
            tl.store(
                output_ptr + output_offset + tl.arange(0, BLOCK_SIZE),
                result.to(y1.dtype),
                mask=h_mask,
            )

        # Stride to the next token owned by this program.
        token_idx += num_blocks
352+
353+
@triton.jit
def _weighted_swiglu_bwd_kernel(
    grad_output_ptr,
    input_ptr,
    weights_ptr,
    grad_input_ptr,
    grad_weights_ptr,
    num_tokens_ptr,
    hidden_size: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Triton kernel for weighted SwiGLU backward pass.

    Computes gradients with respect to input and weights for valid tokens
    only; rows >= num_tokens of the gradient buffers are never written, so
    the caller is responsible for their contents.

    NOTE(review): raw row-major offsets assume every tensor is contiguous —
    confirm at the launch site.
    """
    pid = tl.program_id(axis=0)
    num_blocks = tl.num_programs(axis=0)

    # Load actual number of tokens (on-device, dynamic-shape aware).
    num_tokens = tl.load(num_tokens_ptr)

    # Strided access: each program handles tokens [pid, pid + num_blocks, ...]
    token_idx = pid
    while token_idx < num_tokens:
        # Per-token scalar weight.
        weight = tl.load(weights_ptr + token_idx)

        # Accumulator for this token's weight gradient (fp32 for precision).
        weight_grad_acc = 0.0

        # Process the hidden dimension in BLOCK_SIZE chunks.
        for h_offset in range(0, hidden_size, BLOCK_SIZE):
            h_mask = (h_offset + tl.arange(0, BLOCK_SIZE)) < hidden_size

            # Load grad_output chunk for this token.
            grad_out_offset = token_idx * hidden_size + h_offset
            grad_out = tl.load(
                grad_output_ptr + grad_out_offset + tl.arange(0, BLOCK_SIZE),
                mask=h_mask,
                other=0.0,
            )

            # Load input chunks: gate (y1) and value (y2) halves.
            input_offset_1 = token_idx * (hidden_size * 2) + h_offset
            input_offset_2 = token_idx * (hidden_size * 2) + hidden_size + h_offset

            y1 = tl.load(
                input_ptr + input_offset_1 + tl.arange(0, BLOCK_SIZE), mask=h_mask, other=0.0
            )
            y2 = tl.load(
                input_ptr + input_offset_2 + tl.arange(0, BLOCK_SIZE), mask=h_mask, other=0.0
            )

            # Cast to fp32 for the sigmoid computation (required by Triton).
            y1_fp32 = y1.to(tl.float32)
            y2_fp32 = y2.to(tl.float32)
            grad_out_fp32 = grad_out.to(tl.float32)
            weight_fp32 = weight.to(tl.float32)

            # Recompute forward activations needed by the chain rule.
            sigmoid_y1 = tl.sigmoid(y1_fp32)
            silu_y1 = y1_fp32 * sigmoid_y1

            # Gradient for y1 (gate): grad_out * weight * d(SiLU(y1))/dy1 * y2
            # where d(SiLU(x))/dx = sigmoid(x) * (1 + x * (1 - sigmoid(x))).
            dsilu_dy1 = sigmoid_y1 * (1.0 + y1_fp32 * (1.0 - sigmoid_y1))
            grad_y1 = grad_out_fp32 * weight_fp32 * dsilu_dy1 * y2_fp32

            # Gradient for y2 (value): grad_out * weight * SiLU(y1).
            grad_y2 = grad_out_fp32 * weight_fp32 * silu_y1

            # Store both halves of the input gradient (cast back to input dtype).
            tl.store(
                grad_input_ptr + input_offset_1 + tl.arange(0, BLOCK_SIZE),
                grad_y1.to(y1.dtype),
                mask=h_mask,
            )
            tl.store(
                grad_input_ptr + input_offset_2 + tl.arange(0, BLOCK_SIZE),
                grad_y2.to(y2.dtype),
                mask=h_mask,
            )

            # Weight gradient: dot(swiglu(y), grad_out) over the row, where
            # swiglu(y) = silu_y1 * y2. Masked lanes contribute 0 because all
            # loads above used other=0.0.
            weight_grad_contribution = silu_y1 * y2_fp32 * grad_out_fp32
            weight_grad_acc += tl.sum(weight_grad_contribution)

        # One scalar weight gradient per token, stored after all chunks.
        tl.store(grad_weights_ptr + token_idx, weight_grad_acc)

        # Stride to the next token owned by this program.
        token_idx += num_blocks
446+
447+
def weighted_swiglu_triton(input, weights, num_tokens_tensor):
    """Triton implementation of weighted SwiGLU forward pass.

    Args:
        input: [total_tokens, hidden_size * 2]
        weights: [total_tokens, 1]
        num_tokens_tensor: Scalar tensor with actual token count; rows at
            index >= num_tokens are padding and are left as zeros.

    Returns:
        output: [total_tokens, hidden_size]
    """
    assert input.dim() == 2, "Input must be 2D [total_tokens, hidden_size*2]"
    assert weights.dim() == 2 and weights.size(1) == 1, "Weights must be [total_tokens, 1]"
    # The kernel indexes with raw row-major offsets, so both tensors must be
    # contiguous or it would silently read the wrong elements.
    assert input.is_contiguous(), "Input must be contiguous"
    assert weights.is_contiguous(), "Weights must be contiguous"

    total_tokens, hidden_size_2 = input.shape
    # Guard against an odd last dim being silently truncated by // 2.
    assert hidden_size_2 % 2 == 0, "Input last dim must be 2 * hidden_size"
    hidden_size = hidden_size_2 // 2

    # Zero-initialize the output: the kernel only writes rows < num_tokens,
    # so with torch.empty the padding rows would hold uninitialized memory
    # and leak garbage into downstream computation.
    output = torch.zeros((total_tokens, hidden_size), dtype=input.dtype, device=input.device)

    # Nothing to compute for an empty batch; also avoids a zero-sized grid launch.
    if total_tokens == 0:
        return output

    # Launch configuration: up to 4096 persistent programs, each looping over
    # tokens with stride equal to the grid size.
    BLOCK_SIZE = 128
    num_blocks = min(total_tokens, 4096)
    grid = (num_blocks,)

    _weighted_swiglu_fwd_kernel[grid](
        input,
        weights,
        output,
        num_tokens_tensor,
        hidden_size=hidden_size,
        BLOCK_SIZE=BLOCK_SIZE,
    )

    return output
482+
483+
def weighted_swiglu_triton_back(grad_output, input, weights, num_tokens_tensor):
    """Triton implementation of weighted SwiGLU backward pass.

    Args:
        grad_output: [total_tokens, hidden_size]
        input: [total_tokens, hidden_size * 2]
        weights: [total_tokens, 1]
        num_tokens_tensor: Scalar tensor with actual token count; gradient
            rows at index >= num_tokens are padding and are left as zeros.

    Returns:
        grad_input: [total_tokens, hidden_size * 2]
        grad_weights: [total_tokens, 1]
    """
    total_tokens, hidden_size_2 = input.shape
    hidden_size = hidden_size_2 // 2

    # Autograd may hand us a non-contiguous grad_output; the kernel indexes
    # with raw row-major offsets, so normalize it (no-op when already contiguous).
    grad_output = grad_output.contiguous()

    # Zero-initialize the gradients: the kernel only writes rows < num_tokens.
    # With empty_like, padding rows would contain uninitialized memory that
    # leaks into upstream matmuls and the optimizer step.
    grad_input = torch.zeros_like(input)
    grad_weights = torch.zeros_like(weights)

    # Nothing to compute for an empty batch; also avoids a zero-sized grid launch.
    if total_tokens == 0:
        return grad_input, grad_weights

    # Launch configuration mirrors the forward pass.
    BLOCK_SIZE = 128
    num_blocks = min(total_tokens, 4096)
    grid = (num_blocks,)

    _weighted_swiglu_bwd_kernel[grid](
        grad_output,
        input,
        weights,
        grad_input,
        grad_weights,
        num_tokens_tensor,
        hidden_size=hidden_size,
        BLOCK_SIZE=BLOCK_SIZE,
    )

    return grad_input, grad_weights

megatron/core/transformer/moe/experts.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -941,6 +941,8 @@ def bias_act_func(intermediate_parallel, bias_parallel, permuted_probs):
941941
bias_parallel,
942942
permuted_probs,
943943
self.config.activation_func_fp8_input_store,
944+
tokens_per_expert.sum() if self.packed_offload_moe_act else None,
945+
944946
)
945947
elif self.activation_func == quick_gelu and self.config.gated_linear_unit:
946948
intermediate_parallel = weighted_bias_quick_geglu_impl(

0 commit comments

Comments
 (0)