@@ -153,15 +153,30 @@ def pre_process_static(
         zero_point = torch.nn.functional.pad(zero_point, padding_changes)
         return input, scale, zero_point

-    def post_process(self, input: torch.Tensor) -> torch.Tensor:
+    def post_process(
+        self,
+        input: torch.Tensor,
+        scale: torch.Tensor,
+        zero_point: torch.Tensor,
+        block_size: Tuple[int, ...],
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         orig_out_features, orig_in_features = input.shape
         in_features = find_multiple(orig_in_features, 1024)
         out_features = find_multiple(orig_out_features, 8)
         input = torch.nn.functional.pad(
             input,
             (0, in_features - orig_in_features, 0, out_features - orig_out_features),
         )
-        return input
+        assert (
+            len(block_size) == 2
+        ), f"TensorCoreTiledLayout only supports len(block_size) == 2, got: {block_size}"
+        scale_pad_dim_0 = (out_features - orig_out_features) // block_size[0]
+        scale_pad_dim_1 = (in_features - orig_in_features) // block_size[1]
+        scale = torch.nn.functional.pad(scale, (0, scale_pad_dim_1, 0, scale_pad_dim_0))
+        zero_point = torch.nn.functional.pad(
+            zero_point, (0, scale_pad_dim_1, 0, scale_pad_dim_0)
+        )
+        return input, scale, zero_point

     def extra_repr(self):
         return f"inner_k_tiles={self.inner_k_tiles}"
@@ -335,31 +350,25 @@ def __torch_dispatch__(cls, func, types, args, kwargs):

         if func is aten.slice.Tensor:
             self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
-            if dim == 0:
-                int_data, scale, zero_point = self.get_plain()
-                int_data = aten.slice.Tensor(int_data, dim, start, end, step)
-                # this is to handle padding
-                int_data = self._layout.post_process(int_data)
-                sliced = self.from_plain(int_data, scale, zero_point, self._layout)
-                return return_and_correct_aliasing(func, args, kwargs, sliced)
-            elif dim == 1:
+            if dim in [0, 1]:
                 int_data, scale, zero_point = self.get_plain()
-                assert step == 1, "Only step == 1 is supported in slicing right now"
                 data_len = int_data.shape[dim]
                 scale_len = scale.shape[dim]
                 ratio = data_len / scale_len
                 start_scale = int(start / ratio)
                 end_scale = int(end / ratio)

                 int_data = aten.slice.Tensor(int_data, dim, start, end, step)
-                # this is to handle padding
-                int_data = self._layout.post_process(int_data)
                 scale = aten.slice.Tensor(scale, dim, start_scale, end_scale, step)
                 zero_point = aten.slice.Tensor(
                     zero_point, dim, start_scale, end_scale, step
                 )
+                # this is to handle padding
+                int_data, scale, zero_point = self._layout.post_process(
+                    int_data, scale, zero_point, self.block_size
+                )
                 sliced = self.from_plain(int_data, scale, zero_point, self._layout)
-                return sliced
+                return return_and_correct_aliasing(func, args, kwargs, sliced)
             else:
                 raise NotImplementedError(
                     f"TensorCoreTiledAQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported"
@@ -371,6 +380,18 @@ def __torch_dispatch__(cls, func, types, args, kwargs):

     __torch_function__ = torch._C._disabled_torch_function_impl

+    @property
+    def block_size(self):
+        from torchao.quantization.utils import unpack_tinygemm_scales_and_zeros
+
+        scale, zero = unpack_tinygemm_scales_and_zeros(self.scale_and_zero)
+        cur_shape = self.shape
+        assert len(cur_shape) == 4
+        inner_k_tiles = cur_shape[-1] * 2
+        original_shape = (cur_shape[0] * 8, cur_shape[1] * (inner_k_tiles * 16))
+        groupsize = int(original_shape[1] / scale.shape[-2])
+        return (1, groupsize)
+
     def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         from torchao.quantization.quant_primitives import (
             ZeroPointDomain,
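
The new `block_size` property recovers the group size from the packed 4-D weight rather than storing it. Assuming the tinygemm int4 packing lays the weight out as `(out/8, in/(inner_k_tiles*16), 32, inner_k_tiles/2)`, the original shape can be inverted from the packed one; the numbers below are a hypothetical check of that arithmetic, and `n_groups` stands in for `scale.shape[-2]` after unpacking:

```python
# hypothetical packed shape for a 4096x4096 weight with inner_k_tiles = 8:
# (out/8, in/(inner_k_tiles*16), 32, inner_k_tiles/2) = (512, 32, 32, 4)
cur_shape = (512, 32, 32, 4)

inner_k_tiles = cur_shape[-1] * 2                  # 8
out_features = cur_shape[0] * 8                    # 4096
in_features = cur_shape[1] * (inner_k_tiles * 16)  # 4096

# scale.shape[-2] counts the quantization groups along in_features;
# with groupsize 128 that is 4096 / 128 = 32 (assumed here)
n_groups = 32
groupsize = int(in_features / n_groups)  # 128
block_size = (1, groupsize)              # one group spans 1 row x 128 columns
```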