Commit bfadd63

fix: test_sawb perCh update
Signed-off-by: Brandon Groth <[email protected]>
1 parent 3e8dbce commit bfadd63

1 file changed (+86, -2 lines)


tests/quantizers/test_sawb.py

Lines changed: 86 additions & 2 deletions
@@ -18,6 +18,11 @@
 
 # Third Party
 from test_quantizer_utils import quantizer_error, set_base_options
+
+from fms_mo.quant.quantizers import SAWB
+from fms_mo.quant_refactor.sawb_new import SAWB_new
+from fms_mo.quant_refactor.torch_quantizer import TorchQuantizer
+
 import pytest
 import torch
 

@@ -193,6 +198,34 @@ def set_other_options_new(
 
     torch_quantizer.set_quant_bounds()
 
+def set_per_channel(
+    tensor: torch.FloatTensor,
+    fms_mo_quantizer: torch.autograd.Function,
+    torch_quantizer: torch.nn.Module,
+    axis: int = 0,
+):
+    """
+    Set up quantizers to use per-channel SAWB.
+
+    Args:
+        tensor (torch.FloatTensor): Tensor to quantize.
+        fms_mo_quantizer (torch.autograd.Function): FMS quantizer.
+        torch_quantizer (torch.nn.Module): Torch quantizer.
+        axis (int, optional): Per-channel axis dimension. Defaults to 0.
+    """
+    Nch = tensor.shape[axis]
+
+    # Set up quantizer to use SAWBPlusZeroPerChSTE
+    fms_mo_quantizer.clipSTE = True
+    fms_mo_quantizer.align_zero = True
+    fms_mo_quantizer.set_quantizer()
+
+    torch_quantizer.qscheme.q_unit = "perCh"
+    torch_quantizer.qscheme.Nch = Nch
+    torch_quantizer.qscheme.qlevel_lowering = True
+    torch_quantizer.set_sawb_clip_code(tensor, perCh=True)  # sets clip vals
+    torch_quantizer.set_quant_bounds()
+
 
 ##############
 # SAWB tests #
@@ -217,7 +250,7 @@ def test_sawb_symmetric(
         other_options (dict): Other Options for quantization.
     """
 
-    # Set base quantizer and PACT2 options ; save nativePT
+    # Set base quantizer and SAWB options ; save nativePT
     native_pt = base_options["nativePT"]
     base_options["nativePT"] = False  # Not supported for SAWB
     set_base_options(sawb_quantizer_symmetric, torch_quantizer_symmetric, base_options)
@@ -277,7 +310,7 @@ def test_sawbnew_symmetric(
         base_options (dict): Base options for quantization.
         other_options (dict): Other Options for quantization.
     """
-    # Set base quantizer and PACT2 options
+    # Set base quantizer and SAWB options
     set_base_options(
         sawbnew_quantizer_symmetric, torch_quantizer_symmetric, base_options
     )
@@ -303,3 +336,54 @@ def test_sawbnew_symmetric(
         tensor, qtensor_fms_mo, qtensor_torch, setup, base_options, other_options
     )
     torch_quantizer_symmetric.qscheme.qlevel_lowering = qlevel_lowering # reset value
+
+def test_sawb_symmetric_perCh(
+    tensor: torch.FloatTensor,
+    quantizer_symmetric_perCh: dict,
+    base_options: dict,
+    other_options: dict,  # only 1 STE for this case right now
+):
+    """
+    Test SAWB w/ symmetric tensors for per-channel quantization.
+
+    Args:
+        tensor (torch.FloatTensor): Tensor to quantize.
+        quantizer_symmetric_perCh (dict): Symmetric quantizer settings for per channel.
+        base_options (dict): Base options for quantization.
+        other_options (dict): Other Options for quantization.
+    """
+
+    Nch = tensor.shape[0]
+    clip_val = torch.rand(Nch) + 2.5  # [2.5, 3.5]
+
+    # SAWB computes clip_val_vec in forward()
+    sawb_quantizer_symmetric_perCh = SAWB(
+        num_bits=quantizer_symmetric_perCh["num_bits"],
+        perCh=Nch,
+    )
+
+    # Clip val is not optional, but gets overridden in set_per_channel
+    torch_quantizer_symmetric_perCh = TorchQuantizer(
+        num_bits=quantizer_symmetric_perCh["num_bits"],
+        clip_low=-clip_val,
+        clip_high=clip_val,
+        qscheme=quantizer_symmetric_perCh["scheme"],
+    )
+
+    # Set base quantizer and SAWB options ; save nativePT
+    native_pt = base_options["nativePT"]
+    base_options["nativePT"] = False  # Not supported for SAWB
+    set_base_options(sawb_quantizer_symmetric_perCh, torch_quantizer_symmetric_perCh, base_options)
+    set_per_channel(tensor, sawb_quantizer_symmetric_perCh, torch_quantizer_symmetric_perCh)
+
+    # Create quantized tensors from FMS Model Optimizer + torch
+    qtensor_fms_mo = sawb_quantizer_symmetric_perCh(tensor).detach()
+    qtensor_torch = torch_quantizer_symmetric_perCh(tensor).detach()
+
+    setup = torch_quantizer_symmetric_perCh.get_setup()
+
+    # There should be no differences between these two tensors
+    quantizer_error(
+        tensor, qtensor_fms_mo, qtensor_torch, setup, base_options, other_options
+    )
+    base_options["nativePT"] = native_pt  # reset value

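For context on what the new perCh path exercises: SAWB derives a clip value per output channel (the test comment notes that SAWB computes clip_val_vec in forward()), and set_per_channel configures the reference TorchQuantizer to use the same per-channel bounds. The snippet below is a minimal, self-contained sketch of generic per-channel symmetric fake quantization in plain PyTorch; the function name, the scale formula, and the clip values are illustrative assumptions, not the fms_mo/SAWB implementation.

import torch

def fake_quant_per_channel_symmetric(w, clip_val, num_bits=8):
    """Illustrative per-channel symmetric fake quantization (not fms_mo code).

    w:        weight tensor of shape (Nch, ...), channel axis 0
    clip_val: per-channel clip values, shape (Nch,)
    """
    # Broadcast each channel's clip value across the remaining dimensions
    clip = clip_val.view([-1] + [1] * (w.dim() - 1))

    # Symmetric, zero-aligned integer grid, e.g. [-127, 127] for 8 bits
    qmax = 2 ** (num_bits - 1) - 1
    scale = clip / qmax

    # Quantize and dequantize ("fake quant"), one scale per channel
    q = torch.clamp(torch.round(w / scale), -qmax, qmax)
    return q * scale

# Example: 4 output channels, each with its own clip value in [2.5, 3.5),
# mirroring the clip_val = torch.rand(Nch) + 2.5 used in the new test
w = torch.randn(4, 16)
w_q = fake_quant_per_channel_symmetric(w, torch.rand(4) + 2.5, num_bits=8)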