Skip to content

Commit e49ffdb

Browse files
authored
[quantization] Enable nn.ConvTranspose2D in GPTQs (#428)
This PR enables quantization of `nn.ConvTranspose2D` in FPIGPTQ/GPTQ and adds tests for it.

TICO-DCO-1.0-Signed-off-by: s.malakhov <s.malakhov@partner.samsung.com>
1 parent 09b7a88 commit e49ffdb

File tree

6 files changed

+296
-8
lines changed

6 files changed

+296
-8
lines changed

test/quantization/algorithm/test_fpi_gptq.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,24 @@ def get_example_inputs(self):
119119
return (torch.randn(1, 32, 16),), {}
120120

121121

122+
class TransposedConv2DGeneral(torch.nn.Module):
    """Two stacked ConvTranspose2d layers — a plain one followed by a
    grouped one — used to exercise transposed-conv quantization paths."""

    def __init__(self):
        super().__init__()
        # Plain transposed conv: 16 -> 32 channels, 2x2 kernel, stride 2.
        self.tconv = torch.nn.ConvTranspose2d(16, 32, (2, 2), stride=2, groups=1)
        # General groupwise transposed conv: 32 -> 16 channels, 2 groups.
        self.tconv2 = torch.nn.ConvTranspose2d(32, 16, (3, 3), stride=4, groups=2)

    def forward(self, x):
        # Chain the two transposed convolutions.
        return self.tconv2(self.tconv(x))

    def get_example_inputs(self):
        # Single positional tensor input, no keyword arguments.
        return (torch.randn(1, 16, 7, 7),), {}
122140
class FPIGPTQTest(unittest.TestCase):
123141
@unittest.skipIf(
124142
not IS_INTERNAL_TEST, "Internal test — run only if --include-internal is set"
@@ -321,3 +339,29 @@ def test_groupwise_conv1d(self):
321339
), "second conv node is not quantized"
322340

323341
# TODO add PT2E quantization (right now it can't be evaluated on backend)
342+
343+
@unittest.skipIf(
344+
not IS_INTERNAL_TEST, "Internal test — run only if --include-internal is set"
345+
)
346+
def test_transposed_conv2d(self):
347+
q_m = TransposedConv2DGeneral()
348+
q_m.eval()
349+
ori_m = q_m
350+
args, kwargs = ori_m.get_example_inputs()
351+
352+
# Apply GPTQ
353+
q_m = prepare(q_m, FPIGPTQConfig(show_progress=False))
354+
for _ in range(30):
355+
args, kwargs = ori_m.get_example_inputs()
356+
q_m(*args, **kwargs)
357+
convert(q_m, inplace=True)
358+
# check that all convolution nodes are quantized
359+
assert hasattr(q_m, "quantizers"), "quantized model does not have quantizers"
360+
assert (
361+
"model.layers.0.tconv" in q_m.quantizers
362+
), "first conv node is not quantized"
363+
assert (
364+
"model.layers.0.tconv2" in q_m.quantizers
365+
), "second conv node is not quantized"
366+
367+
# TODO add PT2E quantization

test/quantization/algorithm/test_gptq.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,24 @@ def get_example_inputs(self):
119119
return (torch.randn(1, 32, 16),), {}
120120

121121

122+
class TransposedConv2DGeneral(torch.nn.Module):
    """Two stacked ConvTranspose2d layers — a plain one followed by a
    grouped one — used to exercise transposed-conv quantization paths."""

    def __init__(self):
        super().__init__()
        # Plain transposed conv: 16 -> 32 channels, 2x2 kernel, stride 2.
        self.tconv = torch.nn.ConvTranspose2d(16, 32, (2, 2), stride=2, groups=1)
        # General groupwise transposed conv: 32 -> 16 channels, 2 groups.
        self.tconv2 = torch.nn.ConvTranspose2d(32, 16, (3, 3), stride=4, groups=2)

    def forward(self, x):
        # Chain the two transposed convolutions.
        return self.tconv2(self.tconv(x))

    def get_example_inputs(self):
        # Single positional tensor input, no keyword arguments.
        return (torch.randn(1, 16, 7, 7),), {}
122140
class GPTQTest(unittest.TestCase):
123141
@unittest.skipIf(
124142
not IS_INTERNAL_TEST, "Internal test — run only if --include-internal is set"
@@ -352,3 +370,29 @@ def test_groupwise_conv1d(self):
352370
), "second conv node is not quantized"
353371

354372
# TODO add PT2E quantization (right now it can't be evaluated on backend)
373+
374+
@unittest.skipIf(
375+
not IS_INTERNAL_TEST, "Internal test — run only if --include-internal is set"
376+
)
377+
def test_transposed_conv2d(self):
378+
q_m = TransposedConv2DGeneral()
379+
q_m.eval()
380+
ori_m = q_m
381+
args, kwargs = ori_m.get_example_inputs()
382+
383+
# Apply GPTQ
384+
q_m = prepare(q_m, GPTQConfig(show_progress=False))
385+
for _ in range(30):
386+
args, kwargs = ori_m.get_example_inputs()
387+
q_m(*args, **kwargs)
388+
convert(q_m, inplace=True)
389+
# check that all convolution nodes are quantized
390+
assert hasattr(q_m, "quantizers"), "quantized model does not have quantizers"
391+
assert (
392+
"model.layers.0.tconv" in q_m.quantizers
393+
), "first conv node is not quantized"
394+
assert (
395+
"model.layers.0.tconv2" in q_m.quantizers
396+
), "second conv node is not quantized"
397+
398+
# TODO add PT2E quantization

tico/quantization/algorithm/fpi_gptq/fpi_gptq.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,12 @@
2525
import torch
2626
import torch.nn as nn
2727

28+
from tico.quantization.algorithm.gptq.gptq import (
29+
conv2d_weights_to_convtranspose2d_weights,
30+
convtranspose2d_weights_to_conv2d_weights,
31+
get_matmul_input_for_convtranspose2d,
32+
)
33+
2834
from tico.quantization.algorithm.gptq.quant import quantize, Quantizer
2935

3036

@@ -58,6 +64,9 @@ def __init__(self, layer):
5864
W = layer.weight.data.clone()
5965
if isinstance(self.layer, nn.Conv2d) or isinstance(self.layer, nn.Conv1d):
6066
W = W.flatten(1)
67+
elif isinstance(self.layer, nn.ConvTranspose2d):
68+
W = convtranspose2d_weights_to_conv2d_weights(self.layer, W)
69+
W = W.flatten(1)
6170

6271
self.rows = W.shape[0]
6372
self.columns = W.shape[1]
@@ -132,6 +141,8 @@ def add_batch(self, inp, out):
132141
inp = unfold(inp)
133142
inp = inp.permute([1, 0, 2])
134143
inp = inp.flatten(1)
144+
if isinstance(self.layer, nn.ConvTranspose2d):
145+
inp = get_matmul_input_for_convtranspose2d(self.layer, inp)
135146

136147
self.H *= self.nsamples / (self.nsamples + tmp)
137148
self.nsamples += tmp
@@ -146,6 +157,11 @@ def fasterquant(
146157
W = self.layer.weight.data.clone()
147158
if isinstance(self.layer, nn.Conv2d) or isinstance(self.layer, nn.Conv1d):
148159
W = W.flatten(1)
160+
elif isinstance(self.layer, nn.ConvTranspose2d):
161+
W = convtranspose2d_weights_to_conv2d_weights(self.layer, W)
162+
conv2d_shape = W.shape
163+
W = W.flatten(1) # reshaped to matrix (OUT_channels x the_rest)
164+
149165
W = W.float()
150166
tick = time.time()
151167
if not self.quantizer.ready():
@@ -202,6 +218,15 @@ def fasterquant(
202218
self.quantizer.zero,
203219
self.quantizer.maxq,
204220
)
221+
elif isinstance(self.layer, nn.ConvTranspose2d):
222+
Q[:, dead] = quantize(
223+
convtranspose2d_weights_to_conv2d_weights(
224+
self.layer, self.layer.weight.data
225+
).flatten(1)[:, dead],
226+
self.quantizer.scale,
227+
self.quantizer.zero,
228+
self.quantizer.maxq,
229+
)
205230
else:
206231
Q[:, dead] = quantize(
207232
self.layer.weight[:, dead],
@@ -210,9 +235,15 @@ def fasterquant(
210235
self.quantizer.maxq,
211236
)
212237

213-
self.layer.weight.data = Q.reshape(self.layer.weight.shape).to(
214-
self.layer.weight.data.dtype
215-
)
238+
if isinstance(self.layer, nn.ConvTranspose2d):
239+
Q_conv2d = Q.reshape(conv2d_shape).to(self.layer.weight.data.dtype)
240+
self.layer.weight.data = conv2d_weights_to_convtranspose2d_weights(
241+
self.layer, Q_conv2d
242+
)
243+
else:
244+
self.layer.weight.data = Q.reshape(self.layer.weight.shape).to(
245+
self.layer.weight.data.dtype
246+
)
216247

217248
def free(self):
218249
self.H = None

tico/quantization/algorithm/fpi_gptq/quantizer.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,13 @@ def convert(self, model):
7777
):
7878
# 1) Identify quantizable submodules within the layer
7979
full = find_layers(
80-
layer, layers=[torch.nn.Linear, torch.nn.Conv2d, torch.nn.Conv1d]
80+
layer,
81+
layers=[
82+
torch.nn.Linear,
83+
torch.nn.Conv2d,
84+
torch.nn.Conv1d,
85+
torch.nn.ConvTranspose2d,
86+
],
8187
)
8288
sequential = [list(full.keys())]
8389

0 commit comments

Comments
 (0)