Commit 6794ef5

Reference representation of dqlinear int4 for xnnpack (#2520)
* When replacing literals with placeholders, lists are always converted to tuples

  Summary: Lists are mutable and therefore not hashable, so they cannot be used in the literals_to_ph maps that the pattern rewrites inside reference_representation_rewrite.py rely on.

  Test Plan: CI; the next diff in the stack relies on this feature.

* Allow pattern replacement to ignore literals

  Summary: The matched patterns sometimes contain literals such as tuples of ints. These values should not be used for pattern matching, since they are often constants derived from the example inputs. Ignoring literals is not safe in general, so it is off by default.

  Test Plan: A subsequent diff adds a pattern that relies on this.

* Reference representation of dqlinear int4 for xnnpack

  Summary: This diff adds an integer-arithmetic representation of dynamically quantized 4-bit groupwise linear. The arithmetic is quite close to how it is done in XNNPACK. Basic tests against the q/dq implementation are added to make sure the numerics are sane.

  Follow-ups:
  - See whether such a graph is traceable.
  - Optimize the implementation if needed.

  Test Plan: added

Differential Revision: [D78198154](https://our.internmc.facebook.com/intern/diff/D78198154)

[ghstack-poisoned]
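
The integer arithmetic that the new reference op emulates can be summarized in a short, self-contained sketch. The code below is illustrative only and is not part of this commit: it assumes per-token asymmetric int8 activation quantization and symmetric int4 weights with per-group scales, and checks the per-group integer accumulation against a plain dequantize-then-linear reference. The committed op additionally rounds weight scales through bfloat16 and performs the int32 matmul via out_dtype, which this sketch omits.

```python
import torch

torch.manual_seed(0)
batch, in_features, out_features, group_size = 2, 32, 8, 8
num_groups = in_features // group_size

# Per-token asymmetric int8 activation qparams; int4 weight stored as int8 in [-8, 7]
# with per-group scales and no zero point (symmetric weights), as in the diff below.
x_i8 = torch.randint(-128, 128, (batch, in_features), dtype=torch.int8)
x_scale = torch.rand(batch, 1) * 0.1
x_zp = torch.randint(-10, 10, (batch, 1), dtype=torch.int32)
w_i4 = torch.randint(-8, 8, (out_features, in_features), dtype=torch.int8)
w_scale = torch.rand(out_features, num_groups) * 0.1
bias = torch.randn(out_features)

# Plain q/dq reference: dequantize both operands, then run a float linear.
x_fp32 = (x_i8.to(torch.int32) - x_zp).float() * x_scale
w_fp32 = (
    w_i4.float().view(out_features, num_groups, group_size) * w_scale.unsqueeze(-1)
).view(out_features, in_features)
ref = torch.nn.functional.linear(x_fp32, w_fp32, bias)

# Integer-arithmetic form: per-group int32 multiply-accumulate, zero-point
# correction via the per-group weight column sums, per-group weight scale,
# and a single activation scale applied at the end.
xg = x_i8.to(torch.int32).view(batch, num_groups, group_size)
wg = w_i4.to(torch.int32).view(out_features, num_groups, group_size)
acc = torch.zeros(batch, out_features)
for g in range(num_groups):
    # int32 accumulation of x_g . w_g for every (token, output channel) pair
    int_acc = (xg[:, g, :].unsqueeze(1) * wg[:, g, :].unsqueeze(0)).sum(-1)
    # (X - zp) @ W^T == X @ W^T - zp * colsum(W), applied per group
    int_acc = int_acc - x_zp * wg[:, g, :].sum(dim=-1).view(1, -1)
    acc = acc + int_acc.float() * w_scale[:, g].view(1, -1)
out = acc * x_scale + bias

print(torch.allclose(ref, out, rtol=1e-4, atol=1e-3))  # numerics should agree
```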
1 parent ea3691e commit 6794ef5

File tree

2 files changed: +762 -2 lines changed

torchao/quantization/pt2e/reference_representation_rewrite.py

Lines changed: 324 additions & 2 deletions
@@ -8,12 +8,13 @@
 import contextlib
 from dataclasses import dataclass
 from functools import partial
-from typing import Any, Callable, Optional
+from typing import Any, Callable, List, Optional
 
 import torch
 from torch._higher_order_ops.out_dtype import out_dtype
 from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401
 from torch.fx import GraphModule
+from torch.fx.passes.utils.matcher_with_name_node_map_utils import InternalMatch
 from torch.fx.subgraph_rewriter import replace_pattern_with_filters
 
 from torchao.quantization.pt2e.export_utils import WrapperModule
@@ -23,12 +24,17 @@
     _replace_literals_with_new_placeholders,
     remove_tensor_overload_for_qdq_ops,
 )
+from torchao.quantization.quant_primitives import MappingType
+from torchao.quantization.utils import _get_per_token_block_size
+from torchao.utils import _register_custom_op
 
 try:
     from torch._export.utils import _disable_aten_to_metadata_assertions
 except:
     _disable_aten_to_metadata_assertions = contextlib.nullcontext
 
+quant_lib = torch.library.Library("torchao", "FRAGMENT")
+register_custom_op = _register_custom_op(quant_lib)
 
 __all__ = [
     "reference_representation_rewrite",
@@ -203,6 +209,252 @@ def _reference_dynamic_quantized_linear(
     return out_fp32
 
 
+def _qdq_dynamic_quantized_linear_4bit_groupwise(
+    x_fp32,
+    x_eps,
+    weight_i4,
+    weight_scale,
+    weight_zero_point,
+    bias_fp32,
+    group_size,
+):
+    # Dynamic quantization of activation
+    x_mapping_type = MappingType.ASYMMETRIC
+    per_token_block_size = _get_per_token_block_size(x_fp32)
+    x_quant_min = -128
+    x_quant_max = 127
+    x_scale, x_zero_point = torch.ops.torchao.choose_qparams_affine(
+        x_fp32,
+        x_mapping_type.name,
+        per_token_block_size,
+        torch.int8,
+        x_quant_min,
+        x_quant_max,
+        x_eps,
+        torch.float32,
+        torch.int32,
+    )
+    x_i8 = torch.ops.torchao.quantize_affine(
+        x_fp32,
+        per_token_block_size,
+        x_scale,
+        x_zero_point,
+        torch.int8,
+        x_quant_min,
+        x_quant_max,
+    )
+    x_fp32 = torch.ops.torchao.dequantize_affine(
+        x_i8,
+        per_token_block_size,
+        x_scale,
+        x_zero_point,
+        torch.int8,
+        x_quant_min,
+        x_quant_max,
+        torch.float32,
+    )
+
+    assert group_size > 0, "Group size must be positive"
+    assert weight_i4.shape[1] % group_size == 0, (
+        "Weight must be divisible by group_size"
+    )
+    assert weight_i4.dim() == 2, "Weight must be 2D tensor"
+    block_size = (1, group_size)
+    weight_fp32 = torch.ops.torchao.dequantize_affine(
+        weight_i4,
+        block_size,
+        weight_scale,
+        weight_zero_point,
+        torch.int8,
+        -8,
+        7,
+    )
+
+    out_fp32 = torch.ops.aten.linear.default(x_fp32, weight_fp32, bias_fp32)
+    return out_fp32
+
+
+@register_custom_op
+def _reference_dqlinear_int4(
+    x_fp32: torch.Tensor,
+    x_eps: float,
+    weight_i4: torch.Tensor,
+    weight_scale: torch.Tensor,
+    weight_zero_point: torch.Tensor,  # Not used because assuming weight is symmetric
+    bias_fp32: Optional[torch.Tensor],
+    group_size: List[int],
+) -> torch.Tensor:
+    """
+    Reference implementation for dynamically quantized linear 4-bit groupwise operation.
+    This implementation emulates actual numerics of on-device integer compute.
+
+    Args:
+        x_fp32: Input activation tensor in fp32
+        x_eps: Epsilon for quantization parameter computation
+        weight_i4: 4-bit quantized weight (stored as int8 with values in [-8, 7])
+        weight_scale: Groupwise scales for weight dequantization
+        weight_zero_point: Groupwise zero points for weight (unused for symmetric)
+        bias_fp32: Optional bias tensor in fp32
+        group_size: Size of each group for groupwise quantization
+
+    Returns:
+        Output tensor in fp32
+    """
+    # Dynamic quantization of activation
+    group_size = group_size[1]
+    x_mapping_type = MappingType.ASYMMETRIC
+    per_token_block_size = _get_per_token_block_size(x_fp32)
+    x_quant_min = -128
+    x_quant_max = 127
+    x_scale, x_zero_point = torch.ops.torchao.choose_qparams_affine(
+        x_fp32,
+        x_mapping_type.name,
+        per_token_block_size,
+        torch.int8,
+        x_quant_min,
+        x_quant_max,
+        x_eps,
+        torch.float32,
+        torch.int32,
+    )
+    x_i8 = torch.ops.torchao.quantize_affine(
+        x_fp32,
+        per_token_block_size,
+        x_scale,
+        x_zero_point,
+        torch.int8,
+        x_quant_min,
+        x_quant_max,
+    )
+
+    # For groupwise quantization, we need to handle the computation differently
+    # weight_i4 shape: [out_features, in_features]
+    # weight_scale shape: [out_features, in_features // group_size]
+    # weight_zero_point shape: [out_features, in_features // group_size]
+    out_features, in_features = weight_i4.shape
+    num_groups = in_features // group_size
+
+    # scales in xnnpack are stored as bf16 and converted to fp32 for computation
+    weight_scale = weight_scale.to(torch.bfloat16).to(torch.float32)
+
+    # Reshape for group-wise processing
+    # x: [batch_size, in_features] -> [batch_size, num_groups, group_size]
+    x_orig_shape = x_i8.shape
+    k_dim = x_i8.shape[-1]
+    x_i8 = x_i8.view(-1, k_dim)
+    batch_size = x_i8.shape[0]
+    x_i8_grouped = x_i8.view(batch_size, num_groups, group_size)
+
+    # weight: [out_features, in_features] -> [out_features, num_groups, group_size]
+    weight_i4_grouped = weight_i4.view(out_features, num_groups, group_size)
+
+    # Convert to int32 for computation
+    x_i32_grouped = x_i8_grouped.to(torch.int32)
+    weight_i32_grouped = weight_i4_grouped.to(torch.int32)
+
+    # Perform groupwise integer linear operation
+    acc_fp32 = torch.zeros(
+        batch_size, out_features, dtype=torch.float32, device=x_fp32.device
+    )
+    out_shape = list(x_orig_shape)
+    out_shape[-1] = out_features
+
+    if weight_scale.ndim == 1:
+        weight_scale = weight_scale.unsqueeze(0)
+
+    for group_idx in range(num_groups):
+        # Extract current group
+        x_group = x_i32_grouped[:, group_idx, :]  # [batch_size, group_size]
+        weight_group = weight_i32_grouped[:, group_idx, :]  # [out_features, group_size]
+        weight_group_col_sum = weight_group.sum(dim=-1)  # [out_features]
+
+        # Get scale for this group
+        weight_scale_group = weight_scale[:, group_idx]  # [out_features]
+
+        # Integer matmul: [batch_size, group_size] @ [group_size, out_features] -> [batch_size, out_features]
+        group_acc = out_dtype(
+            torch.ops.aten.linear.default,
+            torch.int32,
+            x_group,
+            weight_group,
+            None,
+        )
+
+        # Output has to be scaled by x_scale * weight_scale_group
+        # However we will first scale by weight_scale_group, that is accounting
+        # only for scale of weight, and then scale by x_scale at the end because
+        # x_scale applies to all groups
+        acc_fp32 = acc_fp32 + group_acc.to(torch.float32) * weight_scale_group.view(
+            1, -1
+        )
+
+        # we must also subtract x_zero_point * weight_group_col_sum
+        # since (X - x_zero_point) * W = X * W - x_zero_point * W
+        weights_col_sum_adjusted = (
+            weight_group_col_sum.to(torch.float32).view(1, -1)
+            * x_zero_point.view(-1, 1)
+            * weight_scale_group.view(1, -1)
+        )
+        acc_fp32 = acc_fp32 - weights_col_sum_adjusted
+    x_scale_multiplier = x_scale.view(-1, 1)
+    out_fp32 = acc_fp32 * x_scale_multiplier
+    if bias_fp32 is not None:
+        out_fp32 = out_fp32 + bias_fp32
+
+    return out_fp32.view(out_shape)
+
+
+def _reference_dynamic_quantized_linear_4bit_groupwise(
+    x_fp32,
+    x_eps,
+    weight_i4,
+    weight_scale,
+    weight_zero_point,  # Not used because assuming weight is symmetric
+    bias_fp32,
+    group_size,
+):
+    """
+    Reference implementation for dynamically quantized linear 4-bit groupwise operation.
+    This function now delegates to the custom op implementation.
+    """
+    return torch.ops.torchao.reference_dqlinear_int4(
+        x_fp32,
+        x_eps,
+        weight_i4,
+        weight_scale,
+        weight_zero_point,
+        bias_fp32,
+        (1, group_size),
+    )
+
+
+def _filter_fn_for_dynamic_quantized_linear_4bit_groupwise(
+    match,
+    original_graph,
+    pattern_graph,
+) -> bool:
+    weight_is_int4 = False
+    act_quant_is_int8 = False
+    for node in match.nodes_map.values():
+        if (
+            isinstance(node, torch.fx.Node)
+            and node.op == "call_function"
+            and node.target == torch.ops.torchao.dequantize_affine.default
+        ):
+            args = node.args
+            if len(args) >= 7:
+                weight_is_int4 = args[5] == -8 and args[6] == 7
+        if (
+            isinstance(node, torch.fx.Node)
+            and node.op == "call_function"
+            and node.target == torch.ops.torchao.quantize_affine.default
+        ):
+            args = node.args
+            if len(args) >= 5:
+                act_quant_is_int8 = args[4] == torch.int8
+    return weight_is_int4 and act_quant_is_int8
+
+
 def _qdq_quantized_conv2d(
     x_i8,
     x_scale,
@@ -627,6 +879,9 @@ class _RewriteInfo:
     # post transformation on the exported pattern and replacement GraphModule
     pattern_post_trans: Optional[Callable[[GraphModule], GraphModule]] = None
     replacement_post_trans: Optional[Callable[[GraphModule], GraphModule]] = None
+    filter_fn: Optional[
+        list[Callable[["InternalMatch", torch.fx.Graph, torch.fx.Graph], bool]]
+    ] = None
     ignore_literals: bool = False
 
 
@@ -739,6 +994,31 @@ def reference_representation_rewrite(model: GraphModule) -> GraphModule:
         127,
     )
 
+    _DYNAMIC_QUANTIZED_LINEAR_4BIT_GROUPWISE_EXAMPLE_INPUTS_1 = (
+        torch.randn((1, 32), dtype=torch.float),  # x_fp32
+        torch.finfo(torch.float32).eps,  # x_eps
+        torch.randint(-8, 7, (8, 32), dtype=torch.int8),  # weight_i4 (stored as int8)
+        torch.randn(8, 4, dtype=torch.float),  # weight_scale [out_features, num_groups]
+        torch.zeros(
+            8, 4, dtype=torch.int
+        ),  # weight_zero_point [out_features, num_groups]
+        torch.randn(8, dtype=torch.float),  # bias_fp32
+        8,  # group_size
+    )
+
+    # just saw that we can match against > 2 dim input. Hacky.
+    _DYNAMIC_QUANTIZED_LINEAR_4BIT_GROUPWISE_EXAMPLE_INPUTS_2 = (
+        torch.randn((1, 1, 32), dtype=torch.float),  # x_fp32
+        torch.finfo(torch.float32).eps,  # x_eps
+        torch.randint(-8, 7, (8, 32), dtype=torch.int8),  # weight_i4 (stored as int8)
+        torch.randn(8, 4, dtype=torch.float),  # weight_scale [out_features, num_groups]
+        torch.zeros(
+            8, 4, dtype=torch.int
+        ),  # weight_zero_point [out_features, num_groups]
+        torch.randn(8, dtype=torch.float),  # bias_fp32
+        8,  # group_size
+    )
+
     _REWRITE_INFO_LIST = [
         _RewriteInfo(
             _DYNAMIC_QUANTIZED_LINEAR_EXAMPLE_INPUTS,
@@ -753,6 +1033,48 @@ def reference_representation_rewrite(model: GraphModule) -> GraphModule:
                 literal_to_ph_idx={-128: 1, 127: 2, torch.finfo(torch.float32).eps: 3},
             ),
         ),
+        _RewriteInfo(
+            _DYNAMIC_QUANTIZED_LINEAR_4BIT_GROUPWISE_EXAMPLE_INPUTS_1,
+            WrapperModule(_qdq_dynamic_quantized_linear_4bit_groupwise),
+            WrapperModule(_reference_dynamic_quantized_linear_4bit_groupwise),
+            partial(
+                _replace_literals_with_existing_placeholders,
+                literal_to_ph_idx={
+                    torch.finfo(torch.float32).eps: 1,
+                    (1, 8): 6,
+                },
+            ),
+            partial(
+                _replace_literals_with_existing_placeholders,
+                literal_to_ph_idx={
+                    torch.finfo(torch.float32).eps: 1,
+                    (1, 8): 6,
+                },
+            ),
+            filter_fn=[_filter_fn_for_dynamic_quantized_linear_4bit_groupwise],
+            ignore_literals=True,
+        ),
+        _RewriteInfo(
+            _DYNAMIC_QUANTIZED_LINEAR_4BIT_GROUPWISE_EXAMPLE_INPUTS_2,
+            WrapperModule(_qdq_dynamic_quantized_linear_4bit_groupwise),
+            WrapperModule(_reference_dynamic_quantized_linear_4bit_groupwise),
+            partial(
+                _replace_literals_with_existing_placeholders,
+                literal_to_ph_idx={
+                    torch.finfo(torch.float32).eps: 1,
+                    (1, 8): 6,
+                },
+            ),
+            partial(
+                _replace_literals_with_existing_placeholders,
+                literal_to_ph_idx={
+                    torch.finfo(torch.float32).eps: 1,
+                    (1, 8): 6,
+                },
+            ),
+            filter_fn=[_filter_fn_for_dynamic_quantized_linear_4bit_groupwise],
+            ignore_literals=True,
+        ),
         _RewriteInfo(
             _QUANTIZED_LINEAR_EXAMPLE_INPUTS,
             WrapperModule(_qdq_quantized_linear),
@@ -835,7 +1157,7 @@ def reference_representation_rewrite(model: GraphModule) -> GraphModule:
             model,
             pattern,
             replacement,
-            match_filters=None,
+            match_filters=rewrite_info.filter_fn,
            ignore_literals=rewrite_info.ignore_literals,
        )  # type: ignore[arg-type]

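The tests mentioned in the test plan presumably live in the second changed file, which is not shown on this page. A minimal sketch of the kind of q/dq-vs-reference comparison they describe, using only the helpers defined in the diff above (imported by their private names) and input shapes matching _DYNAMIC_QUANTIZED_LINEAR_4BIT_GROUPWISE_EXAMPLE_INPUTS_1, could look like the following; treat it as an illustration, not the committed test:

```python
import torch
from torchao.quantization.pt2e.reference_representation_rewrite import (
    _qdq_dynamic_quantized_linear_4bit_groupwise,
    _reference_dynamic_quantized_linear_4bit_groupwise,
)

# Inputs shaped like the first example-input tuple in the diff above.
x_fp32 = torch.randn(1, 32)
x_eps = torch.finfo(torch.float32).eps
weight_i4 = torch.randint(-8, 8, (8, 32), dtype=torch.int8)
weight_scale = torch.rand(8, 4) * 0.1
weight_zero_point = torch.zeros(8, 4, dtype=torch.int)
bias_fp32 = torch.randn(8)
group_size = 8

qdq_out = _qdq_dynamic_quantized_linear_4bit_groupwise(
    x_fp32, x_eps, weight_i4, weight_scale, weight_zero_point, bias_fp32, group_size
)
ref_out = _reference_dynamic_quantized_linear_4bit_groupwise(
    x_fp32, x_eps, weight_i4, weight_scale, weight_zero_point, bias_fp32, group_size
)

# The two paths should agree closely; small differences come from the bf16-rounded
# weight scales and the integer accumulation order in the reference path.
print((qdq_out - ref_out).abs().max())
```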