1919import torch
2020
2121from tico .quantization import convert , prepare
22+ from tico .quantization .algorithm .gptq .utils import SensitivityCalibrator
2223from tico .quantization .config .gptq import GPTQConfig
2324from tico .quantization .config .ptq import PTQConfig
2425from tico .quantization .evaluation .evaluate import BACKEND , evaluate
@@ -100,6 +101,29 @@ def get_example_inputs(self):
100101 return (torch .randn (1 , 32 , 16 , 16 ),), {}
101102
102103
class NormConv2DWithLogits(torch.nn.Module):
    """Two stacked ``Conv2d`` layers whose ``forward`` wraps the result in an
    object exposing a ``.logits`` attribute (HuggingFace-style output).

    Used to exercise GPTQ calibration on models that do not return a raw
    tensor from ``forward``.
    """

    class _OutputWithLogits:
        """Minimal output wrapper mimicking a transformers ``ModelOutput``."""

        def __init__(self, logits):
            self.logits = logits

    def __init__(self):
        super().__init__()
        # device/dtype attributes are read by the quantization tooling
        # (mirrors the other *WithLogits fixtures in this file).
        self.device = torch.device("cpu")
        self.dtype = torch.float32
        self.m = torch.nn.ModuleList()
        self.m.append(torch.nn.Conv2d(128, 256, (3, 3), stride=1))
        self.m.append(torch.nn.Conv2d(256, 512, (5, 5), stride=2))

    def forward(self, x):
        z = self.m[0](x)
        z = self.m[1](z)
        # Flatten into rows of 64 and add a leading batch axis so the result
        # looks like (1, seq, hidden) logits.
        z = z.reshape((-1, 64)).unsqueeze(0)
        # NOTE: the wrapper class is defined once at class scope rather than
        # inside forward(), so it is not re-created on every call and the
        # output type is stable across calls.
        return self._OutputWithLogits(z)

    def get_example_inputs(self):
        """Return ``(args, kwargs)`` with one random input tensor."""
        return (torch.randn(1, 128, 32, 32),), {}
126+
103127class NormConv1D (torch .nn .Module ):
104128 def __init__ (self ):
105129 super ().__init__ ()
@@ -133,6 +157,28 @@ def get_example_inputs(self):
133157 return (torch .randn (1 , 32 , 16 ),), {}
134158
135159
class NormConv1DWithLogits(torch.nn.Module):
    """Two stacked ``Conv1d`` layers whose ``forward`` wraps the result in an
    object exposing a ``.logits`` attribute (HuggingFace-style output).
    """

    class _OutputWithLogits:
        """Minimal output wrapper mimicking a transformers ``ModelOutput``."""

        def __init__(self, logits):
            self.logits = logits

    def __init__(self):
        super().__init__()
        # device/dtype attributes are read by the quantization tooling.
        self.device = torch.device("cpu")
        self.dtype = torch.float32
        self.conv = torch.nn.Conv1d(128, 256, 3, stride=1)
        self.conv2 = torch.nn.Conv1d(256, 512, 5, stride=2)

    def forward(self, x):
        z = self.conv(x)
        z = self.conv2(z)
        # Flatten into rows of 64 and add a leading batch axis so the result
        # looks like (1, seq, hidden) logits.
        z = z.reshape((-1, 64)).unsqueeze(0)
        # Wrapper class hoisted to class scope: not re-created per call, and
        # the output type is stable across calls.
        return self._OutputWithLogits(z)

    def get_example_inputs(self):
        """Return ``(args, kwargs)`` with one random input tensor."""
        return (torch.randn(1, 128, 32),), {}
180+
181+
136182class TransposedConv2DGeneral (torch .nn .Module ):
137183 def __init__ (self ):
138184 super ().__init__ ()
@@ -151,6 +197,30 @@ def get_example_inputs(self):
151197 return (torch .randn (1 , 16 , 7 , 7 ),), {}
152198
153199
class TransposedConv2DGeneralWithLogits(torch.nn.Module):
    """Two stacked ``ConvTranspose2d`` layers (the second one groupwise) whose
    ``forward`` wraps the result in an object exposing a ``.logits`` attribute.
    """

    class _OutputWithLogits:
        """Minimal output wrapper mimicking a transformers ``ModelOutput``."""

        def __init__(self, logits):
            self.logits = logits

    def __init__(self):
        super().__init__()
        # device/dtype attributes are read by the quantization tooling.
        self.device = torch.device("cpu")
        self.dtype = torch.float32
        self.tconv = torch.nn.ConvTranspose2d(16, 32, (2, 2), stride=2, groups=1)
        self.tconv2 = torch.nn.ConvTranspose2d(
            32, 16, (3, 3), stride=4, groups=2
        )  # general groupwise

    def forward(self, x):
        z = self.tconv(x)
        z = self.tconv2(z)
        # Flatten into rows of 8 and add a leading batch axis so the result
        # looks like (1, seq, hidden) logits.
        z = z.reshape((-1, 8)).unsqueeze(0)
        # Wrapper class hoisted to class scope: not re-created per call, and
        # the output type is stable across calls.
        return self._OutputWithLogits(z)

    def get_example_inputs(self):
        """Return ``(args, kwargs)`` with one random input tensor."""
        return (torch.randn(1, 16, 7, 7),), {}
222+
223+
154224class NormConv3D (torch .nn .Module ):
155225 def __init__ (self ):
156226 super ().__init__ ()
@@ -184,6 +254,29 @@ def get_example_inputs(self):
184254 return (torch .randn (5 , 16 , 17 , 19 , 35 ),), {}
185255
186256
class NormConv3DWithLogits(torch.nn.Module):
    """Two stacked ``Conv3d`` layers whose ``forward`` wraps the result in an
    object exposing a ``.logits`` attribute (HuggingFace-style output).
    """

    class _OutputWithLogits:
        """Minimal output wrapper mimicking a transformers ``ModelOutput``."""

        def __init__(self, logits):
            self.logits = logits

    def __init__(self):
        super().__init__()
        # device/dtype attributes are read by the quantization tooling.
        self.device = torch.device("cpu")
        self.dtype = torch.float32
        self.m = torch.nn.ModuleList()
        self.m.append(torch.nn.Conv3d(16, 8, (2, 3, 5), stride=1))
        self.m.append(torch.nn.Conv3d(8, 32, (3, 5, 2), stride=2))

    def forward(self, x):
        z = self.m[0](x)
        z = self.m[1](z)
        # Flatten into rows of 8 and add a leading batch axis so the result
        # looks like (1, seq, hidden) logits.
        z = z.reshape((-1, 8)).unsqueeze(0)
        # Wrapper class hoisted to class scope: not re-created per call, and
        # the output type is stable across calls.
        return self._OutputWithLogits(z)

    def get_example_inputs(self):
        """Return ``(args, kwargs)`` with one random input tensor."""
        return (torch.randn(5, 16, 17, 19, 35),), {}
278+
279+
187280class GPTQTest (unittest .TestCase ):
188281 @unittest .skipIf (
189282 not IS_INTERNAL_TEST , "Internal test — run only if --include-internal is set"
@@ -306,6 +399,44 @@ def test_normconv2d(self):
306399 results ["peir" ][0 ] < tolerance
307400 ), f"PEIR exceeds tolerance. PEIR:{ results ['peir' ][0 ]} %, tolerance: { tolerance } %"
308401
402+ @unittest .skipIf (
403+ not IS_INTERNAL_TEST , "Internal test — run only if --include-internal is set"
404+ )
405+ def test_normconv2d_with_logits (self ):
406+ q_m = NormConv2DWithLogits ()
407+ q_m .eval ()
408+ ori_m = q_m
409+
410+ dataset = []
411+ for _ in range (30 ):
412+ args , _ = ori_m .get_example_inputs ()
413+ dataset .append (* args )
414+
415+ calibrator = SensitivityCalibrator (q_m , dataset , show_progress = False )
416+ sens = calibrator .compute_sensitivity_info ()
417+
418+ # Apply GPTQ
419+ q_m = prepare (
420+ q_m ,
421+ GPTQConfig (
422+ show_progress = False ,
423+ mse = "smse" ,
424+ perchannel = True ,
425+ sensitivity = sens ,
426+ ),
427+ )
428+ for input in dataset :
429+ q_m (input )
430+ convert (q_m , inplace = True )
431+ # check that all convolution nodes are quantized
432+ assert hasattr (q_m , "quantizers" ), "quantized model does not have quantizers"
433+ assert (
434+ "model.layers.0.m.0" in q_m .quantizers # type: ignore[operator]
435+ ), "first conv node is not quantized"
436+ assert (
437+ "model.layers.0.m.1" in q_m .quantizers # type: ignore[operator]
438+ ), "second conv node is not quantized"
439+
309440 @unittest .skipIf (
310441 not IS_INTERNAL_TEST , "Internal test — run only if --include-internal is set"
311442 )
@@ -438,6 +569,46 @@ def test_groupwise_conv1d(self):
438569
439570 # TODO add quantization (right now it can't be evaluated on backend)
440571
572+ @unittest .skipIf (
573+ not IS_INTERNAL_TEST , "Internal test — run only if --include-internal is set"
574+ )
575+ def test_normconv1d_with_logits (self ):
576+ q_m = NormConv1DWithLogits ()
577+ q_m .eval ()
578+ ori_m = q_m
579+
580+ dataset = []
581+ for _ in range (30 ):
582+ args , _ = ori_m .get_example_inputs ()
583+ dataset .append (* args )
584+
585+ calibrator = SensitivityCalibrator (q_m , dataset , show_progress = False )
586+ sens = calibrator .compute_sensitivity_info ()
587+
588+ # Apply GPTQ
589+ q_m = prepare (
590+ q_m ,
591+ GPTQConfig (
592+ show_progress = False ,
593+ mse = "smse" ,
594+ perchannel = True ,
595+ sensitivity = sens ,
596+ ),
597+ )
598+ for input in dataset :
599+ q_m (input )
600+ convert (q_m , inplace = True )
601+ # check that all convolution nodes are quantized
602+ assert hasattr (q_m , "quantizers" ), "quantized model does not have quantizers"
603+ assert (
604+ "model.layers.0.conv" in q_m .quantizers # type: ignore[operator]
605+ ), "first conv node is not quantized"
606+ assert (
607+ "model.layers.0.conv2" in q_m .quantizers # type: ignore[operator]
608+ ), "second conv node is not quantized"
609+
610+ # TODO add quantization
611+
441612 @unittest .skipIf (
442613 not IS_INTERNAL_TEST , "Internal test — run only if --include-internal is set"
443614 )
@@ -464,6 +635,46 @@ def test_transposed_conv2d(self):
464635
465636 # TODO add quantization
466637
638+ @unittest .skipIf (
639+ not IS_INTERNAL_TEST , "Internal test — run only if --include-internal is set"
640+ )
641+ def test_transposed_conv2d_with_logits (self ):
642+ q_m = TransposedConv2DGeneralWithLogits ()
643+ q_m .eval ()
644+ ori_m = q_m
645+
646+ dataset = []
647+ for _ in range (30 ):
648+ args , _ = ori_m .get_example_inputs ()
649+ dataset .append (* args )
650+
651+ calibrator = SensitivityCalibrator (q_m , dataset , show_progress = False )
652+ sens = calibrator .compute_sensitivity_info ()
653+
654+ # Apply GPTQ
655+ q_m = prepare (
656+ q_m ,
657+ GPTQConfig (
658+ show_progress = False ,
659+ mse = "smse" ,
660+ perchannel = True ,
661+ sensitivity = sens ,
662+ ),
663+ )
664+ for input in dataset :
665+ q_m (input )
666+ convert (q_m , inplace = True )
667+ # check that all convolution nodes are quantized
668+ assert hasattr (q_m , "quantizers" ), "quantized model does not have quantizers"
669+ assert (
670+ "model.layers.0.tconv" in q_m .quantizers # type: ignore[operator]
671+ ), "first conv node is not quantized"
672+ assert (
673+ "model.layers.0.tconv2" in q_m .quantizers # type: ignore[operator]
674+ ), "second conv node is not quantized"
675+
676+ # TODO add quantization
677+
467678 @unittest .skipIf (
468679 not IS_INTERNAL_TEST , "Internal test — run only if --include-internal is set"
469680 )
@@ -524,3 +735,41 @@ def test_paddednormconv3d(self):
524735 assert (
525736 "model.layers.0.m.0" in q_m .quantizers # type: ignore[operator]
526737 ), "first conv node is not quantized"
738+
739+ @unittest .skipIf (
740+ not IS_INTERNAL_TEST , "Internal test — run only if --include-internal is set"
741+ )
742+ def test_normconv3d_with_logits (self ):
743+ q_m = NormConv3DWithLogits ()
744+ q_m .eval ()
745+ ori_m = q_m
746+
747+ dataset = []
748+ for _ in range (30 ):
749+ args , _ = ori_m .get_example_inputs ()
750+ dataset .append (* args )
751+
752+ calibrator = SensitivityCalibrator (q_m , dataset , show_progress = False )
753+ sens = calibrator .compute_sensitivity_info ()
754+
755+ # Apply GPTQ
756+ q_m = prepare (
757+ q_m ,
758+ GPTQConfig (
759+ show_progress = False ,
760+ mse = "smse" ,
761+ perchannel = True ,
762+ sensitivity = sens ,
763+ ),
764+ )
765+ for input in dataset :
766+ q_m (input )
767+ convert (q_m , inplace = True )
768+ # check that all convolution nodes are quantized
769+ assert hasattr (q_m , "quantizers" ), "quantized model does not have quantizers"
770+ assert (
771+ "model.layers.0.m.0" in q_m .quantizers # type: ignore[operator]
772+ ), "first conv node is not quantized"
773+ assert (
774+ "model.layers.0.m.1" in q_m .quantizers # type: ignore[operator]
775+ ), "second conv node is not quantized"
0 commit comments