|
23 | 23 | ScalarType.QINT32: torch.qint32,
|
24 | 24 | }
|
25 | 25 |
|
26 |
| -_Number = bool | int | float |
27 |
| - |
28 | 26 |
|
29 | 27 | @impl(m, "quantize_per_tensor")
|
30 | 28 | def quantize_per_tensor(
|
@@ -298,8 +296,7 @@ def quantized_layer_norm_per_tensor(
|
298 | 296 | )
|
299 | 297 |
|
300 | 298 |
|
301 |
| -@impl(m, "quantized_conv_nchw") |
302 |
| -def quantized_conv_nchw( |
| 299 | +def quantized_conv( |
303 | 300 | input_tensor: torch.Tensor,
|
304 | 301 | weight: torch.Tensor,
|
305 | 302 | bias: torch.Tensor,
|
@@ -374,6 +371,120 @@ def quantized_conv_nchw(
|
374 | 371 | )
|
375 | 372 |
|
376 | 373 |
|
@impl(m, "quantized_conv_nchw")
def quantized_conv_nchw(
    input_tensor: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    stride: tuple[int, int],
    padding: tuple[int, int],
    dilation: tuple[int, int],
    groups: int,
    in_zero_point: int,
    weight_zero_point: torch.Tensor,
    bias_scale: torch.Tensor,
    output_scale: float,
    output_zero_point: int,
    out_multiplier: torch.Tensor,
    out_shift: torch.Tensor,
) -> torch.Tensor:
    """
    Quantized convolution over an NCHW-layout activation tensor.

    Verifies the activations are contiguous in the default (NCHW) memory
    format, then delegates the actual computation to ``quantized_conv``.

    Args:
        - input_tensor (Tensor): The activations tensor
        - weight (Tensor): The weight tensor
        - bias (Tensor): The bias tensor
        - stride (Tuple[int]): The stride of the convolution
        - padding (Tuple[int]): The padding of the convolution
        - dilation (Tuple[int]): The dilation of the convolution
        - groups (int): The number of groups
        - in_zero_point (int): The quantized mapping of zero for the input
        - weight_zero_point (Tensor): The quantized mapping of zero for the weight
        - bias_scale (Tensor): The quantized bias scale
        - output_scale (float): The scale of the output
        - output_zero_point (int): The zero point of the output
        - out_multiplier (Tensor): Unused
        - out_shift (Tensor): Unused

    Raises:
        ValueError: if ``input_tensor`` is not contiguous in NCHW format.
    """
    # Guard the layout first; everything else is handled by the shared kernel.
    if input_tensor.is_contiguous(memory_format=torch.contiguous_format):
        return quantized_conv(
            input_tensor, weight, bias, stride, padding, dilation,
            groups, in_zero_point, weight_zero_point, bias_scale,
            output_scale, output_zero_point, out_multiplier, out_shift,
        )
    raise ValueError("Input tensor must be in NCHW format")
| 428 | + |
| 429 | + |
@impl(m, "quantized_conv_nhwc")
def quantized_conv_nhwc(
    input_tensor: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    stride: tuple[int, int],
    padding: tuple[int, int],
    dilation: tuple[int, int],
    groups: int,
    in_zero_point: int,
    weight_zero_point: torch.Tensor,
    bias_scale: torch.Tensor,
    output_scale: float,
    output_zero_point: int,
    out_multiplier: torch.Tensor,
    out_shift: torch.Tensor,
) -> torch.Tensor:
    """
    Quantized convolution over an NHWC (channels-last) activation tensor.

    Verifies the activations are contiguous in the channels-last memory
    format, then delegates the actual computation to ``quantized_conv``.

    Args:
        - input_tensor (Tensor): The activations tensor
        - weight (Tensor): The weight tensor
        - bias (Tensor): The bias tensor
        - stride (Tuple[int]): The stride of the convolution
        - padding (Tuple[int]): The padding of the convolution
        - dilation (Tuple[int]): The dilation of the convolution
        - groups (int): The number of groups
        - in_zero_point (int): The quantized mapping of zero for the input
        - weight_zero_point (Tensor): The quantized mapping of zero for the weight
        - bias_scale (Tensor): The quantized bias scale
        - output_scale (float): The scale of the output
        - output_zero_point (int): The zero point of the output
        - out_multiplier (Tensor): Unused
        - out_shift (Tensor): Unused

    Raises:
        ValueError: if ``input_tensor`` is not contiguous in NHWC format.
    """
    # Guard the layout first; everything else is handled by the shared kernel.
    if input_tensor.is_contiguous(memory_format=torch.channels_last):
        return quantized_conv(
            input_tensor, weight, bias, stride, padding, dilation,
            groups, in_zero_point, weight_zero_point, bias_scale,
            output_scale, output_zero_point, out_multiplier, out_shift,
        )
    raise ValueError("Input tensor must be in NHWC format")
| 486 | + |
| 487 | + |
377 | 488 | @impl(m, "requantize")
|
378 | 489 | def requantize(
|
379 | 490 | input: torch.Tensor,
|
|
0 commit comments