Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
249 changes: 249 additions & 0 deletions test/quantization/algorithm/test_gptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import torch

from tico.quantization import convert, prepare
from tico.quantization.algorithm.gptq.utils import SensitivityCalibrator
from tico.quantization.config.gptq import GPTQConfig
from tico.quantization.config.ptq import PTQConfig
from tico.quantization.evaluation.evaluate import BACKEND, evaluate
Expand Down Expand Up @@ -100,6 +101,29 @@ def get_example_inputs(self):
return (torch.randn(1, 32, 16, 16),), {}


class NormConv2DWithLogits(torch.nn.Module):
    """Two stacked Conv2d layers whose forward returns an object exposing ``logits``.

    The wrapper object mimics transformer-style model outputs so that
    calibration code reading ``output.logits`` can be exercised.
    """

    class _LogitsOutput:
        """Minimal stand-in for a model-output container with a ``logits`` field."""

        def __init__(self, logits):
            self.logits = logits

    def __init__(self):
        super().__init__()
        # NOTE(review): calibration utilities appear to read these attributes
        # off the model directly — keep them.
        self.device = torch.device("cpu")
        self.dtype = torch.float32
        self.m = torch.nn.ModuleList(
            [
                torch.nn.Conv2d(128, 256, (3, 3), stride=1),
                torch.nn.Conv2d(256, 512, (5, 5), stride=2),
            ]
        )

    def forward(self, x):
        out = x
        for layer in self.m:
            out = layer(out)
        # Flatten to rows of 64 and add a leading batch axis, HF-logits style.
        return self._LogitsOutput(out.reshape((-1, 64)).unsqueeze(0))

    def get_example_inputs(self):
        # One calibration sample as (positional args, keyword args).
        return (torch.randn(1, 128, 32, 32),), {}


class NormConv1D(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down Expand Up @@ -133,6 +157,28 @@ def get_example_inputs(self):
return (torch.randn(1, 32, 16),), {}


class NormConv1DWithLogits(torch.nn.Module):
    """Two chained Conv1d layers whose forward returns an object exposing ``logits``.

    Mimics transformer-style model outputs so calibration code that reads
    ``output.logits`` can be tested against a plain CNN.
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): calibration utilities appear to read these attributes
        # off the model directly — keep them.
        self.device = torch.device("cpu")
        self.dtype = torch.float32
        self.conv = torch.nn.Conv1d(128, 256, 3, stride=1)
        self.conv2 = torch.nn.Conv1d(256, 512, 5, stride=2)

    def forward(self, x):
        class _Wrapped:
            """Minimal model-output container carrying only ``logits``."""

            def __init__(self, logits):
                self.logits = logits

        hidden = self.conv2(self.conv(x))
        # Flatten to rows of 64 and add a leading batch axis.
        return _Wrapped(hidden.reshape((-1, 64)).unsqueeze(0))

    def get_example_inputs(self):
        # One calibration sample as (positional args, keyword args).
        return (torch.randn(1, 128, 32),), {}


class TransposedConv2DGeneral(torch.nn.Module):
def __init__(self):
super().__init__()
Expand All @@ -151,6 +197,30 @@ def get_example_inputs(self):
return (torch.randn(1, 16, 7, 7),), {}


class TransposedConv2DGeneralWithLogits(torch.nn.Module):
    """Two ConvTranspose2d layers (including general groupwise) returning ``logits``.

    The second layer uses ``groups=2`` to cover the general groupwise
    transposed-convolution case; the forward wraps its output like a
    transformer model so ``output.logits`` readers can be tested.
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): calibration utilities appear to read these attributes
        # off the model directly — keep them.
        self.device = torch.device("cpu")
        self.dtype = torch.float32
        self.tconv = torch.nn.ConvTranspose2d(16, 32, (2, 2), stride=2, groups=1)
        # General groupwise transposed convolution (groups > 1).
        self.tconv2 = torch.nn.ConvTranspose2d(32, 16, (3, 3), stride=4, groups=2)

    def forward(self, x):
        class _Wrapped:
            """Minimal model-output container carrying only ``logits``."""

            def __init__(self, logits):
                self.logits = logits

        upsampled = self.tconv2(self.tconv(x))
        # Flatten to rows of 8 and add a leading batch axis.
        return _Wrapped(upsampled.reshape((-1, 8)).unsqueeze(0))

    def get_example_inputs(self):
        # One calibration sample as (positional args, keyword args).
        return (torch.randn(1, 16, 7, 7),), {}


class NormConv3D(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down Expand Up @@ -184,6 +254,29 @@ def get_example_inputs(self):
return (torch.randn(5, 16, 17, 19, 35),), {}


class NormConv3DWithLogits(torch.nn.Module):
    """Two stacked Conv3d layers whose forward returns an object exposing ``logits``.

    Uses asymmetric kernels and a batch of 5 samples; the wrapper object
    mimics transformer-style outputs for ``output.logits`` consumers.
    """

    class _LogitsOutput:
        """Minimal stand-in for a model-output container with a ``logits`` field."""

        def __init__(self, logits):
            self.logits = logits

    def __init__(self):
        super().__init__()
        # NOTE(review): calibration utilities appear to read these attributes
        # off the model directly — keep them.
        self.device = torch.device("cpu")
        self.dtype = torch.float32
        self.m = torch.nn.ModuleList(
            [
                torch.nn.Conv3d(16, 8, (2, 3, 5), stride=1),
                torch.nn.Conv3d(8, 32, (3, 5, 2), stride=2),
            ]
        )

    def forward(self, x):
        out = x
        for layer in self.m:
            out = layer(out)
        # Flatten to rows of 8 and add a leading batch axis.
        return self._LogitsOutput(out.reshape((-1, 8)).unsqueeze(0))

    def get_example_inputs(self):
        # One calibration sample as (positional args, keyword args).
        return (torch.randn(5, 16, 17, 19, 35),), {}


class GPTQTest(unittest.TestCase):
@unittest.skipIf(
not IS_INTERNAL_TEST, "Internal test — run only if --include-internal is set"
Expand Down Expand Up @@ -306,6 +399,44 @@ def test_normconv2d(self):
results["peir"][0] < tolerance
), f"PEIR exceeds tolerance. PEIR:{results['peir'][0]}%, tolerance: {tolerance}%"

@unittest.skipIf(
not IS_INTERNAL_TEST, "Internal test — run only if --include-internal is set"
)
def test_normconv2d_with_logits(self):
q_m = NormConv2DWithLogits()
q_m.eval()
ori_m = q_m

dataset = [] # type: ignore[var-annotated]
for _ in range(30):
args, _ = ori_m.get_example_inputs()
dataset.append(*args)

calibrator = SensitivityCalibrator(q_m, dataset, show_progress=False)
sens = calibrator.compute_sensitivity_info()

# Apply GPTQ
q_m = prepare(
q_m,
GPTQConfig(
show_progress=False,
mse="smse",
perchannel=True,
sensitivity=sens,
),
)
for input in dataset:
q_m(input)
convert(q_m, inplace=True)
# check that all convolution nodes are quantized
assert hasattr(q_m, "quantizers"), "quantized model does not have quantizers"
assert (
"model.layers.0.m.0" in q_m.quantizers # type: ignore[operator]
), "first conv node is not quantized"
assert (
"model.layers.0.m.1" in q_m.quantizers # type: ignore[operator]
), "second conv node is not quantized"

@unittest.skipIf(
not IS_INTERNAL_TEST, "Internal test — run only if --include-internal is set"
)
Expand Down Expand Up @@ -438,6 +569,46 @@ def test_groupwise_conv1d(self):

# TODO add quantization (right now it can't be evaluated on backend)

@unittest.skipIf(
not IS_INTERNAL_TEST, "Internal test — run only if --include-internal is set"
)
def test_normconv1d_with_logits(self):
q_m = NormConv1DWithLogits()
q_m.eval()
ori_m = q_m

dataset = [] # type: ignore[var-annotated]
for _ in range(30):
args, _ = ori_m.get_example_inputs()
dataset.append(*args)

calibrator = SensitivityCalibrator(q_m, dataset, show_progress=False)
sens = calibrator.compute_sensitivity_info()

# Apply GPTQ
q_m = prepare(
q_m,
GPTQConfig(
show_progress=False,
mse="smse",
perchannel=True,
sensitivity=sens,
),
)
for input in dataset:
q_m(input)
convert(q_m, inplace=True)
# check that all convolution nodes are quantized
assert hasattr(q_m, "quantizers"), "quantized model does not have quantizers"
assert (
"model.layers.0.conv" in q_m.quantizers # type: ignore[operator]
), "first conv node is not quantized"
assert (
"model.layers.0.conv2" in q_m.quantizers # type: ignore[operator]
), "second conv node is not quantized"

# TODO add quantization

@unittest.skipIf(
not IS_INTERNAL_TEST, "Internal test — run only if --include-internal is set"
)
Expand All @@ -464,6 +635,46 @@ def test_transposed_conv2d(self):

# TODO add quantization

@unittest.skipIf(
not IS_INTERNAL_TEST, "Internal test — run only if --include-internal is set"
)
def test_transposed_conv2d_with_logits(self):
q_m = TransposedConv2DGeneralWithLogits()
q_m.eval()
ori_m = q_m

dataset = [] # type: ignore[var-annotated]
for _ in range(30):
args, _ = ori_m.get_example_inputs()
dataset.append(*args)

calibrator = SensitivityCalibrator(q_m, dataset, show_progress=False)
sens = calibrator.compute_sensitivity_info()

# Apply GPTQ
q_m = prepare(
q_m,
GPTQConfig(
show_progress=False,
mse="smse",
perchannel=True,
sensitivity=sens,
),
)
for input in dataset:
q_m(input)
convert(q_m, inplace=True)
# check that all convolution nodes are quantized
assert hasattr(q_m, "quantizers"), "quantized model does not have quantizers"
assert (
"model.layers.0.tconv" in q_m.quantizers # type: ignore[operator]
), "first conv node is not quantized"
assert (
"model.layers.0.tconv2" in q_m.quantizers # type: ignore[operator]
), "second conv node is not quantized"

# TODO add quantization

@unittest.skipIf(
not IS_INTERNAL_TEST, "Internal test — run only if --include-internal is set"
)
Expand Down Expand Up @@ -524,3 +735,41 @@ def test_paddednormconv3d(self):
assert (
"model.layers.0.m.0" in q_m.quantizers # type: ignore[operator]
), "first conv node is not quantized"

@unittest.skipIf(
not IS_INTERNAL_TEST, "Internal test — run only if --include-internal is set"
)
def test_normconv3d_with_logits(self):
q_m = NormConv3DWithLogits()
q_m.eval()
ori_m = q_m

dataset = [] # type: ignore[var-annotated]
for _ in range(30):
args, _ = ori_m.get_example_inputs()
dataset.append(*args)

calibrator = SensitivityCalibrator(q_m, dataset, show_progress=False)
sens = calibrator.compute_sensitivity_info()

# Apply GPTQ
q_m = prepare(
q_m,
GPTQConfig(
show_progress=False,
mse="smse",
perchannel=True,
sensitivity=sens,
),
)
for input in dataset:
q_m(input)
convert(q_m, inplace=True)
# check that all convolution nodes are quantized
assert hasattr(q_m, "quantizers"), "quantized model does not have quantizers"
assert (
"model.layers.0.m.0" in q_m.quantizers # type: ignore[operator]
), "first conv node is not quantized"
assert (
"model.layers.0.m.1" in q_m.quantizers # type: ignore[operator]
), "second conv node is not quantized"
Loading
Loading