| 
22 | 22 |     "get_quant_patterns_and_replacements",  | 
23 | 23 | ]  | 
24 | 24 | 
 
  | 
 | 25 | + | 
 | 26 | +from torch import Tensor  | 
 | 27 | +from torch.library import custom_op  | 
 | 28 | + | 
 | 29 | + | 
 | 30 | +@custom_op("quant_fusion::_pack_embedding_weight", mutates_args=())  | 
 | 31 | +def _pack_embedding_weight(weight: Tensor, bitwidth: int) -> Tensor:  | 
 | 32 | +    num_embeddings, embedding_dim = weight.shape  | 
 | 33 | + | 
 | 34 | +    if bitwidth == 2:  | 
 | 35 | +        assert embedding_dim % 4 == 0, "embedding_dim must be divisible by 4"  | 
 | 36 | +        weight_range_shifted = weight.add(2).view(torch.uint8)  | 
 | 37 | +        weight_view = weight_range_shifted.view(num_embeddings, embedding_dim // 4, 4)  | 
 | 38 | +        weight_0 = weight_view[:, :, 0]  | 
 | 39 | +        weight_1 = weight_view[:, :, 1] << 2  | 
 | 40 | +        weight_2 = weight_view[:, :, 2] << 4  | 
 | 41 | +        weight_3 = weight_view[:, :, 3] << 6  | 
 | 42 | +        packed_weight = weight_0 | weight_1 | weight_2 | weight_3  | 
 | 43 | +        return packed_weight  | 
 | 44 | +    elif bitwidth == 4:  | 
 | 45 | +        assert embedding_dim % 2 == 0, "embedding_dim must be divisible by 2"  | 
 | 46 | +        weight_range_shifted = weight.add(8).view(torch.uint8)  | 
 | 47 | +        weight_view = weight_range_shifted.view(  | 
 | 48 | +            weight.shape[0], weight.shape[1] // 2, 2  | 
 | 49 | +        )  | 
 | 50 | +        weight_even = weight_view[:, :, 0] << 4  | 
 | 51 | +        weight_odd = weight_view[:, :, 1]  | 
 | 52 | +        packed_weight = weight_even | weight_odd  | 
 | 53 | +        return packed_weight  | 
 | 54 | +    elif bitwidth == 8:  | 
 | 55 | +        return weight  | 
 | 56 | + | 
 | 57 | +    raise RuntimeError(f"Unsupported bitwidth {bitwidth}")  | 
 | 58 | + | 
 | 59 | + | 
 | 60 | +# Use register_fake to add a ``FakeTensor`` kernel for the operator  | 
 | 61 | +@_pack_embedding_weight.register_fake  | 
 | 62 | +def _(weight, bit_width):  | 
 | 63 | +    assert bit_width in [2, 4, 8]  | 
 | 64 | +    num_embeddings, embedding_dim = weight.shape  | 
 | 65 | +    values_per_byte = 8 // bit_width  | 
 | 66 | +    assert embedding_dim % values_per_byte == 0  | 
 | 67 | +    return torch.empty(  | 
 | 68 | +        num_embeddings,  | 
 | 69 | +        embedding_dim // values_per_byte,  | 
 | 70 | +        dtype=torch.uint8,  | 
 | 71 | +        device=weight.device,  | 
 | 72 | +    )  | 
 | 73 | + | 
 | 74 | + | 
25 | 75 | # TODO: extending an existing library that is defined in OSS might be a bit  | 
26 | 76 | # confusing, we can investigate if it is possible to define a new library  | 
27 | 77 | 
 
  | 
@@ -69,9 +119,10 @@ def embedding_weight_checks(weight, weight_scales, weight_zero_points):  | 
69 | 119 |     assert (  | 
70 | 120 |         weight_zero_points is None or weight_zero_points.dtype == weight_scales.dtype  | 
71 | 121 |     ), "Expecting weight_zero_points to be None or have same dtype as weight_scales"  | 
72 |  | -    assert (  | 
73 |  | -        weight_zero_points is None or weight_zero_points.dim() == 1  | 
74 |  | -    ), f"Expecting weight_zero_points tensor to be None or have dim()==1, but found {weight_zero_points.dim()}"  | 
 | 122 | +    assert weight_zero_points is None or weight_zero_points.dim() in [  | 
 | 123 | +        1,  | 
 | 124 | +        2,  | 
 | 125 | +    ], f"Expecting weight_zero_points tensor to be None or have dim()==1, but found {weight_zero_points.dim()}"  | 
75 | 126 |     assert weight_zero_points is None or weight_zero_points.size(0) == weight.size(  | 
76 | 127 |         0  | 
77 | 128 |     ), f"Expecting weight_zero_points tensor to be None or have same number of rows as weights, but found {weight.size()} and {weight_zero_points.size()}"  | 
@@ -234,6 +285,21 @@ def embedding_2bit(  | 
234 | 285 |     return torch.ops.aten.embedding.default(weight, indices)  | 
235 | 286 | 
 
  | 
236 | 287 | 
 
  | 
 | 288 | +@register_fake("quantized_decomposed::embedding_2bit")  | 
 | 289 | +def _(  | 
 | 290 | +    weight: torch.Tensor,  | 
 | 291 | +    weight_scales: torch.Tensor,  | 
 | 292 | +    weight_zero_points: Optional[torch.Tensor],  | 
 | 293 | +    weight_quant_min: int,  | 
 | 294 | +    weight_quant_max: int,  | 
 | 295 | +    indices: torch.Tensor,  | 
 | 296 | +):  | 
 | 297 | +    num_embeddings, packed_embedding_dim = weight.shape  | 
 | 298 | +    embedding_dim = packed_embedding_dim * 4  | 
 | 299 | +    embedding = torch.nn.Embedding(num_embeddings, embedding_dim, device=weight.device)  | 
 | 300 | +    return embedding(indices)  | 
 | 301 | + | 
 | 302 | + | 
237 | 303 | @register_fake("quantized_decomposed::embedding_2bit.out")  | 
238 | 304 | def embedding_2bit_out_meta(  | 
239 | 305 |     weight: torch.Tensor,  | 
@@ -296,6 +362,22 @@ def embedding_2bit_dtype(  | 
296 | 362 |     return torch.ops.aten.embedding.default(weight, indices)  | 
297 | 363 | 
 
  | 
298 | 364 | 
 
  | 
 | 365 | +@register_fake("quantized_decomposed::embedding_2bit.dtype")  | 
 | 366 | +def _(  | 
 | 367 | +    weight: torch.Tensor,  | 
 | 368 | +    weight_scales: torch.Tensor,  | 
 | 369 | +    weight_zero_points: Optional[torch.Tensor],  | 
 | 370 | +    weight_quant_min: int,  | 
 | 371 | +    weight_quant_max: int,  | 
 | 372 | +    indices: torch.Tensor,  | 
 | 373 | +    dtype: Optional[torch.dtype],  | 
 | 374 | +) -> torch.Tensor:  | 
 | 375 | +    num_embeddings, packed_embedding_dim = weight.shape  | 
 | 376 | +    embedding_dim = packed_embedding_dim * 4  | 
 | 377 | +    embedding = torch.nn.Embedding(num_embeddings, embedding_dim, device=weight.device)  | 
 | 378 | +    return embedding(indices).to(dtype)  | 
 | 379 | + | 
 | 380 | + | 
299 | 381 | @register_fake("quantized_decomposed::embedding_2bit.dtype_out")  | 
300 | 382 | def embedding_2bit_dtype_out_meta(  | 
301 | 383 |     weight: torch.Tensor,  | 
@@ -378,6 +460,21 @@ def embedding_4bit(  | 
378 | 460 |     return torch.ops.aten.embedding.default(weight, indices)  | 
379 | 461 | 
 
  | 
380 | 462 | 
 
  | 
 | 463 | +@register_fake("quantized_decomposed::embedding_4bit")  | 
 | 464 | +def _(  | 
 | 465 | +    weight: torch.Tensor,  | 
 | 466 | +    weight_scales: torch.Tensor,  | 
 | 467 | +    weight_zero_points: Optional[torch.Tensor],  | 
 | 468 | +    weight_quant_min: int,  | 
 | 469 | +    weight_quant_max: int,  | 
 | 470 | +    indices: torch.Tensor,  | 
 | 471 | +):  | 
 | 472 | +    num_embeddings, packed_embedding_dim = weight.shape  | 
 | 473 | +    embedding_dim = packed_embedding_dim * 2  | 
 | 474 | +    embedding = torch.nn.Embedding(num_embeddings, embedding_dim, device=weight.device)  | 
 | 475 | +    return embedding(indices)  | 
 | 476 | + | 
 | 477 | + | 
381 | 478 | @register_fake("quantized_decomposed::embedding_4bit.out")  | 
382 | 479 | def embedding_4bit_out_meta(  | 
383 | 480 |     weight: torch.Tensor,  | 
@@ -438,6 +535,22 @@ def embedding_4bit_dtype(  | 
438 | 535 |     return torch.ops.aten.embedding.default(weight, indices)  | 
439 | 536 | 
 
  | 
440 | 537 | 
 
  | 
 | 538 | +@register_fake("quantized_decomposed::embedding_4bit.dtype")  | 
 | 539 | +def _(  | 
 | 540 | +    weight: torch.Tensor,  | 
 | 541 | +    weight_scales: torch.Tensor,  | 
 | 542 | +    weight_zero_points: Optional[torch.Tensor],  | 
 | 543 | +    weight_quant_min: int,  | 
 | 544 | +    weight_quant_max: int,  | 
 | 545 | +    indices: torch.Tensor,  | 
 | 546 | +    dtype: Optional[torch.dtype],  | 
 | 547 | +) -> torch.Tensor:  | 
 | 548 | +    num_embeddings, packed_embedding_dim = weight.shape  | 
 | 549 | +    embedding_dim = packed_embedding_dim * 2  | 
 | 550 | +    embedding = torch.nn.Embedding(num_embeddings, embedding_dim, device=weight.device)  | 
 | 551 | +    return embedding(indices).to(dtype)  | 
 | 552 | + | 
 | 553 | + | 
441 | 554 | @register_fake("quantized_decomposed::embedding_4bit.dtype_out")  | 
442 | 555 | def embedding_4bit_dtype_out_meta(  | 
443 | 556 |     weight: torch.Tensor,  | 
@@ -873,6 +986,186 @@ def replacement(x, dim, start, end, x_scale, x_zero_point, x_qmin, x_qmax):  | 
873 | 986 |     ]  | 
874 | 987 | 
 
  | 
875 | 988 | 
 
  | 
 | 989 | +def _get_embedding_ops_patterns_and_replacements_torchao() -> (  # noqa C901  | 
 | 990 | +    List[Tuple[Callable, Callable, List[Callable]]]  | 
 | 991 | +):  | 
 | 992 | +    def embedding_byte_pattern(indices, int_data, group_size, scale, zero_point):  | 
 | 993 | +        dq = torch.ops.torchao.dequantize_affine.default(  | 
 | 994 | +            int_data, [1, group_size], scale, zero_point, torch.int8, -128, 127  | 
 | 995 | +        )  | 
 | 996 | +        return torch.ops.aten.embedding.default(dq, indices)  | 
 | 997 | + | 
 | 998 | +    def embedding_byte_replacement(indices, int_data, group_size, scale, zero_point):  | 
 | 999 | +        zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype)  | 
 | 1000 | +        return torch.ops.quantized_decomposed.embedding_byte.default(  | 
 | 1001 | +            int_data,  | 
 | 1002 | +            scale,  | 
 | 1003 | +            zero_point_dtype_cast,  | 
 | 1004 | +            -128,  | 
 | 1005 | +            127,  | 
 | 1006 | +            indices,  | 
 | 1007 | +        )  | 
 | 1008 | + | 
 | 1009 | +    def embedding_byte_dtype_pattern(  | 
 | 1010 | +        indices, int_data, group_size, scale, zero_point, output_dtype  | 
 | 1011 | +    ):  | 
 | 1012 | +        dq = torch.ops.torchao.dequantize_affine.default(  | 
 | 1013 | +            int_data,  | 
 | 1014 | +            [1, group_size],  | 
 | 1015 | +            scale,  | 
 | 1016 | +            zero_point,  | 
 | 1017 | +            torch.int8,  | 
 | 1018 | +            -128,  | 
 | 1019 | +            127,  | 
 | 1020 | +            "INT",  | 
 | 1021 | +            output_dtype,  | 
 | 1022 | +        )  | 
 | 1023 | +        return torch.ops.aten.embedding.default(dq, indices)  | 
 | 1024 | + | 
 | 1025 | +    def embedding_byte_dtype_replacement(  | 
 | 1026 | +        indices, int_data, group_size, scale, zero_point, output_dtype  | 
 | 1027 | +    ):  | 
 | 1028 | +        zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype)  | 
 | 1029 | +        return torch.ops.quantized_decomposed.embedding_byte.dtype(  | 
 | 1030 | +            int_data,  | 
 | 1031 | +            scale,  | 
 | 1032 | +            zero_point_dtype_cast,  | 
 | 1033 | +            -128,  | 
 | 1034 | +            127,  | 
 | 1035 | +            indices,  | 
 | 1036 | +            dtype=output_dtype,  | 
 | 1037 | +        )  | 
 | 1038 | + | 
 | 1039 | +    def embedding_2bit_pattern(indices, int_data, group_size, scale, zero_point):  | 
 | 1040 | +        dq = torch.ops.torchao.dequantize_affine.default(  | 
 | 1041 | +            int_data, [1, group_size], scale, zero_point, torch.int8, -2, 1  | 
 | 1042 | +        )  | 
 | 1043 | +        return torch.ops.aten.embedding.default(dq, indices)  | 
 | 1044 | + | 
 | 1045 | +    def embedding_2bit_replacement(indices, int_data, group_size, scale, zero_point):  | 
 | 1046 | +        packed_int_data = torch.ops.quant_fusion._pack_embedding_weight.default(  | 
 | 1047 | +            int_data, 2  | 
 | 1048 | +        )  | 
 | 1049 | +        zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype)  | 
 | 1050 | +        return torch.ops.quantized_decomposed.embedding_2bit.default(  | 
 | 1051 | +            packed_int_data, scale, zero_point_dtype_cast, -2, 1, indices  | 
 | 1052 | +        )  | 
 | 1053 | + | 
 | 1054 | +    def embedding_2bit_dtype_pattern(  | 
 | 1055 | +        indices, int_data, group_size, scale, zero_point, output_dtype  | 
 | 1056 | +    ):  | 
 | 1057 | +        dq = torch.ops.torchao.dequantize_affine.default(  | 
 | 1058 | +            int_data,  | 
 | 1059 | +            [1, group_size],  | 
 | 1060 | +            scale,  | 
 | 1061 | +            zero_point,  | 
 | 1062 | +            torch.int8,  | 
 | 1063 | +            -2,  | 
 | 1064 | +            1,  | 
 | 1065 | +            "INT",  | 
 | 1066 | +            output_dtype,  | 
 | 1067 | +        )  | 
 | 1068 | +        return torch.ops.aten.embedding.default(dq, indices)  | 
 | 1069 | + | 
 | 1070 | +    def embedding_2bit_dtype_replacement(  | 
 | 1071 | +        indices, int_data, group_size, scale, zero_point, output_dtype  | 
 | 1072 | +    ):  | 
 | 1073 | +        packed_int_data = torch.ops.quant_fusion._pack_embedding_weight.default(  | 
 | 1074 | +            int_data, 2  | 
 | 1075 | +        )  | 
 | 1076 | +        zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype)  | 
 | 1077 | +        return torch.ops.quantized_decomposed.embedding_2bit.dtype(  | 
 | 1078 | +            packed_int_data,  | 
 | 1079 | +            scale,  | 
 | 1080 | +            zero_point_dtype_cast,  | 
 | 1081 | +            -2,  | 
 | 1082 | +            1,  | 
 | 1083 | +            indices,  | 
 | 1084 | +            dtype=output_dtype,  | 
 | 1085 | +        )  | 
 | 1086 | + | 
 | 1087 | +    def embedding_4bit_pattern(indices, int_data, group_size, scale, zero_point):  | 
 | 1088 | +        dq = torch.ops.torchao.dequantize_affine.default(  | 
 | 1089 | +            int_data, [1, group_size], scale, zero_point, torch.int8, -8, 7  | 
 | 1090 | +        )  | 
 | 1091 | +        return torch.ops.aten.embedding.default(dq, indices)  | 
 | 1092 | + | 
 | 1093 | +    def embedding_4bit_replacement(indices, int_data, group_size, scale, zero_point):  | 
 | 1094 | +        packed_int_data = torch.ops.quant_fusion._pack_embedding_weight.default(  | 
 | 1095 | +            int_data, 4  | 
 | 1096 | +        )  | 
 | 1097 | +        zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype)  | 
 | 1098 | +        return torch.ops.quantized_decomposed.embedding_4bit.default(  | 
 | 1099 | +            packed_int_data, scale, zero_point_dtype_cast, -8, 7, indices  | 
 | 1100 | +        )  | 
 | 1101 | + | 
 | 1102 | +    def embedding_4bit_dtype_pattern(  | 
 | 1103 | +        indices, int_data, group_size, scale, zero_point, output_dtype  | 
 | 1104 | +    ):  | 
 | 1105 | +        dq = torch.ops.torchao.dequantize_affine.default(  | 
 | 1106 | +            int_data,  | 
 | 1107 | +            [1, group_size],  | 
 | 1108 | +            scale,  | 
 | 1109 | +            zero_point,  | 
 | 1110 | +            torch.int8,  | 
 | 1111 | +            -8,  | 
 | 1112 | +            7,  | 
 | 1113 | +            "INT",  | 
 | 1114 | +            output_dtype,  | 
 | 1115 | +        )  | 
 | 1116 | +        return torch.ops.aten.embedding.default(dq, indices)  | 
 | 1117 | + | 
 | 1118 | +    def embedding_4bit_dtype_replacement(  | 
 | 1119 | +        indices, int_data, group_size, scale, zero_point, output_dtype  | 
 | 1120 | +    ):  | 
 | 1121 | +        packed_int_data = torch.ops.quant_fusion._pack_embedding_weight.default(  | 
 | 1122 | +            int_data, 4  | 
 | 1123 | +        )  | 
 | 1124 | +        zero_point_dtype_cast = torch.ops.aten.to.dtype(zero_point, scale.dtype)  | 
 | 1125 | +        return torch.ops.quantized_decomposed.embedding_4bit.dtype(  | 
 | 1126 | +            packed_int_data,  | 
 | 1127 | +            scale,  | 
 | 1128 | +            zero_point_dtype_cast,  | 
 | 1129 | +            -8,  | 
 | 1130 | +            7,  | 
 | 1131 | +            indices,  | 
 | 1132 | +            dtype=output_dtype,  | 
 | 1133 | +        )  | 
 | 1134 | + | 
 | 1135 | +    return [  | 
 | 1136 | +        (  | 
 | 1137 | +            _trace_and_lower_to_edge_ops(embedding_byte_pattern),  | 
 | 1138 | +            _trace_and_lower_to_edge_ops(embedding_byte_replacement),  | 
 | 1139 | +            [],  | 
 | 1140 | +        ),  | 
 | 1141 | +        (  | 
 | 1142 | +            _trace_and_lower_to_edge_ops(embedding_byte_dtype_pattern),  | 
 | 1143 | +            _trace_and_lower_to_edge_ops(embedding_byte_dtype_replacement),  | 
 | 1144 | +            [],  | 
 | 1145 | +        ),  | 
 | 1146 | +        (  | 
 | 1147 | +            _trace_and_lower_to_edge_ops(embedding_2bit_pattern),  | 
 | 1148 | +            _trace_and_lower_to_edge_ops(embedding_2bit_replacement),  | 
 | 1149 | +            [],  | 
 | 1150 | +        ),  | 
 | 1151 | +        (  | 
 | 1152 | +            _trace_and_lower_to_edge_ops(embedding_2bit_dtype_pattern),  | 
 | 1153 | +            _trace_and_lower_to_edge_ops(embedding_2bit_dtype_replacement),  | 
 | 1154 | +            [],  | 
 | 1155 | +        ),  | 
 | 1156 | +        (  | 
 | 1157 | +            _trace_and_lower_to_edge_ops(embedding_4bit_pattern),  | 
 | 1158 | +            _trace_and_lower_to_edge_ops(embedding_4bit_replacement),  | 
 | 1159 | +            [],  | 
 | 1160 | +        ),  | 
 | 1161 | +        (  | 
 | 1162 | +            _trace_and_lower_to_edge_ops(embedding_4bit_dtype_pattern),  | 
 | 1163 | +            _trace_and_lower_to_edge_ops(embedding_4bit_dtype_replacement),  | 
 | 1164 | +            [],  | 
 | 1165 | +        ),  | 
 | 1166 | +    ]  | 
 | 1167 | + | 
 | 1168 | + | 
876 | 1169 | def _get_embedding_ops_patterns_and_replacements() -> (  | 
877 | 1170 |     List[Tuple[Callable, Callable, List[Callable]]]  | 
878 | 1171 | ):  | 
@@ -1167,5 +1460,6 @@ def get_quant_patterns_and_replacements() -> (  | 
1167 | 1460 |             *_get_slice_patterns_and_replacements(),  | 
1168 | 1461 |             # *_get_fixed_qparams_ops_patterns_and_replacements(),  | 
1169 | 1462 |             *_get_embedding_ops_patterns_and_replacements(),  | 
 | 1463 | +            *_get_embedding_ops_patterns_and_replacements_torchao(),  | 
1170 | 1464 |         ]  | 
1171 | 1465 |     )  | 
0 commit comments