Skip to content

Commit e591965

Browse files
authored
[BUG] Fixes Bottleneck Configs to work with ln_before = True and init_weights = "mam_adapter" (#761)
Fixes #745. When "mam_adapter" is specified, the code will now look for the `nn.Linear` or `PHMLayer` inside the `self.down_adapter` layer sequence and apply the initialization on the correct layer. Edit: this also removes an extra block of code in the `AdapterPlus` notebook.
1 parent ec4a59e commit e591965

File tree

2 files changed

+4
-21
lines changed

2 files changed

+4
-21
lines changed

notebooks/ViT_AdapterPlus_FineTuning.ipynb

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -302,25 +302,6 @@
302302
")"
303303
]
304304
},
305-
{
306-
"cell_type": "code",
307-
"execution_count": null,
308-
"metadata": {},
309-
"outputs": [],
310-
"source": [
311-
"trainer = AdapterTrainer(\n",
312-
" model=model,\n",
313-
" args=training_args,\n",
314-
" data_collator=data_collator,\n",
315-
" train_dataset=train_dataset,\n",
316-
" eval_dataset=eval_dataset,\n",
317-
" tokenizer=processor,\n",
318-
" compute_metrics = compute_metrics\n",
319-
")\n",
320-
"\n",
321-
"trainer.train()"
322-
]
323-
},
324305
{
325306
"cell_type": "code",
326307
"execution_count": null,

src/adapters/methods/modeling.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,9 +123,11 @@ def __init__(
123123
self.gate.apply(self.init_bert_weights)
124124
elif config["init_weights"] == "mam_adapter":
125125
with torch.no_grad():
126-
nn.init.kaiming_uniform_(self.adapter_down[0].weight, a=math.sqrt(5))
126+
for layer in self.adapter_down:
127+
if isinstance(layer, nn.Linear) or isinstance(layer, PHMLayer):
128+
nn.init.kaiming_uniform_(layer.weight, a=math.sqrt(5))
129+
nn.init.zeros_(layer.bias)
127130
nn.init.zeros_(self.adapter_up.weight)
128-
nn.init.zeros_(self.adapter_down[0].bias)
129131
nn.init.zeros_(self.adapter_up.bias)
130132
if self.use_gating:
131133
self.gate.apply(self.init_bert_weights)

0 commit comments

Comments (0)