
Commit bf2b0d0

Add INT8 realquant support (#166)
Co-authored-by: ishan-modi <[email protected]>
1 parent 7a04743 commit bf2b0d0

File tree: 5 files changed, +158 −0 lines changed

modelopt/torch/quantization/nn/modules/tensor_quantizer.py (7 additions, 0 deletions)

@@ -37,6 +37,7 @@
     BaseQuantizedTensor,
     FP8QTensor,
     INT4QTensor,
+    INT8QTensor,
     NF4QTensor,
     NVFP4QTensor,
     QTensorWrapper,
@@ -547,6 +548,7 @@ def _is_real_quantize_support(self):
            (self._num_bits == 4 and self._block_sizes)  # NF4 and Int4
            or (self._num_bits == (2, 1) and self._block_sizes)  # NVFP4
            or (self._num_bits == (4, 3))  # FP8
+           or (self._num_bits == 8)  # Int8
        ):
            return True
        return False
@@ -565,6 +567,11 @@ def _real_quantize(self, inputs):
                scales=self.amax / 448.0 if self.amax is not None else None,
            )
            buffer_to_register["_scale"] = _scale
+       elif self._num_bits == 8:
+           outputs, _scale = INT8QTensor.quantize(
+               inputs, axis=self._axis, block_sizes=self._block_sizes
+           )
+           buffer_to_register["_scale"] = _scale
        elif self._block_sizes.get("scale_bits", 0) == 8 and self._block_sizes.get(
            "scale_block_sizes", None
        ):
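These hunks route any quantizer configured with num_bits == 8 through INT8QTensor.quantize during real quantization. A minimal standalone sketch of the updated support check, with the quantizer attributes replaced by plain arguments (the helper name below is illustrative, not part of the module):

def is_real_quantize_supported(num_bits, block_sizes) -> bool:
    # Mirrors the _is_real_quantize_support condition after this commit.
    return bool(
        (num_bits == 4 and block_sizes)  # NF4 and Int4
        or (num_bits == (2, 1) and block_sizes)  # NVFP4
        or num_bits == (4, 3)  # FP8
        or num_bits == 8  # Int8, added here
    )

assert is_real_quantize_supported(8, None)  # plain INT8 needs no block sizes
assert not is_real_quantize_supported(4, None)  # INT4 still requires block sizes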

modelopt/torch/quantization/qtensor/__init__.py (1 addition, 0 deletions)

@@ -20,4 +20,5 @@
 from .base_qtensor import *
 from .fp8_tensor import *
 from .int4_tensor import *
+from .int8_tensor import *
 from .nf4_tensor import *
modelopt/torch/quantization/qtensor/int8_tensor.py (new file, 125 additions, 0 deletions)

@@ -0,0 +1,125 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Implements INT8 quantization for efficient tensor storage and computation."""
+
+from typing import Union
+
+import torch
+
+from ..qtensor.base_qtensor import BaseQuantizedTensor
+from ..utils import (
+    convert_quantization_axis_to_reduce_axis,
+    reduce_amax,
+    reduce_block_amax,
+    reduce_block_padding,
+)
+
+
+class INT8QTensor(BaseQuantizedTensor):
+    """Implements the INT8 quantization on tensors for more efficient storage or computation.
+
+    Attributes:
+        quantized_data (torch.Tensor): The quantized data stored as an INT8 tensor.
+    """
+
+    @classmethod
+    def quantize(
+        cls,
+        input: torch.Tensor,
+        scales: torch.Tensor = None,
+        axis: Union[tuple, int, None] = None,
+        block_sizes: dict = None,
+    ) -> tuple:
+        """Converting a tensor to a quantized format based on INT8 quantization.
+
+        Args:
+            input (torch.Tensor): The input tensor to be quantized.
+            scales (torch.Tensor): The scales for quantization.
+            axis: The dimensions to reduce for quantization. None or int or tuple of ints.
+            block_sizes (dict): A dictionary specifying the block size for each dimension.
+                Note: One can only provide axis or block_sizes for INT8 quantization.
+
+        Returns:
+            tuple: INT8QTensor, scales
+        """
+        original_input = input
+        if scales is None:
+            if block_sizes:
+                input = reduce_block_padding(input, block_sizes)
+                amax = reduce_block_amax(input, block_sizes)
+            else:
+                reduce_axis = convert_quantization_axis_to_reduce_axis(input, axis)
+                amax = reduce_amax(input, axis=reduce_axis)
+            scales = amax / 127.0
+
+        # Calculate the scale shape and make sure it aligns with input and block_sizes
+        expected_shape = list(input.shape)
+        expanded_scales = scales.clone()
+        if block_sizes:
+            for dim, block_size in block_sizes.items():
+                dim = dim if dim >= 0 else len(input.shape) + dim  # Convert negative index
+                assert input.shape[dim] % block_size == 0, (
+                    f"Tensor dimension {dim}, {input.shape[dim]} is not divisible by {block_size}."
+                )
+                expected_shape[dim] = (
+                    input.shape[dim] // block_size
+                )  # Adjust expected shape for blocks
+
+            # Assert the shape of `scales` matches expected reduced dimensions
+            assert scales.shape == tuple(expected_shape), (
+                f"Mismatch in expected scale shape: {scales.shape} vs {tuple(expected_shape)}"
+            )
+
+            # Expand scales for broadcasting
+            for dim, block_size in block_sizes.items():
+                expanded_scales = expanded_scales.repeat_interleave(block_size, dim=dim)
+
+        # Quantization
+        quantized_data = (input / expanded_scales).round().clamp(-128, 127).to(torch.int8)
+
+        return cls(original_input.shape, original_input.dtype, quantized_data), scales
+
+    def dequantize(self, dtype: torch.dtype = None, **kwarg):
+        """Dequantize INT8 packed tensor to a target dtype."""
+        if dtype is None:
+            dtype = self.metadata["dtype"]
+        assert "scale" in kwarg, "Require scale for INT8 dequantization."
+
+        # Get args
+        scales = kwarg["scale"]
+        block_sizes = kwarg.get("block_sizes", None)
+
+        shape = self._quantized_data.shape
+        if block_sizes:
+            # Compute expanded shape for broadcasting scales
+            expanded_shape = list(shape)
+            for dim, block_size in block_sizes.items():
+                assert shape[dim] % block_size == 0, (
+                    f"Dimension {shape[dim]} is not divisible by {block_size}."
+                )
+                expanded_shape[dim] //= block_size  # Reduce the dimension size for blocks
+
+            assert tuple(expanded_shape) == scales.shape, (
+                f"Scales shape {scales.shape} must match expected {tuple(expanded_shape)}."
+            )
+
+            # Expand scales for broadcasting
+            for dim, block_size in block_sizes.items():
+                scales = scales.repeat_interleave(block_size, dim=dim)
+
+        # Handle padded tensors
+        slices = tuple(slice(0, dim) for dim in self.metadata["shape"])
+
+        return (self._quantized_data.view(torch.int8).to(dtype) * scales.to(dtype))[slices]
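INT8QTensor stores only the int8 payload plus the original shape and dtype metadata; the scales are returned to the caller (the tensor quantizer registers them as a _scale buffer). A small round-trip sketch, assuming the class is importable from modelopt.torch.quantization.qtensor as exported above and that dequantize receives the scale through the scale keyword:

import torch

from modelopt.torch.quantization.qtensor import INT8QTensor

# 2D block quantization: one scale per 2x2 block, mirroring the new unit test.
weight = torch.tensor([[0.0, 1.0, 2.0, 3.0], [4.0, 5.0, 6.0, 7.0]], dtype=torch.bfloat16)
block_sizes = {-1: 2, -2: 2}

qtensor, scales = INT8QTensor.quantize(weight, block_sizes=block_sizes)
print(scales.shape)  # torch.Size([1, 2]) -- one scale per 2x2 block

# Dequantize back to bfloat16; scale and block_sizes are passed as keyword arguments.
recovered = qtensor.dequantize(dtype=torch.bfloat16, scale=scales, block_sizes=block_sizes)
print((recovered - weight).abs().max())  # small quantization error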

tests/gpu/torch/quantization/test_qtensor_cuda.py (22 additions, 0 deletions)

@@ -146,6 +146,28 @@ def test_amax_from_tensor_quantizer(
            dtype=torch.bfloat16,
        ),
    ),
+   # INT8 per channel quantization
+   (
+       8,
+       None,
+       0,
+       torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7]], dtype=torch.bfloat16),
+       torch.tensor(
+           [[0.0000, 0.9922, 1.9844, 2.9844, 3.9688, 4.9688, 5.9688, 7.0000]],
+           dtype=torch.bfloat16,
+       ),
+   ),
+   # INT8 2D block quantization
+   (
+       8,
+       {-1: 2, -2: 2},
+       None,
+       torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7]], dtype=torch.bfloat16),
+       torch.tensor(
+           [[0.0000, 1.0234, 1.9844, 2.9844], [4.0000, 5.0000, 5.9688, 7.0000]],
+           dtype=torch.bfloat16,
+       ),
+   ),
    # FP8, 2D block scales
    (
        (4, 3),
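The expected tensors above are just quantize-dequantize round trips: for the per-channel case the scale is amax / 127 = 7 / 127, each element is rounded to the nearest int8 step, scaled back, and kept in bfloat16, which is why 1 comes back as 0.9922 and 4 as 3.9688. A quick hand check of those values, independent of the test harness (with a single-row input, the per-channel scale equals the per-tensor amax / 127):

import torch

x = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7]], dtype=torch.bfloat16)
scale = x.abs().amax() / 127.0
q = (x / scale).round().clamp(-128, 127).to(torch.int8)
dq = q.to(torch.bfloat16) * scale
print(dq)  # should match the expected per-channel tensor above (0.0000, 0.9922, ..., 7.0000)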

tests/gpu/torch/quantization/test_real_quantize_cuda.py (3 additions, 0 deletions)

@@ -31,6 +31,7 @@
    "config",
    [
        mtq.INT4_AWQ_CFG,
+       mtq.INT8_DEFAULT_CFG,
        mtq.FP8_DEFAULT_CFG,
        mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG,
        mtq.FP8_PER_CHANNEL_PER_TOKEN_CFG,
@@ -82,6 +83,7 @@ def forward_loop(model):
    "config",
    [
        mtq.INT4_AWQ_CFG,
+       mtq.INT8_DEFAULT_CFG,
        mtq.FP8_DEFAULT_CFG,
        mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG,
        mtq.FP8_PER_CHANNEL_PER_TOKEN_CFG,
@@ -105,6 +107,7 @@ def test_save_restore(model_cls, config):
    "quant_config",
    [
        mtq.INT4_AWQ_CFG,
+       mtq.INT8_DEFAULT_CFG,
        mtq.FP8_DEFAULT_CFG,
        mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG,
        mtq.FP8_PER_CHANNEL_PER_TOKEN_CFG,
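With INT8_DEFAULT_CFG added to these parametrized tests, INT8 real quantization goes through the same entry point as the existing formats. A rough end-to-end sketch, assuming the usual modelopt calibration flow; the toy model, the data, and the weight-compression step in the final comment are illustrative assumptions, not part of this commit:

import torch
import modelopt.torch.quantization as mtq

# Hypothetical toy model and calibration loop -- stand-ins for a real workload.
model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU(), torch.nn.Linear(64, 8))

def forward_loop(m):
    # Run a few batches so amax values can be calibrated.
    for _ in range(8):
        m(torch.randn(4, 64))

# Insert quantizers and calibrate with the INT8 default config used in the tests above.
model = mtq.quantize(model, mtq.INT8_DEFAULT_CFG, forward_loop)

# Assumption: materializing weights as real INT8QTensor storage is a separate step
# (e.g. mtq.compress(model)); consult the modelopt docs for the exact API.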
