|
6 | 6 |
|
7 | 7 | # pyre-strict
|
8 | 8 |
|
| 9 | + |
9 | 10 | from typing import Optional
|
10 | 11 |
|
11 | 12 | import torch
|
| 13 | + |
12 | 14 | from executorch.exir.scalar_type import ScalarType
|
13 | 15 | from torch.library import impl, Library
|
14 | 16 |
|
|
21 | 23 | ScalarType.QINT32: torch.qint32,
|
22 | 24 | }
|
23 | 25 |
|
| 26 | +_Number = bool | int | float |
| 27 | + |
24 | 28 |
|
25 | 29 | @impl(m, "quantize_per_tensor")
|
26 | 30 | def quantize_per_tensor(
|
@@ -294,6 +298,82 @@ def quantized_layer_norm_per_tensor(
|
294 | 298 | )
|
295 | 299 |
|
296 | 300 |
|
| 301 | +@impl(m, "quantized_conv_nchw") |
| 302 | +def quantized_conv_nchw( |
| 303 | +    input_tensor: torch.Tensor, |
| 304 | +    weight: torch.Tensor, |
| 305 | +    bias: torch.Tensor, |
| 306 | +    stride: tuple[int, int], |
| 307 | +    padding: tuple[int, int], |
| 308 | +    dilation: tuple[int, int], |
| 309 | +    groups: int, |
| 310 | +    in_zero_point: int, |
| 311 | +    weight_zero_point: torch.Tensor, |
| 312 | +    bias_scale: torch.Tensor, |
| 313 | +    output_scale: float, |
| 314 | +    output_zero_point: int, |
| 315 | +    out_multiplier: torch.Tensor, |
| 316 | +    out_shift: torch.Tensor, |
| 317 | +) -> torch.Tensor: |
| 318 | + """ |
| 319 | + Quantized convolution operation. |
| 320 | +
|
| 321 | + Args: |
| 322 | + - input_tensor (Tensor): The activations tensor |
| 323 | + - weight (Tensor): The weight tensor |
| 324 | + - bias (Tensor): The bias tensor |
| 325 | + - stride (Tuple[int]): The stride of the convolution |
| 326 | + - padding (Tuple[int]): The padding of the convolution |
| 327 | + - dilation (Tuple[int]): The dilation of the convolution |
| 328 | + - groups (int): The number of groups |
| 329 | + - in_zero_point (int): The quantized mapping of zero for the input |
| 330 | + - weight_zero_point (Tensor): The quantized mapping of zero for the weight |
| 331 | + - bias_scale (Tensor): The quantized bias scale |
| 332 | + - output_scale (float): The scale of the output |
| 333 | + - output_zero_point (int): The zero point of the output |
| 334 | + - out_multiplier (Tensor): Unused |
| 335 | + - out_shift (Tensor): Unused |
| 336 | + """ |
| 337 | +    if weight_zero_point.view(-1).shape != (1,):  # per-tensor quantization only |
| 338 | +        raise ValueError("Weight zero point must be a scalar") |
| 339 | + |
| 340 | +    if bias_scale.view(-1).shape != (1,): |
| 341 | +        raise ValueError("Bias scale must be a scalar") |
| 342 | + |
| 343 | +    if len(input_tensor.shape) == 3:  # (N, C, L) input: 1D convolution |
| 344 | +        float_out = torch.nn.functional.conv1d( |
| 345 | +            (input_tensor - in_zero_point).float(),  # shift input by its zero point |
| 346 | +            (weight - weight_zero_point).float(),  # shift weight by its zero point |
| 347 | +            (bias * bias_scale).float(),  # rescale bias back to float |
| 348 | +            stride[1], |
| 349 | +            padding[1], |
| 350 | +            dilation[1], |
| 351 | +            groups, |
| 352 | +        ) |
| 353 | + |
| 354 | +    elif len(input_tensor.shape) == 4:  # (N, C, H, W) input: 2D convolution |
| 355 | +        float_out = torch.nn.functional.conv2d( |
| 356 | +            (input_tensor - in_zero_point).float(), |
| 357 | +            (weight - weight_zero_point).float(), |
| 358 | +            (bias * bias_scale).float(), |
| 359 | +            stride, |
| 360 | +            padding, |
| 361 | +            dilation, |
| 362 | +            groups, |
| 363 | +        ) |
| 364 | +    else: |
| 365 | +        raise ValueError("Input tensor must be 3D or 4D") |
| 366 | + |
| 367 | +    return quantize_per_tensor(  # requantize the float result to the quantized output |
| 368 | +        float_out, |
| 369 | +        1.0 / output_scale, |
| 370 | +        output_zero_point, |
| 371 | +        torch.iinfo(input_tensor.dtype).min,  # clamp to the input dtype's range |
| 372 | +        torch.iinfo(input_tensor.dtype).max, |
| 373 | +        input_tensor.dtype, |
| 374 | +    ) |
| 375 | + |
| 376 | + |
297 | 377 | @impl(m, "requantize")
|
298 | 378 | def requantize(
|
299 | 379 | input: torch.Tensor,
|
|
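For context, here is a minimal smoke-test sketch of the new reference op. All shapes, scales, and zero points below are made up for illustration, and the decorated function defined in the diff is called directly (`@impl` registers and returns the original function), so no dispatcher namespace is assumed:

```python
import torch

# Arbitrary example int8 activations/weights and int32 bias.
input_q = torch.randint(-128, 128, (1, 3, 8, 8), dtype=torch.int8)   # NCHW activations
weight_q = torch.randint(-128, 128, (4, 3, 3, 3), dtype=torch.int8)  # OIHW weights
bias_q = torch.randint(-1000, 1000, (4,), dtype=torch.int32)

out = quantized_conv_nchw(
    input_q, weight_q, bias_q,
    stride=(1, 1), padding=(0, 0), dilation=(1, 1), groups=1,
    in_zero_point=0,
    weight_zero_point=torch.tensor([0]),  # must flatten to shape (1,)
    bias_scale=torch.tensor([0.001]),     # must flatten to shape (1,)
    output_scale=0.05,
    output_zero_point=0,
    out_multiplier=torch.tensor([0]),     # unused by this reference path
    out_shift=torch.tensor([0]),          # unused by this reference path
)
# 8x8 input, 3x3 kernel, stride 1, no padding -> 6x6 spatial output,
# requantized to the input dtype (int8 here).
assert out.shape == (1, 4, 6, 6)
```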