Skip to content

Commit 6303210

Browse files
committed
removed get_default_params and match_layers_params
1 parent d495843 commit 6303210

File tree

1 file changed

+0
-67
lines changed

1 file changed

+0
-67
lines changed

src/llmcompressor/utils/pytorch/module.py

Lines changed: 0 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,6 @@
4848

4949
__all__ = [
5050
"match_targets",
51-
"get_default_params",
52-
"match_layers_params",
5351
"get_terminal_layers",
5452
"get_prunable_layers",
5553
"get_quantizable_layers",
@@ -89,71 +87,6 @@ def match_class(layer: Module, targets: Union[str, List[str]]) -> Tuple[bool, in
8987
return False, -1
9088

9189

92-
def get_default_params(layers: Dict[str, Module]) -> Dict[str, Parameter]:
93-
params = {}
94-
for name, layer in layers.items():
95-
for param_name, param in layer.named_parameters():
96-
if param_name == "weight":
97-
params[name] = param
98-
break
99-
return params
100-
101-
102-
def match_layers_params(
103-
targets: Union[str, List[str]], module: Module, params: bool = False
104-
) -> Dict[str, Union[Module, Parameter]]:
105-
if targets == ALL_TARGET:
106-
values = get_terminal_layers(module)
107-
108-
return values if not params else get_default_params(values)
109-
110-
if targets == ALL_PRUNABLE_TARGET:
111-
values = get_prunable_layers(module)
112-
113-
return values if not params else get_default_params(values)
114-
115-
if targets == ALL_QUANTIZABLE_TARGET:
116-
values = get_quantizable_layers(module)
117-
118-
return values if not params else get_default_params(values)
119-
120-
if isinstance(targets, str):
121-
targets = [targets]
122-
123-
resolved = {}
124-
targets_found = [False for _ in range(len(targets))]
125-
126-
for name, layer in module.named_modules():
127-
# due to nesting, FSDP may not be the top layer
128-
name = fix_fsdp_module_name(name)
129-
match, match_index = match_targets(name, targets)
130-
if match and not params:
131-
targets_found[match_index] = True
132-
resolved[name] = layer
133-
else:
134-
match, match_index = match_class(layer, targets)
135-
if match:
136-
targets_found[match_index] = True
137-
resolved[name] = layer
138-
139-
for param_name, param in layer.named_parameters():
140-
if "." in param_name: # skip parameters of nested layers
141-
continue
142-
143-
param_match, param_match_index = match_targets(
144-
f"{name}.{param_name}", targets
145-
)
146-
if param_match:
147-
targets_found[param_match_index] = True
148-
resolved[f"{name}"] = layer if not params else param
149-
150-
missed = [target for found, target in zip(targets_found, targets) if not found]
151-
if len(missed) > 0:
152-
raise ValueError(f"Could not find targets {missed} in module {module}")
153-
154-
return resolved
155-
156-
15790
def get_terminal_layers(module: Module) -> Dict[str, Module]:
15891
terminal = {}
15992

0 commit comments

Comments
 (0)