from dataclasses import dataclass
import itertools
+import math
import sys
import torch
import triton
# utilities
from triton_kernels import target_info
from triton_kernels.numerics import InFlexData, OutFlexData
-from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx
+from triton_kernels.routing import ExptData, GatherIndx, RoutingData, ScatterIndx
+from triton.tools.tensor_descriptor import TensorDescriptor
# details
from .matmul_ogs_details._matmul_ogs import _compute_writeback_idx
from .matmul_ogs_details._matmul_ogs import _matmul_ogs
from .matmul_ogs_details._p_matmul_ogs import _p_matmul_ogs, get_per_device_per_stream_alloc_fn
from .matmul_ogs_details._finalize_matmul import _finalize_matmul
-from .matmul_ogs_details.opt_flags import make_opt_flags
+from .matmul_ogs_details.opt_flags import make_opt_flags, OptFlags
from .matmul_ogs_details.fast_contiguous import fast_contiguous
from .numerics_details.mxfp import SwizzlingType
from .specialize import specialize
+from typing import Tuple, Optional


@dataclass
@@ -95,6 +98,84 @@ def should_upcast_indices(*args):
    return any(tensor is not None and can_overflow_int32(tensor) for tensor in args)


+class TensorDescriptorBuilder:
+    """Builder for creating different types of tensor descriptors"""
+
+    @staticmethod
+    def create_basic_descriptor(tensor: torch.Tensor, block_shape: Tuple[int, ...],
+                                transpose: bool = False) -> TensorDescriptor:
+        """Create a basic tensor descriptor with optional transpose"""
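+        # A transposed operand is described by swapping the last two block dimensions
+        # and permuting the tensor's trailing axes to match.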
+        if transpose:
+            block_shape = block_shape[:-2] + [block_shape[-1], block_shape[-2]]
+            tensor = tensor.permute(0, 2, 1)
+        return TensorDescriptor.from_tensor(tensor, block_shape=block_shape)
+
+    @staticmethod
+    def create_weight_descriptor(w_tensor: torch.Tensor, block_k: int, block_n: int,
+                                 transpose: bool) -> TensorDescriptor:
+        """Create a tensor descriptor for weight matrix"""
+        # Two e2m1 packed in a uint8 or a single fp8
+        W_PACK_DIVISOR = 2 if w_tensor.dtype == torch.uint8 else 1
+        PACKED_BLOCK_K_W = block_k // W_PACK_DIVISOR
+        return TensorDescriptorBuilder.create_basic_descriptor(w_tensor, block_shape=[1, PACKED_BLOCK_K_W, block_n],
+                                                               transpose=transpose)
+
+    @staticmethod
+    def create_block_scale_descriptor(mx_tensor: torch.Tensor, block_k: int, block_n: int, K: int, N: int,
+                                      mx_scale_stride_k: int, mx_scale_stride_n: int, n_expts_tot: int, batch_size: int,
+                                      expt_data: Optional[ExptData], swizzle_mx: bool,
+                                      transpose: bool) -> TensorDescriptor:
+        """Create a tensor descriptor for block scale factors"""
+        MX_PACK_DIVISOR = 32
+        MX_SCALE_BLOCK_K = block_k // MX_PACK_DIVISOR
+        PackedK = (K + MX_PACK_DIVISOR - 1) // MX_PACK_DIVISOR
+
+        if swizzle_mx:
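+            # Swizzled layout: one group of scales per expert (or batch entry) and per
+            # 128-column tile, laid out as a 5D [1, num_expt_x_ncol, ceil(PackedK / 4), 2, 256] tensor.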
+            num_expt_x_ncol = (n_expts_tot if expt_data is not None and len(expt_data.block_pid_map) > 0 else
+                               batch_size) * ((N + 127) // 128)
+            return TensorDescriptor(
+                base=mx_tensor, shape=[1, num_expt_x_ncol, (PackedK + 3) // 4, 2, 256],
+                strides=[num_expt_x_ncol * mx_scale_stride_n, mx_scale_stride_n, mx_scale_stride_k, 256,
+                         1], block_shape=[1, block_n // 128, MX_SCALE_BLOCK_K // 4, 2, 256])
+        else:
+            # Non-optimal SF layout, expect slow transfers
+            # from global to shmem and from shmem to tmem
+            return TensorDescriptorBuilder.create_basic_descriptor(mx_tensor,
+                                                                   block_shape=[1, MX_SCALE_BLOCK_K,
+                                                                                block_n], transpose=transpose)
+
+    @staticmethod
+    def create_input_descriptor_gather(x_tensor: torch.Tensor, K: int, x_stride_1: int, x_stride_2: int,
+                                       block_k: int) -> TensorDescriptor:
+        """Create a tensor descriptor for input matrix X via TMA gather"""
+        x_desc = x_tensor.squeeze()
+        assert x_desc.ndim == 2, "TMA gather descriptor requires 2D input"
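+        # Rows are selected through the gather indices, so the leading dimension of the
+        # descriptor is set to INT_MAX rather than the actual number of rows.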
+        INT_MAX = 2147483647
+        return TensorDescriptor(base=x_desc, shape=[INT_MAX, K], strides=[x_stride_1, x_stride_2],
+                                block_shape=[1, block_k])
+
+    @staticmethod
+    def create_input_descriptor_load(x_tensor: torch.Tensor, K: int, x_stride_1: int, x_stride_2: int, block_m: int,
+                                     block_k: int) -> TensorDescriptor:
+        """Create a tensor descriptor for input matrix X via TMA"""
+        x_desc = x_tensor.squeeze()
+        assert x_desc.ndim in [2, 3], "LHS input TMA descriptor builder expects 2D or 3D input"
+        return TensorDescriptor(base=x_desc, shape=[x_desc.shape[0], K], strides=[x_stride_1, x_stride_2],
+                                block_shape=[block_m, block_k])
+
+    @staticmethod
+    def create_input_descriptor(x_tensor: torch.Tensor, K: int, x_stride_1: int, x_stride_2: int, block_k: int,
+                                block_m: int, use_gather_tma: bool, use_load_tma: bool) -> TensorDescriptor:
+        """Create a tensor descriptor for input matrix X based on TMA usage"""
+        if use_gather_tma:
+            return TensorDescriptorBuilder.create_input_descriptor_gather(x_tensor, K, x_stride_1, x_stride_2, block_k)
+        elif use_load_tma:
+            return TensorDescriptorBuilder.create_input_descriptor_load(x_tensor, K, x_stride_1, x_stride_2, block_m,
+                                                                        block_k)
+        else:
+            return x_tensor
+
+
# ---------------------
# Numerics
# ---------------------
@@ -490,7 +571,6 @@ def init_allocation(x, w, precision_config, fused_activation, routing_data, gath
        scratchpad["matmul"] = ((opt_flags.split_k, x.shape[0], M, N), dtype)
    return MatmulAllocation(x.device, output, scratchpad)

-
def apply_allocation(allocation: MatmulAllocation, output):
    ret = dict()
    if output is None:
@@ -504,10 +584,82 @@ def apply_allocation(allocation: MatmulAllocation, output):
    }
    return ret

+
# -----------------------------------------------------------------------------
# Triton Implementation
# -----------------------------------------------------------------------------

+def _create_tma_descriptors(
+    x: torch.Tensor,
+    x_tensor: torch.Tensor,
+    w_tensor: torch.Tensor,
+    mx_tensor: Optional[torch.Tensor],
+    routing_data: RoutingData,
+    mx_ctx: MicroscalingCtx,
+    expt_data: ExptData,
+    opt_flags: OptFlags,
+    batch_size: int,
+    K: int,
+    N: int,
+    mx_scale_stride_k: int,
+    mx_scale_stride_n: int,
+    USE_GATHER_TMA: bool,
+    X_USE_LOAD_TMA: bool,
+    w_transpose: bool,
+    mx_transpose: bool,
+) -> Tuple[bool, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+    """Create and cache TMA descriptors for tensors."""
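+    # Host-side descriptors are only attempted for the persistent kernel on devices
+    # with CUDA capability >= 10.0; otherwise the raw tensors are passed through.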
+    use_host_tma_descriptors = opt_flags.is_persistent and target_info.cuda_capability_geq(10, 0)
+
+    x_desc, w_desc = [None] * 2
+    descriptors = []
+    # The dense case currently uses on device descriptor updates
+    # so we bail out on using host descriptors in that case
+    if (use_host_tma_descriptors):
+        if USE_GATHER_TMA or X_USE_LOAD_TMA:
+            x_desc = TensorDescriptorBuilder.create_input_descriptor(
+                x_tensor, K, x.stride(1), x.stride(2),
+                opt_flags.block_k, opt_flags.block_m,
+                USE_GATHER_TMA, X_USE_LOAD_TMA
+            )
+            descriptors.append(x_desc)
+        if (expt_data is not None and len(expt_data.block_pid_map) > 0):
+            w_desc = TensorDescriptorBuilder.create_weight_descriptor(
+                w_tensor, opt_flags.block_k, opt_flags.block_n, w_transpose
+            )
+            is_microscaled_format = (mx_ctx.weight_scale is not None) and (w_tensor.dtype == torch.uint8)
+            if is_microscaled_format:
+                # Pad the inner shape to 128 for mxfp4 weights
+                # for mixed precision fp8 x mxfp4 compute
+                pad = 128
+                dim_to_pad = -1 if w_transpose else -2
+                old_size = w_desc.shape[dim_to_pad]
+                padded_size = math.ceil(old_size / pad) * pad
+                if padded_size != old_size:
+                    w_desc.shape = list(w_desc.shape)
+                    w_desc.shape[dim_to_pad] = padded_size
+        descriptors.append(w_desc)
+        # Optional MX scale descriptor
+        descriptors.append(None)
+        if mx_tensor is not None:
+            descriptors[-1] = TensorDescriptorBuilder.create_block_scale_descriptor(
+                mx_tensor, opt_flags.block_k, opt_flags.block_n, K, N,
+                mx_scale_stride_k, mx_scale_stride_n, routing_data.n_expts_tot,
+                batch_size,
+                expt_data, mx_ctx.swizzle_scale, mx_transpose
+            )
+
+    # TODO: Currently all or none, instead should support a mixture
+    # of host and device descriptors
+    if None in descriptors or len(descriptors) == 0:
+        descriptors = [x_tensor, w_tensor, mx_tensor]
+        use_host_tma_descriptors = False
+    if opt_flags.is_persistent:
+        opt_flags.target_kernel_kwargs["USE_HOST_TMA_DESCRIPTORS"] = use_host_tma_descriptors
+
+    return use_host_tma_descriptors, *descriptors
+
+
def matmul_ogs(x, w, bias,
               routing_data: RoutingData | None = None,
               gather_indx: GatherIndx | None = None,
@@ -601,22 +753,47 @@ def matmul_ogs(x, w, bias,
    flex = precision_config.flex_ctx
    bias_stride = None if bias is None else bias.stride(0)
    num_indx = None if scatter_indx is None else scatter_indx.src_indx.shape[0]
+
    kernels = get_kernels(epilogue.specs, fused_activation.specs)
    expt_data = routing_data.expt_data
    block_m = opt_flags.block_m
    expt_hist = None if expt_data is None else expt_data.hist
    expt_hist_sum = None if expt_data is None else expt_data.token_offs_pad[block_m][-1]
    expt_token_offs_raw = None if expt_data is None else expt_data.token_offs_raw
    expt_block_pid_map = None if expt_data is None else expt_data.block_pid_map[block_m]
+
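+    # TMA gather requires CUDA capability >= 10.0 and a gather index; plain TMA loads
+    # are used for the LHS only when no gather index is given.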
+    HAS_TMA_GS = target_info.cuda_capability_geq(10, 0)
+    USE_GATHER_TMA = HAS_TMA_GS and gather_indx is not None
+    X_USE_LOAD_TMA = gather_indx is None and not USE_GATHER_TMA
+    _, x_tensor, w_tensor, mx_tensor = _create_tma_descriptors(
+        x=x,
+        x_tensor=flex.lhs_data.reinterpret(x),
+        w_tensor=flex.rhs_data.reinterpret(w),
+        mx_tensor=mx_ctx.weight_scale,
+        routing_data=routing_data,
+        mx_ctx=mx_ctx,
+        expt_data=expt_data,
+        opt_flags=opt_flags,
+        batch_size=batch_size,
+        K=K,
+        N=N,
+        mx_scale_stride_k=mx_scale_stride_k,
+        mx_scale_stride_n=mx_scale_stride_n,
+        USE_GATHER_TMA=USE_GATHER_TMA,
+        X_USE_LOAD_TMA=X_USE_LOAD_TMA,
+        w_transpose=w.stride(2) != 1,
+        mx_transpose=mx_scale_stride_n != 1,
+    )
+
    (kernels._p_matmul_ogs if opt_flags.is_persistent else kernels._matmul_ogs)[(n_cta,)](
        flex.out_data.reinterpret(memory["output"]),
        flex.out_data.reinterpret(out0), *out0.stride(),
        *out0_flex,
-        flex.lhs_data.reinterpret(x), x.stride(0), x.stride(1), x.stride(2),
+        x_tensor, x.stride(0), x.stride(1), x.stride(2),
        flex.lhs_data.scale,
-        flex.rhs_data.reinterpret(w), w.stride(0), w.stride(1), w.stride(2), w.stride(2) != 1,
+        w_tensor, w.stride(0), w.stride(1), w.stride(2), w.stride(2) != 1,
        flex.rhs_data.scale,
-        mx_ctx.weight_scale, mx_scale_stride_e, mx_scale_stride_k, mx_scale_stride_n, mx_scale_stride_n != 1,
+        mx_tensor, mx_scale_stride_e, mx_scale_stride_k, mx_scale_stride_n, mx_scale_stride_n != 1,
        bias, bias_stride,
        x.shape[1],
        x.shape[1] if routing_data.expt_hist is None else None,