Commit 72bb862

fix llamapro

Keeps model.config.layer_types in sync when LLaMA Pro inserts duplicated layers: each copied layer now gets a matching layer_types entry, so the list stays consistent with the updated num_hidden_layers.

1 parent 406a716 · commit 72bb862

File tree: 1 file changed (+7 −0)


swift/tuners/llamapro.py

Lines changed: 7 additions & 0 deletions
@@ -82,14 +82,21 @@ def prepare_model(model: nn.Module, config: LLaMAProConfig, adapter_name: str) -
 
         new_module_list = nn.ModuleList()
         new_module_idx = []
+        layer_types = getattr(model.config, 'layer_types', None)
+        _new_layer_type = []
         for idx, module in enumerate(module_list):
             new_module_list.append(module)
+            _layer_type = layer_types[idx] if layer_types else None
+            _new_layer_type.append(_layer_type)
             if (idx + 1) % num_stride == 0:
                 new_module = deepcopy(module)
                 ActivationMixin.mark_all_sub_modules_as_plugin(new_module)
                 new_module_list.append(new_module)
                 new_module_idx.append(idx + 1 + len(new_module_idx))
+                _new_layer_type.append(_layer_type)
 
+        if layer_types is not None:
+            model.config.layer_types = _new_layer_type
         LLaMAPro._update_module_weight(config, new_module_list, new_module_idx)
         LLaMAPro._update_module_attr(config, new_module_list)
         model.config.num_hidden_layers = len(new_module_list)
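
For context, below is a minimal, self-contained sketch of the expansion loop this hunk patches. It is not swift's actual code: ToyConfig, the Linear stand-in layers, and num_stride = 2 are hypothetical; only the layer_types bookkeeping mirrors the committed fix. Some Hugging Face model configs (for models with alternating attention layer patterns) carry a per-layer layer_types list, and without the fix that list would end up shorter than the expanded num_hidden_layers.

    # Minimal sketch (not swift's code) of the expansion loop patched above.
    # ToyConfig, the Linear stand-in layers, and num_stride are hypothetical;
    # only the layer_types bookkeeping mirrors the committed fix.
    from copy import deepcopy

    import torch.nn as nn


    class ToyConfig:
        num_hidden_layers = 4
        # One entry per layer, as in configs of models with alternating attention.
        layer_types = ['full_attention', 'sliding_attention'] * 2


    config = ToyConfig()
    module_list = nn.ModuleList(nn.Linear(8, 8) for _ in range(config.num_hidden_layers))
    num_stride = 2  # insert one duplicated layer after every 2nd original layer

    new_module_list = nn.ModuleList()
    new_module_idx = []
    layer_types = getattr(config, 'layer_types', None)
    _new_layer_type = []
    for idx, module in enumerate(module_list):
        new_module_list.append(module)
        _layer_type = layer_types[idx] if layer_types else None
        _new_layer_type.append(_layer_type)
        if (idx + 1) % num_stride == 0:
            new_module_list.append(deepcopy(module))
            new_module_idx.append(idx + 1 + len(new_module_idx))
            # The fix: the duplicated layer gets a layer_types entry too,
            # otherwise len(layer_types) falls behind num_hidden_layers.
            _new_layer_type.append(_layer_type)

    if layer_types is not None:
        config.layer_types = _new_layer_type
    config.num_hidden_layers = len(new_module_list)

    assert config.num_hidden_layers == len(config.layer_types) == 6

With the two _new_layer_type.append calls removed (the pre-fix behavior), layer_types would keep its original 4 entries against 6 layers and the final assert would fail.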
