Commit b17a25e

feat: Added perCh fixtures for sawb to conftest

Signed-off-by: Brandon Groth <[email protected]>
1 parent 603f060

1 file changed: tests/quantizers/conftest.py (+159 -8 lines)
@@ -94,17 +94,22 @@ def tensor_single_sided(request):
 
 # Symmetric
 qschemes_symmetric_params = []
-for qunit in ["perT"]:  # ['perT','perCh','perGrp']:
+for qunit in ["perT"]:
     for symmetric in [True]:
         for Nch in [None]:
             for Ngrp in [None]:
                 for single_sided in [False]:
-                    for qlevel_lowering in [
-                        True
-                    ]:  # needs to be disabled for some special cases
+                    # needs to be disabled for some special cases
+                    for qlevel_lowering in [True]:
                         qschemes_symmetric_params.append(
                             Qscheme(
-                                qunit, symmetric, Nch, Ngrp, single_sided, qlevel_lowering
+                                unit=qunit,
+                                symmetric=symmetric,
+                                single_sided=single_sided,
+                                qlevel_lowering=qlevel_lowering,
+                                Nch=Nch,
+                                Ngrp=Ngrp,
+                                axis=None,
                             )
                         )
 
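Note on this hunk: the switch from positional to keyword Qscheme arguments matters because the scheme now carries an axis field; positional construction would silently shift every field after single_sided as the signature grows. A minimal sketch of a compatible shape, assuming a dataclass-style Qscheme (the real class lives elsewhere in this repo, and the defaults shown are assumptions):

from dataclasses import dataclass
from typing import Optional

@dataclass
class QschemeSketch:  # hypothetical stand-in for the repo's Qscheme
    unit: str
    symmetric: bool
    single_sided: bool
    qlevel_lowering: bool
    Nch: Optional[int] = None
    Ngrp: Optional[int] = None
    axis: Optional[int] = None

# Keyword construction stays correct even as fields are added or reordered:
scheme = QschemeSketch(unit="perT", symmetric=True, single_sided=False,
                       qlevel_lowering=True, axis=None)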
@@ -121,7 +126,6 @@ def tensor_single_sided(request):
             }
         )
 
-
 @pytest.fixture(scope="session", params=quantizer_symmetric_params)
 def quantizer_symmetric(request):
     """
@@ -135,6 +139,53 @@ def quantizer_symmetric(request):
     """
     return request.param
 
+# Per channel symmetric params
+# clip_high, Nch will be computed at test level from tensor
+qschemes_symmetric_perCh_params = []
+for qunit in ["perCh"]:
+    for symmetric in [True]:
+        for Ngrp in [False]:
+            for single_sided in [False]:
+                # needs to be disabled for some special cases
+                for qlevel_lowering in [True]:
+                    for axis in [0]:
+                        qschemes_symmetric_perCh_params.append(
+                            Qscheme(
+                                unit=qunit,
+                                symmetric=symmetric,
+                                single_sided=single_sided,
+                                qlevel_lowering=qlevel_lowering,
+                                Nch=1,  # temp value
+                                axis=axis,
+                            )
+                        )
+
+quantizer_symmetric_perCh_params = []
+for num_bits in torch.tensor([8, 4]):
+    for scheme in qschemes_symmetric_perCh_params:
+        quantizer_symmetric_perCh_params.append(
+            {
+                "num_bits": num_bits,
+                # "clip_low": -clip_high,
+                # "clip_high": clip_high,
+                "scheme": scheme,
+            }
+        )
+
+
+@pytest.fixture(scope="session", params=quantizer_symmetric_perCh_params)
+def quantizer_symmetric_perCh(request):
+    """
+    Fixture tuple for symmetric quantizer w/ per channel clips
+
+    Args:
+        request (dict): Dict for quantizer args
+
+    Returns:
+        dict: Tuple for quantizer args
+    """
+    return request.param
+
 
 # Asymmetric
 qschemes_asymmetric_params = []
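The Nch=1 placeholder and commented-out clip entries above are deliberate: per the hunk's own comments, the channel count and symmetric clip are computed at test level from the tensor under test. A hedged sketch of that step, assuming the Qscheme exposes its axis field as an attribute and using a stand-in weight tensor (both are assumptions, not this repo's test code):

import torch

def test_symmetric_perCh_setup(quantizer_symmetric_perCh):  # hypothetical test
    w = torch.randn(8, 32)  # stand-in for the tensor fixture real tests use
    scheme = quantizer_symmetric_perCh["scheme"]
    nch = w.shape[scheme.axis]                     # channel count from the tensor
    clip_high = w.abs().amax(dim=1 - scheme.axis)  # one symmetric clip per channel
    assert nch == 8
    assert clip_high.shape == (nch,)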
@@ -146,7 +197,13 @@ def quantizer_symmetric(request):
                     for qlevel_lowering in [False]:
                         qschemes_asymmetric_params.append(
                             Qscheme(
-                                qunit, symmetric, Nch, Ngrp, single_sided, qlevel_lowering
+                                unit=qunit,
+                                symmetric=symmetric,
+                                single_sided=single_sided,
+                                qlevel_lowering=qlevel_lowering,
+                                Nch=Nch,
+                                Ngrp=Ngrp,
+                                axis=None,
                             )
                         )
 
@@ -178,6 +235,41 @@ def quantizer_asymmetric(request):
     """
     return request.param
 
+# Create random clip vals for Per Channel; must be accompanied by the same tensor
+clip_low_perCh = []
+clip_high_perCh = []
+for tensor_size in tensor_sizes:
+    clip_low_row = -torch.rand(tensor_size) - 2.5  # [-3.5, -2.5]
+    clip_high_row = torch.rand(tensor_size) + 2.5  # [2.5, 3.5]
+    clip_low_perCh.append(clip_low_row)
+    clip_high_perCh.append(clip_high_row)
+
+quantizer_asymmetric_perCh_params = []
+for num_bits in torch.tensor([8, 4]):
+    # keep low/high clips paired so both match the same tensor size
+    for clip_low, clip_high in zip(clip_low_perCh, clip_high_perCh):
+        for scheme in qschemes_asymmetric_params:
+            quantizer_asymmetric_perCh_params.append(
+                {
+                    "num_bits": num_bits,
+                    "clip_low": clip_low,
+                    "clip_high": clip_high,
+                    "scheme": scheme,
+                }
+            )
+
+@pytest.fixture(scope="session", params=quantizer_asymmetric_perCh_params)
+def quantizer_asymmetric_perCh(request):
+    """
+    Fixture tuple for asymmetric quantizer w/ per channel clips
+
+    Args:
+        request (dict): Dict for quantizer args
+
+    Returns:
+        dict: Tuple for quantizer args
+    """
+    return request.param
 
 # Single-Sided
 qschemes_single_sided_params = []
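The clip construction in this hunk leans on torch.rand drawing from [0, 1): negating and shifting lands clip_low in (-3.5, -2.5] and clip_high in [2.5, 3.5), one value per channel. A self-contained check of those bounds (the sizes here are placeholders, not the repo's tensor_sizes):

import torch

for n in [4, 8]:  # placeholder sizes
    lo = -torch.rand(n) - 2.5
    hi = torch.rand(n) + 2.5
    assert lo.shape == hi.shape == (n,)
    assert (lo > -3.5).all() and (lo <= -2.5).all()
    assert (hi >= 2.5).all() and (hi < 3.5).all()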
@@ -189,7 +281,13 @@ def quantizer_asymmetric(request):
                     for qlevel_lowering in [False]:
                         qschemes_single_sided_params.append(
                             Qscheme(
-                                qunit, symmetric, Nch, Ngrp, single_sided, qlevel_lowering
+                                unit=qunit,
+                                symmetric=symmetric,
+                                single_sided=single_sided,
+                                qlevel_lowering=qlevel_lowering,
+                                Nch=Nch,
+                                Ngrp=Ngrp,
+                                axis=None,
                             )
                         )
 
@@ -270,6 +368,23 @@ def torch_quantizer_symmetric(quantizer_symmetric):
         qscheme=quantizer_symmetric["scheme"],
     )
 
+@pytest.fixture
+def torch_quantizer_symmetric_perCh(quantizer_symmetric_perCh):
+    """
+    Torch Quantizer w/ symmetric settings for perCh
+
+    Args:
+        quantizer_symmetric_perCh (dict): Symmetric quantizer settings
+
+    Returns:
+        torch.nn.Module: TorchQuantizer
+    """
+    return TorchQuantizer(
+        num_bits=quantizer_symmetric_perCh["num_bits"],
+        clip_low=quantizer_symmetric_perCh.get("clip_low"),  # computed at test level
+        clip_high=quantizer_symmetric_perCh.get("clip_high"),
+        qscheme=quantizer_symmetric_perCh["scheme"],
+    )
 
 @pytest.fixture
 def torch_quantizer_asymmetric(quantizer_asymmetric):
@@ -289,6 +404,24 @@ def torch_quantizer_asymmetric(quantizer_asymmetric):
         qscheme=quantizer_asymmetric["scheme"],
     )
 
+@pytest.fixture
+def torch_quantizer_asymmetric_perCh(quantizer_asymmetric_perCh):
+    """
+    Torch Quantizer w/ asymmetric settings for perCh
+
+    Args:
+        quantizer_asymmetric_perCh (dict): Asymmetric quantizer settings
+
+    Returns:
+        torch.nn.Module: TorchQuantizer
+    """
+    return TorchQuantizer(
+        num_bits=quantizer_asymmetric_perCh["num_bits"],
+        clip_low=quantizer_asymmetric_perCh["clip_low"],
+        clip_high=quantizer_asymmetric_perCh["clip_high"],
+        qscheme=quantizer_asymmetric_perCh["scheme"],
+    )
+
 
 @pytest.fixture
 def torch_quantizer_single_sided(quantizer_single_sided):
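A hypothetical consumer of the new fixture, assuming TorchQuantizer is callable on a tensor (its docstring above types it as torch.nn.Module) and that the input's channel dimension matches the per-channel clip vectors; both are assumptions, not shown in this diff:

import torch

def test_perCh_quantizer_preserves_shape(torch_quantizer_asymmetric_perCh):
    # assumed: the clip vectors were built for an 8-channel tensor of this shape
    x = torch.randn(16, 8)
    y = torch_quantizer_asymmetric_perCh(x)  # nn.Module call, per the docstring
    assert y.shape == x.shape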
@@ -479,6 +612,24 @@ def sawb_quantizer_symmetric(quantizer_symmetric):
         # qscheme=quantizer_symmetric["scheme"],
     )
 
+@pytest.fixture
+def sawb_quantizer_symmetric_perCh(quantizer_symmetric_perCh):
+    """
+    SAWB quantizer w/ symmetric settings for perCh
+
+    Args:
+        quantizer_symmetric_perCh (dict): Symmetric quantizer settings
+
+    Returns:
+        torch.autograd.Function: SAWB
+    """
+    return SAWB(
+        num_bits=quantizer_symmetric_perCh["num_bits"],
+        # init_clip_valn=quantizer_symmetric["clip_low"],
+        # init_clip_val=quantizer_symmetric["clip_high"],
+        # qscheme=quantizer_symmetric["scheme"],
+    )
+
 
 @pytest.fixture
 def sawbnew_quantizer_symmetric(quantizer_symmetric):