2 files changed: +15 −1

Changed paths:
  backends/xnnpack/test/ops
  examples/models/llama/source_transformation

backends/xnnpack/test/ops:

@@ -395,7 +395,9 @@ def _test_groupwise_dq_linear(
         quantize_(
             mod,
             Int8DynamicActivationIntxWeightConfig(
-                weight_dtype=torch.int4, weight_granularity=PerGroup(group_size)
+                # pyre-ignore[16]
+                weight_dtype=torch.int4,
+                weight_granularity=PerGroup(group_size),
             ),
         )
         unwrap_tensor_subclass(mod)
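For reference, the pattern this test exercises can be run standalone. The sketch below is a hedged reconstruction, not part of the diff: the import path of Int8DynamicActivationIntxWeightConfig differs between torchao versions, and the toy module is hypothetical.

import torch
import torch.nn as nn
from torchao.quantization import quantize_
from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import Int8DynamicActivationIntxWeightConfig
from torchao.utils import unwrap_tensor_subclass

group_size = 32
mod = nn.Sequential(nn.Linear(64, 64))

# Dynamic int8 activations with group-wise int4 weights, as in the test above.
quantize_(
    mod,
    Int8DynamicActivationIntxWeightConfig(
        weight_dtype=torch.int4,
        weight_granularity=PerGroup(group_size),
    ),
)
# Replace tensor subclasses with plain tensors so the module can be exported.
unwrap_tensor_subclass(mod)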
examples/models/llama/source_transformation:

@@ -135,6 +135,7 @@ def quantize(  # noqa C901
                 PerAxis(0) if group_size == 0 else PerGroup(group_size)
             ),
             weight_mapping_type=MappingType.SYMMETRIC,
+            # pyre-ignore[6]
             intx_packing_format="opaque_torchao_auto",
         ),
     )
@@ -154,12 +155,23 @@ def quantize(  # noqa C901
     from torchao.quantization.granularity import PerGroup
     from torchao.utils import unwrap_tensor_subclass
 
+    def filter_fn(m, fqn):
+        is_linear = isinstance(m, nn.Linear)
+        has_shape_compatible_with_group_size = False
+        if is_linear:
+            has_shape_compatible_with_group_size = (
+                m.weight.shape[1] % group_size == 0
+            )
+        return is_linear and has_shape_compatible_with_group_size
+
     quantize_(
         model,
         Int8DynamicActivationIntxWeightConfig(
+            # pyre-ignore[16]
             weight_dtype=torch.int4,
             weight_granularity=PerGroup(group_size),
         ),
+        filter_fn=filter_fn,
     )
 
     model = unwrap_tensor_subclass(model)
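The new filter_fn restricts quantization to nn.Linear modules whose in_features dimension divides evenly by group_size, since group-wise quantization splits the weight's second axis into groups of that size. A minimal, torch-only sketch of the selection logic (the toy model below is hypothetical):

import torch.nn as nn

group_size = 32

def filter_fn(m, fqn):
    is_linear = isinstance(m, nn.Linear)
    has_shape_compatible_with_group_size = False
    if is_linear:
        # weight has shape (out_features, in_features); group-wise
        # quantization requires in_features to be a multiple of group_size
        has_shape_compatible_with_group_size = m.weight.shape[1] % group_size == 0
    return is_linear and has_shape_compatible_with_group_size

model = nn.Sequential(
    nn.Linear(64, 8),  # 64 % 32 == 0 -> selected for quantization
    nn.Linear(10, 8),  # 10 % 32 != 0 -> skipped
    nn.ReLU(),         # not a Linear -> skipped
)

for fqn, m in model.named_modules():
    if fqn:  # skip the unnamed root module
        print(fqn, type(m).__name__, filter_fn(m, fqn))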