@@ -144,11 +144,17 @@ def create_block_scale_descriptor(mx_tensor: torch.Tensor, block_k: int, block_n
144
144
block_shape = [1 , MX_SCALE_BLOCK_K ,
145
145
block_n ], transpose = transpose )
146
146
147
+ @staticmethod
148
+ def squeeze_after_dim (x , dim = 2 ):
149
+ shape = list (x .shape )
150
+ new_shape = [s for s in shape [:dim - 1 ] if s != 1 ] + shape [dim - 1 :]
151
+ return x .view (* new_shape )
152
+
147
153
@staticmethod
148
154
def create_input_descriptor_gather (x_tensor : torch .Tensor , K : int , x_stride_1 : int , x_stride_2 : int ,
149
155
block_k : int ) -> TensorDescriptor :
150
156
"""Create a tensor descriptor for input matrix X via TMA gather"""
151
- x_desc = x_tensor . squeeze ( )
157
+ x_desc = TensorDescriptorBuilder . squeeze_after_dim ( x_tensor )
152
158
assert x_desc .ndim == 2 , "TMA gather descriptor requires 2D input"
153
159
INT_MAX = 2147483647
154
160
return TensorDescriptor (base = x_desc , shape = [INT_MAX , K ], strides = [x_stride_1 , x_stride_2 ],
@@ -158,7 +164,7 @@ def create_input_descriptor_gather(x_tensor: torch.Tensor, K: int, x_stride_1: i
158
164
def create_input_descriptor_load (x_tensor : torch .Tensor , K : int , x_stride_1 : int , x_stride_2 : int , block_m : int ,
159
165
block_k : int ) -> TensorDescriptor :
160
166
"""Create a tensor descriptor for input matrix X via TMA"""
161
- x_desc = x_tensor . squeeze ( )
167
+ x_desc = TensorDescriptorBuilder . squeeze_after_dim ( x_tensor )
162
168
assert x_desc .ndim in [2 , 3 ], "LHS input TMA descriptor builder expects 2D or 3D input"
163
169
return TensorDescriptor (base = x_desc , shape = [x_desc .shape [0 ], K ], strides = [x_stride_1 , x_stride_2 ],
164
170
block_shape = [block_m , block_k ])
0 commit comments