@@ -19,7 +19,7 @@
 # Third Party
 from test_quantizer_utils import quantizer_error, set_base_options
 
-from fms_mo.quant.quantizers import SAWB
+from fms_mo.quant_refactor.quantizers_new import SAWB
 from fms_mo.quant_refactor.sawb_new import SAWB_new
 from fms_mo.quant_refactor.torch_quantizer import TorchQuantizer
 
@@ -29,7 +29,7 @@
 other_option_params = []
 for clipSTE in [True, False]:
     for align_zero in [True, False]:
-        for use16bins in [True, False]:
+        for use16bins in [False]:
             other_option_params.append(
                 {"clipSTE": clipSTE, "align_zero": align_zero, "use16bins": use16bins}
             )
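Note: with `use16bins` pinned to `[False]`, the nested loops above expand to four option dicts (2 × 2 × 1). A minimal sketch of the same grid via `itertools.product`, shown only for clarity (the test file builds it with plain loops):

```python
# Sketch: equivalent construction of the option grid with itertools.product
from itertools import product

other_option_params = [
    {"clipSTE": c, "align_zero": a, "use16bins": u}
    for c, a, u in product([True, False], [True, False], [False])
]
assert len(other_option_params) == 4  # 2 * 2 * 1 combinations
```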
@@ -55,6 +55,7 @@ def set_other_options(
     fms_mo_quantizer: torch.autograd.Function,
     torch_quantizer: torch.nn.Module,
     other_option: dict,
+    axis: int = 0,
 ):
     """
     Set other options for FMS and Torch quantizer
@@ -64,6 +65,7 @@ def set_other_options(
         fms_mo_quantizer (torch.autograd.Function): FMS quantizer
         torch_quantizer (torch.nn.Module): Torch Quantizer
         other_option (dict): Other Option params
+        axis (int, optional): Per channel axis dimension. Defaults to 0.
     """
     fms_mo_quantizer.clipSTE = other_option["clipSTE"]
     fms_mo_quantizer.align_zero = other_option["align_zero"]
@@ -76,9 +78,12 @@ def set_other_options(
     # For SAWB Zero STEs
     if other_option["align_zero"]:
         # SAWBPlus16ZeroSTE - no sawb_params
-        if other_option["clipSTE"] and other_option["use16bins"] and num_bits == 4:
+        if torch_quantizer.qscheme.q_unit == "perT" \
+                and other_option["clipSTE"] and other_option["use16bins"] and num_bits == 4:
             # Set num_bits and [clip_low,clip_high]
             torch_quantizer.n_levels = 2**num_bits - 1
+            torch_quantizer.quant_min = -8
+            torch_quantizer.quant_max = 7
             torch_quantizer.set_sawb_clip_code(tensor)
 
             # Scale uses clip_high
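The hard-coded `quant_min = -8`, `quant_max = 7` reflect the 16-bin variant: a zero-aligned symmetric 4-bit grid keeps `2**4 - 1 = 15` levels in `[-7, 7]`, while the `SAWBPlus16ZeroSTE` path uses all 16 integer codes. A small standalone illustration of the arithmetic (independent of the quantizers under test):

```python
num_bits = 4
n_levels = 2**num_bits - 1                                            # 15, matches n_levels above
sym_min, sym_max = -(2**(num_bits - 1) - 1), 2**(num_bits - 1) - 1    # [-7, 7]
bins16_min, bins16_max = -(2**(num_bits - 1)), 2**(num_bits - 1) - 1  # [-8, 7]

assert sym_max - sym_min + 1 == n_levels            # zero-aligned grid drops one code
assert bins16_max - bins16_min + 1 == 2**num_bits   # 16-bin variant keeps all codes
```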
@@ -88,9 +93,17 @@ def set_other_options(
             torch_quantizer.zero_point = torch.tensor(0)
             torch_quantizer.set_shift_sawb(0)
 
+            # Do not call set_quant_range(); overridden w/ fixed [qint_min, qint_max]
+            return
+
         # SAWBPlusZeroPerChSTE - TODO: perCh test not functional yet
         elif other_option["clipSTE"] and torch_quantizer.qscheme.q_unit == "perCh":
-            pass
+            Nch = tensor.shape[axis]
+
+            torch_quantizer.qscheme.q_unit = "perCh"
+            torch_quantizer.qscheme.Nch = Nch
+            torch_quantizer.qscheme.qlevel_lowering = True
+            torch_quantizer.set_sawb_clip_code(tensor, perCh=True)  # sets clip vals
 
         else:  # SAWBZeroSTE, SAWBPlusZeroSTE - sawb_params_code
             # Set num_bits and [clip_low,clip_high]
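In the per-channel branch above, `set_sawb_clip_code(tensor, perCh=True)` produces one clip value per channel from tensor statistics. SAWB-style clips are derived from the first and second moments of the weights; the sketch below is illustrative only, and the `c1`/`c2` coefficients are placeholders rather than the library's calibrated per-bit values:

```python
import torch

def sawb_style_clip_per_channel(w: torch.Tensor, axis: int = 0,
                                c1: float = 3.2, c2: float = -2.1) -> torch.Tensor:
    """Illustrative per-channel clip: c1 * sqrt(E[w^2]) + c2 * E[|w|]."""
    w2d = w.transpose(0, axis).reshape(w.shape[axis], -1)  # one row per channel
    return c1 * w2d.pow(2).mean(dim=1).sqrt() + c2 * w2d.abs().mean(dim=1)

clip_high = sawb_style_clip_per_channel(torch.randn(8, 64))  # shape (8,) == Nch
```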
@@ -146,6 +159,7 @@ def set_other_options_new(
     fms_mo_quantizer: torch.autograd.Function,
     torch_quantizer: torch.nn.Module,
     other_option: dict,
+    axis: int = 0,
 ):
     """
     Set other options for new FMS and Torch quantizer
@@ -155,6 +169,7 @@ def set_other_options_new(
         fms_mo_quantizer (torch.autograd.Function): FMS quantizer
         torch_quantizer (torch.nn.Module): Torch Quantizer
         other_option (dict): Other Option params
+        axis (int, optional): Per channel axis dimension. Defaults to 0.
     """
     fms_mo_quantizer.clipSTE = other_option["clipSTE"]
     fms_mo_quantizer.align_zero = other_option["align_zero"]
@@ -170,7 +185,8 @@ def set_other_options_new(
     # For SAWB Zero STEs
     if other_option["align_zero"]:
         # SAWBPlus16ZeroSTE_new - no sawb_params
-        if other_option["clipSTE"] and other_option["use16bins"] and num_bits == 4:
+        if torch_quantizer.qscheme.q_unit == "perT" \
+                and other_option["clipSTE"] and other_option["use16bins"] and num_bits == 4:
             # Set num_bits and [clip_low,clip_high]
             torch_quantizer.n_levels = 2**num_bits - 1
             torch_quantizer.set_sawb_clip_code(tensor, code=403)  # sets clip_high
@@ -184,8 +200,15 @@ def set_other_options_new(
             torch_quantizer.set_quant_range()
 
         # SAWBPlusZeroPerChSTE_new - TODO: perCh test not functional yet
-        elif other_option["clipSTE"] and torch_quantizer.qscheme.q_unit == "perCh":
-            pass
+        elif torch_quantizer.qscheme.q_unit == "perCh" and other_option["clipSTE"] and \
+                not other_option["use16bins"]:
+            Nch = tensor.shape[axis]
+
+            torch_quantizer.qscheme.q_unit = "perCh"
+            torch_quantizer.qscheme.Nch = Nch
+            torch_quantizer.qscheme.qlevel_lowering = True
+            torch_quantizer.set_sawb_clip_code(tensor, perCh=True)  # sets clip vals
+            torch_quantizer.set_quant_bounds()
 
         else:  # SAWBZeroSTE_new, SAWBPlusZeroSTE_new - sawb_params_code
             torch_quantizer.set_sawb_clip_code(tensor)
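Once the per-channel clips are in place, `set_quant_bounds()` presumably derives per-channel scales from them; conceptually this mirrors PyTorch's built-in per-channel fake quantization. A hedged sketch using `torch.fake_quantize_per_channel_affine` (the symmetric scale rule and the qlevel-lowered `[-7, 7]` range are assumptions based on the code above):

```python
import torch

w = torch.randn(8, 64)                  # Nch = 8, channel axis = 0
clip_high = w.abs().amax(dim=1)         # stand-in for the SAWB clip vector
num_bits = 4
quant_min, quant_max = -(2**(num_bits - 1) - 1), 2**(num_bits - 1) - 1  # [-7, 7]

scale = clip_high / quant_max           # symmetric: one positive scale per channel
zero_point = torch.zeros(w.shape[0], dtype=torch.int32)
qw = torch.fake_quantize_per_channel_affine(
    w, scale, zero_point, axis=0, quant_min=quant_min, quant_max=quant_max
)
```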
@@ -218,6 +241,7 @@ def set_per_channel(
     # Setup quantizer to use SAWBPlusZeroPerChSTE
     fms_mo_quantizer.clipSTE = True
     fms_mo_quantizer.align_zero = True
+    fms_mo_quantizer.recompute_clips = True
     fms_mo_quantizer.set_quantizer()
 
     torch_quantizer.qscheme.q_unit = "perCh"
@@ -255,6 +279,9 @@ def test_sawb_symmetric(
     base_options["nativePT"] = False  # Not supported for SAWB
     set_base_options(sawb_quantizer_symmetric, torch_quantizer_symmetric, base_options)
     # SAWB requires tensor to set parameters for TorchQuantizer
+
+    use16bins = other_options["use16bins"]
+    other_options["use16bins"] = False  # Not implemented for quantizers_new.SAWB
     set_other_options(
         tensor, sawb_quantizer_symmetric, torch_quantizer_symmetric, other_options
     )
@@ -291,6 +318,7 @@ def test_sawb_symmetric(
         tensor, qtensor_fms_mo, qtensor_torch, setup, base_options, other_options
     )
     base_options["nativePT"] = native_pt  # reset value
+    other_options["use16bins"] = use16bins
 
 
 def test_sawbnew_symmetric(
@@ -341,7 +369,7 @@ def test_sawb_symmetric_perCh(
     tensor: torch.FloatTensor,
     quantizer_symmetric_perCh: dict,
     base_options: dict,
-    # other_options: dict,  # only 1 STE for this case right now
+    other_options: dict,  # only 1 STE for this case right now
 ):
     """
     Test SAWB w/ symmetric tensors for per channel
@@ -375,6 +403,9 @@
     base_options["nativePT"] = False  # Not supported for SAWB
     set_base_options(sawb_quantizer_symmetric_perCh, torch_quantizer_symmetric_perCh, base_options)
     set_per_channel(tensor, sawb_quantizer_symmetric_perCh, torch_quantizer_symmetric_perCh)
+    # set_other_options_new(
+    #     tensor, sawb_quantizer_symmetric_perCh, torch_quantizer_symmetric_perCh, other_options
+    # )
 
     # Create quantized tensors from FMS Model Optimizer + torch
     qtensor_fms_mo = sawb_quantizer_symmetric_perCh(tensor).detach()
@@ -383,7 +414,68 @@
     qtensor_torch = torch_quantizer_symmetric_perCh(tensor).detach()
 
     # There should be no differences between these two tensors
+    # SAWB uses torch functions, so zero out error tolerances
     quantizer_error(
-        tensor, qtensor_fms_mo, qtensor_torch, setup, base_options, other_options
+        tensor, qtensor_fms_mo, qtensor_torch, setup, base_options, other_options,
+        max_norm_tol=0.0, l2_norm_tol=0.0, nonzero_tol=0.0,
     )
-    base_options["nativePT"] = native_pt  # reset value
+    base_options["nativePT"] = native_pt  # reset value
+
+
+def test_sawbnew_symmetric_perCh(
+    tensor: torch.FloatTensor,
+    quantizer_symmetric_perCh: dict,
+    base_options: dict,
+    other_options: dict,  # only 1 STE for this case right now
+):
+    """
+    Test SAWB_new w/ symmetric tensors for perCh
+
+    Args:
+        tensor (torch.FloatTensor): Tensor to quantize.
+        quantizer_symmetric_perCh (dict): Per channel quantizer settings.
+        base_options (dict): Base options for quantization.
+        other_options (dict): Other Options for quantization.
+    """
+    Nch = tensor.shape[0]
+    clip_val = torch.rand(Nch) + 2.5  # values in [2.5, 3.5)
+
+    # Need to set proper Nch; registered parameters can't change shape (Quantizer.init())
+    qscheme = quantizer_symmetric_perCh["scheme"]
+    qscheme.Nch = Nch
+
+    # SAWB computes clip_val_vec in forward()
+    sawbnew_quantizer_symmetric_perCh = SAWB_new(
+        num_bits=quantizer_symmetric_perCh["num_bits"],
+        init_clip_valn=-clip_val,
+        init_clip_val=clip_val,
+        qscheme=qscheme,
+    )
+
+    # Clip val is not optional, but gets overridden in set_per_channel
+    torch_quantizer_symmetric_perCh = TorchQuantizer(
+        num_bits=quantizer_symmetric_perCh["num_bits"],
+        clip_low=-clip_val,
+        clip_high=clip_val,
+        qscheme=qscheme,
+    )
+
+    # Set base quantizer and SAWB options; save nativePT
+    native_pt = base_options["nativePT"]
+    base_options["nativePT"] = False  # Not supported for SAWB
+    set_base_options(
+        sawbnew_quantizer_symmetric_perCh, torch_quantizer_symmetric_perCh, base_options
+    )
+    # SAWB requires tensor to set parameters for TorchQuantizer
+    set_per_channel(tensor, sawbnew_quantizer_symmetric_perCh, torch_quantizer_symmetric_perCh)
+
+    # Create quantized tensors from FMS Model Optimizer + torch
+    qtensor_fms_mo = sawbnew_quantizer_symmetric_perCh(tensor).detach()
+    qtensor_torch = torch_quantizer_symmetric_perCh(tensor).detach()
+
+    setup = torch_quantizer_symmetric_perCh.get_setup()
+
+    quantizer_error(
+        tensor, qtensor_fms_mo, qtensor_torch, setup, base_options, other_options,
+    )
+    base_options["nativePT"] = native_pt  # reset value