 
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 
 from tico.quantization.algorithm.gptq.quant import quantize, Quantizer
 from tico.quantization.algorithm.gptq.utils import get_numerical_padding
@@ -167,7 +168,11 @@ def __init__(self, layer):
         self.layer = layer
         self.dev = self.layer.weight.device
         W = layer.weight.data.clone()
-        if isinstance(self.layer, nn.Conv2d) or isinstance(self.layer, nn.Conv1d):
+        if (
+            isinstance(self.layer, nn.Conv2d)
+            or isinstance(self.layer, nn.Conv1d)
+            or isinstance(self.layer, nn.Conv3d)
+        ):
             W = W.flatten(1)  # reshaped to matrix (OUT_channels x the_rest)
         elif isinstance(self.layer, nn.ConvTranspose2d):
             W = convtranspose2d_weights_to_conv2d_weights(self.layer, W)
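
Aside (not part of the patch): `flatten(1)` is what turns any conv weight into the `(out_channels, the_rest)` matrix that GPTQ operates on; for the newly supported `nn.Conv3d` the three trailing kernel dims simply fold into the column axis. A minimal sketch with arbitrary illustrative shapes:

    import torch.nn as nn

    conv = nn.Conv3d(in_channels=4, out_channels=8, kernel_size=(3, 2, 2), bias=False)
    W = conv.weight.data.clone()  # (8, 4, 3, 2, 2)
    W = W.flatten(1)              # (8, 4 * 3 * 2 * 2) == (8, 48)
    assert W.shape == (8, 48)
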
@@ -251,10 +256,87 @@ def add_batch(self, inp, out):
         if isinstance(self.layer, nn.ConvTranspose2d):
             inp = get_matmul_input_for_convtranspose2d(self.layer, inp)
 
+        if isinstance(self.layer, nn.Conv3d):
+            # adapted from https://discuss.pytorch.org/t/manual-implementation-of-unrolled-3d-convolutions/91021
+            assert (
+                self.layer.groups == 1
+            )  # depthwise/groupwise are not supported currently
+            assert all(dilation == 1 for dilation in self.layer.dilation)
+
+            # test
+            # input_dim = [22, 59, 114]
+            # in_channels = 10
+            # out_channels = 5
+            # kernel_size = (4, 2, 3)
+            # padding = (1, 4, 3)
+            # stride = (1, 1, 1)
+            # N = 51
+            # input_tensor = torch.zeros(N, in_channels, input_dim[0], input_dim[1], input_dim[2]).uniform_(-1, 1)
+            # conv = nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, padding=padding, stride=stride, bias=False)
+            # output_tensor = conv(input_tensor)
+            # output_dim = [0, 0, 0]
+            # output_dim[0] = int((input_tensor.shape[2] - kernel_size[0] + 2 * padding[0]) / stride[0]) + 1
+            # output_dim[1] = int((input_tensor.shape[3] - kernel_size[1] + 2 * padding[1]) / stride[1]) + 1
+            # output_dim[2] = int((input_tensor.shape[4] - kernel_size[2] + 2 * padding[2]) / stride[2]) + 1
+            # if not all(item == 0 for item in padding):
+            #     input_tensor = F.pad(input_tensor, pad=(padding[2], padding[2], padding[1], padding[1], padding[0], padding[0]), mode="constant", value=0)
+            #
+            # unfolded_input_tensor = input_tensor.unfold(2, kernel_size[0], stride[0]).unfold(3, kernel_size[1], stride[1]).unfold(4, kernel_size[2], stride[2])
+            # unfolded_input_tensor = unfolded_input_tensor.reshape(N, in_channels, -1, kernel_size[0] * kernel_size[1] * kernel_size[2])
+            # unfolded_input_tensor = unfolded_input_tensor.permute([0, 2, 1, 3])
+            # # unfolded_input_tensor = unfolded_input_tensor.reshape(-1, unfolded_input_tensor.shape[2] * unfolded_input_tensor.shape[3])
+            # # unfolded_input_tensor = unfolded_input_tensor.reshape(unfolded_input_tensor.shape[0], unfolded_input_tensor.shape[1], unfolded_input_tensor.shape[2] * unfolded_input_tensor.shape[3])
+            # # unfolded_input_tensor = unfolded_input_tensor.permute([2, 0, 1])
+            # # unfolded_input_tensor = unfolded_input_tensor.flatten(1).T  # (N * NPatches, inner_dim)
+            # unfolded_input_tensor = unfolded_input_tensor.reshape(unfolded_input_tensor.shape[0] * unfolded_input_tensor.shape[1], unfolded_input_tensor.shape[2] * unfolded_input_tensor.shape[3])
+            #
+            # kernels_flat = conv.weight.detach().clone().flatten(1)  # view(out_channels, -1)
+            # alt_output_tensor = torch.matmul(kernels_flat, unfolded_input_tensor.T)  # (out_channels, N * NPatches)
+            # alt_output_tensor = alt_output_tensor.view(out_channels, N, output_dim[0], output_dim[1], output_dim[2])
+            # alt_output_tensor = alt_output_tensor.permute([1, 0, 2, 3, 4])
+            # eps_max = torch.max(torch.abs(output_tensor - alt_output_tensor))
+            # eps_mean = torch.mean(torch.abs(output_tensor - alt_output_tensor))
+            # assert eps_max < 1.e-04 or eps_mean < 1.e-06
+
+            # inp is assumed to be (N, C_in, D, H, W)
+            padding = get_numerical_padding(self.layer)
+            if isinstance(padding, int):
+                padding = (padding, padding, padding)
+            if not all(item == 0 for item in padding):
+                inp = F.pad(
+                    inp,
+                    pad=(
+                        padding[2],
+                        padding[2],
+                        padding[1],
+                        padding[1],
+                        padding[0],
+                        padding[0],
+                    ),
+                    mode="constant",
+                    value=0,
+                )
+            krn_size = self.layer.kernel_size
+            stride = self.layer.stride
+            inp = (
+                inp.unfold(2, krn_size[0], stride[0])
+                .unfold(3, krn_size[1], stride[1])
+                .unfold(4, krn_size[2], stride[2])
+            )  # inp.shape = (N, C_in, ..patches.., krn_size[0], krn_size[1], krn_size[2])
+            inp = inp.reshape(
+                inp.shape[0], inp.shape[1], -1, krn_size[0] * krn_size[1] * krn_size[2]
+            )  # inp.shape = (N, C_in, num_patches, krn_size[0] * krn_size[1] * krn_size[2])
+            inp = inp.permute(
+                [0, 2, 1, 3]
+            )  # inp.shape = (N, num_patches, C_in, krn_size[0] * krn_size[1] * krn_size[2])
+            inp = inp.reshape(
+                inp.shape[0] * inp.shape[1], inp.shape[2] * inp.shape[3]
+            ).T  # inp.shape = (C_in * krn_size[0] * krn_size[1] * krn_size[2], N * num_patches)
+
         self.H *= self.nsamples / (self.nsamples + tmp)
         self.nsamples += tmp
         inp = math.sqrt(2 / self.nsamples) * inp.float()
-        self.H += inp.matmul(inp.t())
+        self.H += inp.matmul(inp.t()).to(self.H.device)
 
     def fasterquant(
         self,
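
Aside (not part of the patch): the unfold chain above is an im2col for 3-D convolution, so the flattened kernel times the unfolded input must reproduce `F.conv3d`. A condensed, runnable version of the commented-out test above, with arbitrary small sizes:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    N, C_in, C_out = 2, 3, 4
    kD, kH, kW = 3, 2, 2
    conv = nn.Conv3d(C_in, C_out, (kD, kH, kW), stride=(1, 2, 1), padding=(1, 0, 1), bias=False)
    x = torch.randn(N, C_in, 8, 9, 10)
    ref = conv(x).detach()                   # (N, C_out, oD, oH, oW)

    pD, pH, pW = conv.padding
    xp = F.pad(x, (pW, pW, pH, pH, pD, pD))  # F.pad orders pads last-dim-first, hence the reversal
    u = (
        xp.unfold(2, kD, conv.stride[0])
        .unfold(3, kH, conv.stride[1])
        .unfold(4, kW, conv.stride[2])
    )                                        # (N, C_in, oD, oH, oW, kD, kH, kW)
    u = u.reshape(N, C_in, -1, kD * kH * kW)                    # (N, C_in, P, K)
    u = u.permute(0, 2, 1, 3).reshape(-1, C_in * kD * kH * kW)  # (N * P, C_in * K)

    out = conv.weight.detach().flatten(1) @ u.T                 # (C_out, N * P)
    out = out.view(C_out, N, *ref.shape[2:]).permute(1, 0, 2, 3, 4)
    assert torch.allclose(ref, out, atol=1e-5)

The columns of the resulting matrix are the per-patch input vectors that feed the running Hessian estimate H = (2/n) Σ x xᵀ below; the rescale by `nsamples / (nsamples + tmp)` folds each new batch into that average, and the appended `.to(self.H.device)` guards against `inp` and `H` living on different devices.
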
@@ -266,12 +348,23 @@ def fasterquant(
         verbose=False,
     ):
         W = self.layer.weight.data.clone()
-        if isinstance(self.layer, nn.Conv2d) or isinstance(self.layer, nn.Conv1d):
+        if (
+            isinstance(self.layer, nn.Conv2d)
+            or isinstance(self.layer, nn.Conv1d)
+            or isinstance(self.layer, nn.Conv3d)
+        ):
             W = W.flatten(1)  # reshaped to matrix (OUT_channels x the_rest)
+            if self.quantizer.sensitivity is not None:
+                self.quantizer.sensitivity = self.quantizer.sensitivity.flatten(1)
         elif isinstance(self.layer, nn.ConvTranspose2d):
             W = convtranspose2d_weights_to_conv2d_weights(self.layer, W)
             conv2d_shape = W.shape
             W = W.flatten(1)  # reshaped to matrix (OUT_channels x the_rest)
+            if self.quantizer.sensitivity is not None:
+                self.quantizer.sensitivity = convtranspose2d_weights_to_conv2d_weights(
+                    self.layer, self.quantizer.sensitivity
+                )
+                self.quantizer.sensitivity = self.quantizer.sensitivity.flatten(1)
 
         W = W.float()
         tick = time.time()
@@ -313,49 +406,58 @@ def fasterquant(
         Hinv = H
 
         assert isinstance(Hinv, torch.Tensor)
-        for i1 in range(0, self.columns, blocksize):
-            i2 = min(i1 + blocksize, self.columns)
-            count = i2 - i1
-
-            W1 = W[:, i1:i2].clone()
-            Q1 = torch.zeros_like(W1)
-            Err1 = torch.zeros_like(W1)
-            Losses1 = torch.zeros_like(W1)
-            Hinv1 = Hinv[i1:i2, i1:i2]
-
-            for i in range(count):
-                w = W1[:, i]
-                d = Hinv1[i, i]
-
-                if groupsize != -1:
-                    if not static_groups:
-                        if (i1 + i) % groupsize == 0:
-                            self.quantizer.find_params(
-                                W[:, (i1 + i) : (i1 + i + groupsize)], weight=True
-                            )
-                    else:
-                        idx: torch.Tensor | int = i1 + i
-                        if actorder:
-                            idx = perm[idx]
-                        self.quantizer = groups[idx // groupsize]
-
-                q = quantize(
-                    w.unsqueeze(1),
-                    self.quantizer.scale,
-                    self.quantizer.zero,
-                    self.quantizer.maxq,
-                ).flatten()
-                Q1[:, i] = q
-                Losses1[:, i] = (w - q) ** 2 / d**2
-
-                err1 = (w - q) / d
-                W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
-                Err1[:, i] = err1
-
-            Q[:, i1:i2] = Q1
-            Losses[:, i1:i2] = Losses1 / 2
-
-            W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])
+        just_quantize = False
+        if just_quantize:
+            Q = quantize(
+                W,
+                self.quantizer.scale,
+                self.quantizer.zero,
+                self.quantizer.maxq,
+            )
+        else:
+            for i1 in range(0, self.columns, blocksize):
+                i2 = min(i1 + blocksize, self.columns)
+                count = i2 - i1
+
+                W1 = W[:, i1:i2].clone()
+                Q1 = torch.zeros_like(W1)
+                Err1 = torch.zeros_like(W1)
+                Losses1 = torch.zeros_like(W1)
+                Hinv1 = Hinv[i1:i2, i1:i2]
+
+                for i in range(count):
+                    w = W1[:, i]
+                    d = Hinv1[i, i]
+
+                    if groupsize != -1:
+                        if not static_groups:
+                            if (i1 + i) % groupsize == 0:
+                                self.quantizer.find_params(
+                                    W[:, (i1 + i) : (i1 + i + groupsize)], weight=True
+                                )
+                        else:
+                            idx: torch.Tensor | int = i1 + i
+                            if actorder:
+                                idx = perm[idx]
+                            self.quantizer = groups[idx // groupsize]
+
+                    q = quantize(
+                        w.unsqueeze(1),
+                        self.quantizer.scale,
+                        self.quantizer.zero,
+                        self.quantizer.maxq,
+                    ).flatten()
+                    Q1[:, i] = q
+                    Losses1[:, i] = (w - q) ** 2 / d**2
+
+                    err1 = (w - q) / d
+                    W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
+                    Err1[:, i] = err1
+
+                Q[:, i1:i2] = Q1
+                Losses[:, i1:i2] = Losses1 / 2
+
+                W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])
 
         if torch.cuda.is_available():
             torch.cuda.synchronize()
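
Aside (not part of the patch): the loop body is unchanged here, only re-indented under the new `just_quantize` switch; it is the standard GPTQ per-column update (Frantar et al., 2023). Writing $C$ for the inverse-Hessian factor held in `Hinv` (a Cholesky factor in the usual implementation, computed above this hunk), column $i$ is quantized and every not-yet-quantized column absorbs the scaled error:

\[
q_i = \operatorname{quant}(W_{:,i}), \qquad
e = \frac{W_{:,i} - q_i}{C_{ii}}, \qquad
W_{:,j} \leftarrow W_{:,j} - e\, C_{ij} \quad (j > i),
\]

with the recorded proxy loss $\mathrm{Losses}_{:,i} = (W_{:,i} - q_i)^2 / (2\, C_{ii}^2)$; the trailing `W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])` propagates the same correction across block boundaries. Setting `just_quantize = True` skips all error propagation and quantizes `W` in one shot, apparently as a debugging baseline: it is hard-coded to `False`.
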
@@ -366,7 +468,11 @@ def fasterquant(
         if actorder:
             Q = Q[:, invperm]
 
-        if isinstance(self.layer, nn.Conv2d) or isinstance(self.layer, nn.Conv1d):
+        if (
+            isinstance(self.layer, nn.Conv2d)
+            or isinstance(self.layer, nn.Conv1d)
+            or isinstance(self.layer, nn.Conv3d)
+        ):
             if groupsize == -1:  # TODO support groupsize != -1
                 Q[:, dead] = quantize(
                     self.layer.weight.flatten(1)[:, dead],
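
Aside (not part of the patch, which is cut off mid-call here): `dead` is presumably defined earlier in `fasterquant`, outside the shown hunks; in the usual GPTQ code it marks columns whose Hessian diagonal is zero, i.e. input features the calibration data never exercised. GPTQ has no error signal for them, so they fall back to plain round-to-nearest. A toy sketch, assuming the standard GPTQ-style `quantize` helper:

    import torch

    def quantize(x, scale, zero, maxq):
        # uniform affine quantize-dequantize, as in the usual GPTQ quant helper
        q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)
        return scale * (q - zero)

    H_diag = torch.tensor([0.9, 0.0, 1.2])  # toy Hessian diagonal
    dead = H_diag == 0                      # feature 1 never fired during calibration
    W = torch.randn(4, 3)
    W[:, dead] = quantize(W[:, dead], torch.tensor(0.1), torch.tensor(8.0), 15)  # 4-bit RTN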