Skip to content

Commit 5944677

Browse files
author
Ubuntu
committed
fixed minor bugs preventing sweeps
1 parent 52d33c6 commit 5944677

File tree

2 files changed

+8
-8
lines changed

2 files changed

+8
-8
lines changed

src/bart_reddit_lora/train.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ class CustomTrainingArgs(Seq2SeqTrainingArguments):
6161
logging_steps: int = 50
6262
save_total_limit: int = 2
6363
load_best_model_at_end: bool = True
64-
metric_for_best_model: str = "eval/loss"
64+
metric_for_best_model: str = "loss"
6565
greater_is_better: bool = False
6666

6767
fp16: bool = True
@@ -147,7 +147,7 @@ def to_qa(ex):
147147
"validation": str(data_args.validation_file),
148148
"test": str(data_args.test_file),
149149
}
150-
ds = load_dataset("csv", data_files=data_files, streaming=True)
150+
ds = load_dataset("csv", data_files=data_files)
151151
# # load and tokenize dataset
152152
# # load CSVs
153153
# data_files = {

sweep.yaml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,10 @@
11
# sweep spec for bart-base-reddit-lora
22
program: scripts/train.py
3-
4-
project: bart-base-reddit-lora
5-
entity: codinglabsong-keio-jp
3+
name: bart-base-reddit-lora-sweep
64

75
method: bayes # {grid | random | bayes}
8-
run_cap: 15 # sweep run limit
9-
106
metric: # what to optimise
11-
name: eval/loss # must match the key in evaluation.compute_metrics returns
7+
name: eval_loss # must match the key in .evaluate() returns
128
goal: minimize
139

1410
parameters:
@@ -22,6 +18,10 @@ parameters:
2218
values: [32, 64]
2319
train_sample:
2420
values: [True]
21+
run_cap: 10 # sweep run limit
22+
23+
project: bart-base-reddit-lora
24+
entity: codinglabsong-keio-jp
2525

2626
# early_terminate:
2727
# type: hyperband

0 commit comments

Comments
 (0)