@@ -94,17 +94,22 @@ def tensor_single_sided(request):
9494
9595# Symmetric
9696qschemes_symmetric_params = []
97- for qunit in ["perT" ]: # ['perT','perCh','perGrp']:
97+ for qunit in ["perT" ]:
9898 for symmetric in [True ]:
9999 for Nch in [None ]:
100100 for Ngrp in [None ]:
101101 for single_sided in [False ]:
102- for qlevel_lowering in [
103- True
104- ]: # needs to be disabled for some special cases
102+ # needs to be disabled for some special cases
103+ for qlevel_lowering in [True ]:
105104 qschemes_symmetric_params .append (
106105 Qscheme (
107- qunit , symmetric , Nch , Ngrp , single_sided , qlevel_lowering
106+ unit = qunit ,
107+ symmetric = symmetric ,
108+ single_sided = single_sided ,
109+ qlevel_lowering = qlevel_lowering ,
110+ Nch = Nch ,
111+ Ngrp = Ngrp ,
112+ axis = None ,
108113 )
109114 )
110115
@@ -121,7 +126,6 @@ def tensor_single_sided(request):
121126 }
122127 )
123128
124-
125129@pytest .fixture (scope = "session" , params = quantizer_symmetric_params )
126130def quantizer_symmetric (request ):
127131 """
@@ -135,6 +139,53 @@ def quantizer_symmetric(request):
135139 """
136140 return request .param
137141
142+ # Per channel symmetric params
143+ # clip_high, Nch will be computed at test level from tensor
144+ qschemes_symmetric_perCh_params = []
145+ for qunit in ["perCh" ]:
146+ for symmetric in [True ]:
147+ for Ngrp in [False ]:
148+ for single_sided in [False ]:
149+ # needs to be disabled for some special cases
150+ for qlevel_lowering in [True ]:
151+ for axis in [0 ]:
152+ qschemes_symmetric_perCh_params .append (
153+ Qscheme (
154+ unit = qunit ,
155+ symmetric = symmetric ,
156+ single_sided = single_sided ,
157+ qlevel_lowering = qlevel_lowering ,
158+ Nch = 1 , # temp value
159+ axis = axis ,
160+ )
161+ )
162+
163+ quantizer_symmetric_perCh_params = []
164+ for num_bits in torch .tensor ([8 , 4 ]):
165+ for scheme in qschemes_symmetric_perCh_params :
166+ quantizer_symmetric_perCh_params .append (
167+ {
168+ "num_bits" : num_bits ,
169+ # "clip_low": -clip_high,
170+ # "clip_high": clip_high,
171+ "scheme" : scheme ,
172+ }
173+ )
174+
175+
176+ @pytest .fixture (scope = "session" , params = quantizer_symmetric_perCh_params )
177+ def quantizer_symmetric_perCh (request ):
178+ """
179+ Fixture tuple for symmetric quantizer w/ per channel clips
180+
181+ Args:
182+ request (dict): Dict for quantizer args
183+
184+ Returns:
185+ dict: Tuple for quantizer args
186+ """
187+ return request .param
188+
138189
139190# Asymmetric
140191qschemes_asymmetric_params = []
@@ -146,7 +197,13 @@ def quantizer_symmetric(request):
146197 for qlevel_lowering in [False ]:
147198 qschemes_asymmetric_params .append (
148199 Qscheme (
149- qunit , symmetric , Nch , Ngrp , single_sided , qlevel_lowering
200+ unit = qunit ,
201+ symmetric = symmetric ,
202+ single_sided = single_sided ,
203+ qlevel_lowering = qlevel_lowering ,
204+ Nch = Nch ,
205+ Ngrp = Ngrp ,
206+ axis = None ,
150207 )
151208 )
152209
@@ -178,6 +235,41 @@ def quantizer_asymmetric(request):
178235 """
179236 return request .param
180237
238+ # Create random clip vals for Per Channel ; must be accompanied by the same tensor
239+ clip_low_perCh = []
240+ clip_high_perCh = []
241+ for tensor_size in tensor_sizes :
242+ clip_low_row = - torch .rand (tensor_size ) - 2.5 # [-3.5, -2.5]
243+ clip_high_row = torch .rand (tensor_size ) + 2.5 # [2.5, 3.5]
244+ clip_low_perCh .append (clip_low_row )
245+ clip_high_perCh .append (clip_high_row )
246+
247+ quantizer_asymmetric_perCh_params = []
248+ for num_bits in torch .tensor ([8 , 4 ]):
249+ for clip_low in clip_low_perCh :
250+ for clip_high in clip_high_perCh :
251+ for scheme in qschemes_asymmetric_params :
252+ quantizer_asymmetric_params .append (
253+ {
254+ "num_bits" : num_bits ,
255+ "clip_low" : clip_low ,
256+ "clip_high" : clip_high ,
257+ "scheme" : scheme ,
258+ }
259+ )
260+
261+ @pytest .fixture (scope = "session" , params = quantizer_asymmetric_perCh_params )
262+ def quantizer_asymmetric_perCh (request ):
263+ """
264+ Fixture tuple for asymmetric quantizer w/ per channel clips
265+
266+ Args:
267+ request (dict): Dict for quantizer args
268+
269+ Returns:
270+ dict: Tuple for quantizer args
271+ """
272+ return request .param
181273
182274# Single-Sided
183275qschemes_single_sided_params = []
@@ -189,7 +281,13 @@ def quantizer_asymmetric(request):
189281 for qlevel_lowering in [False ]:
190282 qschemes_single_sided_params .append (
191283 Qscheme (
192- qunit , symmetric , Nch , Ngrp , single_sided , qlevel_lowering
284+ unit = qunit ,
285+ symmetric = symmetric ,
286+ single_sided = single_sided ,
287+ qlevel_lowering = qlevel_lowering ,
288+ Nch = Nch ,
289+ Ngrp = Ngrp ,
290+ axis = None ,
193291 )
194292 )
195293
@@ -270,6 +368,23 @@ def torch_quantizer_symmetric(quantizer_symmetric):
270368 qscheme = quantizer_symmetric ["scheme" ],
271369 )
272370
371+ @pytest .fixture
372+ def torch_quantizer_symmetric_perCh (quantizer_symmetric_perCh ):
373+ """
374+ Torch Quantizer w/ symmetric settings for perCh
375+
376+ Args:
377+ quantizer_symmetric (dict): Symmetric quantizer settings
378+
379+ Returns:
380+ torch.nn.Module: TorchQuantizer
381+ """
382+ return TorchQuantizer (
383+ num_bits = quantizer_symmetric_perCh ["num_bits" ],
384+ clip_low = quantizer_symmetric_perCh ["clip_low" ],
385+ clip_high = quantizer_symmetric_perCh ["clip_high" ],
386+ qscheme = quantizer_symmetric_perCh ["scheme" ],
387+ )
273388
274389@pytest .fixture
275390def torch_quantizer_asymmetric (quantizer_asymmetric ):
@@ -289,6 +404,24 @@ def torch_quantizer_asymmetric(quantizer_asymmetric):
289404 qscheme = quantizer_asymmetric ["scheme" ],
290405 )
291406
407+ @pytest .fixture
408+ def torch_quantizer_asymmetric_perCh (quantizer_asymmetric_perCh ):
409+ """
410+ Torch Quantizer w/ asymmetric settings for perCh
411+
412+ Args:
413+ quantizer_asymmetric (dict): Asymmetric quantizer settings
414+
415+ Returns:
416+ torch.nn.Module: TorchQuantizer
417+ """
418+ return TorchQuantizer (
419+ num_bits = quantizer_asymmetric_perCh ["num_bits" ],
420+ clip_low = quantizer_asymmetric_perCh ["clip_low" ],
421+ clip_high = quantizer_asymmetric_perCh ["clip_high" ],
422+ qscheme = quantizer_asymmetric_perCh ["scheme" ],
423+ )
424+
292425
293426@pytest .fixture
294427def torch_quantizer_single_sided (quantizer_single_sided ):
@@ -479,6 +612,24 @@ def sawb_quantizer_symmetric(quantizer_symmetric):
479612 # qscheme=quantizer_symmetric["scheme"],
480613 )
481614
615+ @pytest .fixture
616+ def sawb_quantizer_symmetric_perCh (quantizer_symmetric_perCh ):
617+ """
618+ SAWB quantizer w/ symmetric settings
619+
620+ Args:
621+ quantizer_symmetric (dict): Symmetric quantizer settings
622+
623+ Returns:
624+ torch.autograd.Function: SAWB
625+ """
626+ return SAWB (
627+ num_bits = quantizer_symmetric_perCh ["num_bits" ],
628+ # init_clip_valn=quantizer_symmetric["clip_low"],
629+ # init_clip_val=quantizer_symmetric["clip_high"],
630+ # qscheme=quantizer_symmetric["scheme"],
631+ )
632+
482633
483634@pytest .fixture
484635def sawbnew_quantizer_symmetric (quantizer_symmetric ):
0 commit comments