|
48 | 48 |
|
49 | 49 | __all__ = [
|
50 | 50 | "match_targets",
|
51 |
| - "get_default_params", |
52 |
| - "match_layers_params", |
53 | 51 | "get_terminal_layers",
|
54 | 52 | "get_prunable_layers",
|
55 | 53 | "get_quantizable_layers",
|
@@ -89,71 +87,6 @@ def match_class(layer: Module, targets: Union[str, List[str]]) -> Tuple[bool, in
|
89 | 87 | return False, -1
|
90 | 88 |
|
91 | 89 |
|
92 |
| -def get_default_params(layers: Dict[str, Module]) -> Dict[str, Parameter]: |
93 |
| - params = {} |
94 |
| - for name, layer in layers.items(): |
95 |
| - for param_name, param in layer.named_parameters(): |
96 |
| - if param_name == "weight": |
97 |
| - params[name] = param |
98 |
| - break |
99 |
| - return params |
100 |
| - |
101 |
| - |
102 |
| -def match_layers_params( |
103 |
| - targets: Union[str, List[str]], module: Module, params: bool = False |
104 |
| -) -> Dict[str, Union[Module, Parameter]]: |
105 |
| - if targets == ALL_TARGET: |
106 |
| - values = get_terminal_layers(module) |
107 |
| - |
108 |
| - return values if not params else get_default_params(values) |
109 |
| - |
110 |
| - if targets == ALL_PRUNABLE_TARGET: |
111 |
| - values = get_prunable_layers(module) |
112 |
| - |
113 |
| - return values if not params else get_default_params(values) |
114 |
| - |
115 |
| - if targets == ALL_QUANTIZABLE_TARGET: |
116 |
| - values = get_quantizable_layers(module) |
117 |
| - |
118 |
| - return values if not params else get_default_params(values) |
119 |
| - |
120 |
| - if isinstance(targets, str): |
121 |
| - targets = [targets] |
122 |
| - |
123 |
| - resolved = {} |
124 |
| - targets_found = [False for _ in range(len(targets))] |
125 |
| - |
126 |
| - for name, layer in module.named_modules(): |
127 |
| - # due to nesting, FSDP may not be the top layer |
128 |
| - name = fix_fsdp_module_name(name) |
129 |
| - match, match_index = match_targets(name, targets) |
130 |
| - if match and not params: |
131 |
| - targets_found[match_index] = True |
132 |
| - resolved[name] = layer |
133 |
| - else: |
134 |
| - match, match_index = match_class(layer, targets) |
135 |
| - if match: |
136 |
| - targets_found[match_index] = True |
137 |
| - resolved[name] = layer |
138 |
| - |
139 |
| - for param_name, param in layer.named_parameters(): |
140 |
| - if "." in param_name: # skip parameters of nested layers |
141 |
| - continue |
142 |
| - |
143 |
| - param_match, param_match_index = match_targets( |
144 |
| - f"{name}.{param_name}", targets |
145 |
| - ) |
146 |
| - if param_match: |
147 |
| - targets_found[param_match_index] = True |
148 |
| - resolved[f"{name}"] = layer if not params else param |
149 |
| - |
150 |
| - missed = [target for found, target in zip(targets_found, targets) if not found] |
151 |
| - if len(missed) > 0: |
152 |
| - raise ValueError(f"Could not find targets {missed} in module {module}") |
153 |
| - |
154 |
| - return resolved |
155 |
| - |
156 |
| - |
157 | 90 | def get_terminal_layers(module: Module) -> Dict[str, Module]:
|
158 | 91 | terminal = {}
|
159 | 92 |
|
|
0 commit comments