2020from tico .quantization .config .ptq import PTQConfig
2121from tico .quantization .quantizer import BaseQuantizer
2222from tico .quantization .quantizer_registry import register_quantizer
23-
23+ from tico . quantization . wrapq . wrap_helper import PTQWrapHelper
2424from tico .quantization .wrapq .wrappers .ptq_wrapper import PTQWrapper
2525from tico .quantization .wrapq .wrappers .quant_module_base import QuantModuleBase
2626
@@ -43,6 +43,7 @@ def __init__(self, config: PTQConfig):
4343 super ().__init__ (config )
4444 self .qcfg : PTQConfig = config
4545 self .strict_wrap : bool = bool (getattr (config , "strict_wrap" , True ))
46+ self .wrapper = PTQWrapHelper (strict_wrap = self .strict_wrap )
4647
4748 @torch .no_grad ()
4849 def prepare (
@@ -52,7 +53,7 @@ def prepare(
5253 kwargs : Optional [Dict [str , Any ]] = None ,
5354 ):
5455 # Wrap the tree (or single module) according to strictness policy
55- model = self ._wrap_supported (model , self .qcfg )
56+ model = self .wrapper . wrap_supported (model , self .qcfg )
5657
5758 # Switch all quant modules into calibration mode
5859 if isinstance (model , QuantModuleBase ):
@@ -71,154 +72,3 @@ def convert(self, model):
7172 if isinstance (m , QuantModuleBase ):
7273 m .freeze_qparams ()
7374 return model
74-
def _wrap_supported(
    self,
    root: nn.Module,
    qcfg: PTQConfig,
) -> nn.Module:
    """
    Wrap `root` with PTQWrapper, falling back to wrapping its sub-modules.

    A direct `PTQWrapper(root, ...)` is tried first. If that raises
    NotImplementedError, the following cases are tried in order:

      A. `root.model` is a module whose `layers` is an `nn.ModuleList`:
         `embed_tokens`, `norm`, `root.lm_head` and every layer are wrapped
         in place and `root` is returned.
      B. `root` is an `nn.Sequential`, `nn.ModuleList` or `nn.ModuleDict`:
         each entry is wrapped (or recursed into) and written back.
      C. `root` itself is handed to `_try_wrap`; a new object is returned as-is.
      D. Otherwise each named child is wrapped (or recursed into).

    Every boundary goes through `_try_wrap` with
    `raise_on_fail=self.strict_wrap`, so in strict mode the first unsupported
    boundary raises NotImplementedError.

    Args:
        root: Module to wrap; must not already be a QuantModuleBase.
        qcfg: Config for `root`; per-child configs come from `qcfg.child(name)`.

    Returns:
        The wrapper around `root`, or `root` itself with wrapped sub-modules.
    """
    assert not isinstance(root, QuantModuleBase), "The module is already wrapped."
    try:
        return PTQWrapper(root, qcfg=qcfg, fp_name="model")
    except NotImplementedError:
        print("no special wrapper for model, wrappig using general case")

    def wrap_boundary(module, cfg, scope):
        # Single wrapping attempt; strictness is decided by the quantizer.
        return self._try_wrap(
            module, cfg, fp_name=scope, raise_on_fail=self.strict_wrap
        )

    def wrap_or_descend(child, name):
        # Wrap the child directly, or recurse into it when no wrapper exists.
        cfg = qcfg.child(name)
        result = wrap_boundary(child, cfg, name)
        if result is not child:
            return result
        # An unchanged child can only come back in permissive mode.
        assert not self.strict_wrap
        return self._wrap_supported(child, cfg)

    # Case A: HuggingFace-style transformers: model.model.layers
    lm = getattr(root, "model", None)
    has_lm = isinstance(lm, nn.Module)

    # (owner, attribute, config scope) for the fixed sub-modules, in the
    # order they are processed.
    fixed_parts = (
        (lm, "embed_tokens", "model.embeddings"),
        (lm, "norm", "model.norm"),
        (root, "lm_head", "lm_head"),
    )
    for owner, attr, scope in fixed_parts:
        part = getattr(owner, attr, None) if has_lm else None
        if isinstance(part, nn.Module):
            part_cfg = qcfg.child(scope)
            setattr(owner, attr, wrap_boundary(part, part_cfg, scope))

    layers = getattr(lm, "layers", None) if has_lm else None
    if isinstance(layers, nn.ModuleList):
        rewrapped = nn.ModuleList()
        for idx, layer in enumerate(layers):
            scope = f"layer{idx}"
            layer_cfg = qcfg.child(scope)
            # Enforce strictness at the child boundary
            rewrapped.append(wrap_boundary(layer, layer_cfg, scope))
        lm.layers = rewrapped  # type: ignore[union-attr]
        return root

    # Case B: Containers
    if isinstance(root, (nn.Sequential, nn.ModuleList)):
        # Snapshot the entries first: the container is mutated while looping.
        for position, entry in enumerate(list(root)):
            root[position] = wrap_or_descend(entry, str(position))  # type: ignore[index]
        return root

    if isinstance(root, nn.ModuleDict):
        for key, entry in list(root.items()):
            root[key] = wrap_or_descend(entry, key)  # type: ignore[index]
        return root

    # Case C: Leaf node
    root_name = getattr(root, "_get_name", lambda: None)()
    leaf = wrap_boundary(root, qcfg, root_name)
    if leaf is not root:
        return leaf

    assert not self.strict_wrap
    # Case D: Named children
    for child_name, child in list(root.named_children()):
        setattr(root, child_name, wrap_or_descend(child, child_name))

    return root
199-
def _try_wrap(
    self,
    module: nn.Module,
    qcfg_for_child: PTQConfig,
    *,
    fp_name: Optional[str],
    raise_on_fail: bool,
) -> nn.Module:
    """
    Try to wrap one boundary with PTQWrapper.

    Args:
        module: Module to wrap.
        qcfg_for_child: Config passed to PTQWrapper as `qcfg`.
        fp_name: Name passed to PTQWrapper as `fp_name`.
        raise_on_fail: Selects strict (True) or permissive (False) handling
            when PTQWrapper raises NotImplementedError.

    Returns:
        The PTQWrapper instance on success, or the original `module` when
        wrapping fails and `raise_on_fail` is False.

    Raises:
        NotImplementedError: When wrapping fails and `raise_on_fail` is True;
            the PTQWrapper error is chained as the cause.
    """
    try:
        wrapped = PTQWrapper(module, qcfg=qcfg_for_child, fp_name=fp_name)
    except NotImplementedError as err:
        # Permissive mode: hand the module back untouched.
        if not raise_on_fail:
            return module
        raise NotImplementedError(
            f"PTQQuantizer: no quantization wrapper for {type(module).__name__}"
        ) from err
    return wrapped
0 commit comments