Skip to content

Commit 3bb1bcc

Browse files
committed
removed lora parse args to simplify; tuned sweep config
1 parent d9f8b98 commit 3bb1bcc

File tree

3 files changed

+13
-17
lines changed

3 files changed

+13
-17
lines changed

src/bart_reddit_lora/model.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,14 @@ def build_base_model(
1313
def build_peft_model(
1414
base_model: BartForConditionalGeneration,
1515
r: int = 8,
16-
lora_alpha: int = 16,
1716
lora_dropout: float = 0.1,
1817
bias: str = "none",
1918
target_modules: list[str] = ("q_proj", "k_proj", "v_proj", "o_proj", "fc1", "fc2"),
2019
modules_to_save: list[str] = ("lm_head",),
2120
) -> PeftModel:
2221
config = LoraConfig(
2322
r=r,
24-
lora_alpha=lora_alpha,
23+
lora_alpha=r * 2,
2524
target_modules=list(target_modules),
2625
lora_dropout=lora_dropout,
2726
bias=bias,

src/bart_reddit_lora/train.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,6 @@ class CustomTrainingArgs(Seq2SeqTrainingArguments):
7575

7676
# additional custom args
7777
peft_rank: int = field(default=32, metadata={"help": "LoRA adapter rank (r)."})
78-
lora_alpha: int = 64
7978
hf_hub_repo_id: str | None = None
8079
run_test: bool = field(
8180
default=False,
@@ -171,11 +170,9 @@ def to_qa(ex):
171170
logger.info(
172171
f"Base model trainable params:\n{print_trainable_parameters(base_model)}"
173172
)
174-
lora_model = build_peft_model(
175-
base_model, training_args.peft_rank, training_args.lora_alpha
176-
)
173+
lora_model = build_peft_model(base_model, training_args.peft_rank)
177174
logger.info(
178-
f"LoRA model (peft_rank={training_args.peft_rank}, lora_alpha={training_args.lora_alpha}) trainable params:\n{print_trainable_parameters(lora_model)}"
175+
f"LoRA model (peft_rank={training_args.peft_rank}) trainable params:\n{print_trainable_parameters(lora_model)}"
179176
)
180177

181178
# ---------- Train ----------

sweep.yaml

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,24 +5,24 @@ project: bart-base-reddit-lora
55
entity: codinglabsong-keio-jp
66

77
method: bayes # {grid | random | bayes}
8-
run_cap: 10 # sweep run limit
8+
run_cap: 15 # sweep run limit
99

1010
metric: # what to optimise
1111
name: eval/loss # must match the key in evaluation.compute_metrics returns
12-
goal: maximize
12+
goal: minimize
1313

1414
parameters:
1515
learning_rate:
16-
min: 0.00001
17-
max: 0.001
16+
min: 0.00001 # 1e-5
17+
max: 0.001 # 1e-3
1818
distribution: log_uniform_values
1919
num_train_epochs:
20-
values: [2]
20+
values: 4
2121
peft_rank:
22-
values: [32]
22+
values: [32, 64]
2323
train_sample:
24-
values: [True]
24+
values: True
2525

26-
early_terminate:
27-
type: hyperband
28-
min_iter: 1
26+
# early_terminate:
27+
# type: hyperband
28+
# min_iter: 3

0 commit comments

Comments
 (0)