Skip to content

Commit f1d90d0

Browse files
committed
fix wandb config update
1 parent 98c0284 commit f1d90d0

File tree

3 files changed

+5
-4
lines changed

3 files changed

+5
-4
lines changed

docs/multi_gpu.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ The args used in the command above are:
3434

3535
* `--use_peft` boolean flag to enable PEFT methods in the script
3636

37-
* `--peft_method` to specify the PEFT method, here we use `lora` other options are `llama_adapter`, `prefix`.
37+
* `--peft_method` to specify the PEFT method, here we use `lora` other options are `llama_adapter`.
3838

3939
We use `torchrun` here to spawn multiple processes for FSDP.
4040

docs/single_gpu.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ The args used in the command above are:
2727

2828
* `--use_peft` boolean flag to enable PEFT methods in the script
2929

30-
* `--peft_method` to specify the PEFT method, here we use `lora` other options are `llama_adapter`, `prefix`.
30+
* `--peft_method` to specify the PEFT method, here we use `lora` other options are `llama_adapter`.
3131

3232
* `--quantization` boolean flag to enable int8 quantization
3333

src/llama_recipes/finetuning.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,12 +154,13 @@ def main(**kwargs):
154154
# Load the pre-trained peft model checkpoint and setup its configuration
155155
if train_config.from_peft_checkpoint:
156156
model = PeftModel.from_pretrained(model, train_config.from_peft_checkpoint, is_trainable=True)
157+
peft_config = model.peft_config()
157158
# Generate the peft config and start fine-tuning from original model
158159
else:
159160
peft_config = generate_peft_config(train_config, kwargs)
160161
model = get_peft_model(model, peft_config)
161-
if wandb_run:
162-
wandb_run.config.update(peft_config)
162+
if wandb_run:
163+
wandb_run.config.update(peft_config)
163164
model.print_trainable_parameters()
164165

165166

0 commit comments

Comments
 (0)