from dataclasses import dataclass
import itertools
- import math
import sys
import torch
import triton
@@ -121,20 +120,19 @@ def create_weight_descriptor(w_tensor: torch.Tensor, block_k: int, block_n: int,
                                        transpose=transpose)

    @staticmethod
-     def create_block_scale_descriptor(mx_tensor: torch.Tensor, block_k: int, block_n: int, K: int, N: int,
-                                       mx_scale_stride_k: int, mx_scale_stride_n: int, n_expts_tot: int, batch_size: int,
-                                       expt_data: Optional[ExptData], swizzle_mx: bool,
-                                       transpose: bool) -> TensorDescriptor:
+     def create_block_scale_descriptor(mx_tensor: torch.Tensor, block_k: int, block_n: int, B: int, K: int, N: int,
+                                       mx_scale_stride_k: int, mx_scale_stride_n: int, swizzle_mx: bool,
+                                       transpose: Optional[bool]) -> TensorDescriptor:
        """Create a tensor descriptor for block scale factors"""
        MX_PACK_DIVISOR = 32
        MX_SCALE_BLOCK_K = block_k // MX_PACK_DIVISOR
        PackedK = (K + MX_PACK_DIVISOR - 1) // MX_PACK_DIVISOR

        if swizzle_mx:
-             num_expt_x_ncol = (n_expts_tot if expt_data is not None and len(expt_data.block_pid_map) > 0 else
-                                batch_size) * ((N + 127) // 128)
+             assert transpose is None
+             num_expt_x_ncol = B * triton.cdiv(N, 128)
            return TensorDescriptor(
-                 base=mx_tensor, shape=[1, num_expt_x_ncol, (PackedK + 3) // 4, 2, 256],
+                 base=mx_tensor, shape=[1, num_expt_x_ncol, triton.cdiv(PackedK, 4), 2, 256],
                strides=[num_expt_x_ncol * mx_scale_stride_n, mx_scale_stride_n, mx_scale_stride_k, 256,
                         1], block_shape=[1, block_n // 128, MX_SCALE_BLOCK_K // 4, 2, 256])
        else:
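
To follow the swizzled branch above, here is a small standalone sketch of the shape arithmetic with hypothetical sizes (block_k, block_n, B, K, N are illustrative; cdiv simply mirrors triton.cdiv):

# Sketch of the swizzled block-scale descriptor shapes (illustrative values only).
def cdiv(a, b):
    return (a + b - 1) // b

MX_PACK_DIVISOR = 32
block_k, block_n = 128, 256        # hypothetical tile sizes
B, K, N = 4, 4096, 2048            # hypothetical batch/problem sizes

MX_SCALE_BLOCK_K = block_k // MX_PACK_DIVISOR                      # 4
PackedK = cdiv(K, MX_PACK_DIVISOR)                                 # 128
num_expt_x_ncol = B * cdiv(N, 128)                                 # 4 * 16 = 64

shape = [1, num_expt_x_ncol, cdiv(PackedK, 4), 2, 256]             # [1, 64, 32, 2, 256]
block_shape = [1, block_n // 128, MX_SCALE_BLOCK_K // 4, 2, 256]   # [1, 2, 1, 2, 256]
print(shape, block_shape)
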
@@ -151,35 +149,12 @@ def squeeze_after_dim(x, dim=2):
        return x.view(*new_shape)

    @staticmethod
-     def create_input_descriptor_gather(x_tensor: torch.Tensor, K: int, x_stride_1: int, x_stride_2: int,
-                                        block_k: int) -> TensorDescriptor:
-         """Create a tensor descriptor for input matrix X via TMA gather"""
-         x_desc = TensorDescriptorBuilder.squeeze_after_dim(x_tensor)
-         assert x_desc.ndim == 2, "TMA gather descriptor requires 2D input"
-         INT_MAX = 2147483647
-         return TensorDescriptor(base=x_desc, shape=[INT_MAX, K], strides=[x_stride_1, x_stride_2],
-                                 block_shape=[1, block_k])
-
-     @staticmethod
-     def create_input_descriptor_load(x_tensor: torch.Tensor, K: int, x_stride_1: int, x_stride_2: int, block_m: int,
-                                      block_k: int) -> TensorDescriptor:
-         """Create a tensor descriptor for input matrix X via TMA"""
-         x_desc = TensorDescriptorBuilder.squeeze_after_dim(x_tensor)
-         assert x_desc.ndim in [2, 3], "LHS input TMA descriptor builder expects 2D or 3D input"
-         return TensorDescriptor(base=x_desc, shape=[x_desc.shape[0], K], strides=[x_stride_1, x_stride_2],
-                                 block_shape=[block_m, block_k])
-
-     @staticmethod
-     def create_input_descriptor(x_tensor: torch.Tensor, K: int, x_stride_1: int, x_stride_2: int, block_k: int,
-                                 block_m: int, use_gather_tma: bool, use_load_tma: bool) -> TensorDescriptor:
-         """Create a tensor descriptor for input matrix X based on TMA usage"""
-         if use_gather_tma:
-             return TensorDescriptorBuilder.create_input_descriptor_gather(x_tensor, K, x_stride_1, x_stride_2, block_k)
-         elif use_load_tma:
-             return TensorDescriptorBuilder.create_input_descriptor_load(x_tensor, K, x_stride_1, x_stride_2, block_m,
-                                                                         block_k)
-         else:
-             return x_tensor
+     def create_descriptor(x_tensor: torch.Tensor, block_m: int, block_k: int) -> TensorDescriptor:
+         """Create a tensor descriptor for matrix X via TMA"""
+         x_tensor = TensorDescriptorBuilder.squeeze_after_dim(x_tensor)
+         assert x_tensor.ndim in [2, 3], "TMA descriptor builder expects 2D or 3D input"
+         block_shape = [1] * (x_tensor.ndim - 2) + [block_m, block_k]
+         return TensorDescriptor.from_tensor(x_tensor, block_shape=block_shape)


# ---------------------
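
The new create_descriptor builds block_shape so that any leading batch dimension is stepped one slice at a time. A minimal sketch of that construction with hypothetical block sizes (only the pure-Python part is shown; TensorDescriptor.from_tensor itself needs a CUDA-enabled Triton build):

# Sketch: block_shape handed to TensorDescriptor.from_tensor for 2D vs 3D inputs (illustrative sizes).
block_m, block_k = 128, 64
for ndim in (2, 3):
    # leading (batch) dimensions beyond the trailing [M, K] tile get a block size of 1
    block_shape = [1] * (ndim - 2) + [block_m, block_k]
    print(ndim, block_shape)
# 2 -> [128, 64]
# 3 -> [1, 128, 64]
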
@@ -590,66 +565,53 @@ def _create_tma_descriptors(
    mx_ctx: MicroscalingCtx,
    expt_data: ExptData,
    opt_flags: OptFlags,
-     batch_size: int,
+     B: int,
    K: int,
    N: int,
    mx_scale_stride_k: int,
    mx_scale_stride_n: int,
-     USE_GATHER_TMA: bool,
-     X_USE_LOAD_TMA: bool,
-     w_transpose: bool,
-     mx_transpose: bool,
+     HAS_GATHER: bool,
) -> Tuple[bool, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """Create and cache TMA descriptors for tensors."""
-     use_host_tma_descriptors = opt_flags.is_persistent and target_info.cuda_capability_geq(10, 0)
-
-     x_desc, w_desc = [None] * 2
-     descriptors = []
-     # The dense case currently uses on device descriptor updates
-     # so we bail out on using host descriptors in that case
-     if (use_host_tma_descriptors):
-         if USE_GATHER_TMA or X_USE_LOAD_TMA:
-             x_desc = TensorDescriptorBuilder.create_input_descriptor(
-                 x, K, x.stride(1), x.stride(2),
-                 opt_flags.block_k, opt_flags.block_m,
-                 USE_GATHER_TMA, X_USE_LOAD_TMA
-             )
-             descriptors.append(x_desc)
-         if (expt_data is not None and len(expt_data.block_pid_map) > 0):
-             w_desc = TensorDescriptorBuilder.create_weight_descriptor(
-                 w, opt_flags.block_k, opt_flags.block_n, w_transpose
-             )
-             is_microscaled_format = (mx_ctx.weight_scale is not None) and (w.dtype == torch.uint8)
-             if is_microscaled_format:
-                 # Pad the inner shape to 128 for mxfp4 weights
-                 # for mixed precision fp8 x mxfp4 compute
-                 pad = 128
-                 dim_to_pad = -1
-                 old_size = w_desc.shape[dim_to_pad]
-                 padded_size = math.ceil(old_size / pad) * pad
-                 if padded_size != old_size:
-                     w_desc.shape = list(w_desc.shape)
-                     w_desc.shape[dim_to_pad] = padded_size
-             descriptors.append(w_desc)
-         # Optional MX scale descriptor
-         descriptors.append(None)
-         if mx_tensor is not None:
-             descriptors[-1] = TensorDescriptorBuilder.create_block_scale_descriptor(
-                 mx_tensor, opt_flags.block_k, opt_flags.block_n, K, N,
-                 mx_scale_stride_k, mx_scale_stride_n, routing_data.n_expts_tot,
-                 batch_size,
-                 expt_data, mx_ctx.swizzle_scale, mx_transpose
-             )

-     # TODO: Currently all or none, instead should support a mixture
-     # of host and device descriptors
-     if None in descriptors or len(descriptors) == 0:
-         descriptors = [x, w, mx_tensor]
-         use_host_tma_descriptors = False
-     if opt_flags.is_persistent:
-         opt_flags.target_kernel_kwargs["USE_HOST_TMA_DESCRIPTORS"] = use_host_tma_descriptors
+     x_tensor_or_desc, mx_desc_and_transpose = x, (None, False)

-     return use_host_tma_descriptors, *descriptors
+     if not HAS_GATHER:
+         x_tensor_or_desc = TensorDescriptorBuilder.create_descriptor(x, opt_flags.block_m, opt_flags.block_k)
+
+     w_transpose = w.stride(2) != 1
+     w_desc = TensorDescriptorBuilder.create_weight_descriptor(
+         w, opt_flags.block_k, opt_flags.block_n, w_transpose
+     )
+     w_desc_and_transpose = (w_desc, w_transpose)
+
+     is_microscaled_format = mx_ctx.weight_scale is not None and w.dtype == torch.uint8
+     if is_microscaled_format:
+         # Pad the inner shape to 128 for mxfp4 weights; TMA requires this when the compiler uses
+         # CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B.
+         # This technically makes the shape masking incorrect, but it's fine because:
+         # - When the N dim is padded, the scales will be masked to 0.
+         # - When the K dim is padded, the activations we perform tl.dot with will be masked to 0.
+         #   Note: the scales can't be relied on for zeroing in this case, because they apply to groups
+         #   of 32 elements in the K dimension.
+         pad = 128
+         dim_to_pad = -1
+         old_size = w_desc.shape[dim_to_pad]
+         padded_size = triton.cdiv(old_size, pad) * pad
+         if padded_size != old_size:
+             w_desc.shape = list(w_desc.shape)
+             w_desc.shape[dim_to_pad] = padded_size
+
+     if mx_tensor is not None:
+         mx_transpose = mx_scale_stride_n != 1 if mx_ctx.swizzle_scale is None else None
+         mx_desc = TensorDescriptorBuilder.create_block_scale_descriptor(
+             mx_tensor, opt_flags.block_k, opt_flags.block_n,
+             routing_data.n_expts_tot if expt_data is not None and len(expt_data.block_pid_map) > 0 else B, K, N,
+             mx_scale_stride_k, mx_scale_stride_n, mx_ctx.swizzle_scale, mx_transpose
+         )
+         mx_desc_and_transpose = (mx_desc, mx_transpose)
+
+     return x_tensor_or_desc, w_desc_and_transpose, mx_desc_and_transpose


def matmul_ogs(x, w, bias,
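
One incidental change above is that the mxfp4 padding now uses triton.cdiv instead of math.ceil; for positive integer sizes the two round up identically. A quick sketch with illustrative sizes:

import math

def cdiv(a, b):    # mirrors triton.cdiv for positive integers
    return (a + b - 1) // b

pad = 128
for old_size in (96, 128, 200, 256):
    padded = cdiv(old_size, pad) * pad
    assert padded == math.ceil(old_size / pad) * pad
    print(old_size, "->", padded)    # 96 -> 128, 128 -> 128, 200 -> 256, 256 -> 256
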
@@ -754,41 +716,39 @@ def matmul_ogs(x, w, bias,
    expt_token_offs_raw = None if expt_data is None else expt_data.token_offs_raw
    expt_block_pid_map = None if expt_data is None else expt_data.block_pid_map[block_m]

-     HAS_TMA_GS = target_info.cuda_capability_geq(10, 0)
-     USE_GATHER_TMA = HAS_TMA_GS and gather_indx is not None
-     X_USE_LOAD_TMA = gather_indx is None and not USE_GATHER_TMA
-     _, x_tensor, w_tensor, mx_tensor = _create_tma_descriptors(
-         x=x, w=w,
-         mx_tensor=mx_ctx.weight_scale,
-         routing_data=routing_data,
-         mx_ctx=mx_ctx,
-         expt_data=expt_data,
-         opt_flags=opt_flags,
-         batch_size=batch_size,
-         K=K,
-         N=N,
-         mx_scale_stride_k=mx_scale_stride_k,
-         mx_scale_stride_n=mx_scale_stride_n,
-         USE_GATHER_TMA=USE_GATHER_TMA,
-         X_USE_LOAD_TMA=X_USE_LOAD_TMA,
-         w_transpose=w.stride(2) != 1,
-         mx_transpose=mx_scale_stride_n != 1,
-     )
+     if opt_flags.is_persistent:
+         x_tensor, w_tensor_and_transpose, mx_tensor_and_transpose = _create_tma_descriptors(
+             x=x, w=w, mx_tensor=mx_ctx.weight_scale,
+             routing_data=routing_data,
+             mx_ctx=mx_ctx,
+             expt_data=expt_data,
+             opt_flags=opt_flags,
+             B=batch_size,
+             K=K,
+             N=N,
+             mx_scale_stride_k=mx_scale_stride_k,
+             mx_scale_stride_n=mx_scale_stride_n,
+             HAS_GATHER=gather_indx is not None,
+         )
+         w_tensor, w_tma_transpose = w_tensor_and_transpose
+         mx_tensor, mx_tma_transpose = mx_tensor_and_transpose
+     else:
+         x_tensor = x
+         w_tensor, w_tma_transpose = w, False
+         mx_tensor, mx_tma_transpose = mx_ctx.weight_scale, False

    if isinstance(x_tensor, torch.Tensor):
        x_tensor = flex.lhs_data.reinterpret(x)
    if isinstance(w_tensor, torch.Tensor):
        w_tensor = flex.rhs_data.reinterpret(w)
    (kernels._p_matmul_ogs if opt_flags.is_persistent else kernels._matmul_ogs)[(n_cta,)](
        flex.out_data.reinterpret(memory["output"]),
-         flex.out_data.reinterpret(out0), *out0.stride(),
-         *out0_flex,
+         flex.out_data.reinterpret(out0), *out0.stride(), *out0_flex,
        x_tensor, x.stride(0), x.stride(1), x.stride(2),
        flex.lhs_data.scale,
-         w_tensor, w.stride(0), w.stride(1), w.stride(2), w.stride(2) != 1,
+         w_tensor, w.stride(0), w.stride(1), w.stride(2), w_tma_transpose,
        flex.rhs_data.scale,
-         mx_tensor, mx_scale_stride_e, mx_scale_stride_k, mx_scale_stride_n, mx_scale_stride_n != 1,
+         mx_tensor, mx_scale_stride_e, mx_scale_stride_k, mx_scale_stride_n, mx_tma_transpose,
        bias, bias_stride,
-         x.shape[1],
        x.shape[1] if routing_data.expt_hist is None else None,
        N, K,
        betas, gammas,
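
The launch above now forwards the transpose flags computed once in _create_tma_descriptors rather than re-deriving them from strides at the call site. A minimal sketch of that stride check, using an illustrative weight tensor:

import torch

w = torch.empty(2, 64, 32)    # contiguous (E, K, N): last-dim stride is 1
wt = w.transpose(1, 2)        # transposed view: last-dim stride is no longer 1
print(w.stride(2) != 1)       # False -> not transposed
print(wt.stride(2) != 1)      # True  -> treated as transposed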