Skip to content

Commit d373df3

Browse files
committed
[quantization] [draft] GPTQ for VLM
This PR is the first try-out for full quantization of a VLM model by GPTQ+PTQ. TICO-DCO-1.0-Signed-off-by: s.malakhov <s.malakhov@partner.samsung.com>
1 parent fcacf65 commit d373df3

File tree

6 files changed

+733
-115
lines changed

6 files changed

+733
-115
lines changed

test/quantization/algorithm/test_gptq.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,36 @@ def forward(self, x):
150150
def get_example_inputs(self):
151151
return (torch.randn(1, 16, 7, 7),), {}
152152

153+
class NormConv3D(torch.nn.Module):
    """Two stacked 3-D convolutions used as a GPTQ quantization fixture."""

    def __init__(self):
        super().__init__()
        # Same submodule names ("m.0", "m.1") as building via repeated append.
        self.m = torch.nn.ModuleList(
            [
                torch.nn.Conv3d(16, 8, (2, 3, 5), stride=1),
                torch.nn.Conv3d(8, 32, (3, 5, 2), stride=2),
            ]
        )

    def forward(self, x):
        # Feed the input through both convolutions in order.
        out = x
        for conv in self.m:
            out = conv(out)
        return out

    def get_example_inputs(self):
        # Random activations laid out as (batch, channels, D, H, W).
        return (torch.randn(5, 16, 17, 19, 35),), {}

    def get_zero_inputs(self):
        # All-zero activations; exercises degenerate calibration statistics.
        return (torch.zeros(5, 16, 17, 19, 35),), {}
170+
171+
class PaddedNormConv3D(torch.nn.Module):
    """Single-Conv3d fixture exercising the string padding mode.

    NOTE(review): ``padding="valid"`` actually means *no* padding, so the
    class name is slightly misleading — what this fixture really covers is
    the string-valued padding code path.
    """

    def __init__(self):
        super().__init__()
        conv = torch.nn.Conv3d(16, 8, (2, 3, 5), stride=1, padding="valid")
        self.m = torch.nn.ModuleList([conv])

    def forward(self, x):
        # Single convolution; kept inside a ModuleList to mirror NormConv3D.
        return self.m[0](x)

    def get_example_inputs(self):
        # Random activations laid out as (batch, channels, D, H, W).
        return (torch.randn(5, 16, 17, 19, 35),), {}
153183

154184
class GPTQTest(unittest.TestCase):
155185
@unittest.skipIf(
@@ -430,3 +460,65 @@ def test_transposed_conv2d(self):
430460
), "second conv node is not quantized"
431461

432462
# TODO add quantization
463+
464+
@unittest.skipIf(
    not IS_INTERNAL_TEST, "Internal test — run only if --include-internal is set"
)
def test_normconv3d(self):
    """GPTQ should quantize both Conv3d layers of NormConv3D."""
    q_m = NormConv3D()
    q_m.eval()
    # Alias kept only to draw fresh calibration inputs after `q_m` is
    # rebound by prepare().
    ori_m = q_m

    # Apply GPTQ: prepare, run calibration batches, then convert in place.
    # (Fix: dropped the dead get_example_inputs() call that preceded the
    # loop — its result was immediately overwritten on the first iteration.)
    q_m = prepare(q_m, GPTQConfig(show_progress=False))
    for _ in range(30):
        args, kwargs = ori_m.get_example_inputs()
        q_m(*args, **kwargs)
    convert(q_m, inplace=True)

    # check that all convolution nodes are quantized
    assert hasattr(q_m, "quantizers"), "quantized model does not have quantizers"
    assert (
        "model.layers.0.m.0" in q_m.quantizers  # type: ignore[operator]
    ), "first conv node is not quantized"
    assert (
        "model.layers.0.m.1" in q_m.quantizers  # type: ignore[operator]
    ), "second conv node is not quantized"
487+
488+
@unittest.skipIf(
    not IS_INTERNAL_TEST, "Internal test — run only if --include-internal is set"
)
def test_normconv3d_on_zero_inputs(self):
    """Calibrating on all-zero inputs must not zero out the weights."""
    q_m = NormConv3D()
    q_m.eval()
    ori_m = q_m

    # Apply GPTQ with 30 all-zero calibration batches.
    q_m = prepare(q_m, GPTQConfig(show_progress=False))
    for _ in range(30):
        batch_args, batch_kwargs = ori_m.get_zero_inputs()
        q_m(*batch_args, **batch_kwargs)
    convert(q_m, inplace=True)
    assert torch.sum(q_m.m[0].weight != 0) > 0, "weights should not be all zeros"  # type: ignore[arg-type]
503+
504+
505+
@unittest.skipIf(
    not IS_INTERNAL_TEST, "Internal test — run only if --include-internal is set"
)
def test_paddednormconv3d(self):
    """GPTQ should quantize the string-padded Conv3d of PaddedNormConv3D."""
    q_m = PaddedNormConv3D()
    q_m.eval()
    # Alias kept only to draw fresh calibration inputs after `q_m` is
    # rebound by prepare().
    ori_m = q_m

    # Apply GPTQ: prepare, run calibration batches, then convert in place.
    # (Fix: dropped the dead get_example_inputs() call that preceded the
    # loop — its result was immediately overwritten on the first iteration.)
    q_m = prepare(q_m, GPTQConfig(show_progress=False))
    for _ in range(30):
        args, kwargs = ori_m.get_example_inputs()
        q_m(*args, **kwargs)
    convert(q_m, inplace=True)

    # check that all convolution nodes are quantized
    assert hasattr(q_m, "quantizers"), "quantized model does not have quantizers"
    assert (
        "model.layers.0.m.0" in q_m.quantizers  # type: ignore[operator]
    ), "first conv node is not quantized"

tico/quantization/algorithm/gptq/gptq.py

Lines changed: 153 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
import torch
2626
import torch.nn as nn
27+
import torch.nn.functional as F
2728

2829
from tico.quantization.algorithm.gptq.quant import quantize, Quantizer
2930
from tico.quantization.algorithm.gptq.utils import get_numerical_padding
@@ -167,7 +168,11 @@ def __init__(self, layer):
167168
self.layer = layer
168169
self.dev = self.layer.weight.device
169170
W = layer.weight.data.clone()
170-
if isinstance(self.layer, nn.Conv2d) or isinstance(self.layer, nn.Conv1d):
171+
if (
172+
isinstance(self.layer, nn.Conv2d)
173+
or isinstance(self.layer, nn.Conv1d)
174+
or isinstance(self.layer, nn.Conv3d)
175+
):
171176
W = W.flatten(1) # reshaped to matrix (OUT_channels x the_rest)
172177
elif isinstance(self.layer, nn.ConvTranspose2d):
173178
W = convtranspose2d_weights_to_conv2d_weights(self.layer, W)
@@ -251,10 +256,87 @@ def add_batch(self, inp, out):
251256
if isinstance(self.layer, nn.ConvTranspose2d):
252257
inp = get_matmul_input_for_convtranspose2d(self.layer, inp)
253258

259+
if isinstance(self.layer, nn.Conv3d):
260+
# adapted from https://discuss.pytorch.org/t/manual-implementation-of-unrolled-3d-convolutions/91021
261+
assert (
262+
self.layer.groups == 1
263+
) # depthwise/groupwise are not supported currently
264+
assert all(dilation == 1 for dilation in self.layer.dilation)
265+
266+
# test
267+
# input_dim = [22, 59, 114]
268+
# in_channels = 10
269+
# out_channels = 5
270+
# kernel_size = (4, 2, 3)
271+
# padding = (1, 4, 3)
272+
# stride = (1, 1, 1)
273+
# N = 51
274+
# input_tensor = torch.zeros(N, in_channels, input_dim[0], input_dim[1], input_dim[2]).uniform_(-1, 1)
275+
# conv = nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, padding=padding, stride=stride, bias=False)
276+
# output_tensor = conv(input_tensor)
277+
# output_dim = [0, 0, 0]
278+
# output_dim[0] = int((input_tensor.shape[2] - kernel_size[0] + 2 * padding[0]) / stride[0]) + 1
279+
# output_dim[1] = int((input_tensor.shape[3] - kernel_size[1] + 2 * padding[1]) / stride[1]) + 1
280+
# output_dim[2] = int((input_tensor.shape[4] - kernel_size[2] + 2 * padding[2]) / stride[2]) + 1
281+
# if not all(item == 0 for item in padding):
282+
# input_tensor = F.pad(input_tensor, pad=(padding[2], padding[2], padding[1], padding[1], padding[0], padding[0]), mode="constant", value=0)
283+
#
284+
# unfolded_input_tensor = input_tensor.unfold(2, kernel_size[0], stride[0]).unfold(3, kernel_size[1], stride[1]).unfold(4, kernel_size[2], stride[2])
285+
# unfolded_input_tensor = unfolded_input_tensor.reshape(N, in_channels, -1, kernel_size[0] * kernel_size[1] * kernel_size[2])
286+
# unfolded_input_tensor = unfolded_input_tensor.permute([0, 2, 1, 3])
287+
# #unfolded_input_tensor = unfolded_input_tensor.reshape(-1, unfolded_input_tensor.shape[2] * unfolded_input_tensor.shape[3])
288+
# #unfolded_input_tensor = unfolded_input_tensor.reshape( unfolded_input_tensor.shape[0], unfolded_input_tensor.shape[1], unfolded_input_tensor.shape[2] * unfolded_input_tensor.shape[3])
289+
# #unfolded_input_tensor = unfolded_input_tensor.permute([2, 0, 1])
290+
# #unfolded_input_tensor = unfolded_input_tensor.flatten(1).T #(N * NPatches, inner_dim)
291+
# unfolded_input_tensor = unfolded_input_tensor.reshape(unfolded_input_tensor.shape[0] * unfolded_input_tensor.shape[1], unfolded_input_tensor.shape[2] * unfolded_input_tensor.shape[3])
292+
#
293+
# kernels_flat = conv.weight.detach().clone().flatten(1)#view(out_channels, -1)
294+
# alt_output_tensor = torch.matmul(kernels_flat, unfolded_input_tensor.T) #(out_channels, N * NPatches)
295+
# alt_output_tensor = alt_output_tensor.view(out_channels, N, output_dim[0], output_dim[1], output_dim[2])
296+
# alt_output_tensor = alt_output_tensor.permute([1, 0, 2, 3, 4])
297+
# eps_max = torch.max(torch.abs(output_tensor - alt_output_tensor))
298+
# eps_mean = torch.mean(torch.abs(output_tensor - alt_output_tensor))
299+
# assert( eps_max < 1.e-04 or eps_mean < 1.e-06)
300+
301+
# inp is assumed to be (N, C_in, H, W, D)
302+
padding = get_numerical_padding(self.layer)
303+
if isinstance(padding, int):
304+
padding = (padding, padding, padding)
305+
if not all(item == 0 for item in padding):
306+
inp = F.pad(
307+
inp,
308+
pad=(
309+
padding[2],
310+
padding[2],
311+
padding[1],
312+
padding[1],
313+
padding[0],
314+
padding[0],
315+
),
316+
mode="constant",
317+
value=0,
318+
)
319+
krn_size = self.layer.kernel_size
320+
stride = self.layer.stride
321+
inp = (
322+
inp.unfold(2, krn_size[0], stride[0])
323+
.unfold(3, krn_size[1], stride[1])
324+
.unfold(4, krn_size[2], stride[2])
325+
) # inp.shape = (N, C_in, ..patches... , krn_size[0], krn_size[1], krn_size[2])
326+
inp = inp.reshape(
327+
inp.shape[0], inp.shape[1], -1, krn_size[0] * krn_size[1] * krn_size[2]
328+
) # inp.shape = (N, C_in, num_patches, krn_size[0] * krn_size[1] * krn_size[2])
329+
inp = inp.permute(
330+
[0, 2, 1, 3]
331+
) # inp.shape = (N, num_patches, C_in, krn_size[0] * krn_size[1] * krn_size[2])
332+
inp = inp.reshape(
333+
inp.shape[0] * inp.shape[1], inp.shape[2] * inp.shape[3]
334+
).T # inp.shape =(C_in * krn_size[0] * krn_size[1] * krn_size[2], N * num_patches)
335+
254336
self.H *= self.nsamples / (self.nsamples + tmp)
255337
self.nsamples += tmp
256338
inp = math.sqrt(2 / self.nsamples) * inp.float()
257-
self.H += inp.matmul(inp.t())
339+
self.H += inp.matmul(inp.t()).to(self.H.device)
258340

259341
def fasterquant(
260342
self,
@@ -266,12 +348,23 @@ def fasterquant(
266348
verbose=False,
267349
):
268350
W = self.layer.weight.data.clone()
269-
if isinstance(self.layer, nn.Conv2d) or isinstance(self.layer, nn.Conv1d):
351+
if (
352+
isinstance(self.layer, nn.Conv2d)
353+
or isinstance(self.layer, nn.Conv1d)
354+
or isinstance(self.layer, nn.Conv3d)
355+
):
270356
W = W.flatten(1) # reshaped to matrix (OUT_channels x the_rest)
357+
if self.quantizer.sensitivity is not None:
358+
self.quantizer.sensitivity = self.quantizer.sensitivity.flatten(1)
271359
elif isinstance(self.layer, nn.ConvTranspose2d):
272360
W = convtranspose2d_weights_to_conv2d_weights(self.layer, W)
273361
conv2d_shape = W.shape
274362
W = W.flatten(1) # reshaped to matrix (OUT_channels x the_rest)
363+
if self.quantizer.sensitivity is not None:
364+
self.quantizer.sensitivity = convtranspose2d_weights_to_conv2d_weights(
365+
self.layer, self.quantizer.sensitivity
366+
)
367+
self.quantizer.sensitivity = self.quantizer.sensitivity.flatten(1)
275368

276369
W = W.float()
277370
tick = time.time()
@@ -313,49 +406,58 @@ def fasterquant(
313406
Hinv = H
314407

315408
assert isinstance(Hinv, torch.Tensor)
316-
for i1 in range(0, self.columns, blocksize):
317-
i2 = min(i1 + blocksize, self.columns)
318-
count = i2 - i1
319-
320-
W1 = W[:, i1:i2].clone()
321-
Q1 = torch.zeros_like(W1)
322-
Err1 = torch.zeros_like(W1)
323-
Losses1 = torch.zeros_like(W1)
324-
Hinv1 = Hinv[i1:i2, i1:i2]
325-
326-
for i in range(count):
327-
w = W1[:, i]
328-
d = Hinv1[i, i]
329-
330-
if groupsize != -1:
331-
if not static_groups:
332-
if (i1 + i) % groupsize == 0:
333-
self.quantizer.find_params(
334-
W[:, (i1 + i) : (i1 + i + groupsize)], weight=True
335-
)
336-
else:
337-
idx: torch.Tensor | int = i1 + i
338-
if actorder:
339-
idx = perm[idx]
340-
self.quantizer = groups[idx // groupsize]
341-
342-
q = quantize(
343-
w.unsqueeze(1),
344-
self.quantizer.scale,
345-
self.quantizer.zero,
346-
self.quantizer.maxq,
347-
).flatten()
348-
Q1[:, i] = q
349-
Losses1[:, i] = (w - q) ** 2 / d**2
350-
351-
err1 = (w - q) / d
352-
W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
353-
Err1[:, i] = err1
354-
355-
Q[:, i1:i2] = Q1
356-
Losses[:, i1:i2] = Losses1 / 2
357-
358-
W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])
409+
just_quantize = False
410+
if just_quantize:
411+
Q = quantize(
412+
W,
413+
self.quantizer.scale,
414+
self.quantizer.zero,
415+
self.quantizer.maxq,
416+
)
417+
else:
418+
for i1 in range(0, self.columns, blocksize):
419+
i2 = min(i1 + blocksize, self.columns)
420+
count = i2 - i1
421+
422+
W1 = W[:, i1:i2].clone()
423+
Q1 = torch.zeros_like(W1)
424+
Err1 = torch.zeros_like(W1)
425+
Losses1 = torch.zeros_like(W1)
426+
Hinv1 = Hinv[i1:i2, i1:i2]
427+
428+
for i in range(count):
429+
w = W1[:, i]
430+
d = Hinv1[i, i]
431+
432+
if groupsize != -1:
433+
if not static_groups:
434+
if (i1 + i) % groupsize == 0:
435+
self.quantizer.find_params(
436+
W[:, (i1 + i) : (i1 + i + groupsize)], weight=True
437+
)
438+
else:
439+
idx: torch.Tensor | int = i1 + i
440+
if actorder:
441+
idx = perm[idx]
442+
self.quantizer = groups[idx // groupsize]
443+
444+
q = quantize(
445+
w.unsqueeze(1),
446+
self.quantizer.scale,
447+
self.quantizer.zero,
448+
self.quantizer.maxq,
449+
).flatten()
450+
Q1[:, i] = q
451+
Losses1[:, i] = (w - q) ** 2 / d**2
452+
453+
err1 = (w - q) / d
454+
W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
455+
Err1[:, i] = err1
456+
457+
Q[:, i1:i2] = Q1
458+
Losses[:, i1:i2] = Losses1 / 2
459+
460+
W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])
359461

360462
if torch.cuda.is_available():
361463
torch.cuda.synchronize()
@@ -366,7 +468,11 @@ def fasterquant(
366468
if actorder:
367469
Q = Q[:, invperm]
368470

369-
if isinstance(self.layer, nn.Conv2d) or isinstance(self.layer, nn.Conv1d):
471+
if (
472+
isinstance(self.layer, nn.Conv2d)
473+
or isinstance(self.layer, nn.Conv1d)
474+
or isinstance(self.layer, nn.Conv3d)
475+
):
370476
if groupsize == -1: # TODO support groupsize != -1
371477
Q[:, dead] = quantize(
372478
self.layer.weight.flatten(1)[:, dead],

tico/quantization/algorithm/gptq/quantizer.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -110,9 +110,10 @@ def forward(layer, *args, **kwargs):
110110
):
111111
self._first_layer_ref = model.model.layers[0]
112112
else:
113-
raise RuntimeError(
114-
"GPTQ Quantizer assumes the model has a nested structure like `model.model.layers`, commonly found in LLaMA and other Hugging Face transformer models."
115-
)
113+
self._first_layer_ref = model # let's treat it as a single layer
114+
# raise RuntimeError(
115+
# "GPTQ Quantizer assumes the model has a nested structure like `model.model.layers`, commonly found in LLaMA and other Hugging Face transformer models."
116+
# )
116117
else:
117118
# fallback if the model is not LLaMA-like; treat whole model as single layer
118119
self._first_layer_ref = model
@@ -180,7 +181,10 @@ def convert(self, model):
180181

181182
# Identify layers
182183
if hasattr(model, "model"):
183-
target_layers = model.model.layers
184+
if hasattr(model.model, "layers"):
185+
target_layers = model.model.layers
186+
else:
187+
target_layers = [model]
184188
else:
185189
target_layers = [model]
186190

@@ -204,6 +208,7 @@ def convert(self, model):
204208
torch.nn.Linear,
205209
torch.nn.Conv2d,
206210
torch.nn.Conv1d,
211+
torch.nn.Conv3d,
207212
torch.nn.ConvTranspose2d,
208213
],
209214
)
@@ -300,7 +305,8 @@ def _hook(_, inp, out):
300305
# This line ensures we always take the first element when it's a tuple.
301306
outs = outs[0] if isinstance(outs, tuple) else outs
302307
# Update inputs for next iteration.
303-
self.cache_args[0][batch_idx] = outs
308+
if len(self.cache_args) > 0:
309+
self.cache_args[0][batch_idx] = outs
304310

305311
if torch.cuda.is_available():
306312
torch.cuda.empty_cache()

0 commit comments

Comments
 (0)