@@ -53,10 +53,12 @@ def __init__(
5353 self ,
5454 unit : str ,
5555 symmetric : bool = True ,
56- Nch : int = None ,
57- Ngrp : int = None ,
5856 single_sided : bool = False ,
5957 qlevel_lowering : bool = True ,
58+ Nch : int = None ,
59+ Ngrp : int = None ,
60+ NperGrp : int = None ,
61+ axis : int = None ,
6062 ):
6163 """
6264 Init Qscheme
@@ -69,6 +71,8 @@ def __init__(
6971 single_sided (bool, optional): Specify if clip values are positive. Defaults to False.
7072 qlevel_lowering (bool, optional): Specify lowering of quantized levels.
7173 Defaults to True.
74+ axis (int, optional): Specify which tensor dimension to quantize indiviually.
75+ Defaults to 0.
7276
7377 Raises:
7478 RuntimeError: New PyTorch qscheme found. Need to update.
@@ -101,17 +105,16 @@ def __init__(
101105 "perCh was selected without specifying Nch."
102106 )
103107 elif unit == "perGrp" :
104- if issubclass (type (Ngrp ), int ):
108+ if issubclass (type (NperGrp ), int ):
109+ assert NperGrp > 0 , "Provided NperGrp is negative"
110+ self .NperGrp = NperGrp
111+ elif issubclass (type (Ngrp ), int ):
105112 assert Ngrp > 0 , "Provided Ngrp is negative"
106113 self .Ngrp = Ngrp
107114 else :
108115 raise RuntimeError (
109- "perGrp was selected without specifying Ngrp."
116+ "perGrp was selected without specifying Ngrp or NperGrp ."
110117 )
111- # perGrp can be across channels, but is not required
112- if issubclass (type (Nch ), int ):
113- assert Nch > 0 , "Provided Nch is negative"
114- self .Nch = Nch
115118
116119 self .single_sided = single_sided
117120 self .qlevel_lowering = qlevel_lowering
@@ -128,12 +131,13 @@ def __repr__(self):
128131 q_uint_str = f"qunit={ self .q_unit } "
129132 symmetric_str = f", symmetric={ self .symmetric } "
130133 Nch_str = f", Nch={ self .Nch } " if self .Nch is not None else "" ,
131- Ngrp_str = f", Ngrp={ self .Ngrp } " if self .Ngrp is not None else "" ,
134+ Ngrp_str = f", NperGrp={ self .NperGrp } " if self .NperGrp else "" ,
135+ NperGrp_str = f", Ngrp={ self .NperGrp } " if self .NperGrp is not None else "" ,
132136 single_sided_str = f", single_sided={ self .single_sided } "
133137 qlevel_lowering_str = f", qlevel_lowering={ self .qlevel_lowering } "
134138 return (
135- f"{ self .__class__ .__name__ } ({ q_uint_str } { symmetric_str } { Nch_str } { Ngrp_str } "
136- f"{ single_sided_str } { qlevel_lowering_str } )"
139+ f"{ self .__class__ .__name__ } ({ q_uint_str } { symmetric_str } { Nch_str } "
140+ f"{ Ngrp_str } { NperGrp_str } { single_sided_str } { qlevel_lowering_str } )"
137141 )
138142
139143
0 commit comments