
Commit 04b2608

lowered input, target token length
1 parent 0dd52df commit 04b2608


2 files changed

+3 -3 lines changed


src/bart_reddit_lora/data.py

Lines changed: 2 additions & 2 deletions
@@ -192,8 +192,8 @@ def split_and_save(df, out_dir: Union[str, Path]):
 def tokenize_and_format(
     ds: DatasetDict,
     checkpoint: str = "facebook/bart-base",
-    max_input_length: int = 1024,  # max 1024 224
-    max_target_length: int = 800,  # max 1024
+    max_input_length: int = 512,  # max 1024 1024
+    max_target_length: int = 128,  # max 1024 800
 ) -> Tuple[DatasetDict, AutoTokenizer]:
     tok = AutoTokenizer.from_pretrained(checkpoint)
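For reference, a minimal sketch of how tokenize_and_format might apply these limits. The function body is not part of this diff, so the column names ("post", "summary"), the map call, and the label handling are assumptions; only the signature and the AutoTokenizer line above come from the source.

from typing import Tuple

from datasets import DatasetDict
from transformers import AutoTokenizer


def tokenize_and_format(
    ds: DatasetDict,
    checkpoint: str = "facebook/bart-base",
    max_input_length: int = 512,   # max 1024
    max_target_length: int = 128,  # max 1024
) -> Tuple[DatasetDict, AutoTokenizer]:
    tok = AutoTokenizer.from_pretrained(checkpoint)

    def _tokenize(batch):
        # "post" and "summary" are hypothetical column names.
        # Inputs are truncated to max_input_length tokens.
        model_inputs = tok(
            batch["post"], max_length=max_input_length, truncation=True
        )
        # text_target tokenizes the summaries as labels,
        # truncated to max_target_length tokens.
        labels = tok(
            text_target=batch["summary"],
            max_length=max_target_length,
            truncation=True,
        )
        model_inputs["labels"] = labels["input_ids"]
        return model_inputs

    ds = ds.map(_tokenize, batched=True, remove_columns=ds["train"].column_names)
    return ds, tok

Halving inputs to 512 tokens and capping targets at 128 shrinks the padded batch tensors considerably, which is the usual motivation for a change like this when seq2seq fine-tuning is memory-bound.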

src/bart_reddit_lora/train.py

Lines changed: 1 addition & 1 deletion
@@ -46,7 +46,7 @@ class CustomTrainingArgs(Seq2SeqTrainingArguments):
         default="outputs/bart-base-reddit-lora",
         metadata={"help": "Prefix folder for all checkpoints/run logs."},
     )
-    num_train_epochs: int = 6
+    num_train_epochs: int = 12
     per_device_train_batch_size: int = 8
     per_device_eval_batch_size: int = 16
     learning_rate: float = 6e-5
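Seen in full, the dataclass around this hunk plausibly looks like the sketch below. The @dataclass decorator and the name of the output-prefix field are assumptions; the field defaults are the post-commit values shown in the diff.

from dataclasses import dataclass, field

from transformers import Seq2SeqTrainingArguments


@dataclass
class CustomTrainingArgs(Seq2SeqTrainingArguments):
    # Field name "output_prefix" is hypothetical; the diff shows only
    # its default value and help metadata.
    output_prefix: str = field(
        default="outputs/bart-base-reddit-lora",
        metadata={"help": "Prefix folder for all checkpoints/run logs."},
    )
    num_train_epochs: int = 12
    per_device_train_batch_size: int = 8
    per_device_eval_batch_size: int = 16
    learning_rate: float = 6e-5

Because Seq2SeqTrainingArguments still requires output_dir, instantiation would look like CustomTrainingArgs(output_dir="..."), with these class-level defaults overriding the upstream ones (num_train_epochs defaults to 3 upstream).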
