|
19 | 19 | fix_fsdp_module_name,
|
20 | 20 | summon_full_params_context,
|
21 | 21 | )
|
| 22 | +from compressed_tensors import match_named_modules |
22 | 23 |
|
23 | 24 | try:
|
24 | 25 | quant_err = None
|
|
49 | 50 | "match_targets",
|
50 | 51 | "get_default_params",
|
51 | 52 | "match_layers_params",
|
52 |
| - "get_layers", |
53 |
| - "get_layer", |
54 | 53 | "get_terminal_layers",
|
55 | 54 | "get_prunable_layers",
|
56 | 55 | "get_quantizable_layers",
|
@@ -155,45 +154,6 @@ def match_layers_params(
|
155 | 154 | return resolved
|
156 | 155 |
|
157 | 156 |
|
158 |
| -def get_layers( |
159 |
| - targets: Union[str, List[str]], |
160 |
| - module: Module, |
161 |
| - exclude_internal_modules: bool = False, |
162 |
| -) -> Dict[str, Module]: |
163 |
| - """ |
164 |
| - Get layers (also known as submodules) of module based on targets |
165 |
| -
|
166 |
| - :param targets: names or regexes to search for |
167 |
| - Can be regex, e.g. "re:.*input_layernorm$" to find all layers |
168 |
| - in module whose names end in string "input_layernorm" |
169 |
| - :param module: Parent module in which to search for targets |
170 |
| - :param exclude_internal_modules: If True, don't include internal |
171 |
| - modules added by llm-compressor, e.g. Observers and Transforms. |
172 |
| - Defaults to False to maintain backward compatibility |
173 |
| -
|
174 |
| - :return: dict of {layer name -> module} of all layers in module |
175 |
| - that match targets |
176 |
| - """ |
177 |
| - layer_dict = match_layers_params(targets, module) |
178 |
| - if exclude_internal_modules: |
179 |
| - layer_dict = { |
180 |
| - name: layer |
181 |
| - for name, layer in layer_dict.items() |
182 |
| - if not isinstance(layer, InternalModule) |
183 |
| - } |
184 |
| - |
185 |
| - return layer_dict |
186 |
| - |
187 |
| - |
188 |
| -def get_layer(target: str, module: Module) -> Tuple[str, Module]: |
189 |
| - layers = get_layers(target, module) |
190 |
| - if len(layers) != 1: |
191 |
| - raise ValueError(f"Expected 1 layer for target {target}, found {len(layers)}") |
192 |
| - name, layer = next(iter(layers.items())) |
193 |
| - |
194 |
| - return name, layer |
195 |
| - |
196 |
| - |
197 | 157 | def get_terminal_layers(module: Module) -> Dict[str, Module]:
|
198 | 158 | terminal = {}
|
199 | 159 |
|
@@ -271,7 +231,7 @@ def get_matching_layer(
|
271 | 231 | :return: Tuple containing the layer name and module that fits the target regex and
|
272 | 232 | best matches name_to_match, or None if no match can be found
|
273 | 233 | """
|
274 |
| - potential_matches = get_layers(target, module) |
| 234 | + potential_matches = match_named_modules(target, module) |
275 | 235 | largest_substring = 0
|
276 | 236 | match = None
|
277 | 237 | for name, module in potential_matches.items():
|
|
0 commit comments