diff --git a/sonicmoe/enums.py b/sonicmoe/enums.py
index c7d7132..6643d50 100644
--- a/sonicmoe/enums.py
+++ b/sonicmoe/enums.py
@@ -26,5 +26,10 @@ class ActivationType(Enum):
     SILU = "silu"
 
 
+class ScoringFuncType(Enum):
+    SOFTMAX = "softmax"
+    SIGMOID = "sigmoid"
+
+
 def is_glu(activation_type: ActivationType):
     return activation_type in [ActivationType.SWIGLU, ActivationType.REGLU, ActivationType.GEGLU]
diff --git a/sonicmoe/functional/__init__.py b/sonicmoe/functional/__init__.py
index a29c235..cfa6522 100644
--- a/sonicmoe/functional/__init__.py
+++ b/sonicmoe/functional/__init__.py
@@ -9,7 +9,7 @@ from quack.gemm_interface import gemm
 
 from ..count_cumsum import count_cumsum
-from ..enums import ActivationType, is_glu
+from ..enums import ActivationType, ScoringFuncType, is_glu
 from ..quack_utils import gemm_dgated, gemm_gated
 from .backward import _down_projection_backward, _softmax_topk_bwd, _token_broadcast_backward, _up_projection_backward
 from .forward import _down_projection_forward, _router_forward, _softmax_topk_fwd, _up_projection_forward
@@ -84,7 +84,9 @@ class TC_Softmax_Topk_Router_Function(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, router_logits: torch.Tensor, E: int, K: int) -> tuple[torch.Tensor, torch.Tensor]:
+    def forward(
+        ctx, router_logits: torch.Tensor, E: int, K: int, scoring_func: ScoringFuncType
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         T = router_logits.size(0)
 
         # change this to router_logits.dtype (bfloat16) increase another 5 tflops at fwd at the cost of numerical accuracy
@@ -92,7 +94,7 @@ def forward(ctx, router_logits: torch.Tensor, E: int, K: int) -> tuple[torch.Ten
         topk_router_indices = torch.empty(T, K, dtype=torch.int32, device=router_logits.device)
         ctx.mark_non_differentiable(topk_router_indices)
 
-        _softmax_topk_fwd(router_logits, topk_router_score, topk_router_indices, E, K)
+        _softmax_topk_fwd(router_logits, topk_router_score, topk_router_indices, E, K, scoring_func)
 
         ctx.save_for_backward(topk_router_score, topk_router_indices)
         ctx.E = E
+        ctx.scoring_func = scoring_func
@@ -108,7 +110,7 @@ def backward(ctx, dtopk_score: torch.Tensor, _: torch.Tensor) -> tuple[torch.Ten
         topk_router_score, topk_router_indices = ctx.saved_tensor()
 
         dlogits = torch.zeros(T, ctx.E, dtype=ctx.dtype, device=topk_router_score.device)
-        _softmax_topk_bwd(dlogits, None, dtopk_score, topk_router_score, topk_router_indices, K)
+        _softmax_topk_bwd(dlogits, None, dtopk_score, topk_router_score, topk_router_indices, K, ctx.scoring_func)
 
         return dlogits
@@ -445,13 +447,14 @@ def moe_TC_softmax_topk_layer(
     K: int,
     stream_id: int,
     activation_type: ActivationType | str = ActivationType.SWIGLU,
+    scoring_func: ScoringFuncType | str = ScoringFuncType.SOFTMAX,
     is_inference_mode_enabled: bool = False,
 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     assert ((b1 is None) and (b2 is None)) or (
         (b1 is not None) and (b2 is not None)
     ), "b1 and b2 has to be None or not None at the same time!"
router_logits = F.linear(x, router_w) - topk_scores, topk_indices = TC_Softmax_Topk_Router_Function.apply(router_logits, router_w.size(0), K) + topk_scores, topk_indices = TC_Softmax_Topk_Router_Function.apply(router_logits, router_w.size(0), K, scoring_func) expert_frequency, expert_frequency_offset = count_cumsum(topk_indices.view(-1), router_w.size(0), do_cumsum=True) ( diff --git a/sonicmoe/functional/backward.py b/sonicmoe/functional/backward.py index 12d6c98..01250c4 100644 --- a/sonicmoe/functional/backward.py +++ b/sonicmoe/functional/backward.py @@ -10,7 +10,7 @@ import triton import triton.language as tl -from ..enums import LIBRARY_NAME, TENSORMAP, ActivationType, is_glu +from ..enums import LIBRARY_NAME, TENSORMAP, ActivationType, ScoringFuncType, is_glu from ..utils import ceil_divide, convert_torch_tensor_to_cute_tensor, get_powers_of_2 from .moe_config import ( HopperWgmma_MoE_Down_proj_ActGrad_Bwd, @@ -485,7 +485,7 @@ def _token_broadcast_backward( @triton.jit -def _softmax_bwd_scatter_small_kernel( +def _activation_bwd_scatter_small_kernel( dlogits_ptr, dlogits_full_ptr, score_ptr, @@ -502,6 +502,7 @@ def _softmax_bwd_scatter_small_kernel( K: tl.constexpr, BLOCK_K: tl.constexpr, dlogits_is_none: tl.constexpr, + scoring_func: tl.constexpr, ): row = tl.program_id(axis=0) @@ -509,19 +510,38 @@ def _softmax_bwd_scatter_small_kernel( k_offs = tl.arange(0, BLOCK_K) k_mask = k_offs < K + # Load indices idx = tl.load(idx_ptr + row * stride_im + k_offs * stride_ik, mask=k_mask, other=0).to(tl.int32) + + # Load forward probabilities (y) and incoming gradients (g) s_sel = tl.load(score_ptr + row * stride_sm + k_offs * stride_sn, mask=k_mask, other=0).to(tl.float32) g_sel = tl.load(dscore_ptr + row * stride_gm + k_offs * stride_gk, mask=k_mask, other=0).to(tl.float32) - # dot = sum_j g_j * y_j over selected columns - dot = tl.sum(g_sel * s_sel, axis=0) + if scoring_func == ScoringFuncType.SIGMOID: + # === Sigmoid Backward === + # Derivative: dx = g * y * (1 - y) + # No cross-term reduction needed + add_vals = g_sel * s_sel * (1.0 - s_sel) + elif scoring_func == ScoringFuncType.SOFTMAX: + # === Softmax Backward === + # Derivative: dx = y * (g - dot(g, y)) + # Note: Even though Softmax was done on N (Global), since g is 0 for unselected + # indices, the dot product over just the TopK selected elements is correct. 
- # scatter-only: dx[idx] += y_sel * (g_sel - dot) - add_vals = s_sel * (g_sel - dot) + # dot = sum_j g_j * y_j over selected columns + dot = tl.sum(g_sel * s_sel, axis=0) + + # scatter-only: dx[idx] += y_sel * (g_sel - dot) + add_vals = s_sel * (g_sel - dot) + # Calculate pointers to the full gradient matrix indices = row * stride_dm + idx * stride_dn + + # Accumulate into existing gradients if necessary if not dlogits_is_none: add_vals += tl.load(dlogits_ptr + indices, mask=k_mask) + + # Store result tl.store(dlogits_full_ptr + indices, add_vals, mask=k_mask) @@ -533,10 +553,11 @@ def _softmax_topk_bwd( topk_router_score: torch.Tensor, topk_router_indices: torch.Tensor, K: int, + scoring_func: ScoringFuncType, ) -> None: T = dtopk_score.shape[0] - _softmax_bwd_scatter_small_kernel[T,]( + _activation_bwd_scatter_small_kernel[T,]( dlogits, dlogits_full, topk_router_score, @@ -553,6 +574,7 @@ def _softmax_topk_bwd( K, triton.next_power_of_2(K), (dlogits is None), + scoring_func, ) diff --git a/sonicmoe/functional/forward.py b/sonicmoe/functional/forward.py index c9b67c5..18b7e27 100644 --- a/sonicmoe/functional/forward.py +++ b/sonicmoe/functional/forward.py @@ -10,16 +10,21 @@ from cutlass.cute.runtime import from_dlpack from quack.cute_dsl_utils import torch2cute_dtype_map -from ..enums import LIBRARY_NAME, TENSORMAP, ActivationType +from ..enums import LIBRARY_NAME, TENSORMAP, ActivationType, ScoringFuncType from ..utils import convert_torch_tensor_to_cute_tensor from .moe_config import HopperWgmma_MoE_Down_proj_Fwd, HopperWgmma_MoE_Up_proj_Fwd from .reduction_over_k_gather import token_gather_and_sum_varlen_K_triton -from .topk_softmax import TopK_Softmax +from .topk_softmax import Sigmoid_TopK, Softmax_TopK, TopK_Softmax @torch.library.custom_op(f"{LIBRARY_NAME}::_topk_fwd", mutates_args={"values", "indices"}) def _topk_fwd( - x: torch.Tensor, k: int, values: torch.Tensor, indices: torch.Tensor, require_softmax_fusion: bool = True + x: torch.Tensor, + k: int, + values: torch.Tensor, + indices: torch.Tensor, + require_softmax_fusion: bool = True, + scoring_func: ScoringFuncType = ScoringFuncType.SOFTMAX, ) -> None: """Top-k forward pass. 
Args: @@ -38,13 +43,25 @@ def _topk_fwd( x_tensor, values_tensor, indices_tensor = [convert_from_dlpack(tensor) for tensor in (x, values, indices)] current_stream = cuda.CUstream(torch.cuda.current_stream().stream_base.raw_stream) - compile_key = (input_dtype, output_dtype, N, k, require_softmax_fusion) - if compile_key not in _topk_fwd.compile_cache: - topk_op = TopK_Softmax(input_dtype, output_dtype, N, k, require_softmax_fusion) - _topk_fwd.compile_cache[compile_key] = cute.compile( - topk_op, x_tensor, values_tensor, indices_tensor, current_stream - ) - _topk_fwd.compile_cache[compile_key](x_tensor, values_tensor, indices_tensor, current_stream) + + if scoring_func == ScoringFuncType.SOFTMAX: + compile_key = (input_dtype, output_dtype, N, k, require_softmax_fusion, scoring_func) + if compile_key not in _topk_fwd.compile_cache: + topk_op = Softmax_TopK(input_dtype, output_dtype, N, k, require_softmax_fusion, scoring_func) + _topk_fwd.compile_cache[compile_key] = cute.compile( + topk_op, x_tensor, values_tensor, indices_tensor, current_stream + ) + _topk_fwd.compile_cache[compile_key](x_tensor, values_tensor, indices_tensor, current_stream) + elif scoring_func == ScoringFuncType.SIGMOID: + compile_key = (input_dtype, output_dtype, N, k, require_softmax_fusion, scoring_func) + if compile_key not in _topk_fwd.compile_cache: + topk_op = Sigmoid_TopK(input_dtype, output_dtype, N, k, require_softmax_fusion, scoring_func) + _topk_fwd.compile_cache[compile_key] = cute.compile( + topk_op, x_tensor, values_tensor, indices_tensor, current_stream + ) + _topk_fwd.compile_cache[compile_key](x_tensor, values_tensor, indices_tensor, current_stream) + + # TODO: support TopK_Softmax and TopK_Sigmoid _topk_fwd.compile_cache = {} @@ -225,12 +242,24 @@ def _softmax_fwd_small_kernel( f"{LIBRARY_NAME}::_softmax_topk_fwd", mutates_args={"topk_router_score", "topk_router_indices"} ) def _softmax_topk_fwd( - router_logits: torch.Tensor, topk_router_score: torch.Tensor, topk_router_indices: torch.Tensor, E: int, K: int + router_logits: torch.Tensor, + topk_router_score: torch.Tensor, + topk_router_indices: torch.Tensor, + E: int, + K: int, + scoring_func: ScoringFuncType, ) -> None: # T = router_logits.shape[0] if E <= 4096 and K <= 16 and E % 8 == 0: # fast topk-softmax fusion that covers most common MoE configs - _topk_fwd(router_logits, K, topk_router_score, topk_router_indices, require_softmax_fusion=True) + _topk_fwd( + router_logits, + K, + topk_router_score, + topk_router_indices, + require_softmax_fusion=True, + scoring_func=scoring_func, + ) else: topk_results = router_logits.topk(K, dim=-1) topk_router_score.copy_(topk_results.values.softmax(dim=-1, dtype=torch.float32).to(topk_router_score.dtype)) diff --git a/sonicmoe/functional/topk_softmax.py b/sonicmoe/functional/topk_softmax.py index 5cad3f2..5b7108b 100644 --- a/sonicmoe/functional/topk_softmax.py +++ b/sonicmoe/functional/topk_softmax.py @@ -15,7 +15,7 @@ from triton import next_power_of_2 -class TopK_Softmax: +class _BaseTopK: def __init__( self, input_dtype: Type[cutlass.Numeric], @@ -82,6 +82,190 @@ def __call__( stream=stream, ) + def _load_and_process_input( + self, + mX: cute.Tensor, + input_tv_layout: cute.Layout, + input_tiler_mn: cute.Shape, + ): + """Load input data from global memory and convert to f32.""" + import cute + + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + + shape = mX.shape + idX = cute.make_identity_tensor(shape) + + # --- Data Loading --- + mX = utils.domain_offset_i64((bidx * 
input_tiler_mn[0], 0), mX) + gX = cute.local_tile(mX, input_tiler_mn, (0, 0)) + cX = cute.local_tile(idX, input_tiler_mn, (bidx, 0)) + + copy_atom_load_X = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gX.element_type, num_bits_per_copy=128) + thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, input_tv_layout, input_tiler_mn).get_slice(tidx) + tXgX = thr_copy_X.partition_S(gX) + tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None] + + tXrX = cute.make_fragment_like(tXgX) + + is_even_N = const_expr(shape[1] == input_tiler_mn[1]) + tXpX = ( + utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1]) + if const_expr((not is_even_N) or (self.N != self.next_power_of_2_N)) + else None + ) + if tXcX[0][0] < shape[0]: + cute.copy(copy_atom_load_X, tXgX, tXrX, pred=tXpX) + + tXrX_f32 = cute.make_fragment(tXrX.shape, cutlass.Float32) + tXrX_f32.store(tXrX.load().to(cutlass.Float32)) + + return tXrX_f32, tXcX, is_even_N, tXpX, shape + + def _apply_activation(self, tXrX_f32: cute.Fragment, threads_per_row: int, input_tv_layout: cute.Layout): + """Apply activation function. Override in subclasses.""" + raise NotImplementedError + + def _encode_indices( + self, + tXrX_f32: cute.Fragment, + tXcX: cute.Tensor, + input_tv_layout: cute.Layout, + ): + """Encode indices into the bottom bits of values.""" + log_N = int(math.log2(self.next_power_of_2_N)) + idx_mask = const_expr((1 << log_N) - 1) + input_vecsize = cutlass.const_expr(input_tv_layout.shape[1][0]) + tXrX_u32 = cute.recast_tensor(tXrX_f32, cutlass.Uint32) + + for i in cutlass.range(cute.size(tXrX_u32), unroll_full=True): + col_idx = cutlass.Uint32(tXcX[i // input_vecsize][1] + i % input_vecsize) + encoded_idx = ~col_idx + encoded_idx = encoded_idx & idx_mask + tXrX_u32[i] = (tXrX_u32[i] & ~idx_mask) | encoded_idx + + return tXrX_u32, idx_mask + + def _extract_and_clean_indices( + self, topk_vals: cute.Fragment, topk_vals_u32: cute.Fragment, idx_mask: int, k: int + ): + """Extract indices and clean values after top-k.""" + topk_indices = cute.make_fragment(k, cutlass.Int32) + for i in cutlass.range_constexpr(k): + encoded_idx = topk_vals_u32[i] & idx_mask + topk_vals_u32[i] = topk_vals_u32[i] & ~idx_mask + col_idx = ~encoded_idx if topk_vals[i] >= 0 else encoded_idx + topk_indices[i] = cutlass.Int32(col_idx & idx_mask) + return topk_indices + + def _store_results( + self, + topk_vals: cute.Fragment, + topk_vals_out: cute.Fragment, + topk_indices: cute.Fragment, + mValues: cute.Tensor, + mIndices: cute.Tensor, + tXcX: cute.Tensor, + shape: tuple, + output_tv_layout: cute.Layout, + ): + """Store top-k results to global memory.""" + row = tXcX[0][0] + output_vecsize = cutlass.const_expr(output_tv_layout.shape[1][0]) + if row < shape[0] and tXcX[0][1] == 0: + elems_per_store = const_expr(math.gcd(output_vecsize, self.k)) + mValues_store = cute.tiled_divide(mValues[row, None], (elems_per_store,)) + mIndices_store = cute.tiled_divide(mIndices[row, None], (elems_per_store,)) + topk_vals_out_store = cute.tiled_divide(topk_vals_out, (elems_per_store,)) + topk_indices_store = cute.tiled_divide(topk_indices, (elems_per_store,)) + for i in cutlass.range_constexpr(cute.size(topk_vals_out_store.shape, [1])): + cute.autovec_copy(topk_vals_out_store[None, i], mValues_store[None, i]) + cute.autovec_copy(topk_indices_store[None, i], mIndices_store[None, i]) + + +class TopK_Softmax(_BaseTopK): + # @cute.kernel + # def kernel( + # self, + # mX: cute.Tensor, + # mValues: cute.Tensor, + # mIndices: cute.Tensor, + # input_tv_layout: cute.Layout, + # 
input_tiler_mn: cute.Shape, + # output_tv_layout: cute.Layout, + # output_tiler_mn: cute.Shape, + # ): + # tXrX_f32, tXcX, is_even_N, tXpX, shape = self._load_and_process_input(mX, input_tv_layout, input_tiler_mn) + + # # Encode the indices into the bottom bits of values. + # log_N = int(math.log2(self.next_power_of_2_N)) + # idx_mask = const_expr((1 << log_N) - 1) + # input_vecsize = cutlass.const_expr(input_tv_layout.shape[1][0]) + # tXrX_u32 = cute.recast_tensor(tXrX_f32, cutlass.Uint32) + # # Encode indices into the last log_N bits of tXrX_u32 + # for i in cutlass.range(cute.size(tXrX_u32), unroll_full=True): + # # tXcX only keeps track of the indices for every @vecsize elements + # col_idx = cutlass.Uint32(tXcX[i // input_vecsize][1] + i % input_vecsize) + # # If positive, invert the bits of the index, so that if there's a tie, + # # indices coming from a earlier column will win. + # encoded_idx = ~col_idx if tXrX_f32[i] >= 0 else col_idx + # # Mask to keep only the last log_N bits of the encoded index + # encoded_idx = encoded_idx & idx_mask + # # Clear the last log_N bits and set them to our encoded index + # tXrX_u32[i] = (tXrX_u32[i] & ~idx_mask) | encoded_idx + + # # Fill OOB values with -inf for top-k + # if const_expr((not is_even_N) or (self.N != self.next_power_of_2_N)): + # utils.fill_oob(tXrX_f32, tXpX, -tXrX_f32.element_type.inf) + + # threads_per_row = input_tv_layout.shape[0][0] + # topk_vals = bitonic_topk(tXrX_f32, self.next_power_of_2_K, warp_width=threads_per_row) + + # # Extract indices and clean values + # topk_vals_u32 = cute.recast_tensor(topk_vals, cutlass.Uint32) + # topk_indices = cute.make_fragment(self.k, cutlass.Int32) + # for i in cutlass.range_constexpr(self.k): + # # Extract the encoded index from the last log_N bits + # encoded_idx = topk_vals_u32[i] & idx_mask + # # Check if original value was positive by looking at the cleaned value + # topk_vals_u32[i] = topk_vals_u32[i] & ~idx_mask # Clear last log_N bits + # # If positive, we need to invert the bits back to get original index + # col_idx = ~encoded_idx if topk_vals[i] >= 0 else encoded_idx + # topk_indices[i] = cutlass.Int32(col_idx & idx_mask) + + # if const_expr(self.require_softmax_fusion): + # topk_vals_max = -cutlass.Float32.inf + # for i in cutlass.range_constexpr(self.k): + # topk_vals_max = cute.arch.fmax(topk_vals[i], topk_vals_max) + + # topk_exp_sum = cutlass.Int32(0.0) + # for i in cutlass.range_constexpr(self.k): + # topk_vals[i] = cute.math.exp(topk_vals[i] - topk_vals_max) + # topk_exp_sum = topk_exp_sum + topk_vals[i] + + # for i in cutlass.range_constexpr(self.k): + # topk_vals[i] = topk_vals[i] / topk_exp_sum + + # # Convert cleaned values to output type + # topk_vals_out = cute.make_fragment_like(topk_indices, mValues.element_type) + # for i in cutlass.range_constexpr(self.k): + # topk_vals_out[i] = topk_vals[i].to(mValues.element_type) + + # row = tXcX[0][0] + # # Only the 1st thread in this row writes the top-k values and indices + # output_vecsize = cutlass.const_expr(output_tv_layout.shape[1][0]) + # if row < shape[0] and tXcX[0][1] == 0: + # # Vectorized write + # elems_per_store = const_expr(math.gcd(output_vecsize, self.k)) + # mValues_store = cute.tiled_divide(mValues[row, None], (elems_per_store,)) + # mIndices_store = cute.tiled_divide(mIndices[row, None], (elems_per_store,)) + # topk_vals_out_store = cute.tiled_divide(topk_vals_out, (elems_per_store,)) + # topk_indices_store = cute.tiled_divide(topk_indices, (elems_per_store,)) + # for i in 
cutlass.range_constexpr(cute.size(topk_vals_out_store.shape, [1])): + # cute.autovec_copy(topk_vals_out_store[None, i], mValues_store[None, i]) + # cute.autovec_copy(topk_indices_store[None, i], mIndices_store[None, i]) + @cute.kernel def kernel( self, @@ -191,3 +375,345 @@ def kernel( for i in cutlass.range_constexpr(cute.size(topk_vals_out_store.shape, [1])): cute.autovec_copy(topk_vals_out_store[None, i], mValues_store[None, i]) cute.autovec_copy(topk_indices_store[None, i], mIndices_store[None, i]) + + +class Softmax_TopK(_BaseTopK): + # @cute.kernel + # def kernel( + # self, + # mX: cute.Tensor, + # mValues: cute.Tensor, + # mIndices: cute.Tensor, + # input_tv_layout: cute.Layout, + # input_tiler_mn: cute.Shape, + # output_tv_layout: cute.Layout, + # output_tiler_mn: cute.Shape, + # ): + # tXrX_f32, tXcX, is_even_N, tXpX, shape = self._load_and_process_input(mX, input_tv_layout, input_tiler_mn) + + # # --- STEP 1: Mask OOB with -inf --- + # if const_expr((not is_even_N) or (self.N != self.next_power_of_2_N)): + # utils.fill_oob(tXrX_f32, tXpX, -tXrX_f32.element_type.inf) + + # # --- STEP 2: Softmax Activation (Row-wise Reduction) --- + # threads_per_row = input_tv_layout.shape[0][0] + + # # 2a. Find Max (Reduction) to ensure numerical stability + # row_max = -cutlass.Float32.inf + # for i in cutlass.range(cute.size(tXrX_f32)): + # row_max = cute.arch.fmax(row_max, tXrX_f32[i]) + + # # Warp Reduce Max + # mask = 1 + # while mask < threads_per_row: + # other_max = cute.arch.shfl_xor_sync(0xFFFFFFFF, row_max, mask) + # row_max = cute.arch.fmax(row_max, other_max) + # mask *= 2 + + # # 2b. Compute Exp and Sum + # row_sum = cutlass.Float32(0.0) + # for i in cutlass.range(cute.size(tXrX_f32)): + # tXrX_f32[i] = cute.math.exp(tXrX_f32[i] - row_max) + # row_sum += tXrX_f32[i] + + # # Warp Reduce Sum + # mask = 1 + # while mask < threads_per_row: + # other_sum = cute.arch.shfl_xor_sync(0xFFFFFFFF, row_sum, mask) + # row_sum += other_sum + # mask *= 2 + + # # 2c. 
Normalize + # inv_sum = cutlass.Float32(1.0) / row_sum + # for i in cutlass.range(cute.size(tXrX_f32)): + # tXrX_f32[i] = tXrX_f32[i] * inv_sum + + # # --- STEP 3: Encode Indices (on Probabilities) --- + # tXrX_u32, idx_mask = self._encode_indices(tXrX_f32, tXcX, input_tv_layout) + + # # --- STEP 4: TopK --- + # topk_vals = bitonic_topk(tXrX_f32, self.next_power_of_2_K, warp_width=threads_per_row) + + # # --- STEP 5: Extract Indices and Clean Values --- + # topk_vals_u32 = cute.recast_tensor(topk_vals, cutlass.Uint32) + # topk_indices = self._extract_and_clean_indices(topk_vals, topk_vals_u32, idx_mask, self.k) + + # # --- Store Results --- + # topk_vals_out = cute.make_fragment_like(topk_indices, mValues.element_type) + # for i in cutlass.range_constexpr(self.k): + # topk_vals_out[i] = topk_vals[i].to(mValues.element_type) + + # self._store_results(topk_vals, topk_vals_out, topk_indices, mValues, mIndices, tXcX, shape, output_tv_layout) + + @cute.kernel + def kernel( + self, + mX: cute.Tensor, + mValues: cute.Tensor, + mIndices: cute.Tensor, + input_tv_layout: cute.Layout, + input_tiler_mn: cute.Shape, + output_tv_layout: cute.Layout, + output_tiler_mn: cute.Shape, + ): + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + + shape = mX.shape + idX = cute.make_identity_tensor(shape) + + # --- Data Loading --- + mX = utils.domain_offset_i64((bidx * input_tiler_mn[0], 0), mX) + gX = cute.local_tile(mX, input_tiler_mn, (0, 0)) + cX = cute.local_tile(idX, input_tiler_mn, (bidx, 0)) + + copy_atom_load_X = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gX.element_type, num_bits_per_copy=128) + thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, input_tv_layout, input_tiler_mn).get_slice(tidx) + tXgX = thr_copy_X.partition_S(gX) + tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None] + + tXrX = cute.make_fragment_like(tXgX) + + is_even_N = const_expr(shape[1] == input_tiler_mn[1]) + tXpX = ( + utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1]) + if const_expr((not is_even_N) or (self.N != self.next_power_of_2_N)) + else None + ) + if tXcX[0][0] < shape[0]: + cute.copy(copy_atom_load_X, tXgX, tXrX, pred=tXpX) + + tXrX_f32 = cute.make_fragment(tXrX.shape, cutlass.Float32) + tXrX_f32.store(tXrX.load().to(cutlass.Float32)) + + # --- STEP 1: Mask OOB with -inf --- + # This works for both Softmax and Sigmoid: + # Softmax: exp(-inf) -> 0 + # Sigmoid: 1 / (1 + exp(-(-inf))) -> 1 / (1 + inf) -> 0 + if const_expr((not is_even_N) or (self.N != self.next_power_of_2_N)): + utils.fill_oob(tXrX_f32, tXpX, -tXrX_f32.element_type.inf) + + # --- STEP 2: Activation (Sigmoid OR Softmax) --- + threads_per_row = input_tv_layout.shape[0][0] + + # === Softmax Logic (Row-wise Reduction) === + + # 2a. Find Max (Reduction) to ensure numerical stability + row_max = -cutlass.Float32.inf + for i in cutlass.range(cute.size(tXrX_f32)): + row_max = cute.arch.fmax(row_max, tXrX_f32[i]) + + # Warp Reduce Max + mask = 1 + while mask < threads_per_row: + other_max = cute.arch.shfl_xor_sync(0xFFFFFFFF, row_max, mask) + row_max = cute.arch.fmax(row_max, other_max) + mask *= 2 + + # 2b. Compute Exp and Sum + row_sum = cutlass.Float32(0.0) + for i in cutlass.range(cute.size(tXrX_f32)): + tXrX_f32[i] = cute.math.exp(tXrX_f32[i] - row_max) + row_sum += tXrX_f32[i] + + # Warp Reduce Sum + mask = 1 + while mask < threads_per_row: + other_sum = cute.arch.shfl_xor_sync(0xFFFFFFFF, row_sum, mask) + row_sum += other_sum + mask *= 2 + + # 2c. 
Normalize + inv_sum = cutlass.Float32(1.0) / row_sum + for i in cutlass.range(cute.size(tXrX_f32)): + tXrX_f32[i] = tXrX_f32[i] * inv_sum + + # --- STEP 3: Encode Indices (on Probabilities) --- + log_N = int(math.log2(self.next_power_of_2_N)) + idx_mask = const_expr((1 << log_N) - 1) + input_vecsize = cutlass.const_expr(input_tv_layout.shape[1][0]) + tXrX_u32 = cute.recast_tensor(tXrX_f32, cutlass.Uint32) + + for i in cutlass.range(cute.size(tXrX_u32), unroll_full=True): + col_idx = cutlass.Uint32(tXcX[i // input_vecsize][1] + i % input_vecsize) + # Probabilities are always >= 0 (for both Softmax and Sigmoid) + encoded_idx = ~col_idx + encoded_idx = encoded_idx & idx_mask + tXrX_u32[i] = (tXrX_u32[i] & ~idx_mask) | encoded_idx + + # --- STEP 4: TopK --- + topk_vals = bitonic_topk(tXrX_f32, self.next_power_of_2_K, warp_width=threads_per_row) + + # --- STEP 5: Extract Indices and Clean Values --- + topk_vals_u32 = cute.recast_tensor(topk_vals, cutlass.Uint32) + topk_indices = cute.make_fragment(self.k, cutlass.Int32) + for i in cutlass.range_constexpr(self.k): + encoded_idx = topk_vals_u32[i] & idx_mask + topk_vals_u32[i] = topk_vals_u32[i] & ~idx_mask + + # Decode index + col_idx = ~encoded_idx if topk_vals[i] >= 0 else encoded_idx + topk_indices[i] = cutlass.Int32(col_idx & idx_mask) + + # --- Store Results --- + topk_vals_out = cute.make_fragment_like(topk_indices, mValues.element_type) + for i in cutlass.range_constexpr(self.k): + topk_vals_out[i] = topk_vals[i].to(mValues.element_type) + + row = tXcX[0][0] + output_vecsize = cutlass.const_expr(output_tv_layout.shape[1][0]) + if row < shape[0] and tXcX[0][1] == 0: + elems_per_store = const_expr(math.gcd(output_vecsize, self.k)) + mValues_store = cute.tiled_divide(mValues[row, None], (elems_per_store,)) + mIndices_store = cute.tiled_divide(mIndices[row, None], (elems_per_store,)) + topk_vals_out_store = cute.tiled_divide(topk_vals_out, (elems_per_store,)) + topk_indices_store = cute.tiled_divide(topk_indices, (elems_per_store,)) + for i in cutlass.range_constexpr(cute.size(topk_vals_out_store.shape, [1])): + cute.autovec_copy(topk_vals_out_store[None, i], mValues_store[None, i]) + cute.autovec_copy(topk_indices_store[None, i], mIndices_store[None, i]) + + +class Sigmoid_TopK(_BaseTopK): + # @cute.kernel + # def kernel( + # self, + # mX: cute.Tensor, + # mValues: cute.Tensor, + # mIndices: cute.Tensor, + # input_tv_layout: cute.Layout, + # input_tiler_mn: cute.Shape, + # output_tv_layout: cute.Layout, + # output_tiler_mn: cute.Shape, + # ): + # tXrX_f32, tXcX, is_even_N, tXpX, shape = self._load_and_process_input(mX, input_tv_layout, input_tiler_mn) + + # # --- STEP 1: Mask OOB with -inf --- + # if const_expr((not is_even_N) or (self.N != self.next_power_of_2_N)): + # utils.fill_oob(tXrX_f32, tXpX, -tXrX_f32.element_type.inf) + + # # --- STEP 2: Sigmoid Activation (Element-wise) --- + # # Formula: 1 / (1 + exp(-x)) + # for i in cutlass.range(cute.size(tXrX_f32)): + # val = tXrX_f32[i] + # tXrX_f32[i] = cutlass.Float32(1.0) / (cutlass.Float32(1.0) + cute.math.exp(-val)) + + # # --- STEP 3: Encode Indices (on Probabilities) --- + # tXrX_u32, idx_mask = self._encode_indices(tXrX_f32, tXcX, input_tv_layout) + + # # --- STEP 4: TopK --- + # threads_per_row = input_tv_layout.shape[0][0] + # topk_vals = bitonic_topk(tXrX_f32, self.next_power_of_2_K, warp_width=threads_per_row) + + # # --- STEP 5: Extract Indices and Clean Values --- + # topk_vals_u32 = cute.recast_tensor(topk_vals, cutlass.Uint32) + # topk_indices = 
self._extract_and_clean_indices(topk_vals, topk_vals_u32, idx_mask, self.k) + + # # --- Store Results --- + # topk_vals_out = cute.make_fragment_like(topk_indices, mValues.element_type) + # for i in cutlass.range_constexpr(self.k): + # topk_vals_out[i] = topk_vals[i].to(mValues.element_type) + + # self._store_results(topk_vals, topk_vals_out, topk_indices, mValues, mIndices, tXcX, shape, output_tv_layout) + + @cute.kernel + def kernel( + self, + mX: cute.Tensor, + mValues: cute.Tensor, + mIndices: cute.Tensor, + input_tv_layout: cute.Layout, + input_tiler_mn: cute.Shape, + output_tv_layout: cute.Layout, + output_tiler_mn: cute.Shape, + ): + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + + shape = mX.shape + idX = cute.make_identity_tensor(shape) + + # --- Data Loading --- + mX = utils.domain_offset_i64((bidx * input_tiler_mn[0], 0), mX) + gX = cute.local_tile(mX, input_tiler_mn, (0, 0)) + cX = cute.local_tile(idX, input_tiler_mn, (bidx, 0)) + + copy_atom_load_X = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gX.element_type, num_bits_per_copy=128) + thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, input_tv_layout, input_tiler_mn).get_slice(tidx) + tXgX = thr_copy_X.partition_S(gX) + tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None] + + tXrX = cute.make_fragment_like(tXgX) + + is_even_N = const_expr(shape[1] == input_tiler_mn[1]) + tXpX = ( + utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1]) + if const_expr((not is_even_N) or (self.N != self.next_power_of_2_N)) + else None + ) + if tXcX[0][0] < shape[0]: + cute.copy(copy_atom_load_X, tXgX, tXrX, pred=tXpX) + + tXrX_f32 = cute.make_fragment(tXrX.shape, cutlass.Float32) + tXrX_f32.store(tXrX.load().to(cutlass.Float32)) + + # --- STEP 1: Mask OOB with -inf --- + # This works for both Softmax and Sigmoid: + # Softmax: exp(-inf) -> 0 + # Sigmoid: 1 / (1 + exp(-(-inf))) -> 1 / (1 + inf) -> 0 + if const_expr((not is_even_N) or (self.N != self.next_power_of_2_N)): + utils.fill_oob(tXrX_f32, tXpX, -tXrX_f32.element_type.inf) + + # --- STEP 2: Activation (Sigmoid OR Softmax) --- + threads_per_row = input_tv_layout.shape[0][0] + + # === Sigmoid Logic (Element-wise) === + # Formula: 1 / (1 + exp(-x)) + for i in cutlass.range(cute.size(tXrX_f32)): + val = tXrX_f32[i] + # Check for -inf to avoid NaN during arithmetic if generic exp handles it poorly + # (Though standard exp(-inf) is 0, so exp(inf) is inf. 1/inf is 0. Safe.) 
+ tXrX_f32[i] = cutlass.Float32(1.0) / (cutlass.Float32(1.0) + cute.math.exp(-val)) + + # --- STEP 3: Encode Indices (on Probabilities) --- + log_N = int(math.log2(self.next_power_of_2_N)) + idx_mask = const_expr((1 << log_N) - 1) + input_vecsize = cutlass.const_expr(input_tv_layout.shape[1][0]) + tXrX_u32 = cute.recast_tensor(tXrX_f32, cutlass.Uint32) + + for i in cutlass.range(cute.size(tXrX_u32), unroll_full=True): + col_idx = cutlass.Uint32(tXcX[i // input_vecsize][1] + i % input_vecsize) + # Probabilities are always >= 0 (for both Softmax and Sigmoid) + encoded_idx = ~col_idx + encoded_idx = encoded_idx & idx_mask + tXrX_u32[i] = (tXrX_u32[i] & ~idx_mask) | encoded_idx + + # --- STEP 4: TopK --- + topk_vals = bitonic_topk(tXrX_f32, self.next_power_of_2_K, warp_width=threads_per_row) + + # --- STEP 5: Extract Indices and Clean Values --- + topk_vals_u32 = cute.recast_tensor(topk_vals, cutlass.Uint32) + topk_indices = cute.make_fragment(self.k, cutlass.Int32) + for i in cutlass.range_constexpr(self.k): + encoded_idx = topk_vals_u32[i] & idx_mask + topk_vals_u32[i] = topk_vals_u32[i] & ~idx_mask + + # Decode index + col_idx = ~encoded_idx if topk_vals[i] >= 0 else encoded_idx + topk_indices[i] = cutlass.Int32(col_idx & idx_mask) + + # --- Store Results --- + topk_vals_out = cute.make_fragment_like(topk_indices, mValues.element_type) + for i in cutlass.range_constexpr(self.k): + topk_vals_out[i] = topk_vals[i].to(mValues.element_type) + + row = tXcX[0][0] + output_vecsize = cutlass.const_expr(output_tv_layout.shape[1][0]) + if row < shape[0] and tXcX[0][1] == 0: + elems_per_store = const_expr(math.gcd(output_vecsize, self.k)) + mValues_store = cute.tiled_divide(mValues[row, None], (elems_per_store,)) + mIndices_store = cute.tiled_divide(mIndices[row, None], (elems_per_store,)) + topk_vals_out_store = cute.tiled_divide(topk_vals_out, (elems_per_store,)) + topk_indices_store = cute.tiled_divide(topk_indices, (elems_per_store,)) + for i in cutlass.range_constexpr(cute.size(topk_vals_out_store.shape, [1])): + cute.autovec_copy(topk_vals_out_store[None, i], mValues_store[None, i]) + cute.autovec_copy(topk_indices_store[None, i], mIndices_store[None, i]) diff --git a/sonicmoe/moe.py b/sonicmoe/moe.py index 65f3218..3de82e9 100644 --- a/sonicmoe/moe.py +++ b/sonicmoe/moe.py @@ -5,13 +5,12 @@ from typing import Callable import paddle - import torch import torch.nn as nn import torch.nn.functional as F from .count_cumsum import count_cumsum -from .enums import ActivationType, KernelBackendMoE, is_glu +from .enums import ActivationType, KernelBackendMoE, ScoringFuncType, is_glu from .functional import moe_TC_softmax_topk_layer @@ -174,6 +173,7 @@ def __init__( hidden_size: int, intermediate_size: int, activation_function: ActivationType, + scoring_func: ScoringFuncType, add_bias: bool, std: float, ) -> None: @@ -188,6 +188,7 @@ def __init__( self.router = nn.Linear(in_features=self.hidden_size, out_features=num_experts, bias=False) self.activation_function = activation_function + self.scoring_func = scoring_func self.c_fc = Experts( num_experts=num_experts, @@ -229,6 +230,7 @@ def forward( self.top_k, self.stream_id, self.activation_function, + self.scoring_func, is_inference_mode or not self.training, ) else:
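
The fused Softmax_TopK and Sigmoid_TopK kernels above can be cross-checked against a plain PyTorch oracle. The sketch below is illustrative and not part of the diff; it assumes Softmax_TopK normalizes over all E experts before taking the top-k, while Sigmoid_TopK applies an elementwise sigmoid before taking the top-k, and it ignores the kernels' earliest-column tie-breaking rule.

import torch

def reference_route(router_logits: torch.Tensor, k: int, scoring_func: str = "softmax"):
    """router_logits: (T, E) float tensor. Returns (topk_scores, topk_indices)."""
    if scoring_func == "softmax":
        # softmax over the full expert dimension, then select the k largest probabilities
        scores = router_logits.float().softmax(dim=-1)
    elif scoring_func == "sigmoid":
        # elementwise sigmoid; monotonic, so topk-before or topk-after gives the same set
        scores = router_logits.float().sigmoid()
    else:
        raise ValueError(f"unknown scoring_func: {scoring_func}")
    topk_scores, topk_indices = scores.topk(k, dim=-1)
    return topk_scores, topk_indices.to(torch.int32)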
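
The scatter backward in _activation_bwd_scatter_small_kernel can be checked the same way. A minimal sketch, assuming the saved topk_router_score holds the post-activation probabilities at the selected columns; the function and variable names here are illustrative only.

import torch

def reference_topk_backward(dtopk_score, topk_score, topk_indices, num_experts, scoring_func="softmax"):
    """dtopk_score, topk_score: (T, K) float; topk_indices: (T, K) integer. Returns (T, E) dlogits."""
    g = dtopk_score.float()
    y = topk_score.float()
    if scoring_func == "softmax":
        # dx = y * (g - sum_j g_j * y_j); restricting the dot to the selected columns is
        # valid because g is zero at unselected columns
        dot = (g * y).sum(dim=-1, keepdim=True)
        add_vals = y * (g - dot)
    else:  # sigmoid
        # purely elementwise, no cross-term reduction
        add_vals = g * y * (1.0 - y)
    dlogits = torch.zeros(g.size(0), num_experts, dtype=torch.float32, device=g.device)
    dlogits.scatter_add_(1, topk_indices.long(), add_vals)
    return dlogits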
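
Both top-k kernels rely on the same index-encoding trick: the bit-inverted column index is packed into the low log2(next_pow2(N)) bits of the fp32 score, so a single bitonic top-k over the packed values also carries the indices, with earlier columns winning ties among non-negative scores. A plain-Python illustration of that packing; the helper names are illustrative, not part of the repo.

import struct

def encode(score: float, col: int, n_experts: int) -> int:
    # log2 of the next power of two >= n_experts; the low log_n bits hold the index
    log_n = (n_experts - 1).bit_length()
    idx_mask = (1 << log_n) - 1
    bits = struct.unpack("<I", struct.pack("<f", score))[0]
    # scores are probabilities (>= 0), so the index is always stored bit-inverted
    return (bits & ~idx_mask) | ((~col) & idx_mask)

def decode_index(packed: int, n_experts: int) -> int:
    log_n = (n_experts - 1).bit_length()
    idx_mask = (1 << log_n) - 1
    # value was non-negative, so un-invert the stored bits to recover the column
    return (~(packed & idx_mask)) & idx_mask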
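
moe_TC_softmax_topk_layer accepts ScoringFuncType | str, while the custom ops and the Triton constexpr branch compare against ScoringFuncType members; whether a string argument is normalized to the enum elsewhere is not visible in this diff. A minimal sketch of such a normalization, assuming the standard by-value Enum lookup; normalize_scoring_func is a hypothetical helper, not part of the repo.

from sonicmoe.enums import ScoringFuncType

def normalize_scoring_func(scoring_func) -> ScoringFuncType:
    # ScoringFuncType("sigmoid") -> ScoringFuncType.SIGMOID (by-value lookup)
    return ScoringFuncType(scoring_func) if isinstance(scoring_func, str) else scoring_func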