1919import torch
2020
2121from tico .quantization import convert , prepare
22+ from tico .quantization .algorithm .gptq .utils import SensitivityCalibrator
2223from tico .quantization .config .gptq import GPTQConfig
2324from tico .quantization .config .ptq import PTQConfig
2425from tico .quantization .evaluation .evaluate import BACKEND , evaluate
@@ -100,6 +101,29 @@ def get_example_inputs(self):
100101 return (torch .randn (1 , 32 , 16 , 16 ),), {}
101102
102103
class NormConv2DWithLogits(torch.nn.Module):
    """Two stacked Conv2d layers whose result is wrapped in a ``.logits`` holder.

    Mimics transformer-style model outputs (``output.logits``) so quantization
    code that unwraps logits can be exercised on a small conv model.
    """

    def __init__(self):
        super().__init__()
        # device/dtype attributes mirror HF-style models; presumably read by
        # the sensitivity calibrator — confirm against its implementation.
        self.device = torch.device("cpu")
        self.dtype = torch.float32
        self.m = torch.nn.ModuleList(
            [
                torch.nn.Conv2d(128, 256, (3, 3), stride=1),
                torch.nn.Conv2d(256, 512, (5, 5), stride=2),
            ]
        )

    def forward(self, x):
        class OutputWithLogits:
            def __init__(self, logits):
                self.logits = logits

        feats = self.m[1](self.m[0](x))
        logits = feats.reshape((-1, 64)).unsqueeze(0)
        return OutputWithLogits(logits)

    def get_example_inputs(self):
        # One random calibration sample: (args, kwargs) pair.
        return (torch.randn(1, 128, 32, 32),), {}
126+
103127class NormConv1D (torch .nn .Module ):
104128 def __init__ (self ):
105129 super ().__init__ ()
@@ -133,6 +157,28 @@ def get_example_inputs(self):
133157 return (torch .randn (1 , 32 , 16 ),), {}
134158
135159
class NormConv1DWithLogits(torch.nn.Module):
    """Two chained Conv1d layers returning a logits-wrapper object.

    The ``.logits`` return mimics transformer-style outputs so code paths
    that unwrap ``output.logits`` can be tested with 1-D convolutions.
    """

    def __init__(self):
        super().__init__()
        # device/dtype attributes mirror HF-style models; presumably read by
        # the sensitivity calibrator — confirm against its implementation.
        self.device = torch.device("cpu")
        self.dtype = torch.float32
        self.conv = torch.nn.Conv1d(128, 256, 3, stride=1)
        self.conv2 = torch.nn.Conv1d(256, 512, 5, stride=2)

    def forward(self, x):
        class OutputWithLogits:
            def __init__(self, logits):
                self.logits = logits

        logits = self.conv2(self.conv(x)).reshape((-1, 64)).unsqueeze(0)
        return OutputWithLogits(logits)

    def get_example_inputs(self):
        # One random calibration sample: (args, kwargs) pair.
        return (torch.randn(1, 128, 32),), {}
180+
181+
136182class TransposedConv2DGeneral (torch .nn .Module ):
137183 def __init__ (self ):
138184 super ().__init__ ()
@@ -151,6 +197,86 @@ def get_example_inputs(self):
151197 return (torch .randn (1 , 16 , 7 , 7 ),), {}
152198
153199
class TransposedConv2DGeneralWithLogits(torch.nn.Module):
    """Stacked ConvTranspose2d layers (second one groupwise) with ``.logits`` output."""

    def __init__(self):
        super().__init__()
        # device/dtype attributes mirror HF-style models; presumably read by
        # the sensitivity calibrator — confirm against its implementation.
        self.device = torch.device("cpu")
        self.dtype = torch.float32
        self.tconv = torch.nn.ConvTranspose2d(16, 32, (2, 2), stride=2, groups=1)
        # groups=2 makes this a general groupwise transposed convolution.
        self.tconv2 = torch.nn.ConvTranspose2d(32, 16, (3, 3), stride=4, groups=2)

    def forward(self, x):
        class OutputWithLogits:
            def __init__(self, logits):
                self.logits = logits

        feats = self.tconv2(self.tconv(x))
        return OutputWithLogits(feats.reshape((-1, 8)).unsqueeze(0))

    def get_example_inputs(self):
        # One random calibration sample: (args, kwargs) pair.
        return (torch.randn(1, 16, 7, 7),), {}
222+
223+
class NormConv3D(torch.nn.Module):
    """Two stacked Conv3d layers used as a plain-tensor-output test fixture."""

    def __init__(self):
        super().__init__()
        self.m = torch.nn.ModuleList(
            [
                torch.nn.Conv3d(16, 8, (2, 3, 5), stride=1),
                torch.nn.Conv3d(8, 32, (3, 5, 2), stride=2),
            ]
        )

    def forward(self, x):
        out = x
        for layer in self.m:
            out = layer(out)
        return out

    def get_example_inputs(self):
        # Random calibration sample: batch of 5, 16 channels, 17x19x35 volume.
        return (torch.randn(5, 16, 17, 19, 35),), {}

    def get_zero_inputs(self):
        # All-zero sample for the degenerate-calibration regression test.
        return (torch.zeros(5, 16, 17, 19, 35),), {}
241+
242+
class PaddedNormConv3D(torch.nn.Module):
    """Single Conv3d declared with an explicit ``padding="valid"`` (no padding)."""

    def __init__(self):
        super().__init__()
        self.m = torch.nn.ModuleList(
            [torch.nn.Conv3d(16, 8, (2, 3, 5), stride=1, padding="valid")]
        )

    def forward(self, x):
        return self.m[0](x)

    def get_example_inputs(self):
        # Random calibration sample: batch of 5, 16 channels, 17x19x35 volume.
        return (torch.randn(5, 16, 17, 19, 35),), {}
255+
256+
class NormConv3DWithLogits(torch.nn.Module):
    """Two stacked Conv3d layers whose result is wrapped in a ``.logits`` holder."""

    def __init__(self):
        super().__init__()
        # device/dtype attributes mirror HF-style models; presumably read by
        # the sensitivity calibrator — confirm against its implementation.
        self.device = torch.device("cpu")
        self.dtype = torch.float32
        self.m = torch.nn.ModuleList(
            [
                torch.nn.Conv3d(16, 8, (2, 3, 5), stride=1),
                torch.nn.Conv3d(8, 32, (3, 5, 2), stride=2),
            ]
        )

    def forward(self, x):
        class OutputWithLogits:
            def __init__(self, logits):
                self.logits = logits

        feats = self.m[1](self.m[0](x))
        return OutputWithLogits(feats.reshape((-1, 8)).unsqueeze(0))

    def get_example_inputs(self):
        # Random calibration sample: batch of 5, 16 channels, 17x19x35 volume.
        return (torch.randn(5, 16, 17, 19, 35),), {}
278+
279+
154280class GPTQTest (unittest .TestCase ):
155281 @unittest .skipIf (
156282 not IS_INTERNAL_TEST , "Internal test — run only if --include-internal is set"
@@ -273,6 +399,44 @@ def test_normconv2d(self):
273399 results ["peir" ][0 ] < tolerance
274400 ), f"PEIR exceeds tolerance. PEIR:{ results ['peir' ][0 ]} %, tolerance: { tolerance } %"
275401
402+ # @unittest.skipIf(
403+ # not IS_INTERNAL_TEST, "Internal test — run only if --include-internal is set"
404+ # )
405+ def test_normconv2d_with_logits (self ):
406+ q_m = NormConv2DWithLogits ()
407+ q_m .eval ()
408+ ori_m = q_m
409+
410+ dataset = []
411+ for _ in range (30 ):
412+ args , _ = ori_m .get_example_inputs ()
413+ dataset .append (* args )
414+
415+ calibrator = SensitivityCalibrator (q_m , dataset , show_progress = False )
416+ sens = calibrator .compute_sensitivity_info ()
417+
418+ # Apply GPTQ
419+ q_m = prepare (
420+ q_m ,
421+ GPTQConfig (
422+ show_progress = False ,
423+ mse = "smse" ,
424+ perchannel = True ,
425+ sensitivity = sens ,
426+ ),
427+ )
428+ for input in dataset :
429+ q_m (input )
430+ convert (q_m , inplace = True )
431+ # check that all convolution nodes are quantized
432+ assert hasattr (q_m , "quantizers" ), "quantized model does not have quantizers"
433+ assert (
434+ "model.layers.0.m.0" in q_m .quantizers # type: ignore[operator]
435+ ), "first conv node is not quantized"
436+ assert (
437+ "model.layers.0.m.1" in q_m .quantizers # type: ignore[operator]
438+ ), "second conv node is not quantized"
439+
276440 @unittest .skipIf (
277441 not IS_INTERNAL_TEST , "Internal test — run only if --include-internal is set"
278442 )
@@ -405,6 +569,46 @@ def test_groupwise_conv1d(self):
405569
406570 # TODO add quantization (right now it can't be evaluated on backend)
407571
572+ # @unittest.skipIf(
573+ # not IS_INTERNAL_TEST, "Internal test — run only if --include-internal is set"
574+ # )
575+ def test_normconv1d_with_logits (self ):
576+ q_m = NormConv1DWithLogits ()
577+ q_m .eval ()
578+ ori_m = q_m
579+
580+ dataset = []
581+ for _ in range (30 ):
582+ args , _ = ori_m .get_example_inputs ()
583+ dataset .append (* args )
584+
585+ calibrator = SensitivityCalibrator (q_m , dataset , show_progress = False )
586+ sens = calibrator .compute_sensitivity_info ()
587+
588+ # Apply GPTQ
589+ q_m = prepare (
590+ q_m ,
591+ GPTQConfig (
592+ show_progress = False ,
593+ mse = "smse" ,
594+ perchannel = True ,
595+ sensitivity = sens ,
596+ ),
597+ )
598+ for input in dataset :
599+ q_m (input )
600+ convert (q_m , inplace = True )
601+ # check that all convolution nodes are quantized
602+ assert hasattr (q_m , "quantizers" ), "quantized model does not have quantizers"
603+ assert (
604+ "model.layers.0.conv" in q_m .quantizers # type: ignore[operator]
605+ ), "first conv node is not quantized"
606+ assert (
607+ "model.layers.0.conv2" in q_m .quantizers # type: ignore[operator]
608+ ), "second conv node is not quantized"
609+
610+ # TODO add quantization
611+
408612 @unittest .skipIf (
409613 not IS_INTERNAL_TEST , "Internal test — run only if --include-internal is set"
410614 )
@@ -430,3 +634,142 @@ def test_transposed_conv2d(self):
430634 ), "second conv node is not quantized"
431635
432636 # TODO add quantization
637+
638+ # @unittest.skipIf(
639+ # not IS_INTERNAL_TEST, "Internal test — run only if --include-internal is set"
640+ # )
641+ def test_transposed_conv2d_with_logits (self ):
642+ q_m = TransposedConv2DGeneralWithLogits ()
643+ q_m .eval ()
644+ ori_m = q_m
645+
646+ dataset = []
647+ for _ in range (30 ):
648+ args , _ = ori_m .get_example_inputs ()
649+ dataset .append (* args )
650+
651+ calibrator = SensitivityCalibrator (q_m , dataset , show_progress = False )
652+ sens = calibrator .compute_sensitivity_info ()
653+
654+ # Apply GPTQ
655+ q_m = prepare (
656+ q_m ,
657+ GPTQConfig (
658+ show_progress = False ,
659+ mse = "smse" ,
660+ perchannel = True ,
661+ sensitivity = sens ,
662+ ),
663+ )
664+ for input in dataset :
665+ q_m (input )
666+ convert (q_m , inplace = True )
667+ # check that all convolution nodes are quantized
668+ assert hasattr (q_m , "quantizers" ), "quantized model does not have quantizers"
669+ assert (
670+ "model.layers.0.tconv" in q_m .quantizers # type: ignore[operator]
671+ ), "first conv node is not quantized"
672+ assert (
673+ "model.layers.0.tconv2" in q_m .quantizers # type: ignore[operator]
674+ ), "second conv node is not quantized"
675+
676+ # TODO add quantization
677+
678+ @unittest .skipIf (
679+ not IS_INTERNAL_TEST , "Internal test — run only if --include-internal is set"
680+ )
681+ def test_normconv3d (self ):
682+ q_m = NormConv3D ()
683+ q_m .eval ()
684+ ori_m = q_m
685+ args , kwargs = ori_m .get_example_inputs ()
686+
687+ # Apply GPTQ
688+ q_m = prepare (q_m , GPTQConfig (show_progress = False ))
689+ for _ in range (30 ):
690+ args , kwargs = ori_m .get_example_inputs ()
691+ q_m (* args , ** kwargs )
692+ convert (q_m , inplace = True )
693+ # check that all convolution nodes are quantized
694+ assert hasattr (q_m , "quantizers" ), "quantized model does not have quantizers"
695+ assert (
696+ "model.layers.0.m.0" in q_m .quantizers # type: ignore[operator]
697+ ), "first conv node is not quantized"
698+ assert (
699+ "model.layers.0.m.1" in q_m .quantizers # type: ignore[operator]
700+ ), "second conv node is not quantized"
701+
702+ @unittest .skipIf (
703+ not IS_INTERNAL_TEST , "Internal test — run only if --include-internal is set"
704+ )
705+ def test_normconv3d_on_zero_inputs (self ):
706+ q_m = NormConv3D ()
707+ q_m .eval ()
708+ ori_m = q_m
709+
710+ # Apply GPTQ
711+ q_m = prepare (q_m , GPTQConfig (show_progress = False ))
712+ for _ in range (30 ):
713+ args , kwargs = ori_m .get_zero_inputs ()
714+ q_m (* args , ** kwargs )
715+ convert (q_m , inplace = True )
716+ assert torch .sum (q_m .m [0 ].weight != 0 ) > 0 , "weights should not be all zeros" # type: ignore[arg-type]
717+
718+ @unittest .skipIf (
719+ not IS_INTERNAL_TEST , "Internal test — run only if --include-internal is set"
720+ )
721+ def test_paddednormconv3d (self ):
722+ q_m = PaddedNormConv3D ()
723+ q_m .eval ()
724+ ori_m = q_m
725+ args , kwargs = ori_m .get_example_inputs ()
726+
727+ # Apply GPTQ
728+ q_m = prepare (q_m , GPTQConfig (show_progress = False ))
729+ for _ in range (30 ):
730+ args , kwargs = ori_m .get_example_inputs ()
731+ q_m (* args , ** kwargs )
732+ convert (q_m , inplace = True )
733+ # check that all convolution nodes are quantized
734+ assert hasattr (q_m , "quantizers" ), "quantized model does not have quantizers"
735+ assert (
736+ "model.layers.0.m.0" in q_m .quantizers # type: ignore[operator]
737+ ), "first conv node is not quantized"
738+
739+ # @unittest.skipIf(
740+ # not IS_INTERNAL_TEST, "Internal test — run only if --include-internal is set"
741+ # )
742+ def test_normconv3d (self ):
743+ q_m = NormConv3DWithLogits ()
744+ q_m .eval ()
745+ ori_m = q_m
746+
747+ dataset = []
748+ for _ in range (30 ):
749+ args , _ = ori_m .get_example_inputs ()
750+ dataset .append (* args )
751+
752+ calibrator = SensitivityCalibrator (q_m , dataset , show_progress = False )
753+ sens = calibrator .compute_sensitivity_info ()
754+
755+ # Apply GPTQ
756+ q_m = prepare (
757+ q_m ,
758+ GPTQConfig (
759+ show_progress = False ,
760+ mse = "smse" ,
761+ perchannel = True ,
762+ sensitivity = sens ,
763+ ),
764+ )
765+ for input in dataset :
766+ q_m (input )
767+ convert (q_m , inplace = True )
768+ # check that all convolution nodes are quantized
769+ assert hasattr (q_m , "quantizers" ), "quantized model does not have quantizers"
770+ assert (
771+ "model.layers.0.m.0" in q_m .quantizers # type: ignore[operator]
772+ ), "first conv node is not quantized"
773+ assert (
774+ "model.layers.0.m.1" in q_m .quantizers # type: ignore[operator]
775+ ), "second conv node is not quantized"
0 commit comments