Skip to content

Commit b474690

Browse files
committed
[quantization] [draft] GPTQ for VLM
This PR is a first attempt at full quantization of a VLM model via GPTQ+PTQ. TICO-DCO-1.0-Signed-off-by: s.malakhov <s.malakhov@partner.samsung.com>
1 parent 3c33243 commit b474690

File tree

5 files changed

+805
-67
lines changed

5 files changed

+805
-67
lines changed

test/quantization/algorithm/test_gptq.py

Lines changed: 343 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import torch
2020

2121
from tico.quantization import convert, prepare
22+
from tico.quantization.algorithm.gptq.utils import SensitivityCalibrator
2223
from tico.quantization.config.gptq import GPTQConfig
2324
from tico.quantization.config.ptq import PTQConfig
2425
from tico.quantization.evaluation.evaluate import BACKEND, evaluate
@@ -100,6 +101,29 @@ def get_example_inputs(self):
100101
return (torch.randn(1, 32, 16, 16),), {}
101102

102103

104+
class NormConv2DWithLogits(torch.nn.Module):
    """Two stacked Conv2d layers whose forward() returns an object exposing a
    ``logits`` attribute, mimicking a HuggingFace-style (V)LM model output.

    The ``device``/``dtype`` attributes are set because the GPTQ calibration
    utilities appear to read them off the model, as HF models expose them
    (assumption based on sibling *WithLogits fixtures — confirm in calibrator).
    """

    class _OutputWithLogits:
        # Minimal stand-in for a transformers ModelOutput: only `.logits`.
        # Defined once at class scope instead of being re-created inside every
        # forward() call as in the original draft.
        def __init__(self, logits):
            self.logits = logits

    def __init__(self):
        super().__init__()
        self.device = torch.device("cpu")
        self.dtype = torch.float32
        self.m = torch.nn.ModuleList()
        self.m.append(torch.nn.Conv2d(128, 256, (3, 3), stride=1))
        self.m.append(torch.nn.Conv2d(256, 512, (5, 5), stride=2))

    def forward(self, x):
        z = self.m[0](x)
        z = self.m[1](z)
        # Flatten the feature map to (1, tokens, 64) so it reads like logits.
        z = z.reshape((-1, 64)).unsqueeze(0)
        return self._OutputWithLogits(z)

    def get_example_inputs(self):
        """Return (args, kwargs) with one random (1, 128, 32, 32) input."""
        return (torch.randn(1, 128, 32, 32),), {}
125+
126+
103127
class NormConv1D(torch.nn.Module):
104128
def __init__(self):
105129
super().__init__()
@@ -133,6 +157,28 @@ def get_example_inputs(self):
133157
return (torch.randn(1, 32, 16),), {}
134158

135159

160+
class NormConv1DWithLogits(torch.nn.Module):
    """Two stacked Conv1d layers whose forward() returns an object exposing a
    ``logits`` attribute, mimicking a HuggingFace-style (V)LM model output.

    ``device``/``dtype`` mirror attributes that HF models expose; the GPTQ
    calibration utilities presumably read them — confirm in calibrator.
    """

    class _OutputWithLogits:
        # Minimal stand-in for a transformers ModelOutput: only `.logits`.
        # Hoisted to class scope; the original re-created this class object on
        # every forward() call.
        def __init__(self, logits):
            self.logits = logits

    def __init__(self):
        super().__init__()
        self.device = torch.device("cpu")
        self.dtype = torch.float32
        self.conv = torch.nn.Conv1d(128, 256, 3, stride=1)
        self.conv2 = torch.nn.Conv1d(256, 512, 5, stride=2)

    def forward(self, x):
        z = self.conv(x)
        z = self.conv2(z)
        # Flatten the feature map to (1, tokens, 64) so it reads like logits.
        z = z.reshape((-1, 64)).unsqueeze(0)
        return self._OutputWithLogits(z)

    def get_example_inputs(self):
        """Return (args, kwargs) with one random (1, 128, 32) input."""
        return (torch.randn(1, 128, 32),), {}
180+
181+
136182
class TransposedConv2DGeneral(torch.nn.Module):
137183
def __init__(self):
138184
super().__init__()
@@ -151,6 +197,86 @@ def get_example_inputs(self):
151197
return (torch.randn(1, 16, 7, 7),), {}
152198

153199

200+
class TransposedConv2DGeneralWithLogits(torch.nn.Module):
    """Two stacked ConvTranspose2d layers (second one groupwise) whose
    forward() returns an object exposing a ``logits`` attribute, mimicking a
    HuggingFace-style (V)LM model output.

    ``device``/``dtype`` mirror attributes that HF models expose; the GPTQ
    calibration utilities presumably read them — confirm in calibrator.
    """

    class _OutputWithLogits:
        # Minimal stand-in for a transformers ModelOutput: only `.logits`.
        # Hoisted to class scope; the original re-created this class object on
        # every forward() call.
        def __init__(self, logits):
            self.logits = logits

    def __init__(self):
        super().__init__()
        self.device = torch.device("cpu")
        self.dtype = torch.float32
        self.tconv = torch.nn.ConvTranspose2d(16, 32, (2, 2), stride=2, groups=1)
        self.tconv2 = torch.nn.ConvTranspose2d(
            32, 16, (3, 3), stride=4, groups=2
        )  # general groupwise

    def forward(self, x):
        z = self.tconv(x)
        z = self.tconv2(z)
        # Flatten the feature map to (1, tokens, 8) so it reads like logits.
        z = z.reshape((-1, 8)).unsqueeze(0)
        return self._OutputWithLogits(z)

    def get_example_inputs(self):
        """Return (args, kwargs) with one random (1, 16, 7, 7) input."""
        return (torch.randn(1, 16, 7, 7),), {}
222+
223+
224+
class NormConv3D(torch.nn.Module):
    """Two stacked Conv3d layers used as a GPTQ quantization target."""

    def __init__(self):
        super().__init__()
        layers = [
            torch.nn.Conv3d(16, 8, (2, 3, 5), stride=1),
            torch.nn.Conv3d(8, 32, (3, 5, 2), stride=2),
        ]
        self.m = torch.nn.ModuleList(layers)

    def forward(self, x):
        out = x
        for layer in self.m:
            out = layer(out)
        return out

    def get_example_inputs(self):
        """Random calibration sample shaped (batch, C, D, H, W)."""
        return (torch.randn(5, 16, 17, 19, 35),), {}

    def get_zero_inputs(self):
        """All-zero sample of the same shape (degenerate-calibration case)."""
        return (torch.zeros(5, 16, 17, 19, 35),), {}
241+
242+
243+
class PaddedNormConv3D(torch.nn.Module):
    """Single Conv3d using the explicit ``padding="valid"`` spelling
    (equivalent to no padding) — exercises the string-padding code path."""

    def __init__(self):
        super().__init__()
        self.m = torch.nn.ModuleList(
            [torch.nn.Conv3d(16, 8, (2, 3, 5), stride=1, padding="valid")]
        )

    def forward(self, x):
        (conv,) = self.m
        return conv(x)

    def get_example_inputs(self):
        """Random calibration sample shaped (batch, C, D, H, W)."""
        return (torch.randn(5, 16, 17, 19, 35),), {}
255+
256+
257+
class NormConv3DWithLogits(torch.nn.Module):
    """Two stacked Conv3d layers whose forward() returns an object exposing a
    ``logits`` attribute, mimicking a HuggingFace-style (V)LM model output.

    ``device``/``dtype`` mirror attributes that HF models expose; the GPTQ
    calibration utilities presumably read them — confirm in calibrator.
    """

    class _OutputWithLogits:
        # Minimal stand-in for a transformers ModelOutput: only `.logits`.
        # Hoisted to class scope; the original re-created this class object on
        # every forward() call.
        def __init__(self, logits):
            self.logits = logits

    def __init__(self):
        super().__init__()
        self.device = torch.device("cpu")
        self.dtype = torch.float32
        self.m = torch.nn.ModuleList()
        self.m.append(torch.nn.Conv3d(16, 8, (2, 3, 5), stride=1))
        self.m.append(torch.nn.Conv3d(8, 32, (3, 5, 2), stride=2))

    def forward(self, x):
        z = self.m[0](x)
        z = self.m[1](z)
        # Flatten the feature map to (1, tokens, 8) so it reads like logits.
        z = z.reshape((-1, 8)).unsqueeze(0)
        return self._OutputWithLogits(z)

    def get_example_inputs(self):
        """Return (args, kwargs) with one random (5, 16, 17, 19, 35) input."""
        return (torch.randn(5, 16, 17, 19, 35),), {}
278+
279+
154280
class GPTQTest(unittest.TestCase):
155281
@unittest.skipIf(
156282
not IS_INTERNAL_TEST, "Internal test — run only if --include-internal is set"
@@ -273,6 +399,44 @@ def test_normconv2d(self):
273399
results["peir"][0] < tolerance
274400
), f"PEIR exceeds tolerance. PEIR:{results['peir'][0]}%, tolerance: {tolerance}%"
275401

402+
@unittest.skipIf(
    not IS_INTERNAL_TEST, "Internal test — run only if --include-internal is set"
)
def test_normconv2d_with_logits(self):
    """GPTQ end-to-end on a Conv2d model whose forward returns `.logits`.

    Restores the internal-test guard that was left commented out in the
    draft, consistent with every sibling test in this class.
    """
    q_m = NormConv2DWithLogits()
    q_m.eval()
    ori_m = q_m

    # Fixed calibration dataset: one input tensor per sample.
    dataset = []
    for _ in range(30):
        args, _ = ori_m.get_example_inputs()
        dataset.append(args[0])

    # Per-layer sensitivity information feeds GPTQConfig below.
    calibrator = SensitivityCalibrator(q_m, dataset, show_progress=False)
    sens = calibrator.compute_sensitivity_info()

    # Apply GPTQ
    q_m = prepare(
        q_m,
        GPTQConfig(
            show_progress=False,
            mse="smse",
            perchannel=True,
            sensitivity=sens,
        ),
    )
    for sample in dataset:  # renamed from `input` (shadowed the builtin)
        q_m(sample)
    convert(q_m, inplace=True)
    # check that all convolution nodes are quantized
    # NOTE(review): quantizer keys carry a "model.layers.0." prefix —
    # presumably added by the PTQ wrapper; verify against prepare().
    assert hasattr(q_m, "quantizers"), "quantized model does not have quantizers"
    assert (
        "model.layers.0.m.0" in q_m.quantizers  # type: ignore[operator]
    ), "first conv node is not quantized"
    assert (
        "model.layers.0.m.1" in q_m.quantizers  # type: ignore[operator]
    ), "second conv node is not quantized"
439+
276440
@unittest.skipIf(
277441
not IS_INTERNAL_TEST, "Internal test — run only if --include-internal is set"
278442
)
@@ -405,6 +569,46 @@ def test_groupwise_conv1d(self):
405569

406570
# TODO add quantization (right now it can't be evaluated on backend)
407571

572+
# @unittest.skipIf(
573+
# not IS_INTERNAL_TEST, "Internal test — run only if --include-internal is set"
574+
# )
575+
def test_normconv1d_with_logits(self):
576+
q_m = NormConv1DWithLogits()
577+
q_m.eval()
578+
ori_m = q_m
579+
580+
dataset = []
581+
for _ in range(30):
582+
args, _ = ori_m.get_example_inputs()
583+
dataset.append(*args)
584+
585+
calibrator = SensitivityCalibrator(q_m, dataset, show_progress=False)
586+
sens = calibrator.compute_sensitivity_info()
587+
588+
# Apply GPTQ
589+
q_m = prepare(
590+
q_m,
591+
GPTQConfig(
592+
show_progress=False,
593+
mse="smse",
594+
perchannel=True,
595+
sensitivity=sens,
596+
),
597+
)
598+
for input in dataset:
599+
q_m(input)
600+
convert(q_m, inplace=True)
601+
# check that all convolution nodes are quantized
602+
assert hasattr(q_m, "quantizers"), "quantized model does not have quantizers"
603+
assert (
604+
"model.layers.0.conv" in q_m.quantizers # type: ignore[operator]
605+
), "first conv node is not quantized"
606+
assert (
607+
"model.layers.0.conv2" in q_m.quantizers # type: ignore[operator]
608+
), "second conv node is not quantized"
609+
610+
# TODO add quantization
611+
408612
@unittest.skipIf(
409613
not IS_INTERNAL_TEST, "Internal test — run only if --include-internal is set"
410614
)
@@ -430,3 +634,142 @@ def test_transposed_conv2d(self):
430634
), "second conv node is not quantized"
431635

432636
# TODO add quantization
637+
638+
@unittest.skipIf(
    not IS_INTERNAL_TEST, "Internal test — run only if --include-internal is set"
)
def test_transposed_conv2d_with_logits(self):
    """GPTQ end-to-end on a ConvTranspose2d model returning `.logits`.

    Restores the internal-test guard that was left commented out in the
    draft, consistent with every sibling test in this class.
    """
    q_m = TransposedConv2DGeneralWithLogits()
    q_m.eval()
    ori_m = q_m

    # Fixed calibration dataset: one input tensor per sample.
    dataset = []
    for _ in range(30):
        args, _ = ori_m.get_example_inputs()
        dataset.append(args[0])

    # Per-layer sensitivity information feeds GPTQConfig below.
    calibrator = SensitivityCalibrator(q_m, dataset, show_progress=False)
    sens = calibrator.compute_sensitivity_info()

    # Apply GPTQ
    q_m = prepare(
        q_m,
        GPTQConfig(
            show_progress=False,
            mse="smse",
            perchannel=True,
            sensitivity=sens,
        ),
    )
    for sample in dataset:  # renamed from `input` (shadowed the builtin)
        q_m(sample)
    convert(q_m, inplace=True)
    # check that all convolution nodes are quantized
    # NOTE(review): quantizer keys carry a "model.layers.0." prefix —
    # presumably added by the PTQ wrapper; verify against prepare().
    assert hasattr(q_m, "quantizers"), "quantized model does not have quantizers"
    assert (
        "model.layers.0.tconv" in q_m.quantizers  # type: ignore[operator]
    ), "first conv node is not quantized"
    assert (
        "model.layers.0.tconv2" in q_m.quantizers  # type: ignore[operator]
    ), "second conv node is not quantized"

    # TODO add quantization
677+
678+
@unittest.skipIf(
    not IS_INTERNAL_TEST, "Internal test — run only if --include-internal is set"
)
def test_normconv3d(self):
    """GPTQ quantizes both Conv3d layers of NormConv3D."""
    q_m = NormConv3D()
    q_m.eval()
    ori_m = q_m
    # (Removed a dead `args, kwargs = ori_m.get_example_inputs()` call that
    # was immediately overwritten inside the calibration loop.)

    # Apply GPTQ
    q_m = prepare(q_m, GPTQConfig(show_progress=False))
    for _ in range(30):
        args, kwargs = ori_m.get_example_inputs()
        q_m(*args, **kwargs)
    convert(q_m, inplace=True)
    # check that all convolution nodes are quantized
    assert hasattr(q_m, "quantizers"), "quantized model does not have quantizers"
    assert (
        "model.layers.0.m.0" in q_m.quantizers  # type: ignore[operator]
    ), "first conv node is not quantized"
    assert (
        "model.layers.0.m.1" in q_m.quantizers  # type: ignore[operator]
    ), "second conv node is not quantized"
701+
702+
@unittest.skipIf(
    not IS_INTERNAL_TEST, "Internal test — run only if --include-internal is set"
)
def test_normconv3d_on_zero_inputs(self):
    """GPTQ on all-zero calibration data must not zero out the weights."""
    q_m = NormConv3D()
    q_m.eval()
    ori_m = q_m

    # Apply GPTQ
    q_m = prepare(q_m, GPTQConfig(show_progress=False))
    for _ in range(30):
        zero_args, zero_kwargs = ori_m.get_zero_inputs()
        q_m(*zero_args, **zero_kwargs)
    convert(q_m, inplace=True)
    nonzero_count = torch.sum(q_m.m[0].weight != 0)  # type: ignore[arg-type]
    assert nonzero_count > 0, "weights should not be all zeros"
717+
718+
@unittest.skipIf(
    not IS_INTERNAL_TEST, "Internal test — run only if --include-internal is set"
)
def test_paddednormconv3d(self):
    """GPTQ quantizes a Conv3d declared with string padding ("valid")."""
    q_m = PaddedNormConv3D()
    q_m.eval()
    ori_m = q_m
    # (Removed a dead `args, kwargs = ori_m.get_example_inputs()` call that
    # was immediately overwritten inside the calibration loop.)

    # Apply GPTQ
    q_m = prepare(q_m, GPTQConfig(show_progress=False))
    for _ in range(30):
        args, kwargs = ori_m.get_example_inputs()
        q_m(*args, **kwargs)
    convert(q_m, inplace=True)
    # check that all convolution nodes are quantized
    assert hasattr(q_m, "quantizers"), "quantized model does not have quantizers"
    assert (
        "model.layers.0.m.0" in q_m.quantizers  # type: ignore[operator]
    ), "first conv node is not quantized"
738+
739+
@unittest.skipIf(
    not IS_INTERNAL_TEST, "Internal test — run only if --include-internal is set"
)
def test_normconv3d_with_logits(self):
    """GPTQ end-to-end on a Conv3d model whose forward returns `.logits`.

    Renamed from the draft's `test_normconv3d`: the duplicate name silently
    shadowed the earlier `test_normconv3d` in the class body, so that test
    never ran. Also restores the internal-test guard left commented out.
    """
    q_m = NormConv3DWithLogits()
    q_m.eval()
    ori_m = q_m

    # Fixed calibration dataset: one input tensor per sample.
    dataset = []
    for _ in range(30):
        args, _ = ori_m.get_example_inputs()
        dataset.append(args[0])

    # Per-layer sensitivity information feeds GPTQConfig below.
    calibrator = SensitivityCalibrator(q_m, dataset, show_progress=False)
    sens = calibrator.compute_sensitivity_info()

    # Apply GPTQ
    q_m = prepare(
        q_m,
        GPTQConfig(
            show_progress=False,
            mse="smse",
            perchannel=True,
            sensitivity=sens,
        ),
    )
    for sample in dataset:  # renamed from `input` (shadowed the builtin)
        q_m(sample)
    convert(q_m, inplace=True)
    # check that all convolution nodes are quantized
    # NOTE(review): quantizer keys carry a "model.layers.0." prefix —
    # presumably added by the PTQ wrapper; verify against prepare().
    assert hasattr(q_m, "quantizers"), "quantized model does not have quantizers"
    assert (
        "model.layers.0.m.0" in q_m.quantizers  # type: ignore[operator]
    ), "first conv node is not quantized"
    assert (
        "model.layers.0.m.1" in q_m.quantizers  # type: ignore[operator]
    ), "second conv node is not quantized"

0 commit comments

Comments
 (0)