 
 
 @dataclass
-class EpilogueSpecs:
+class FnSpecs:
     name: str
     fn: "triton.runtime.jit.JITFunction"
     fn_arg_names: tuple[str]
     fn_arg_do_not_specialize: tuple[str] = tuple()
 
+    @staticmethod
+    def default():
+        return FnSpecs("dflt", None, tuple())
+
+
+@dataclass
+class FusedActivation:
+    specs: FnSpecs
+    fn_args: tuple[object]
+    reduction_n: int
+
 
 @dataclass
 class Epilogue:
-    specs: EpilogueSpecs
+    specs: FnSpecs
     fn_arg_values_matmul: tuple[object]
     fn_arg_values_finalize: tuple[object]
     is_expensive: bool = False
 
 
+EpilogueSpecs = FnSpecs  # TODO: remove this alias when callers are updated
+
 _kernels = dict()
 
 
-def get_kernels(epilogue: EpilogueSpecs):
+def get_kernels(epilogue: FnSpecs = FnSpecs.default(), fused_activation: FnSpecs = FnSpecs.default()):
     global _kernels
-    if epilogue.name in _kernels:
-        return _kernels[epilogue.name]
-    spec_constants = {"EPILOGUE_FN": epilogue.fn}
-    spec_tuples = {"epilogue_fn_args": epilogue.fn_arg_names}
-    do_not_specialize = epilogue.fn_arg_do_not_specialize
+    key = (fused_activation.name, epilogue.name)
+    if key in _kernels:
+        return _kernels[key]
+    spec_constants = {
+        "ACTIVATION_FN": fused_activation.fn,
+        "EPILOGUE_FN": epilogue.fn,
+    }
+    spec_tuples = {
+        "activation_fn_args": fused_activation.fn_arg_names,
+        "epilogue_fn_args": epilogue.fn_arg_names,
+    }
+    do_not_specialize = fused_activation.fn_arg_do_not_specialize + epilogue.fn_arg_do_not_specialize
     import types
 
-    module = types.ModuleType(f"matmul_ogs_{epilogue.name}")
+    module = types.ModuleType(f"matmul_ogs_{'_'.join(key)}")
     sys.modules[module.__name__] = module
     module._finalize_matmul = specialize(_finalize_matmul, module, spec_constants, spec_tuples,
                                          do_not_specialize=do_not_specialize)
     module._matmul_ogs = specialize(_matmul_ogs, module, spec_constants, spec_tuples,
                                     do_not_specialize=do_not_specialize)
     module._p_matmul_ogs = specialize(_p_matmul_ogs, module, spec_constants, spec_tuples,
                                       do_not_specialize=do_not_specialize)
-    _kernels[epilogue.name] = module
+    _kernels[key] = module
     return module
 
 
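(Note, not part of the diff: a minimal sketch of how the reworked `get_kernels` might be driven once this lands. The `my_activation_kernel` JIT function and its `("alpha",)` argument name are hypothetical stand-ins, not defined in this commit.)

```python
# Hypothetical caller-side sketch; the names below are illustrative only.
act_specs = FnSpecs("my_act", my_activation_kernel, ("alpha",))  # assumed @triton.jit fn
epi_specs = FnSpecs.default()                                    # no-op epilogue

# Specialized kernel modules are now cached per (activation name, epilogue name)
# pair, i.e. under the key ("my_act", "dflt") in _kernels.
kernels = get_kernels(epi_specs, act_specs)
```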
@@ -254,8 +274,8 @@ def can_use_persistent_tma(x, w, gather_indx, precision_config):
         and mx_ctx.swizzle_value is None
     )
 
-def can_use_fused_scatter(scatter_indx):
-    return scatter_indx is not None
+def can_use_fused_scatter(scatter_indx, fused_activation):
+    return scatter_indx is not None and fused_activation.specs.fn is None
 
 # ---------------------
 # Preprocessing
@@ -341,7 +361,7 @@ def init_postprocessing_features(routing_data, scatter_indx, opt_flags):
     return PostprocessingFeatures(finalize)
 
 def apply_postprocessing_features(scatter_indx, finalize_scatter_idxs, opt_flags, expt_offs, num_indx, precision_config, routing_data,
-                                  postprocess_features, memory, epilogue):
+                                  postprocess_features, memory, fused_activation, epilogue):
     out = memory["output"]
     flex_ctx = precision_config.flex_ctx
     if postprocess_features.finalize:
@@ -407,14 +427,15 @@ def compute_grid(BLOCK_N, num_warps):
     grid, (BLOCK_N, num_warps) = sorted([(compute_grid(*c), c) for c in candidates], key=lambda x: x[0][1])[0]
     STAGES = 1 if num_warps == 1 else min(triton.cdiv(triton.cdiv(N, BLOCK_N), grid[1]), 5)
 
-    kernels = get_kernels(epilogue.specs)
+    kernels = get_kernels(epilogue.specs, fused_activation.specs)
     kernels._finalize_matmul[grid](
         flex_ctx.out_data.reinterpret(out_scatter),
         *out_scatter_flex,
         flex_ctx.out_data.reinterpret(inp), inp.stride(0), inp.stride(2),
         inp_flex.expected_scale,
         scatter_src_indx, finalize_scatter_idxs,
         inp.shape[0], M, N, num_rows,
+        *fused_activation.fn_args, fused_activation.reduction_n,
         *epilogue.fn_arg_values_finalize,
         EXPT_PER_TOK=EXPT_PER_TOK,
         BLOCK_N=BLOCK_N,
@@ -443,7 +464,7 @@ class MatmulAllocation:
     output: tuple[tuple[int], torch.dtype]
     scratchpads: dict[str, tuple]
 
-def init_allocation(x, w, precision_config, routing_data, gather_indx, scatter_indx, opt_flags,
+def init_allocation(x, w, precision_config, fused_activation, routing_data, gather_indx, scatter_indx, opt_flags,
                     preprocessing_features, postprocessing_features):
     # ---- output ------
     N = precision_config.mx_ctx.get_packed_tensor_logical_shape(w)[-1]
@@ -462,7 +483,7 @@ def init_allocation(x, w, precision_config, routing_data, gather_indx, scatter_i
     else:
         Mc = scatter_indx.src_indx.shape[0] // routing_data.n_expts_act  # compressed number of rows
         y_rows = Mc
-    y_shape = (x.shape[0], y_rows, N)
+    y_shape = (x.shape[0], y_rows, N // fused_activation.reduction_n)
     out_dtype = precision_config.out_dtype or x.dtype
     output = (y_shape, out_dtype)
     # ---- scratchpad -----#
@@ -500,6 +521,7 @@ def matmul_ogs(x, w, bias,
                gammas: torch.Tensor | None = None,
                out_alpha: float | None = None,
                y: torch.Tensor | None = None,
+               fused_activation: FusedActivation | None = None,
                epilogue: Epilogue | None = None,
                ):
     """
@@ -516,9 +538,10 @@ def matmul_ogs(x, w, bias,
         assert w.ndim == 3 and w.shape[0] == x.shape[0]
     if precision_config is None:
         precision_config = PrecisionConfig()
+    if fused_activation is None:
+        fused_activation = FusedActivation(FnSpecs.default(), tuple(), 1)
     if epilogue is None:
-        epilogue_specs = EpilogueSpecs("dflt", None, tuple(), tuple())
-        epilogue = Epilogue(epilogue_specs, tuple(), tuple(), False)
+        epilogue = Epilogue(FnSpecs.default(), tuple(), tuple(), False)
     if w.ndim == 2:
         w = w.view(1, w.shape[-2], w.shape[-1])
     if x.ndim == 2:
@@ -540,7 +563,7 @@ def matmul_ogs(x, w, bias,
     opt_flags = make_opt_flags(out_dtype, x.dtype, w.dtype, precision_config,
                                M, N, K, routing_data,
                                can_use_persistent_tma(x, w, gather_indx, precision_config),
-                               can_use_fused_scatter(scatter_indx),
+                               can_use_fused_scatter(scatter_indx, fused_activation),
                                epilogue.is_expensive,
     )
     # compute grid size
@@ -551,25 +574,27 @@ def matmul_ogs(x, w, bias,
     grid_n = triton.cdiv(N, opt_flags.block_n)
     assert n_expts_tot == routing_data.n_expts_tot
     assert grid_m > 0
-    assert x.dtype == w.dtype or mx_ctx.weight_scale is not None
     # determine necessary pre/post processing
     preprocessing_features = init_preprocessing_features(w, precision_config, opt_flags)
     postprocessing_features = init_postprocessing_features(routing_data, scatter_indx, opt_flags)
     # allocate output/scratchpad memory
-    allocation = init_allocation(x, w, precision_config, routing_data, gather_indx, scatter_indx, opt_flags,
+    allocation = init_allocation(x, w, precision_config, fused_activation, routing_data, gather_indx, scatter_indx, opt_flags,
                                  preprocessing_features, postprocessing_features)
     memory = apply_allocation(allocation, y)
     # TMA descriptors require a global memory allocation
     if opt_flags.is_persistent:
         triton.set_allocator(get_per_device_per_stream_alloc_fn(x.device))
     # Intermediate tensors and postprocess kernels for each situation
     out0, out0_flex = memory["output"], precision_config.flex_ctx.out_data
+    fused_postprocess_activation = FusedActivation(FnSpecs.default(), tuple(), 1)
     if postprocessing_features.finalize:
         if opt_flags.fused_scatter:
             out0 = memory["output"]
         else:
             out0 = memory["scratchpad"]["matmul"]
         out0_flex = OutFlexData() if out0.dtype == torch.float32 else precision_config.flex_ctx.out_data
+
+        fused_activation, fused_postprocess_activation = fused_postprocess_activation, fused_activation
     # pre-processing
     x, w, swap_xw, writeback_idxs, writeback_size, finalize_scatter_idxs, expt_data = apply_preprocessing_features(
         x, w, gather_indx, scatter_indx, routing_data, opt_flags, preprocessing_features
@@ -584,7 +609,7 @@ def matmul_ogs(x, w, bias,
     flex = precision_config.flex_ctx
     bias_stride = None if bias is None else bias.stride(0)
     num_indx = None if scatter_indx is None else scatter_indx.src_indx.shape[0]
-    kernels = get_kernels(epilogue.specs)
+    kernels = get_kernels(epilogue.specs, fused_activation.specs)
     (kernels._p_matmul_ogs if opt_flags.is_persistent else kernels._matmul_ogs)[(n_cta,)](
         flex.out_data.reinterpret(memory["output"]),
         flex.out_data.reinterpret(out0), *out0.stride(),
@@ -606,6 +631,7 @@ def matmul_ogs(x, w, bias,
         expt_data.hist, expt_data.offs, expt_data.offs_sum, expt_data.blocks,
         batch_size, grid_m, grid_n,
         out_alpha,
+        *fused_activation.fn_args, fused_activation.reduction_n,
         *epilogue.fn_arg_values_matmul,
         routing_data.n_expts_tot, routing_data.n_expts_act,
         precision_config.max_num_imprecise_acc,
@@ -635,7 +661,7 @@ def matmul_ogs(x, w, bias,
     # post-processing
     out = apply_postprocessing_features(scatter_indx, finalize_scatter_idxs, opt_flags, expt_data.offs,
                                         num_indx, precision_config, routing_data,
-                                        postprocessing_features, memory, epilogue)
+                                        postprocessing_features, memory, fused_postprocess_activation, epilogue)
 
     # remove split-k
     out = out.squeeze(0)
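
(Note, not part of the diff: a minimal end-to-end sketch of the new `fused_activation` argument. The `swiglu_fn` kernel and its `("alpha", "limit")` argument names are hypothetical; only the `FusedActivation`/`FnSpecs` plumbing is from this commit.)

```python
# Hypothetical usage sketch. A gated activation that consumes pairs of output
# columns would use reduction_n=2, so the allocated output is N // 2 wide
# (see the y_shape change in init_allocation above).
act = FusedActivation(FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")), (1.702, None), 2)
y = matmul_ogs(x, w, bias, fused_activation=act)

# Omitting the argument keeps the old behaviour: a default
# FusedActivation(FnSpecs.default(), tuple(), 1) is substituted, which is a
# no-op with reduction_n=1.
```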