Commit ec07cba

metascroy authored and facebook-github-bot committed
Forward fix for D82242003 (#14241)
Summary:
This fixes internal failures on D82242003:

* pyre errors
* buck build --flagfile fbcode//mode/dev fbcode//executorch/examples/models/fb/llama4:ngtts_semantic_lm_xnnpack_quantized.pte

The second failure occurs because the old and new APIs behave differently when group_size is incompatible with the nn.Linear module's shape. The old API silently skips quantizing the layer, whereas the new API is more explicit and throws an error. This diff uses a filter_fn to restore the previous behavior.

Reviewed By: digantdesai

Differential Revision: D82265586
1 parent e6b9111 commit ec07cba

File tree: 2 files changed (+15, -1 lines)


backends/xnnpack/test/ops/test_linear.py

Lines changed: 3 additions & 1 deletion
@@ -395,7 +395,9 @@ def _test_groupwise_dq_linear(
         quantize_(
             mod,
             Int8DynamicActivationIntxWeightConfig(
-                weight_dtype=torch.int4, weight_granularity=PerGroup(group_size)
+                # pyre-ignore[16]
+                weight_dtype=torch.int4,
+                weight_granularity=PerGroup(group_size),
             ),
         )
         unwrap_tensor_subclass(mod)

examples/models/llama/source_transformation/quantize.py

Lines changed: 12 additions & 0 deletions
@@ -135,6 +135,7 @@ def quantize(  # noqa C901
                 PerAxis(0) if group_size == 0 else PerGroup(group_size)
             ),
             weight_mapping_type=MappingType.SYMMETRIC,
+            # pyre-ignore[6]
             intx_packing_format="opaque_torchao_auto",
         ),
     )
@@ -154,12 +155,23 @@ def quantize(  # noqa C901
         from torchao.quantization.granularity import PerGroup
         from torchao.utils import unwrap_tensor_subclass
 
+        def filter_fn(m, fqn):
+            is_linear = isinstance(m, nn.Linear)
+            has_shape_compatible_with_group_size = False
+            if is_linear:
+                has_shape_compatible_with_group_size = (
+                    m.weight.shape[1] % group_size == 0
+                )
+            return is_linear and has_shape_compatible_with_group_size
+
         quantize_(
             model,
             Int8DynamicActivationIntxWeightConfig(
+                # pyre-ignore[16]
                 weight_dtype=torch.int4,
                 weight_granularity=PerGroup(group_size),
             ),
+            filter_fn=filter_fn,
         )
 
         model = unwrap_tensor_subclass(model)
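
For readers outside the ExecuTorch tree, below is a minimal, self-contained sketch of the filter_fn pattern this commit applies. The toy model shapes and group_size value are illustrative, not taken from the diff, and it assumes Int8DynamicActivationIntxWeightConfig is importable from torchao.quantization, as in recent torchao releases.

import torch
import torch.nn as nn

from torchao.quantization import Int8DynamicActivationIntxWeightConfig, quantize_
from torchao.quantization.granularity import PerGroup

group_size = 32  # illustrative value

# Toy model: the last layer's in_features (100) is not divisible by
# group_size, so the old API would have skipped it silently while the
# new API raises an error unless the layer is filtered out.
model = nn.Sequential(
    nn.Linear(64, 128),   # 64 % 32 == 0  -> quantized
    nn.Linear(128, 100),  # 128 % 32 == 0 -> quantized
    nn.Linear(100, 10),   # 100 % 32 != 0 -> skipped by filter_fn
)

def filter_fn(m: nn.Module, fqn: str) -> bool:
    # Quantize only Linear layers whose in_features is a multiple of
    # group_size, restoring the old API's silent-skip behavior.
    return isinstance(m, nn.Linear) and m.weight.shape[1] % group_size == 0

# quantize_ applies the config only to modules where filter_fn returns True.
quantize_(
    model,
    Int8DynamicActivationIntxWeightConfig(
        weight_dtype=torch.int4,
        weight_granularity=PerGroup(group_size),
    ),
    filter_fn=filter_fn,
)

Running this leaves the last Linear in floating point while the first two receive groupwise int4 weights, mirroring the filter added in quantize.py above.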
