@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import operator
+
 from typing import Optional
 
 import executorch.backends.vulkan.utils as utils
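
A quick aside on why operator is now imported (illustrative sketch, not part of the patch): in torch.fx graphs, indexing into the result of a multi-output op is recorded as a call_function node whose target is operator.getitem. choose_qparams-style ops return (scales, zero_points), and each element access becomes such a node; torch.topk stands in below as an arbitrary multi-output op:

    import operator

    import torch
    import torch.fx


    def f(x):
        out = torch.topk(x, 2)  # multi-output op, like choose_qparams
        return out[0].sum() + out[1].sum()


    gm = torch.fx.symbolic_trace(f)
    # Each indexing above appears as a node whose target is operator.getitem
    print([n for n in gm.graph.nodes if n.target is operator.getitem])
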
@@ -117,8 +119,19 @@ def __init__(self, mm_node: torch.fx.Node) -> None:
             self.match_found = True
             return
 
-        self.input_scales_node = self.quantize_input_node.args[1]
-        self.input_zeros_node = self.quantize_input_node.args[2]
+        scales_arg_idx = 1
+        zeros_arg_idx = 2
+
+        # The torchao quantize_affine op has a different schema: block_size comes before scale/zero_point
+        if (
+            self.quantize_input_node.target
+            == exir_ops.edge.torchao.quantize_affine.default
+        ):
+            scales_arg_idx = 2
+            zeros_arg_idx = 3
+
+        self.input_scales_node = self.quantize_input_node.args[scales_arg_idx]
+        self.input_zeros_node = self.quantize_input_node.args[zeros_arg_idx]
 
         assert dq_node is not None
         self.all_nodes.extend(
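
For reference, the two quantize op schemas that the new index logic accounts for (argument layouts as in the standard definitions; shown here for illustration only):

    # quantized_decomposed-style op: scales/zero points come right after the input
    #   quantize_per_token(input, scales, zero_points, quant_min, quant_max, dtype)
    #                             args[1] args[2]
    # torchao op: block_size comes first, shifting scale/zero_point over by one
    #   torchao.quantize_affine(input, block_size, scale, zero_point, output_dtype, ...)
    #                                              args[2] args[3]
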
@@ -164,6 +177,27 @@ def is_input_static_per_tensor_quantized(self) -> bool:
         # are scalars.
         return isinstance(self.input_scales_node, float)
 
+    def is_input_dynamic_perchannel_quantized(self) -> bool:
+        if self.quantize_input_node is None:
+            return False
+
+        if not isinstance(self.input_scales_node, torch.fx.Node):
+            return False
+
+        # For dynamic quantization, the input scales node should be a getitem
+        # operator retrieving one of the outputs of a choose_qparams op
+        if self.input_scales_node.target != operator.getitem:
+            return False
+
+        # The getitem node should be retrieving from a choose_qparams op
+        if not utils.is_choose_qparams_node(self.input_scales_node.args[0]):
+            return False
+
+        scales_shape = self.input_scales_node.meta["val"].shape
+        input_shape = self.fp_input_node.meta["val"].shape
+
+        return input_shape[-2] == scales_shape[-1]
+
 
 linear_anchor_nodes = {
     exir_ops.edge.aten.linear.default,
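
Note (illustrative, not part of the patch): the final shape comparison encodes the per-token property. For a floating-point input of shape [..., M, K], dynamic per-token quantization yields one scale per token, so the scales tensor's last dim must equal the input's second-to-last dim. A minimal sketch with assumed toy shapes:

    import torch

    x = torch.randn(8, 64)  # fp input to the linear: [M, K] = 8 tokens, 64 channels
    scales = torch.rand(8)  # dynamic per-token quantization: one scale per token -> [M]
    # This is the condition the new matcher method checks
    assert x.shape[-2] == scales.shape[-1]
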
@@ -230,6 +264,34 @@ def pack_4bit_weight_tensor(weight_tensor: torch.Tensor) -> torch.Tensor:
     return weight_tensor[::, 1::2] << 4 | weight_tensor[::, ::2]
 
 
+def compute_per_group_sums(weight_tensor: torch.Tensor, group_size: int) -> torch.Tensor:
+    """
+    Compute the sum of weights in each quantization group.
+
+    Args:
+        weight_tensor (torch.Tensor): Tensor of shape [out_channels, in_channels], dtype int8.
+        group_size (int): Number of input channels per quantization group.
+
+    Returns:
+        torch.Tensor: Tensor of shape [num_groups, out_channels], where num_groups = in_channels // group_size; the out_channels dim is padded to a multiple of 8.
+    """
+    out_channels, in_channels = weight_tensor.shape
+    num_groups = in_channels // group_size
+    # Reshape to [out_channels, num_groups, group_size]
+    reshaped = weight_tensor.view(out_channels, num_groups, group_size)
+    # Sum over the group_size dimension to get [out_channels, num_groups]
+    sums = reshaped.sum(dim=2)
+    # Transpose to [num_groups, out_channels]
+    sums = sums.transpose(0, 1).contiguous()
+    # Pad the out_channels dim (dim=1) to a multiple of 8 if needed
+    out_channels = sums.shape[1]
+    if out_channels % 8 != 0:
+        num_pad = 8 - (out_channels % 8)
+        sums = F.pad(sums, (0, num_pad))
+
+    return sums.to(torch.int32).contiguous()
+
+
 ##
 ## Pattern Replacement
 ##
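
A small usage sketch of compute_per_group_sums with toy shapes (values assumed for illustration; this presumes the function above is in scope, with torch and torch.nn.functional as F imported as the module requires):

    import torch

    w = torch.arange(24, dtype=torch.int8).reshape(3, 8)  # [out_channels=3, in_channels=8]
    sums = compute_per_group_sums(w, group_size=4)
    # 8 // 4 = 2 groups per output channel; out_channels padded 3 -> 8
    print(sums.shape)  # torch.Size([2, 8])
    print(sums[:, 0])  # row 0 of w summed per group: tensor([ 6, 22], dtype=torch.int32)
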
@@ -281,6 +343,73 @@ def make_linear_q4gsw_op(
     match.output_node.replace_all_uses_with(linear_q4gsw_node)
 
 
+def make_linear_dq8ca_q4gsw_op(
+    ep: ExportedProgram,
+    graph_module: torch.fx.GraphModule,
+    match: QuantizedLinearMatch,
+    weight_tensor: torch.Tensor,
+    weight_scales_tensor: torch.Tensor,
+):
+    num_groups = weight_scales_tensor.shape[-1]
+    in_channels = weight_tensor.shape[-1]
+    group_size = in_channels // num_groups
+
+    # Compute the per quant group sums before packing the weight tensor
+    sum_per_quant_group = compute_per_group_sums(weight_tensor, group_size)
+
+    weight_tensor = pack_4bit_weight_tensor(weight_tensor)
+    # Use this function for convenience to update the state dict with the packed
+    # weight tensor. Alignment will already have been done in the above function.
+    weight_tensor = utils.align_width_and_update_state_dict(
+        ep, match.weight_node, weight_tensor, align_to=1, force_update=True
+    )
+
+    # Also transpose the weight scales tensor to shape [num_groups, N]
+    weight_scales_tensor = weight_scales_tensor.transpose(0, 1).contiguous()
+    utils.align_width_and_update_state_dict(
+        ep,
+        match.weight_scales_node,
+        weight_scales_tensor,
+        align_to=1,
+        force_update=True,
+    )
+
+    first_graph_node = list(graph_module.graph.nodes)[0]
+    with graph_module.graph.inserting_before(first_graph_node):
+        weight_tensor_name = utils.get_tensor_name(ep, match.weight_node)
+        # Pre-compute the weight sums, which are needed to apply the activation
+        # zero point correction when using integer accumulation.
+        sums_name = weight_tensor_name + "_sums"
+        # Sanitize the name
+        sums_name = sums_name.replace(".", "_")
+
+        weight_sums_node = create_constant_placeholder(
+            exp_program=ep,
+            graph=graph_module.graph,
+            kind=InputKind.CONSTANT_TENSOR,
+            name=sums_name,
+            data=sum_per_quant_group,
+        )
+
+    with graph_module.graph.inserting_before(match.output_node):
+        qlinear_node = graph_module.graph.create_node(
+            "call_function",
+            exir_ops.edge.et_vk.linear_dq8ca_q4gsw.default,
+            args=(
+                match.fp_input_node,
+                match.input_scales_node,
+                match.input_zeros_node,
+                match.weight_node,
+                weight_sums_node,
+                match.weight_scales_node,
+                group_size,
+            ),
+        )
+
+    qlinear_node.meta["val"] = match.output_node.meta["val"]
+    match.output_node.replace_all_uses_with(qlinear_node)
+
+
 def make_linear_q8ta_q8csw_custom_op(
     ep: ExportedProgram,
     graph_module: torch.fx.GraphModule,
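
Why the weight sums are precomputed: with asymmetric 8-bit activations (q_x = x / s_x + zp_x) and symmetric quantized weights (q_w = w / s_w), the dequantized dot product factors as s_x * s_w * (sum(q_x * q_w) - zp_x * sum(q_w)), so the shader can accumulate in integers and fold the activation zero point in at the end using sum(q_w) per quant group. A numeric sanity check of that identity under these assumptions (toy values):

    import torch

    s_x, zp_x, s_w = 0.05, 3, 0.1  # assumed toy quantization params
    q_x = torch.randint(-128, 128, (16,), dtype=torch.int32)
    q_w = torch.randint(-8, 8, (16,), dtype=torch.int32)  # 4-bit range
    lhs = ((q_x - zp_x) * s_x * (q_w * s_w)).sum()  # dequantize, then dot product
    rhs = s_x * s_w * ((q_x * q_w).sum() - zp_x * q_w.sum())  # int accumulate + correction
    assert torch.allclose(lhs, rhs)
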
@@ -354,10 +483,16 @@ def replace_quantized_linear_patterns(
         make_linear_q4gsw_op(
             ep, graph_module, match, weight_tensor, weight_scales_tensor
         )
+    elif (
+        match.is_input_dynamic_perchannel_quantized()
+        and match.is_weight_pergroup_quantized()
+        and utils.is_in_4bit_range(weight_tensor)
+    ):
+        make_linear_dq8ca_q4gsw_op(
+            ep, graph_module, match, weight_tensor, weight_scales_tensor
+        )
     elif (
         match.is_input_static_per_tensor_quantized()
         and match.is_weight_perchannel_quantized()
     ):
         make_linear_q8ta_q8csw_custom_op(ep, graph_module, match, weight_tensor)
-
-    # No-op for unsupported quant patterns