Skip to content

Commit fc67e35

Browse files
authored
[quantization] Introduce wrap helper (#557)
This commit introduces wrap helper. TICO-DCO-1.0-Signed-off-by: seongwoo <mhs4670go@naver.com>
1 parent 01a402a commit fc67e35

File tree

2 files changed

+202
-153
lines changed

2 files changed

+202
-153
lines changed

tico/quantization/wrapq/quantizer.py

Lines changed: 3 additions & 153 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from tico.quantization.config.ptq import PTQConfig
2121
from tico.quantization.quantizer import BaseQuantizer
2222
from tico.quantization.quantizer_registry import register_quantizer
23-
23+
from tico.quantization.wrapq.wrap_helper import PTQWrapHelper
2424
from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper
2525
from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
2626

@@ -43,6 +43,7 @@ def __init__(self, config: PTQConfig):
4343
super().__init__(config)
4444
self.qcfg: PTQConfig = config
4545
self.strict_wrap: bool = bool(getattr(config, "strict_wrap", True))
46+
self.wrapper = PTQWrapHelper(strict_wrap=self.strict_wrap)
4647

4748
@torch.no_grad()
4849
def prepare(
@@ -52,7 +53,7 @@ def prepare(
5253
kwargs: Optional[Dict[str, Any]] = None,
5354
):
5455
# Wrap the tree (or single module) according to strictness policy
55-
model = self._wrap_supported(model, self.qcfg)
56+
model = self.wrapper.wrap_supported(model, self.qcfg)
5657

5758
# Switch all quant modules into calibration mode
5859
if isinstance(model, QuantModuleBase):
@@ -71,154 +72,3 @@ def convert(self, model):
7172
if isinstance(m, QuantModuleBase):
7273
m.freeze_qparams()
7374
return model
74-
75-
def _wrap_supported(
76-
self,
77-
root: nn.Module,
78-
qcfg: PTQConfig,
79-
) -> nn.Module:
80-
"""
81-
Recursively attempt to wrap boundaries. Strictness is applied at every boundary.
82-
"""
83-
assert not isinstance(root, QuantModuleBase), "The module is already wrapped."
84-
try:
85-
return PTQWrapper(root, qcfg=qcfg, fp_name="model")
86-
except NotImplementedError as e:
87-
print("no special wrapper for model, wrappig using general case")
88-
89-
# Case A: HuggingFace-style transformers: model.model.layers
90-
lm = getattr(root, "model", None)
91-
92-
embeddings = (
93-
getattr(lm, "embed_tokens", None) if isinstance(lm, nn.Module) else None
94-
)
95-
if isinstance(embeddings, nn.Module):
96-
child_scope = "model.embeddings"
97-
child_cfg = qcfg.child(child_scope)
98-
wrapped = self._try_wrap(
99-
embeddings,
100-
child_cfg,
101-
fp_name=child_scope,
102-
raise_on_fail=self.strict_wrap,
103-
)
104-
lm.embed_tokens = wrapped # type: ignore[union-attr]
105-
106-
model_norm = getattr(lm, "norm", None) if isinstance(lm, nn.Module) else None
107-
if isinstance(model_norm, nn.Module):
108-
child_scope = "model.norm"
109-
child_cfg = qcfg.child(child_scope)
110-
wrapped = self._try_wrap(
111-
model_norm,
112-
child_cfg,
113-
fp_name=child_scope,
114-
raise_on_fail=self.strict_wrap,
115-
)
116-
lm.norm = wrapped # type: ignore[union-attr]
117-
118-
lm_head = getattr(root, "lm_head", None) if isinstance(lm, nn.Module) else None
119-
if isinstance(lm_head, nn.Module):
120-
child_scope = "lm_head"
121-
child_cfg = qcfg.child(child_scope)
122-
wrapped = self._try_wrap(
123-
lm_head,
124-
child_cfg,
125-
fp_name=child_scope,
126-
raise_on_fail=self.strict_wrap,
127-
)
128-
root.lm_head = wrapped
129-
130-
layers = getattr(lm, "layers", None) if isinstance(lm, nn.Module) else None
131-
if isinstance(layers, nn.ModuleList):
132-
new_list = nn.ModuleList()
133-
for idx, layer in enumerate(layers):
134-
child_scope = f"layer{idx}"
135-
child_cfg = qcfg.child(child_scope)
136-
137-
# Enforce strictness at the child boundary
138-
wrapped = self._try_wrap(
139-
layer,
140-
child_cfg,
141-
fp_name=child_scope,
142-
raise_on_fail=self.strict_wrap,
143-
)
144-
new_list.append(wrapped)
145-
lm.layers = new_list # type: ignore[union-attr]
146-
return root
147-
148-
# Case B: Containers
149-
if isinstance(root, (nn.Sequential, nn.ModuleList)):
150-
for i, child in enumerate(list(root)):
151-
name = str(i)
152-
child_cfg = qcfg.child(name)
153-
154-
wrapped = self._try_wrap(
155-
child, child_cfg, fp_name=name, raise_on_fail=self.strict_wrap
156-
)
157-
if wrapped is child:
158-
assert not self.strict_wrap
159-
wrapped = self._wrap_supported(wrapped, child_cfg)
160-
root[i] = wrapped # type: ignore[index]
161-
return root
162-
163-
if isinstance(root, nn.ModuleDict):
164-
for k, child in list(root.items()):
165-
name = k
166-
child_cfg = qcfg.child(name)
167-
168-
wrapped = self._try_wrap(
169-
child, child_cfg, fp_name=name, raise_on_fail=self.strict_wrap
170-
)
171-
if wrapped is child:
172-
assert not self.strict_wrap
173-
wrapped = self._wrap_supported(wrapped, child_cfg)
174-
root[k] = wrapped # type: ignore[index]
175-
return root
176-
177-
# Case C: Leaf node
178-
root_name = getattr(root, "_get_name", lambda: None)()
179-
wrapped = self._try_wrap(
180-
root, qcfg, fp_name=root_name, raise_on_fail=self.strict_wrap
181-
)
182-
if wrapped is not root:
183-
return wrapped
184-
185-
assert not self.strict_wrap
186-
# Case D: Named children
187-
for name, child in list(root.named_children()):
188-
child_cfg = qcfg.child(name)
189-
190-
wrapped = self._try_wrap(
191-
child, child_cfg, fp_name=name, raise_on_fail=self.strict_wrap
192-
)
193-
if wrapped is child:
194-
assert not self.strict_wrap
195-
wrapped = self._wrap_supported(wrapped, child_cfg)
196-
setattr(root, name, wrapped)
197-
198-
return root
199-
200-
def _try_wrap(
201-
self,
202-
module: nn.Module,
203-
qcfg_for_child: PTQConfig,
204-
*,
205-
fp_name: Optional[str],
206-
raise_on_fail: bool,
207-
) -> nn.Module:
208-
"""
209-
Attempt to wrap a boundary with PTQWrapper.
210-
211-
Behavior:
212-
• If PTQWrapper succeeds: return wrapped module.
213-
• If PTQWrapper raises NotImplementedError:
214-
- raise_on_fail=True -> re-raise (strict)
215-
- raise_on_fail=False -> return original module (permissive)
216-
"""
217-
try:
218-
return PTQWrapper(module, qcfg=qcfg_for_child, fp_name=fp_name)
219-
except NotImplementedError as e:
220-
if raise_on_fail:
221-
raise NotImplementedError(
222-
f"PTQQuantizer: no quantization wrapper for {type(module).__name__}"
223-
) from e
224-
return module
tico/quantization/wrapq/wrap_helper.py

Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import Optional
16+
17+
import torch.nn as nn
18+
19+
from tico.quantization.config.ptq import PTQConfig
20+
from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper
21+
from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
22+
23+
24+
class PTQWrapHelper:
    """
    Reusable helper that applies PTQWrapper recursively according to PTQConfig.

    This class contains only the structural wrapping logic and can be reused by
    other algorithms (e.g. SpinQuant) that want to leverage the same wrapper
    hierarchy without enabling calibration immediately.
    """

    def __init__(self, *, strict_wrap: bool = True):
        # strict_wrap=True: a boundary PTQWrapper rejects raises
        # NotImplementedError. strict_wrap=False: the module is kept as-is and
        # its children are visited instead.
        self.strict_wrap = strict_wrap

    def wrap_supported(
        self,
        root: nn.Module,
        qcfg: PTQConfig,
    ) -> nn.Module:
        """
        Recursively attempt to wrap boundaries. Strictness is applied at every boundary.

        The whole of `root` is tried first. If PTQWrapper rejects it, the
        following structural cases are tried in order:

        - Case A: HuggingFace-style layout (`root.model.embed_tokens`,
          `root.model.norm`, `root.lm_head`, `root.model.layers`).
        - Case B: `nn.Sequential` / `nn.ModuleList` / `nn.ModuleDict` containers.
        - Case C: `root` as a leaf, named by its `_get_name()`.
        - Case D: every named child of `root` (permissive mode only).

        Args:
            root: Module to wrap. Must not already be a QuantModuleBase.
            qcfg: Configuration for `root`; per-child configurations are
                derived with `qcfg.child(<scope>)`.

        Returns:
            A PTQWrapper around `root`, or `root` itself with its children
            replaced in place.

        Raises:
            AssertionError: If `root` is already a QuantModuleBase.
            NotImplementedError: In strict mode, when a boundary cannot be
                wrapped (raised by `try_wrap`).
        """
        assert not isinstance(root, QuantModuleBase), "The module is already wrapped."

        try:
            return PTQWrapper(root, qcfg=qcfg, fp_name="model")
        except NotImplementedError:
            print(
                f"No specialized wrapper found for {type(root).__name__}; applying recursive wrapping."
            )

        # Case A: HuggingFace-style transformers: model.model.layers
        lm = getattr(root, "model", None)
        if isinstance(lm, nn.Module):
            self._wrap_attr(lm, "embed_tokens", "model.embeddings", qcfg)
            self._wrap_attr(lm, "norm", "model.norm", qcfg)
        self._wrap_attr(root, "lm_head", "lm_head", qcfg)

        layers = getattr(lm, "layers", None) if isinstance(lm, nn.Module) else None
        if isinstance(layers, nn.ModuleList):
            new_list = nn.ModuleList()
            for idx, layer in enumerate(layers):
                child_scope = f"layer{idx}"
                new_list.append(
                    self.try_wrap(
                        layer,
                        qcfg.child(child_scope),
                        fp_name=child_scope,
                        raise_on_fail=self.strict_wrap,
                    )
                )
            lm.layers = new_list  # type: ignore[union-attr]
            return root

        # NOTE: when there is no `model.layers` ModuleList, the attributes
        # wrapped above stay wrapped and execution continues below. The
        # child loops therefore skip anything that is already wrapped.

        # Case B: Containers
        if isinstance(root, (nn.Sequential, nn.ModuleList)):
            for i, child in enumerate(list(root)):
                name = str(i)
                root[i] = self._wrap_or_descend(  # type: ignore[index]
                    child, qcfg.child(name), fp_name=name
                )
            return root

        if isinstance(root, nn.ModuleDict):
            for k, child in list(root.items()):
                root[k] = self._wrap_or_descend(  # type: ignore[index]
                    child, qcfg.child(k), fp_name=k
                )
            return root

        # Case C: Leaf node
        root_name = getattr(root, "_get_name", lambda: None)()
        wrapped = self.try_wrap(
            root,
            qcfg,
            fp_name=root_name,
            raise_on_fail=self.strict_wrap,
        )
        if wrapped is not root:
            return wrapped

        # In strict mode try_wrap has already raised, so only permissive mode
        # reaches this point.
        assert not self.strict_wrap

        # Case D: Named children
        for name, child in list(root.named_children()):
            setattr(
                root,
                name,
                self._wrap_or_descend(child, qcfg.child(name), fp_name=name),
            )

        return root

    def _wrap_attr(
        self,
        owner: nn.Module,
        attr: str,
        scope: str,
        qcfg: PTQConfig,
    ) -> None:
        """Wrap `owner.<attr>` in place under config scope `scope`, if it is an nn.Module."""
        child = getattr(owner, attr, None)
        if not isinstance(child, nn.Module):
            return
        wrapped = self.try_wrap(
            child,
            qcfg.child(scope),
            fp_name=scope,
            raise_on_fail=self.strict_wrap,
        )
        setattr(owner, attr, wrapped)

    def _wrap_or_descend(
        self,
        child: nn.Module,
        child_cfg: PTQConfig,
        *,
        fp_name: Optional[str],
    ) -> nn.Module:
        """Wrap `child` directly, or recurse into it when no wrapper exists (permissive mode)."""
        if isinstance(child, QuantModuleBase):
            # Already wrapped (e.g. by Case A before falling through).
            # Recursing would trip the "already wrapped" assertion in
            # wrap_supported. Presumably PTQWrapper instances are
            # QuantModuleBase -- TODO confirm.
            return child

        wrapped = self.try_wrap(
            child,
            child_cfg,
            fp_name=fp_name,
            raise_on_fail=self.strict_wrap,
        )
        if wrapped is child:
            # try_wrap only hands back the original module when it is not
            # asked to raise, i.e. in permissive mode.
            assert not self.strict_wrap
            wrapped = self.wrap_supported(wrapped, child_cfg)
        return wrapped

    def try_wrap(
        self,
        module: nn.Module,
        qcfg_for_child: PTQConfig,
        *,
        fp_name: Optional[str],
        raise_on_fail: bool,
    ) -> nn.Module:
        """
        Attempt to wrap a boundary with PTQWrapper.

        Behavior:
        - If PTQWrapper succeeds: return wrapped module.
        - If PTQWrapper raises NotImplementedError:
            * raise_on_fail=True  -> re-raise
            * raise_on_fail=False -> return original module

        Args:
            module: Module to wrap.
            qcfg_for_child: Configuration passed to PTQWrapper as `qcfg`.
            fp_name: Name passed to PTQWrapper as `fp_name`.
            raise_on_fail: Whether a missing wrapper is an error.

        Raises:
            NotImplementedError: If wrapping fails and `raise_on_fail` is True;
                the original error is chained as the cause.
        """
        try:
            return PTQWrapper(module, qcfg=qcfg_for_child, fp_name=fp_name)
        except NotImplementedError as e:
            if raise_on_fail:
                raise NotImplementedError(
                    f"PTQWrapHelper: no quantization wrapper for {type(module).__name__}"
                ) from e
            return module

0 commit comments

Comments
 (0)