5 changes: 5 additions & 0 deletions sonicmoe/enums.py
@@ -26,5 +26,10 @@ class ActivationType(Enum):
SILU = "silu"


class ScoringFuncType(Enum):
SOFTMAX = "softmax"
SIGMOID = "sigmoid"


def is_glu(activation_type: ActivationType):
return activation_type in [ActivationType.SWIGLU, ActivationType.REGLU, ActivationType.GEGLU]
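
For reference, `ScoringFuncType` is string-valued like `ActivationType`, so a plain Enum value lookup converts user-facing strings to members. A minimal usage sketch (the `_normalize_scoring_func` helper below is hypothetical, not part of this change):

from sonicmoe.enums import ScoringFuncType

def _normalize_scoring_func(scoring_func: ScoringFuncType | str) -> ScoringFuncType:
    # Enum lookup by value: ScoringFuncType("sigmoid") -> ScoringFuncType.SIGMOID
    return ScoringFuncType(scoring_func) if isinstance(scoring_func, str) else scoring_func

assert _normalize_scoring_func("sigmoid") is ScoringFuncType.SIGMOID
assert _normalize_scoring_func(ScoringFuncType.SOFTMAX) is ScoringFuncType.SOFTMAX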
13 changes: 8 additions & 5 deletions sonicmoe/functional/__init__.py
@@ -9,7 +9,7 @@
from quack.gemm_interface import gemm

from ..count_cumsum import count_cumsum
from ..enums import ActivationType, is_glu
from ..enums import ActivationType, ScoringFuncType, is_glu
from ..quack_utils import gemm_dgated, gemm_gated
from .backward import _down_projection_backward, _softmax_topk_bwd, _token_broadcast_backward, _up_projection_backward
from .forward import _down_projection_forward, _router_forward, _softmax_topk_fwd, _up_projection_forward
@@ -84,15 +84,17 @@ def general_routing_router_metadata(

class TC_Softmax_Topk_Router_Function(torch.autograd.Function):
@staticmethod
def forward(ctx, router_logits: torch.Tensor, E: int, K: int) -> tuple[torch.Tensor, torch.Tensor]:
def forward(
ctx, router_logits: torch.Tensor, E: int, K: int, scoring_func: ScoringFuncType
) -> tuple[torch.Tensor, torch.Tensor]:
T = router_logits.size(0)

# changing this to router_logits.dtype (bfloat16) gains roughly 5 TFLOPS in the forward pass at the cost of numerical accuracy
topk_router_score = torch.empty(T, K, dtype=torch.float32, device=router_logits.device)
topk_router_indices = torch.empty(T, K, dtype=torch.int32, device=router_logits.device)
ctx.mark_non_differentiable(topk_router_indices)

_softmax_topk_fwd(router_logits, topk_router_score, topk_router_indices, E, K)
_softmax_topk_fwd(router_logits, topk_router_score, topk_router_indices, E, K, scoring_func)

ctx.save_for_backward(topk_router_score, topk_router_indices)
ctx.E = E
@@ -108,7 +110,7 @@ def backward(ctx, dtopk_score: torch.Tensor, _: torch.Tensor) -> tuple[torch.Ten
topk_router_score, topk_router_indices = ctx.saved_tensors
dlogits = torch.zeros(T, ctx.E, dtype=ctx.dtype, device=topk_router_score.device)

_softmax_topk_bwd(dlogits, None, dtopk_score, topk_router_score, topk_router_indices, K)
_softmax_topk_bwd(dlogits, None, dtopk_score, topk_router_score, topk_router_indices, K, ctx.scoring_func)

return dlogits
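
For a quick sanity check of this autograd.Function, a small eager-mode reference and comparison harness (not part of the diff; it assumes a CUDA device and that the fused path follows the fallback's top-k-then-activate semantics) could look like this:

import torch

from sonicmoe.enums import ScoringFuncType

def reference_topk_router(router_logits: torch.Tensor, K: int, scoring_func: ScoringFuncType):
    # eager reference: top-k over the logits, then activate the K selected values
    topk = router_logits.topk(K, dim=-1)
    if scoring_func == ScoringFuncType.SOFTMAX:
        scores = topk.values.softmax(dim=-1, dtype=torch.float32)
    else:  # ScoringFuncType.SIGMOID
        scores = topk.values.to(torch.float32).sigmoid()
    return scores, topk.indices.to(torch.int32)

# hypothetical comparison against the fused path:
# logits = torch.randn(T, E, device="cuda", dtype=torch.float32)
# ref_scores, ref_idx = reference_topk_router(logits, K, ScoringFuncType.SIGMOID)
# fused_scores, fused_idx = TC_Softmax_Topk_Router_Function.apply(logits, E, K, ScoringFuncType.SIGMOID)
# torch.testing.assert_close(fused_scores, ref_scores)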

@@ -445,13 +447,14 @@ def moe_TC_softmax_topk_layer(
K: int,
stream_id: int,
activation_type: ActivationType | str = ActivationType.SWIGLU,
scoring_func: ScoringFuncType | str = ScoringFuncType.SOFTMAX,
is_inference_mode_enabled: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
assert ((b1 is None) and (b2 is None)) or (
(b1 is not None) and (b2 is not None)
), "b1 and b2 has to be None or not None at the same time!"
router_logits = F.linear(x, router_w)
topk_scores, topk_indices = TC_Softmax_Topk_Router_Function.apply(router_logits, router_w.size(0), K)
topk_scores, topk_indices = TC_Softmax_Topk_Router_Function.apply(router_logits, router_w.size(0), K, scoring_func)
expert_frequency, expert_frequency_offset = count_cumsum(topk_indices.view(-1), router_w.size(0), do_cumsum=True)

(
36 changes: 29 additions & 7 deletions sonicmoe/functional/backward.py
@@ -10,7 +10,7 @@
import triton
import triton.language as tl

from ..enums import LIBRARY_NAME, TENSORMAP, ActivationType, is_glu
from ..enums import LIBRARY_NAME, TENSORMAP, ActivationType, ScoringFuncType, is_glu
from ..utils import ceil_divide, convert_torch_tensor_to_cute_tensor, get_powers_of_2
from .moe_config import (
HopperWgmma_MoE_Down_proj_ActGrad_Bwd,
@@ -485,7 +485,7 @@ def _token_broadcast_backward(


@triton.jit
def _softmax_bwd_scatter_small_kernel(
def _activation_bwd_scatter_small_kernel(
dlogits_ptr,
dlogits_full_ptr,
score_ptr,
@@ -502,26 +502,46 @@ def _softmax_bwd_scatter_small_kernel(
K: tl.constexpr,
BLOCK_K: tl.constexpr,
dlogits_is_none: tl.constexpr,
scoring_func: tl.constexpr,
):
row = tl.program_id(axis=0)

# tl.assume(K <= BLOCK_K)
k_offs = tl.arange(0, BLOCK_K)
k_mask = k_offs < K

# Load indices
idx = tl.load(idx_ptr + row * stride_im + k_offs * stride_ik, mask=k_mask, other=0).to(tl.int32)

# Load forward probabilities (y) and incoming gradients (g)
s_sel = tl.load(score_ptr + row * stride_sm + k_offs * stride_sn, mask=k_mask, other=0).to(tl.float32)
g_sel = tl.load(dscore_ptr + row * stride_gm + k_offs * stride_gk, mask=k_mask, other=0).to(tl.float32)

# dot = sum_j g_j * y_j over selected columns
dot = tl.sum(g_sel * s_sel, axis=0)
if scoring_func == ScoringFuncType.SIGMOID:
# === Sigmoid Backward ===
# Derivative: dx = g * y * (1 - y)
# No cross-term reduction needed
add_vals = g_sel * s_sel * (1.0 - s_sel)
elif scoring_func == ScoringFuncType.SOFTMAX:
# === Softmax Backward ===
# Derivative: dx = y * (g - dot(g, y))
# Note: Even though Softmax was done on N (Global), since g is 0 for unselected
# indices, the dot product over just the TopK selected elements is correct.

# scatter-only: dx[idx] += y_sel * (g_sel - dot)
add_vals = s_sel * (g_sel - dot)
# dot = sum_j g_j * y_j over selected columns
dot = tl.sum(g_sel * s_sel, axis=0)

# scatter-only: dx[idx] += y_sel * (g_sel - dot)
add_vals = s_sel * (g_sel - dot)

# Calculate pointers to the full gradient matrix
indices = row * stride_dm + idx * stride_dn

# Accumulate into existing gradients if necessary
if not dlogits_is_none:
add_vals += tl.load(dlogits_ptr + indices, mask=k_mask)

# Store result
tl.store(dlogits_full_ptr + indices, add_vals, mask=k_mask)
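
The two closed-form derivatives above can be checked against autograd on a tiny CPU example; a throwaway verification sketch (assuming the same top-k-then-activate semantics as the eager fallback in forward.py):

import torch

T, E, K = 4, 8, 2
logits = torch.randn(T, E, dtype=torch.float64, requires_grad=True)
vals, idx = logits.topk(K, dim=-1)

for mode in ("softmax", "sigmoid"):
    y = vals.softmax(dim=-1) if mode == "softmax" else vals.sigmoid()
    g = torch.randn_like(y)
    # reference gradient of the selected scores w.r.t. the logits
    (dx_ref,) = torch.autograd.grad(y, logits, grad_outputs=g, retain_graph=True)
    # closed-form scatter used by the kernel
    if mode == "sigmoid":
        add_vals = g * y * (1.0 - y)               # dx = g * y * (1 - y)
    else:
        dot = (g * y).sum(dim=-1, keepdim=True)    # dot = sum_j g_j * y_j
        add_vals = y * (g - dot)                   # dx = y * (g - dot)
    dx = torch.zeros_like(logits).scatter_(1, idx, add_vals)
    torch.testing.assert_close(dx, dx_ref)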


@@ -533,10 +553,11 @@ def _softmax_topk_bwd(
topk_router_score: torch.Tensor,
topk_router_indices: torch.Tensor,
K: int,
scoring_func: ScoringFuncType,
) -> None:
T = dtopk_score.shape[0]

_softmax_bwd_scatter_small_kernel[T,](
_activation_bwd_scatter_small_kernel[T,](
dlogits,
dlogits_full,
topk_router_score,
@@ -553,6 +574,7 @@
K,
triton.next_power_of_2(K),
(dlogits is None),
scoring_func,
)


53 changes: 41 additions & 12 deletions sonicmoe/functional/forward.py
@@ -10,16 +10,21 @@
from cutlass.cute.runtime import from_dlpack
from quack.cute_dsl_utils import torch2cute_dtype_map

from ..enums import LIBRARY_NAME, TENSORMAP, ActivationType
from ..enums import LIBRARY_NAME, TENSORMAP, ActivationType, ScoringFuncType
from ..utils import convert_torch_tensor_to_cute_tensor
from .moe_config import HopperWgmma_MoE_Down_proj_Fwd, HopperWgmma_MoE_Up_proj_Fwd
from .reduction_over_k_gather import token_gather_and_sum_varlen_K_triton
from .topk_softmax import TopK_Softmax
from .topk_softmax import Sigmoid_TopK, Softmax_TopK, TopK_Softmax


@torch.library.custom_op(f"{LIBRARY_NAME}::_topk_fwd", mutates_args={"values", "indices"})
def _topk_fwd(
x: torch.Tensor, k: int, values: torch.Tensor, indices: torch.Tensor, require_softmax_fusion: bool = True
x: torch.Tensor,
k: int,
values: torch.Tensor,
indices: torch.Tensor,
require_softmax_fusion: bool = True,
scoring_func: ScoringFuncType = ScoringFuncType.SOFTMAX,
) -> None:
"""Top-k forward pass.
Args:
@@ -38,13 +43,25 @@ def _topk_fwd(

x_tensor, values_tensor, indices_tensor = [convert_from_dlpack(tensor) for tensor in (x, values, indices)]
current_stream = cuda.CUstream(torch.cuda.current_stream().stream_base.raw_stream)
compile_key = (input_dtype, output_dtype, N, k, require_softmax_fusion)
if compile_key not in _topk_fwd.compile_cache:
topk_op = TopK_Softmax(input_dtype, output_dtype, N, k, require_softmax_fusion)
_topk_fwd.compile_cache[compile_key] = cute.compile(
topk_op, x_tensor, values_tensor, indices_tensor, current_stream
)
_topk_fwd.compile_cache[compile_key](x_tensor, values_tensor, indices_tensor, current_stream)

if scoring_func == ScoringFuncType.SOFTMAX:
compile_key = (input_dtype, output_dtype, N, k, require_softmax_fusion, scoring_func)
if compile_key not in _topk_fwd.compile_cache:
topk_op = Softmax_TopK(input_dtype, output_dtype, N, k, require_softmax_fusion, scoring_func)
_topk_fwd.compile_cache[compile_key] = cute.compile(
topk_op, x_tensor, values_tensor, indices_tensor, current_stream
)
_topk_fwd.compile_cache[compile_key](x_tensor, values_tensor, indices_tensor, current_stream)
elif scoring_func == ScoringFuncType.SIGMOID:
compile_key = (input_dtype, output_dtype, N, k, require_softmax_fusion, scoring_func)
if compile_key not in _topk_fwd.compile_cache:
topk_op = Sigmoid_TopK(input_dtype, output_dtype, N, k, require_softmax_fusion, scoring_func)
_topk_fwd.compile_cache[compile_key] = cute.compile(
topk_op, x_tensor, values_tensor, indices_tensor, current_stream
)
_topk_fwd.compile_cache[compile_key](x_tensor, values_tensor, indices_tensor, current_stream)

# TODO: support TopK_Softmax and TopK_Sigmoid


_topk_fwd.compile_cache = {}
@@ -225,12 +242,24 @@ def _softmax_fwd_small_kernel(
f"{LIBRARY_NAME}::_softmax_topk_fwd", mutates_args={"topk_router_score", "topk_router_indices"}
)
def _softmax_topk_fwd(
router_logits: torch.Tensor, topk_router_score: torch.Tensor, topk_router_indices: torch.Tensor, E: int, K: int
router_logits: torch.Tensor,
topk_router_score: torch.Tensor,
topk_router_indices: torch.Tensor,
E: int,
K: int,
scoring_func: ScoringFuncType,
) -> None:
# T = router_logits.shape[0]
if E <= 4096 and K <= 16 and E % 8 == 0:
# fast topk-softmax fusion that covers most common MoE configs
_topk_fwd(router_logits, K, topk_router_score, topk_router_indices, require_softmax_fusion=True)
_topk_fwd(
router_logits,
K,
topk_router_score,
topk_router_indices,
require_softmax_fusion=True,
scoring_func=scoring_func,
)
else:
topk_results = router_logits.topk(K, dim=-1)
topk_router_score.copy_(topk_results.values.softmax(dim=-1, dtype=torch.float32).to(topk_router_score.dtype))
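
The non-fused fallback shown here only covers softmax scoring; below is a sketch of what a sigmoid-aware fallback could look like (an assumption for illustration, not code from this diff):

import torch

from sonicmoe.enums import ScoringFuncType

def _fallback_topk_scores(
    router_logits: torch.Tensor,
    topk_router_score: torch.Tensor,
    topk_router_indices: torch.Tensor,
    K: int,
    scoring_func: ScoringFuncType,
) -> None:
    # hypothetical sigmoid-aware variant of the non-fused fallback branch
    topk_results = router_logits.topk(K, dim=-1)
    if scoring_func == ScoringFuncType.SOFTMAX:
        scores = topk_results.values.softmax(dim=-1, dtype=torch.float32)
    else:  # ScoringFuncType.SIGMOID
        scores = topk_results.values.to(torch.float32).sigmoid()
    topk_router_score.copy_(scores.to(topk_router_score.dtype))
    topk_router_indices.copy_(topk_results.indices.to(topk_router_indices.dtype))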