
Commit a43b47a: lint

1 parent: 4c610e4

4 files changed: +217 -71 lines changed

exir/passes/_quant_patterns_and_replacements.py

Lines changed: 168 additions & 43 deletions
@@ -25,16 +25,16 @@

 from torch import Tensor
 from torch.library import custom_op
+
+
 @custom_op("quant_fusion::_pack_embedding_weight", mutates_args=())
 def _pack_embedding_weight(weight: Tensor, bitwidth: int) -> Tensor:
     num_embeddings, embedding_dim = weight.shape

     if bitwidth == 2:
         assert embedding_dim % 4 == 0, "embedding_dim must be divisible by 4"
         weight_range_shifted = weight.add(2).view(torch.uint8)
-        weight_view = weight_range_shifted.view(
-            num_embeddings, embedding_dim // 4, 4
-        )
+        weight_view = weight_range_shifted.view(num_embeddings, embedding_dim // 4, 4)
         weight_0 = weight_view[:, :, 0]
         weight_1 = weight_view[:, :, 1] << 2
         weight_2 = weight_view[:, :, 2] << 4
@@ -53,7 +53,7 @@ def _pack_embedding_weight(weight: Tensor, bitwidth: int) -> Tensor:
         return packed_weight
     elif bitwidth == 8:
         return weight
-
+
     raise RuntimeError(f"Unsupported bitwidth {bitwidth}")


@@ -64,7 +64,12 @@ def _(weight, bit_width):
     num_embeddings, embedding_dim = weight.shape
     values_per_byte = 8 // bit_width
     assert embedding_dim % values_per_byte == 0
-    return torch.empty(num_embeddings, embedding_dim // values_per_byte, dtype=torch.uint8, device=weight.device)
+    return torch.empty(
+        num_embeddings,
+        embedding_dim // values_per_byte,
+        dtype=torch.uint8,
+        device=weight.device,
+    )


 # TODO: extending an existing library that is defined in OSS might be a bit
@@ -114,9 +119,10 @@ def embedding_weight_checks(weight, weight_scales, weight_zero_points):
     assert (
         weight_zero_points is None or weight_zero_points.dtype == weight_scales.dtype
     ), "Expecting weight_zero_points to be None or have same dtype as weight_scales"
-    assert (
-        weight_zero_points is None or weight_zero_points.dim() in [1, 2]
-    ), f"Expecting weight_zero_points tensor to be None or have dim()==1, but found {weight_zero_points.dim()}"
+    assert weight_zero_points is None or weight_zero_points.dim() in [
+        1,
+        2,
+    ], f"Expecting weight_zero_points tensor to be None or have dim()==1, but found {weight_zero_points.dim()}"
     assert weight_zero_points is None or weight_zero_points.size(0) == weight.size(
         0
     ), f"Expecting weight_zero_points tensor to be None or have same number of rows as weights, but found {weight.size()} and {weight_zero_points.size()}"
@@ -278,6 +284,7 @@ def embedding_2bit(
     )
     return torch.ops.aten.embedding.default(weight, indices)

+
 @register_fake("quantized_decomposed::embedding_2bit")
 def _(
     weight: torch.Tensor,
@@ -286,12 +293,13 @@ def _(
     weight_quant_min: int,
     weight_quant_max: int,
     indices: torch.Tensor,
-):
+):
     num_embeddings, packed_embedding_dim = weight.shape
     embedding_dim = packed_embedding_dim * 4
     embedding = torch.nn.Embedding(num_embeddings, embedding_dim, device=weight.device)
     return embedding(indices)

+
 @register_fake("quantized_decomposed::embedding_2bit.out")
 def embedding_2bit_out_meta(
     weight: torch.Tensor,
@@ -311,6 +319,7 @@ def embedding_2bit_out_meta(
         indices,
     )

+
 @impl(quantized_decomposed_lib, "embedding_2bit.dtype", "CompositeExplicitAutograd")
 def embedding_2bit_dtype(
     weight: torch.Tensor,
@@ -352,6 +361,7 @@ def embedding_2bit_dtype(
     )
     return torch.ops.aten.embedding.default(weight, indices)

+
 @register_fake("quantized_decomposed::embedding_2bit.dtype")
 def _(
     weight: torch.Tensor,
@@ -361,12 +371,13 @@ def _(
     weight_quant_max: int,
     indices: torch.Tensor,
     dtype: Optional[torch.dtype],
-) -> torch.Tensor:
+) -> torch.Tensor:
     num_embeddings, packed_embedding_dim = weight.shape
     embedding_dim = packed_embedding_dim * 4
     embedding = torch.nn.Embedding(num_embeddings, embedding_dim, device=weight.device)
     return embedding(indices).to(dtype)

+
 @register_fake("quantized_decomposed::embedding_2bit.dtype_out")
 def embedding_2bit_dtype_out_meta(
     weight: torch.Tensor,
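
An aside on the `register_fake` registrations these hunks keep touching: each packed-embedding op gets a fake (meta) kernel so the export pipeline can propagate shapes without running a real kernel. The fake infers the unpacked width (packed_dim * 4 at 2-bit, packed_dim * 2 at 4-bit) and returns a dummy lookup of that shape. Below is a minimal sketch of the same idiom, using a hypothetical `demo::packed_lookup` op rather than the real quantized_decomposed ops:

import torch
from torch.library import custom_op, register_fake

# Hypothetical 2-bit-packed lookup op, for illustration only.
@custom_op("demo::packed_lookup", mutates_args=())
def packed_lookup(packed: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    raise NotImplementedError("a real kernel would unpack and gather here")

@register_fake("demo::packed_lookup")
def _(packed, indices):
    num_embeddings, packed_dim = packed.shape
    embedding_dim = packed_dim * 4  # four 2-bit values per packed byte
    return packed.new_empty(*indices.shape, embedding_dim, dtype=torch.float32)

# Shape propagation works on meta tensors without any real kernel:
packed = torch.empty(10, 16, dtype=torch.uint8, device="meta")
idx = torch.empty(5, dtype=torch.long, device="meta")
assert torch.ops.demo.packed_lookup(packed, idx).shape == (5, 64)
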
@@ -448,6 +459,7 @@ def embedding_4bit(
     )
     return torch.ops.aten.embedding.default(weight, indices)

+
 @register_fake("quantized_decomposed::embedding_4bit")
 def _(
     weight: torch.Tensor,
@@ -456,12 +468,13 @@ def _(
     weight_quant_min: int,
     weight_quant_max: int,
     indices: torch.Tensor,
-):
+):
     num_embeddings, packed_embedding_dim = weight.shape
     embedding_dim = packed_embedding_dim * 2
     embedding = torch.nn.Embedding(num_embeddings, embedding_dim, device=weight.device)
     return embedding(indices)

+
 @register_fake("quantized_decomposed::embedding_4bit.out")
 def embedding_4bit_out_meta(
     weight: torch.Tensor,
@@ -521,6 +534,7 @@ def embedding_4bit_dtype(
     )
     return torch.ops.aten.embedding.default(weight, indices)

+
 @register_fake("quantized_decomposed::embedding_4bit.dtype")
 def _(
     weight: torch.Tensor,
@@ -530,12 +544,13 @@ def _(
     weight_quant_max: int,
     indices: torch.Tensor,
     dtype: Optional[torch.dtype],
-) -> torch.Tensor:
+) -> torch.Tensor:
     num_embeddings, packed_embedding_dim = weight.shape
     embedding_dim = packed_embedding_dim * 2
     embedding = torch.nn.Embedding(num_embeddings, embedding_dim, device=weight.device)
     return embedding(indices).to(dtype)

+
 @register_fake("quantized_decomposed::embedding_4bit.dtype_out")
 def embedding_4bit_dtype_out_meta(
     weight: torch.Tensor,
@@ -970,10 +985,16 @@ def replacement(x, dim, start, end, x_scale, x_zero_point, x_qmin, x_qmax):
         )
     ]

-def _get_embedding_ops_patterns_and_replacements_torchao() -> List[Tuple[Callable, Callable, List[Callable]]]:
+
+def _get_embedding_ops_patterns_and_replacements_torchao() -> ( # noqa C901
+    List[Tuple[Callable, Callable, List[Callable]]]
+):
     def embedding_byte_pattern(indices, int_data, group_size, scale, zero_point):
-        dq = torch.ops.torchao.dequantize_affine.default(int_data, [1, group_size], scale, zero_point, torch.int8, -128, 127)
+        dq = torch.ops.torchao.dequantize_affine.default(
+            int_data, [1, group_size], scale, zero_point, torch.int8, -128, 127
+        )
         return torch.ops.aten.embedding.default(dq, indices)
+
     def embedding_byte_replacement(indices, int_data, group_size, scale, zero_point):
         zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype)
         return torch.ops.quantized_decomposed.embedding_byte.default(
@@ -984,10 +1005,26 @@ def embedding_byte_replacement(indices, int_data, group_size, scale, zero_point)
             127,
             indices,
         )
-    def embedding_byte_dtype_pattern(indices, int_data, group_size, scale, zero_point, output_dtype):
-        dq = torch.ops.torchao.dequantize_affine.default(int_data, [1, group_size], scale, zero_point, torch.int8, -128, 127, 'INT', output_dtype)
+
+    def embedding_byte_dtype_pattern(
+        indices, int_data, group_size, scale, zero_point, output_dtype
+    ):
+        dq = torch.ops.torchao.dequantize_affine.default(
+            int_data,
+            [1, group_size],
+            scale,
+            zero_point,
+            torch.int8,
+            -128,
+            127,
+            "INT",
+            output_dtype,
+        )
         return torch.ops.aten.embedding.default(dq, indices)
-    def embedding_byte_dtype_replacement(indices, int_data, group_size, scale, zero_point, output_dtype):
+
+    def embedding_byte_dtype_replacement(
+        indices, int_data, group_size, scale, zero_point, output_dtype
+    ):
         zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype)
         return torch.ops.quantized_decomposed.embedding_byte.dtype(
             int_data,
@@ -996,48 +1033,136 @@ def embedding_byte_dtype_replacement(indices, int_data, group_size, scale, zero_
             -128,
             127,
             indices,
-            dtype=output_dtype
+            dtype=output_dtype,
         )
-
+
     def embedding_2bit_pattern(indices, int_data, group_size, scale, zero_point):
-        dq = torch.ops.torchao.dequantize_affine.default(int_data, [1, group_size], scale, zero_point, torch.int8, -2, 1)
+        dq = torch.ops.torchao.dequantize_affine.default(
+            int_data, [1, group_size], scale, zero_point, torch.int8, -2, 1
+        )
         return torch.ops.aten.embedding.default(dq, indices)
+
     def embedding_2bit_replacement(indices, int_data, group_size, scale, zero_point):
-        packed_int_data = torch.ops.quant_fusion._pack_embedding_weight.default(int_data, 2)
+        packed_int_data = torch.ops.quant_fusion._pack_embedding_weight.default(
+            int_data, 2
+        )
         zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype)
-        return torch.ops.quantized_decomposed.embedding_2bit.default(packed_int_data, scale, zero_point_dtype_cast, -2, 1, indices)
+        return torch.ops.quantized_decomposed.embedding_2bit.default(
+            packed_int_data, scale, zero_point_dtype_cast, -2, 1, indices
+        )

-    def embedding_2bit_dtype_pattern(indices, int_data, group_size, scale, zero_point, output_dtype):
-        dq = torch.ops.torchao.dequantize_affine.default(int_data, [1, group_size], scale, zero_point, torch.int8, -2, 1, 'INT', output_dtype)
+    def embedding_2bit_dtype_pattern(
+        indices, int_data, group_size, scale, zero_point, output_dtype
+    ):
+        dq = torch.ops.torchao.dequantize_affine.default(
+            int_data,
+            [1, group_size],
+            scale,
+            zero_point,
+            torch.int8,
+            -2,
+            1,
+            "INT",
+            output_dtype,
+        )
         return torch.ops.aten.embedding.default(dq, indices)
-    def embedding_2bit_dtype_replacement(indices, int_data, group_size, scale, zero_point, output_dtype):
-        packed_int_data = torch.ops.quant_fusion._pack_embedding_weight.default(int_data, 2)
+
+    def embedding_2bit_dtype_replacement(
+        indices, int_data, group_size, scale, zero_point, output_dtype
+    ):
+        packed_int_data = torch.ops.quant_fusion._pack_embedding_weight.default(
+            int_data, 2
+        )
         zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype)
-        return torch.ops.quantized_decomposed.embedding_2bit.dtype(packed_int_data, scale, zero_point_dtype_cast, -2, 1, indices, dtype=output_dtype)
-
+        return torch.ops.quantized_decomposed.embedding_2bit.dtype(
+            packed_int_data,
+            scale,
+            zero_point_dtype_cast,
+            -2,
+            1,
+            indices,
+            dtype=output_dtype,
+        )
+
     def embedding_4bit_pattern(indices, int_data, group_size, scale, zero_point):
-        dq = torch.ops.torchao.dequantize_affine.default(int_data, [1, group_size], scale, zero_point, torch.int8, -8, 7)
+        dq = torch.ops.torchao.dequantize_affine.default(
+            int_data, [1, group_size], scale, zero_point, torch.int8, -8, 7
+        )
         return torch.ops.aten.embedding.default(dq, indices)
+
     def embedding_4bit_replacement(indices, int_data, group_size, scale, zero_point):
-        packed_int_data = torch.ops.quant_fusion._pack_embedding_weight.default(int_data, 4)
+        packed_int_data = torch.ops.quant_fusion._pack_embedding_weight.default(
+            int_data, 4
+        )
         zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype)
-        return torch.ops.quantized_decomposed.embedding_4bit.default(packed_int_data, scale, zero_point_dtype_cast, -8, 7, indices)
-
-    def embedding_4bit_dtype_pattern(indices, int_data, group_size, scale, zero_point, output_dtype):
-        dq = torch.ops.torchao.dequantize_affine.default(int_data, [1, group_size], scale, zero_point, torch.int8, -8, 7, 'INT', output_dtype)
+        return torch.ops.quantized_decomposed.embedding_4bit.default(
+            packed_int_data, scale, zero_point_dtype_cast, -8, 7, indices
+        )
+
+    def embedding_4bit_dtype_pattern(
+        indices, int_data, group_size, scale, zero_point, output_dtype
+    ):
+        dq = torch.ops.torchao.dequantize_affine.default(
+            int_data,
+            [1, group_size],
+            scale,
+            zero_point,
+            torch.int8,
+            -8,
+            7,
+            "INT",
+            output_dtype,
+        )
         return torch.ops.aten.embedding.default(dq, indices)
-    def embedding_4bit_dtype_replacement(indices, int_data, group_size, scale, zero_point, output_dtype):
-        packed_int_data = torch.ops.quant_fusion._pack_embedding_weight.default(int_data, 4)
+
+    def embedding_4bit_dtype_replacement(
+        indices, int_data, group_size, scale, zero_point, output_dtype
+    ):
+        packed_int_data = torch.ops.quant_fusion._pack_embedding_weight.default(
+            int_data, 4
+        )
         zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype)
-        return torch.ops.quantized_decomposed.embedding_4bit.dtype(packed_int_data, scale, zero_point_dtype_cast, -8, 7, indices, dtype=output_dtype)
+        return torch.ops.quantized_decomposed.embedding_4bit.dtype(
+            packed_int_data,
+            scale,
+            zero_point_dtype_cast,
+            -8,
+            7,
+            indices,
+            dtype=output_dtype,
+        )

     return [
-        (_trace_and_lower_to_edge_ops(embedding_byte_pattern), _trace_and_lower_to_edge_ops(embedding_byte_replacement), []),
-        (_trace_and_lower_to_edge_ops(embedding_byte_dtype_pattern), _trace_and_lower_to_edge_ops(embedding_byte_dtype_replacement), []),
-        (_trace_and_lower_to_edge_ops(embedding_2bit_pattern), _trace_and_lower_to_edge_ops(embedding_2bit_replacement), []),
-        (_trace_and_lower_to_edge_ops(embedding_2bit_dtype_pattern), _trace_and_lower_to_edge_ops(embedding_2bit_dtype_replacement), []),
-        (_trace_and_lower_to_edge_ops(embedding_4bit_pattern), _trace_and_lower_to_edge_ops(embedding_4bit_replacement), []),
-        (_trace_and_lower_to_edge_ops(embedding_4bit_dtype_pattern), _trace_and_lower_to_edge_ops(embedding_4bit_dtype_replacement), []),
+        (
+            _trace_and_lower_to_edge_ops(embedding_byte_pattern),
+            _trace_and_lower_to_edge_ops(embedding_byte_replacement),
+            [],
+        ),
+        (
+            _trace_and_lower_to_edge_ops(embedding_byte_dtype_pattern),
+            _trace_and_lower_to_edge_ops(embedding_byte_dtype_replacement),
+            [],
+        ),
+        (
+            _trace_and_lower_to_edge_ops(embedding_2bit_pattern),
+            _trace_and_lower_to_edge_ops(embedding_2bit_replacement),
+            [],
+        ),
+        (
+            _trace_and_lower_to_edge_ops(embedding_2bit_dtype_pattern),
+            _trace_and_lower_to_edge_ops(embedding_2bit_dtype_replacement),
+            [],
+        ),
+        (
+            _trace_and_lower_to_edge_ops(embedding_4bit_pattern),
+            _trace_and_lower_to_edge_ops(embedding_4bit_replacement),
+            [],
+        ),
+        (
+            _trace_and_lower_to_edge_ops(embedding_4bit_dtype_pattern),
+            _trace_and_lower_to_edge_ops(embedding_4bit_dtype_replacement),
+            [],
+        ),
     ]
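
Most of this file's hunks are mechanical line-wrapping changes from the linter; the substantive piece is `_pack_embedding_weight`, which the 2-bit and 4-bit replacement functions call to pack quantized values four (or two) to a byte. Below is a minimal standalone sketch of the 2-bit scheme from the first hunk, with a round-trip check added for illustration. The `pack_2bit` name and the unpack helper are not part of ExecuTorch, and the final OR with the `<< 6` plane is inferred from the shown shifts plus the 4-values-per-byte output shape:

import torch

def pack_2bit(weight: torch.Tensor) -> torch.Tensor:
    # weight holds int8 quantized values in [-2, 1]; shifting by +2 maps them
    # into [0, 3] so each fits in two bits; the storage is then viewed as uint8.
    num_embeddings, embedding_dim = weight.shape
    assert embedding_dim % 4 == 0, "embedding_dim must be divisible by 4"
    shifted = weight.add(2).view(torch.uint8)
    groups = shifted.view(num_embeddings, embedding_dim // 4, 4)
    # Value i of each group of four lands in bits [2i, 2i+1] of one byte.
    return (
        groups[:, :, 0]
        | (groups[:, :, 1] << 2)
        | (groups[:, :, 2] << 4)
        | (groups[:, :, 3] << 6)
    )

w = torch.randint(-2, 2, (3, 8), dtype=torch.int8)
packed = pack_2bit(w)  # shape (3, 2), dtype uint8
unpacked = (
    torch.stack([(packed >> (2 * i)) & 0b11 for i in range(4)], dim=-1)
    .view(3, 8)
    .to(torch.int8)
    - 2
)
assert torch.equal(unpacked, w)
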


exir/passes/quant_fusion_pass.py

Lines changed: 3 additions & 1 deletion
@@ -89,6 +89,7 @@ def _get_qparams(node):
         qnode.replace_all_uses_with(maybe_cat)
         model.graph.erase_node(qnode)

+
 def _remove_dtype_getattr_nodes(model: GraphModule) -> None:
     for n in model.graph.nodes:
         if n.op == "call_function" and n.target == getattr:
@@ -99,7 +100,8 @@ def _remove_dtype_getattr_nodes(model: GraphModule) -> None:
     model.graph.eliminate_dead_code()
     model.graph.lint()
     model.recompile()
-
+
+
 class QuantFusionPass(ExportPass):
     def __init__(self, _fix_node_meta_val=False):
        super().__init__()
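
For context, the (pattern, replacement, match_filters) triples returned by `_get_embedding_ops_patterns_and_replacements_torchao` are what QuantFusionPass feeds to subgraph rewriting. A rough, self-contained sketch of that mechanism using torch.fx's subgraph rewriter, with a toy dequantize-then-embed chain standing in for the real torchao/quantized_decomposed ops; this is illustrative only, not ExecuTorch's actual fusion driver:

import torch
import torch.fx as fx
from torch.fx import subgraph_rewriter

class ToyModel(torch.nn.Module):
    def forward(self, weight, indices):
        # Toy "dequantize then embed" chain, mirroring the shape of
        # dequantize_affine -> aten.embedding in the real patterns.
        dq = (weight.to(torch.float32) - 1.0) * 0.5
        return torch.nn.functional.embedding(indices, dq)

def pattern(weight, indices):
    dq = (weight.to(torch.float32) - 1.0) * 0.5
    return torch.nn.functional.embedding(indices, dq)

def replacement(weight, indices):
    # Stand-in for the fused quantized-embedding op the real replacement emits.
    return torch.nn.functional.embedding(indices, weight.to(torch.float32))

gm = fx.symbolic_trace(ToyModel())
matches = subgraph_rewriter.replace_pattern(gm, pattern, replacement)
assert len(matches) == 1  # the dequant+embedding subgraph was rewritten

The third element of each real triple (the empty list `[]`) carries optional match filters; the rewriter only applies a replacement when every filter accepts the candidate match.
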
