 from dataclasses import dataclass
 import itertools
-import math
 import sys
 import torch
 import triton
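Note: math is dropped because the ceiling divisions in this file now go through triton.cdiv. For positive integers the two are interchangeable; a quick standalone sketch (the values are made up):

    import math
    import triton

    a, b = 1000, 128
    assert triton.cdiv(a, b) == (a + b - 1) // b == math.ceil(a / b)  # all equal 8
    # rounding up to a multiple, as the mxfp4 weight padding further down does:
    assert triton.cdiv(a, b) * b == 1024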
@@ -121,20 +120,19 @@ def create_weight_descriptor(w_tensor: torch.Tensor, block_k: int, block_n: int,
                                                        transpose=transpose)

     @staticmethod
-    def create_block_scale_descriptor(mx_tensor: torch.Tensor, block_k: int, block_n: int, K: int, N: int,
-                                      mx_scale_stride_k: int, mx_scale_stride_n: int, n_expts_tot: int, batch_size: int,
-                                      expt_data: Optional[ExptData], swizzle_mx: bool,
-                                      transpose: bool) -> TensorDescriptor:
+    def create_block_scale_descriptor(mx_tensor: torch.Tensor, block_k: int, block_n: int, B: int, K: int, N: int,
+                                      mx_scale_stride_k: int, mx_scale_stride_n: int, swizzle_mx: bool,
+                                      transpose: Optional[bool]) -> TensorDescriptor:
         """Create a tensor descriptor for block scale factors"""
         MX_PACK_DIVISOR = 32
         MX_SCALE_BLOCK_K = block_k // MX_PACK_DIVISOR
         PackedK = (K + MX_PACK_DIVISOR - 1) // MX_PACK_DIVISOR

         if swizzle_mx:
-            num_expt_x_ncol = (n_expts_tot if expt_data is not None and len(expt_data.block_pid_map) > 0 else
-                               batch_size) * ((N + 127) // 128)
+            assert transpose is None
+            num_expt_x_ncol = B * triton.cdiv(N, 128)
             return TensorDescriptor(
-                base=mx_tensor, shape=[1, num_expt_x_ncol, (PackedK + 3) // 4, 2, 256],
+                base=mx_tensor, shape=[1, num_expt_x_ncol, triton.cdiv(PackedK, 4), 2, 256],
                 strides=[num_expt_x_ncol * mx_scale_stride_n, mx_scale_stride_n, mx_scale_stride_k, 256,
                          1], block_shape=[1, block_n // 128, MX_SCALE_BLOCK_K // 4, 2, 256])
         else:
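For intuition: the swizzled branch packs one 8-bit scale per 32 elements along K and groups columns 128 at a time, so the caller now passes a single pre-resolved B. A standalone sketch of the shape arithmetic, with made-up sizes (B=2, K=1024, N=256, block_k=128):

    import triton

    MX_PACK_DIVISOR = 32
    B, K, N, block_k = 2, 1024, 256, 128
    MX_SCALE_BLOCK_K = block_k // MX_PACK_DIVISOR            # 4
    PackedK = (K + MX_PACK_DIVISOR - 1) // MX_PACK_DIVISOR   # 32 scale groups along K
    num_expt_x_ncol = B * triton.cdiv(N, 128)                # 2 * 2 = 4
    shape = [1, num_expt_x_ncol, triton.cdiv(PackedK, 4), 2, 256]
    assert shape == [1, 4, 8, 2, 256]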
@@ -151,35 +149,12 @@ def squeeze_after_dim(x, dim=2):
         return x.view(*new_shape)

     @staticmethod
-    def create_input_descriptor_gather(x_tensor: torch.Tensor, K: int, x_stride_1: int, x_stride_2: int,
-                                       block_k: int) -> TensorDescriptor:
-        """Create a tensor descriptor for input matrix X via TMA gather"""
-        x_desc = TensorDescriptorBuilder.squeeze_after_dim(x_tensor)
-        assert x_desc.ndim == 2, "TMA gather descriptor requires 2D input"
-        INT_MAX = 2147483647
-        return TensorDescriptor(base=x_desc, shape=[INT_MAX, K], strides=[x_stride_1, x_stride_2],
-                                block_shape=[1, block_k])
-
-    @staticmethod
-    def create_input_descriptor_load(x_tensor: torch.Tensor, K: int, x_stride_1: int, x_stride_2: int, block_m: int,
-                                     block_k: int) -> TensorDescriptor:
-        """Create a tensor descriptor for input matrix X via TMA"""
-        x_desc = TensorDescriptorBuilder.squeeze_after_dim(x_tensor)
-        assert x_desc.ndim in [2, 3], "LHS input TMA descriptor builder expects 2D or 3D input"
-        return TensorDescriptor(base=x_desc, shape=[x_desc.shape[0], K], strides=[x_stride_1, x_stride_2],
-                                block_shape=[block_m, block_k])
-
-    @staticmethod
-    def create_input_descriptor(x_tensor: torch.Tensor, K: int, x_stride_1: int, x_stride_2: int, block_k: int,
-                                block_m: int, use_gather_tma: bool, use_load_tma: bool) -> TensorDescriptor:
-        """Create a tensor descriptor for input matrix X based on TMA usage"""
-        if use_gather_tma:
-            return TensorDescriptorBuilder.create_input_descriptor_gather(x_tensor, K, x_stride_1, x_stride_2, block_k)
-        elif use_load_tma:
-            return TensorDescriptorBuilder.create_input_descriptor_load(x_tensor, K, x_stride_1, x_stride_2, block_m,
-                                                                        block_k)
-        else:
-            return x_tensor
+    def create_descriptor(x_tensor: torch.Tensor, block_m: int, block_k: int) -> TensorDescriptor:
+        """Create a tensor descriptor for matrix X via TMA"""
+        x_tensor = TensorDescriptorBuilder.squeeze_after_dim(x_tensor)
+        assert x_tensor.ndim in [2, 3], "TMA descriptor builder expects 2D or 3D input"
+        block_shape = [1] * (x_tensor.ndim - 2) + [block_m, block_k]
+        return TensorDescriptor.from_tensor(x_tensor, block_shape=block_shape)


 # ---------------------
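The consolidated create_descriptor tiles only the last two dimensions; any leading batch dimension gets a block extent of 1. A small sketch of that block-shape computation (tensor sizes and block sizes are made-up example values):

    import torch

    block_m, block_k = 128, 64
    for x in (torch.empty(512, 1024), torch.empty(4, 512, 1024)):
        block_shape = [1] * (x.ndim - 2) + [block_m, block_k]
        print(x.ndim, block_shape)   # 2 -> [128, 64], then 3 -> [1, 128, 64]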
@@ -590,66 +565,53 @@ def _create_tma_descriptors(
     mx_ctx: MicroscalingCtx,
     expt_data: ExptData,
     opt_flags: OptFlags,
-    batch_size: int,
+    B: int,
     K: int,
     N: int,
     mx_scale_stride_k: int,
     mx_scale_stride_n: int,
-    USE_GATHER_TMA: bool,
-    X_USE_LOAD_TMA: bool,
-    w_transpose: bool,
-    mx_transpose: bool,
+    HAS_GATHER: bool,
 ) -> Tuple[bool, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
     """Create and cache TMA descriptors for tensors."""
-    use_host_tma_descriptors = opt_flags.is_persistent and target_info.cuda_capability_geq(10, 0)
-
-    x_desc, w_desc = [None] * 2
-    descriptors = []
-    # The dense case currently uses on device descriptor updates
-    # so we bail out on using host descriptors in that case
-    if (use_host_tma_descriptors):
-        if USE_GATHER_TMA or X_USE_LOAD_TMA:
-            x_desc = TensorDescriptorBuilder.create_input_descriptor(
-                x, K, x.stride(1), x.stride(2),
-                opt_flags.block_k, opt_flags.block_m,
-                USE_GATHER_TMA, X_USE_LOAD_TMA
-            )
-            descriptors.append(x_desc)
-        if (expt_data is not None and len(expt_data.block_pid_map) > 0):
-            w_desc = TensorDescriptorBuilder.create_weight_descriptor(
-                w, opt_flags.block_k, opt_flags.block_n, w_transpose
-            )
-            is_microscaled_format = (mx_ctx.weight_scale is not None) and (w.dtype == torch.uint8)
-            if is_microscaled_format:
-                # Pad the inner shape to 128 for mxfp4 weights
-                # for mixed precision fp8 x mxfp4 compute
-                pad = 128
-                dim_to_pad = -1
-                old_size = w_desc.shape[dim_to_pad]
-                padded_size = math.ceil(old_size / pad) * pad
-                if padded_size != old_size:
-                    w_desc.shape = list(w_desc.shape)
-                    w_desc.shape[dim_to_pad] = padded_size
-            descriptors.append(w_desc)
-        # Optional MX scale descriptor
-        descriptors.append(None)
-        if mx_tensor is not None:
-            descriptors[-1] = TensorDescriptorBuilder.create_block_scale_descriptor(
-                mx_tensor, opt_flags.block_k, opt_flags.block_n, K, N,
-                mx_scale_stride_k, mx_scale_stride_n, routing_data.n_expts_tot,
-                batch_size,
-                expt_data, mx_ctx.swizzle_scale, mx_transpose
-            )

-    # TODO: Currently all or none, instead should support a mixture
-    # of host and device descriptors
-    if None in descriptors or len(descriptors) == 0:
-        descriptors = [x, w, mx_tensor]
-        use_host_tma_descriptors = False
-    if opt_flags.is_persistent:
-        opt_flags.target_kernel_kwargs["USE_HOST_TMA_DESCRIPTORS"] = use_host_tma_descriptors
+    x_tensor_or_desc, mx_desc_and_transpose = x, (None, False)

-    return use_host_tma_descriptors, *descriptors
+    if not HAS_GATHER:
+        x_tensor_or_desc = TensorDescriptorBuilder.create_descriptor(x, opt_flags.block_m, opt_flags.block_k)
+
+    w_transpose = w.stride(2) != 1
+    w_desc = TensorDescriptorBuilder.create_weight_descriptor(
+        w, opt_flags.block_k, opt_flags.block_n, w_transpose
+    )
+    w_desc_and_transpose = (w_desc, w_transpose)
+
+    is_microscaled_format = mx_ctx.weight_scale is not None and w.dtype == torch.uint8
+    if is_microscaled_format:
+        # Pad the inner shape to 128 for mxfp4 weights; TMA requires this when the compiler uses
+        # CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B.
+        # This technically makes the shape masking incorrect, but it's fine because:
+        # - When the N dim is padded, the scales will be masked to 0.
+        # - When the K dim is padded, the activations we perform tl.dot with will be masked to 0.
+        # Note: the scales can't be relied on for zeroing in this case, because they apply to groups
+        # of 32 elements in the K dimension.
+        pad = 128
+        dim_to_pad = -1
+        old_size = w_desc.shape[dim_to_pad]
+        padded_size = triton.cdiv(old_size, pad) * pad
+        if padded_size != old_size:
+            w_desc.shape = list(w_desc.shape)
+            w_desc.shape[dim_to_pad] = padded_size
+
+    if mx_tensor is not None:
+        mx_transpose = mx_scale_stride_n != 1 if mx_ctx.swizzle_scale is None else None
+        mx_desc = TensorDescriptorBuilder.create_block_scale_descriptor(
+            mx_tensor, opt_flags.block_k, opt_flags.block_n,
+            routing_data.n_expts_tot if expt_data is not None and len(expt_data.block_pid_map) > 0 else B, K, N,
+            mx_scale_stride_k, mx_scale_stride_n, mx_ctx.swizzle_scale, mx_transpose
+        )
+        mx_desc_and_transpose = (mx_desc, mx_transpose)
+
+    return x_tensor_or_desc, w_desc_and_transpose, mx_desc_and_transpose


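The transpose flags that used to be recomputed at the call site (w.stride(2) != 1, mx_scale_stride_n != 1) are now derived once here and returned alongside the descriptors. A standalone sketch of what the stride test detects, using a made-up (n_expts, K, N) weight:

    import torch

    w = torch.empty(2, 64, 32)           # (n_expts, K, N), contiguous
    assert (w.stride(2) != 1) is False   # innermost dim is unit-stride: not transposed
    wt = w.transpose(1, 2)               # (n_expts, N, K) view of the same storage
    assert (wt.stride(2) != 1) is True   # K is now strided: transposed layout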
 def matmul_ogs(x, w, bias,
@@ -754,41 +716,39 @@ def matmul_ogs(x, w, bias,
     expt_token_offs_raw = None if expt_data is None else expt_data.token_offs_raw
     expt_block_pid_map = None if expt_data is None else expt_data.block_pid_map[block_m]

-    HAS_TMA_GS = target_info.cuda_capability_geq(10, 0)
-    USE_GATHER_TMA = HAS_TMA_GS and gather_indx is not None
-    X_USE_LOAD_TMA = gather_indx is None and not USE_GATHER_TMA
-    _, x_tensor, w_tensor, mx_tensor = _create_tma_descriptors(
-        x=x, w=w,
-        mx_tensor=mx_ctx.weight_scale,
-        routing_data=routing_data,
-        mx_ctx=mx_ctx,
-        expt_data=expt_data,
-        opt_flags=opt_flags,
-        batch_size=batch_size,
-        K=K,
-        N=N,
-        mx_scale_stride_k=mx_scale_stride_k,
-        mx_scale_stride_n=mx_scale_stride_n,
-        USE_GATHER_TMA=USE_GATHER_TMA,
-        X_USE_LOAD_TMA=X_USE_LOAD_TMA,
-        w_transpose=w.stride(2) != 1,
-        mx_transpose=mx_scale_stride_n != 1,
-    )
+    if opt_flags.is_persistent:
+        x_tensor, w_tensor_and_transpose, mx_tensor_and_transpose = _create_tma_descriptors(
+            x=x, w=w, mx_tensor=mx_ctx.weight_scale,
+            routing_data=routing_data,
+            mx_ctx=mx_ctx,
+            expt_data=expt_data,
+            opt_flags=opt_flags,
+            B=batch_size,
+            K=K,
+            N=N,
+            mx_scale_stride_k=mx_scale_stride_k,
+            mx_scale_stride_n=mx_scale_stride_n,
+            HAS_GATHER=gather_indx is not None,
+        )
+        w_tensor, w_tma_transpose = w_tensor_and_transpose
+        mx_tensor, mx_tma_transpose = mx_tensor_and_transpose
+    else:
+        x_tensor = x
+        w_tensor, w_tma_transpose = w, False
+        mx_tensor, mx_tma_transpose = mx_ctx.weight_scale, False
     if isinstance(x_tensor, torch.Tensor):
         x_tensor = flex.lhs_data.reinterpret(x)
     if isinstance(w_tensor, torch.Tensor):
         w_tensor = flex.rhs_data.reinterpret(w)
     (kernels._p_matmul_ogs if opt_flags.is_persistent else kernels._matmul_ogs)[(n_cta,)](
         flex.out_data.reinterpret(memory["output"]),
-        flex.out_data.reinterpret(out0), *out0.stride(),
-        *out0_flex,
+        flex.out_data.reinterpret(out0), *out0.stride(), *out0_flex,
         x_tensor, x.stride(0), x.stride(1), x.stride(2),
         flex.lhs_data.scale,
-        w_tensor, w.stride(0), w.stride(1), w.stride(2), w.stride(2) != 1,
+        w_tensor, w.stride(0), w.stride(1), w.stride(2), w_tma_transpose,
         flex.rhs_data.scale,
-        mx_tensor, mx_scale_stride_e, mx_scale_stride_k, mx_scale_stride_n, mx_scale_stride_n != 1,
+        mx_tensor, mx_scale_stride_e, mx_scale_stride_k, mx_scale_stride_n, mx_tma_transpose,
         bias, bias_stride,
-        x.shape[1],
         x.shape[1] if routing_data.expt_hist is None else None,
         N, K,
         betas, gammas,