
Commit 49ae7d9

fix: Updates for test_sawb perCh
Signed-off-by: Brandon Groth <[email protected]>
1 parent: 86a82ee

1 file changed: +100 −10 lines

tests/quantizers/test_sawb.py

Lines changed: 100 additions & 10 deletions
@@ -19,7 +19,7 @@
 # Third Party
 from test_quantizer_utils import quantizer_error, set_base_options
 
-from fms_mo.quant.quantizers import SAWB
+from fms_mo.quant_refactor.quantizers_new import SAWB
 from fms_mo.quant_refactor.sawb_new import SAWB_new
 from fms_mo.quant_refactor.torch_quantizer import TorchQuantizer
 
@@ -29,7 +29,7 @@
 other_option_params = []
 for clipSTE in [True, False]:
     for align_zero in [True, False]:
-        for use16bins in [True, False]:
+        for use16bins in [False]:
             other_option_params.append(
                 {"clipSTE": clipSTE, "align_zero": align_zero, "use16bins": use16bins}
             )
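With use16bins narrowed to [False], the sweep feeding the parametrized tests shrinks from eight to four option dicts; the 16-bin path is instead gated behind the perT checks added below. For reference, a sketch of the list this loop now produces:

    # The narrowed sweep yields four combinations, all with use16bins=False.
    other_option_params = [
        {"clipSTE": True, "align_zero": True, "use16bins": False},
        {"clipSTE": True, "align_zero": False, "use16bins": False},
        {"clipSTE": False, "align_zero": True, "use16bins": False},
        {"clipSTE": False, "align_zero": False, "use16bins": False},
    ]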
@@ -55,6 +55,7 @@ def set_other_options(
     fms_mo_quantizer: torch.autograd.Function,
     torch_quantizer: torch.nn.Module,
     other_option: dict,
+    axis: int = 0,
 ):
     """
     Set other options for FMS and Torch quantizer
@@ -64,6 +65,7 @@ def set_other_options(
         fms_mo_quantizer (torch.autograd.Function): FMS quantizer
         torch_quantizer (torch.nn.Module): Torch Quantizer
         other_option (dict): Other Option params
+        axis (int, optional): Per channel axis dimension. Defaults to 0.
     """
     fms_mo_quantizer.clipSTE = other_option["clipSTE"]
     fms_mo_quantizer.align_zero = other_option["align_zero"]
@@ -76,9 +78,12 @@
     # For SAWB Zero STEs
     if other_option["align_zero"]:
         # SAWBPlus16ZeroSTE - no sawb_params
-        if other_option["clipSTE"] and other_option["use16bins"] and num_bits == 4:
+        if torch_quantizer.qscheme.q_unit == "perT" \
+            and other_option["clipSTE"] and other_option["use16bins"] and num_bits == 4:
             # Set num_bits and [clip_low,clip_high]
             torch_quantizer.n_levels = 2**num_bits - 1
+            torch_quantizer.quant_min = -8
+            torch_quantizer.quant_max = 7
             torch_quantizer.set_sawb_clip_code(tensor)
 
             # Scale uses clip_high
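A quick arithmetic check on the fixed range introduced here: with num_bits == 4, the 16-bin STE spans the full signed 4-bit grid [-8, 7], while n_levels counts the 15 steps between those 16 codes. A standalone sketch (values mirror the diff; nothing here is fms-mo API):

    num_bits = 4
    n_levels = 2**num_bits - 1           # 15 steps between codes
    quant_min, quant_max = -8, 7         # full signed 4-bit integer grid
    assert quant_max - quant_min + 1 == 2**num_bits  # 16 bins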
@@ -88,9 +93,17 @@
             torch_quantizer.zero_point = torch.tensor(0)
             torch_quantizer.set_shift_sawb(0)
 
+            # Do not call set_quant_range() ; overriden w/ fixed [qint_min, qint_max]
+            return
+
         # SAWBPlusZeroPerChSTE - TODO: perCh test not functional yet
         elif other_option["clipSTE"] and torch_quantizer.qscheme.q_unit == "perCh":
-            pass
+            Nch = tensor.shape[axis]
+
+            torch_quantizer.qscheme.q_unit = "perCh"
+            torch_quantizer.qscheme.Nch = Nch
+            torch_quantizer.qscheme.qlevel_lowering = True
+            torch_quantizer.set_sawb_clip_code(tensor, perCh=True) # sets clip vals
 
         else: # SAWBZeroSTE, SAWBPlusZeroSTE - sawb_params_code
             # Set num_bits and [clip_low,clip_high]
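As background for set_sawb_clip_code(tensor, perCh=True): SAWB derives its clip threshold from weight statistics, and in the SAWB formulation the clip is a linear combination of E[|w|] and sqrt(E[w**2]); on this path it is computed once per channel rather than once per tensor. A minimal sketch under that assumption; the helper name and coefficients are illustrative, not fms-mo's actual table:

    import torch

    def sawb_clip_per_channel(w: torch.Tensor, axis: int = 0,
                              c1: float = 3.2, c2: float = -2.1) -> torch.Tensor:
        # Reduce over every dim except the channel axis.
        dims = [d for d in range(w.dim()) if d != axis]
        e_abs = w.abs().mean(dim=dims)          # E[|w|] per channel
        e_std = w.pow(2).mean(dim=dims).sqrt()  # sqrt(E[w^2]) per channel
        # SAWB-style clip: coefficients depend on the bit-width "code"
        # (e.g. code=403 in the new-path hunk below); values here are illustrative.
        return c1 * e_std + c2 * e_abs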
@@ -146,6 +159,7 @@ def set_other_options_new(
     fms_mo_quantizer: torch.autograd.Function,
     torch_quantizer: torch.nn.Module,
     other_option: dict,
+    axis: int = 0,
 ):
     """
     Set other options for new FMS and Torch quantizer
@@ -155,6 +169,7 @@ def set_other_options_new(
         fms_mo_quantizer (torch.autograd.Function): FMS quantizer
         torch_quantizer (torch.nn.Module): Torch Quantizer
         other_option (dict): Other Option params
+        axis (int, optional): Per channel axis dimension. Defaults to 0.
     """
     fms_mo_quantizer.clipSTE = other_option["clipSTE"]
     fms_mo_quantizer.align_zero = other_option["align_zero"]
@@ -170,7 +185,8 @@
     # For SAWB Zero STEs
     if other_option["align_zero"]:
         # SAWBPlus16ZeroSTE_new - no sawb_params
-        if other_option["clipSTE"] and other_option["use16bins"] and num_bits == 4:
+        if torch_quantizer.qscheme.q_unit == "perT" \
+            and other_option["clipSTE"] and other_option["use16bins"] and num_bits == 4:
             # Set num_bits and [clip_low,clip_high]
             torch_quantizer.n_levels = 2**num_bits - 1
             torch_quantizer.set_sawb_clip_code(tensor, code=403) # sets clip_high
@@ -184,8 +200,15 @@
             torch_quantizer.set_quant_range()
 
         # SAWBPlusZeroPerChSTE_new - TODO: perCh test not functional yet
-        elif other_option["clipSTE"] and torch_quantizer.qscheme.q_unit == "perCh":
-            pass
+        elif torch_quantizer.qscheme.q_unit == "perCh" and other_option["clipSTE"] and\
+            not other_option["use16bins"]:
+            Nch = tensor.shape[axis]
+
+            torch_quantizer.qscheme.q_unit = "perCh"
+            torch_quantizer.qscheme.Nch = Nch
+            torch_quantizer.qscheme.qlevel_lowering = True
+            torch_quantizer.set_sawb_clip_code(tensor, perCh=True) # sets clip vals
+            torch_quantizer.set_quant_bounds()
 
         else: # SAWBZeroSTE_new, SAWBPlusZeroSTE_new - sawb_params_code
             torch_quantizer.set_sawb_clip_code(tensor)
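For context on the added set_quant_bounds() call: with a symmetric perCh scheme and qlevel_lowering, the integer grid is typically narrowed to +/-(2**(num_bits-1) - 1) and one scale is derived per channel from its clip value. A sketch under those assumptions, not necessarily fms-mo's exact implementation:

    import torch

    def symmetric_per_channel_bounds(clip_high: torch.Tensor, num_bits: int):
        # qlevel_lowering drops one code so the grid is symmetric about zero,
        # e.g. [-127, 127] for 8 bits rather than [-128, 127].
        n_half = 2 ** (num_bits - 1) - 1
        scale = clip_high / n_half                # per-channel scale
        zero_point = torch.zeros_like(clip_high, dtype=torch.int64)
        return scale, zero_point, -n_half, n_half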
@@ -218,6 +241,7 @@ def set_per_channel(
     # Setup quantizer to use SAWBPlusZeroPerChSTE
     fms_mo_quantizer.clipSTE = True
     fms_mo_quantizer.align_zero = True
+    fms_mo_quantizer.recompute_clips = True
     fms_mo_quantizer.set_quantizer()
 
     torch_quantizer.qscheme.q_unit = "perCh"
@@ -255,6 +279,9 @@ def test_sawb_symmetric(
     base_options["nativePT"] = False # Not supported for SAWB
     set_base_options(sawb_quantizer_symmetric, torch_quantizer_symmetric, base_options)
     # SAWB requires tensor to set parameters for TorchQuantizer
+
+    use16bins = other_options["use16bins"]
+    other_options["use16bins"] = False # Not implemented for quantizer_new.SAWB
     set_other_options(
         tensor, sawb_quantizer_symmetric, torch_quantizer_symmetric, other_options
     )
@@ -291,6 +318,7 @@ def test_sawb_symmetric(
         tensor, qtensor_fms_mo, qtensor_torch, setup, base_options, other_options
     )
     base_options["nativePT"] = native_pt # reset value
+    other_options["use16bins"] = use16bins
 
 
 def test_sawbnew_symmetric(
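The save/restore of use16bins above is needed because other_options is a shared parametrized fixture, and mutating it without restoring would leak into later test cases. An equivalent, more defensive pattern (a sketch with a hypothetical helper name, not what the test currently does) copies the dict instead:

    import copy

    def _with_use16bins_disabled(other_options: dict) -> dict:
        # Work on a copy so the shared fixture is never mutated.
        local_options = copy.deepcopy(other_options)
        local_options["use16bins"] = False  # not implemented for quantizer_new.SAWB
        return local_options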
@@ -341,7 +369,7 @@ def test_sawb_symmetric_perCh(
     tensor: torch.FloatTensor,
     quantizer_symmetric_perCh: dict,
     base_options: dict,
-    # other_options: dict, # only 1 STE for this case right now
+    other_options: dict, # only 1 STE for this case right now
 ):
     """
     Test SAWB w/ symmetric tensors for per channel
@@ -375,6 +403,9 @@ def test_sawb_symmetric_perCh(
     base_options["nativePT"] = False # Not supported for SAWB
     set_base_options(sawb_quantizer_symmetric_perCh, torch_quantizer_symmetric_perCh, base_options)
     set_per_channel(tensor, sawb_quantizer_symmetric_perCh, torch_quantizer_symmetric_perCh)
+    # set_other_options_new(
+    #     tensor, sawb_quantizer_symmetric_perCh, torch_quantizer_symmetric_perCh, other_options
+    # )
 
     # Create quantized tensors from FMS Model Optimizer + torch
     qtensor_fms_mo = sawb_quantizer_symmetric_perCh(tensor).detach()
@@ -383,7 +414,66 @@ def test_sawb_symmetric_perCh(
     setup = torch_quantizer_symmetric_perCh.get_setup()
 
     # There should be no differences between these two tensors
+    # SAWB uses torch functions, so zero out errors
     quantizer_error(
-        tensor, qtensor_fms_mo, qtensor_torch, setup, base_options, other_options
+        tensor, qtensor_fms_mo, qtensor_torch, setup, base_options, other_options,
+        max_norm_tol=0.0, l2_norm_tol=0.0, nonzero_tol=0.0
     )
-    base_options["nativePT"] = native_pt # reset value
+    base_options["nativePT"] = native_pt # reset value
+
+
+def test_sawbnew_symmetric_perCh(
+    tensor: torch.FloatTensor,
+    quantizer_symmetric_perCh: dict,
+    base_options: dict,
+    other_options: dict, # only 1 STE for this case right now
+):
+    """
+    Test SAWB_new w/ symmetric tensors for perCh
+
+    Args:
+        tensor (torch.FloatTensor): Tensor to quantize.
+        base_options (dict): Base options for quantization.
+        other_options (dict): Other Options for quantization.
+    """
+    Nch = tensor.shape[0]
+    clip_val = torch.rand(Nch) + 2.5 # [2.5,3.5]
+
+    # Need to set proper Nch; registered parameters can't change shape (Quantizer.init())
+    qscheme = quantizer_symmetric_perCh["scheme"]
+    qscheme.Nch = Nch
+
+    # SAWB computes clip_val_vec in forward()
+    sawbnew_quantizer_symmetric_perCh = SAWB_new(
+        num_bits = quantizer_symmetric_perCh["num_bits"],
+        init_clip_valn=-clip_val,
+        init_clip_val=clip_val,
+        qscheme = qscheme,
+    )
+
+    # Clip val is not optional, but gets overriden in set_per_channel
+    torch_quantizer_symmetric_perCh = TorchQuantizer(
+        num_bits = quantizer_symmetric_perCh["num_bits"],
+        clip_low = -clip_val,
+        clip_high = clip_val,
+        qscheme = qscheme,
+    )
+
+    # Set base quantizer and SAWB options ; save nativePT
+    native_pt = base_options["nativePT"]
+    base_options["nativePT"] = False # Not supported for SAWB
+    set_base_options(
+        sawbnew_quantizer_symmetric_perCh, torch_quantizer_symmetric_perCh, base_options
+    )
+    # SAWB requires tensor to set parameters for TorchQuantizer
+    set_per_channel(tensor, sawbnew_quantizer_symmetric_perCh, torch_quantizer_symmetric_perCh)
+
+
+    # Create quantized tensors from FMS Model Optimizer + torch
+    qtensor_fms_mo = sawbnew_quantizer_symmetric_perCh(tensor).detach()
+    qtensor_torch = torch_quantizer_symmetric_perCh(tensor).detach()
+
+    setup = torch_quantizer_symmetric_perCh.get_setup()
+
+    quantizer_error(
+        tensor, qtensor_fms_mo, qtensor_torch, setup, base_options, other_options,
+    )
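On the zero tolerances passed to quantizer_error in the perCh test above: per the added comment, both quantizers bottom out in the same torch per-channel kernels, so outputs are expected to match bit-for-bit rather than merely within a tolerance. A standalone sketch of that idea using PyTorch's public API (independent of the test helpers; values are illustrative):

    import torch

    w = torch.randn(8, 16)                      # 8 output channels on axis 0
    scale = w.abs().amax(dim=1) / 127.0         # one scale per channel
    zero_point = torch.zeros(8, dtype=torch.int32)

    q1 = torch.fake_quantize_per_channel_affine(w, scale, zero_point, 0, -127, 127)
    q2 = torch.fake_quantize_per_channel_affine(w, scale, zero_point, 0, -127, 127)
    assert torch.equal(q1, q2)                  # same kernel, zero error expected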
