 
 # Third Party
 from test_quantizer_utils import quantizer_error, set_base_options
+
+from fms_mo.quant.quantizers import SAWB
+from fms_mo.quant_refactor.sawb_new import SAWB_new
+from fms_mo.quant_refactor.torch_quantizer import TorchQuantizer
+
 import pytest
 import torch
 
@@ -193,6 +198,34 @@ def set_other_options_new(
 
     torch_quantizer.set_quant_bounds()
 
+def set_per_channel(
+    tensor: torch.FloatTensor,
+    fms_mo_quantizer: torch.autograd.Function,
+    torch_quantizer: torch.nn.Module,
+    axis: int = 0,
+):
+    """
+    Setup quantizers to use per channel SAWB
+
+    Args:
+        tensor (torch.FloatTensor): Tensor to quantize.
+        fms_mo_quantizer (torch.autograd.Function): FMS quantizer.
+        torch_quantizer (torch.nn.Module): Torch quantizer.
+        axis (int, optional): Per channel axis dimension. Defaults to 0.
+    """
+    Nch = tensor.shape[axis]
+
+    # Setup quantizer to use SAWBPlusZeroPerChSTE
+    fms_mo_quantizer.clipSTE = True
+    fms_mo_quantizer.align_zero = True
+    fms_mo_quantizer.set_quantizer()
+
+    torch_quantizer.qscheme.q_unit = "perCh"
+    torch_quantizer.qscheme.Nch = Nch
+    torch_quantizer.qscheme.qlevel_lowering = True
+    torch_quantizer.set_sawb_clip_code(tensor, perCh=True)  # sets clip vals
+    torch_quantizer.set_quant_bounds()
+
 
 ##############
 # SAWB tests #
@@ -217,7 +250,7 @@ def test_sawb_symmetric(
         other_options (dict): Other Options for quantization.
     """
 
-    # Set base quantizer and PACT2 options; save nativePT
+    # Set base quantizer and SAWB options; save nativePT
     native_pt = base_options["nativePT"]
     base_options["nativePT"] = False  # Not supported for SAWB
     set_base_options(sawb_quantizer_symmetric, torch_quantizer_symmetric, base_options)
@@ -277,7 +310,7 @@ def test_sawbnew_symmetric(
         base_options (dict): Base options for quantization.
         other_options (dict): Other Options for quantization.
     """
-    # Set base quantizer and PACT2 options
+    # Set base quantizer and SAWB options
     set_base_options(
         sawbnew_quantizer_symmetric, torch_quantizer_symmetric, base_options
     )
@@ -303,3 +336,54 @@ def test_sawbnew_symmetric(
         tensor, qtensor_fms_mo, qtensor_torch, setup, base_options, other_options
     )
     torch_quantizer_symmetric.qscheme.qlevel_lowering = qlevel_lowering  # reset value
+
+def test_sawb_symmetric_perCh(
+    tensor: torch.FloatTensor,
+    quantizer_symmetric_perCh: dict,
+    base_options: dict,
+    # other_options: dict,  # only 1 STE for this case right now
+):
+    """
+    Test SAWB w/ symmetric tensors for per channel
+
+    Args:
+        tensor (torch.FloatTensor): Tensor to quantize.
+        quantizer_symmetric_perCh (dict): Symmetric quantizer settings for per channel.
+        base_options (dict): Base options for quantization.
+        other_options (dict): Not used; only one STE exists for this case.
+    """
+
+    Nch = tensor.shape[0]
+    clip_val = torch.rand(Nch) + 2.5  # values in [2.5, 3.5)
+
+    # SAWB computes clip_val_vec in forward()
+    sawb_quantizer_symmetric_perCh = SAWB(
+        num_bits=quantizer_symmetric_perCh["num_bits"],
+        perCh=Nch,
+    )
+
+    # Clip val is not optional, but gets overridden in set_per_channel
+    torch_quantizer_symmetric_perCh = TorchQuantizer(
+        num_bits=quantizer_symmetric_perCh["num_bits"],
+        clip_low=-clip_val,
+        clip_high=clip_val,
+        qscheme=quantizer_symmetric_perCh["scheme"],
+    )
+
+    # Set base quantizer and SAWB options; save nativePT
+    native_pt = base_options["nativePT"]
+    base_options["nativePT"] = False  # Not supported for SAWB
+    set_base_options(sawb_quantizer_symmetric_perCh, torch_quantizer_symmetric_perCh, base_options)
+    set_per_channel(tensor, sawb_quantizer_symmetric_perCh, torch_quantizer_symmetric_perCh)
+
+    # Create quantized tensors from FMS Model Optimizer + torch
+    qtensor_fms_mo = sawb_quantizer_symmetric_perCh(tensor).detach()
+    qtensor_torch = torch_quantizer_symmetric_perCh(tensor).detach()
+
+    setup = torch_quantizer_symmetric_perCh.get_setup()
+
+    # There should be no differences between these two tensors
+    quantizer_error(
+        tensor, qtensor_fms_mo, qtensor_torch, setup, base_options, {}  # no other_options for this case
+    )
+    base_options["nativePT"] = native_pt  # reset value
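
For context (not part of the diff): the sketch below illustrates the per-channel symmetric fake-quantization that these tests exercise, assuming a symmetric grid with qlevel lowering (2**num_bits - 2 levels, as SAWB with align_zero uses). The helper name `per_channel_fake_quant` and the naive max-based clip are illustrative only and not part of fms_mo.

```python
import torch


def per_channel_fake_quant(
    x: torch.Tensor, clip_val: torch.Tensor, num_bits: int = 8
) -> torch.Tensor:
    """Illustrative per-channel symmetric fake-quantization along dim 0.

    clip_val holds one positive clip value per channel; with qlevel lowering
    the grid has 2**num_bits - 2 levels so that zero is exactly representable.
    """
    n_levels = 2 ** num_bits - 2  # e.g. 254 levels for 8 bits
    # Reshape clip values so they broadcast over all non-channel dimensions
    clip = clip_val.reshape(-1, *([1] * (x.dim() - 1)))
    scale = 2 * clip / n_levels  # step size for the symmetric range [-clip, clip]
    x_clamped = torch.clamp(x, min=-clip, max=clip)
    q = torch.round(x_clamped / scale)  # integer levels, roughly [-127, 127] for 8 bits
    return q * scale  # dequantize back to float ("fake quant")


if __name__ == "__main__":
    w = torch.randn(4, 16)  # 4 output channels
    # Naive per-channel clip; SAWB instead derives clip values from weight statistics
    clip = w.abs().amax(dim=1)
    print(per_channel_fake_quant(w, clip).shape)  # torch.Size([4, 16])
```

In the actual quantizers, SAWB derives the per-channel clip values from weight statistics in its forward pass, and the test relies on `set_sawb_clip_code(tensor, perCh=True)` to reproduce those clip values on the TorchQuantizer side before comparing the two outputs.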