Skip to content

Commit d152369

Browse files
committed
tests: gradient accum steps = 1, get checkpoint path
Signed-off-by: Will Johnson <mwjohnson728@gmail.com>
1 parent a825a8a commit d152369

File tree

2 files changed

+10
-3
lines changed

2 files changed

+10
-3
lines changed

tests/build/test_launch_script.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
"num_train_epochs": 5,
4747
"per_device_train_batch_size": 4,
4848
"per_device_eval_batch_size": 4,
49-
"gradient_accumulation_steps": 4,
49+
"gradient_accumulation_steps": 1,
5050
"learning_rate": 0.00001,
5151
"weight_decay": 0,
5252
"warmup_ratio": 0.03,

tests/test_sft_trainer.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import copy
2323
import json
2424
import os
25+
import re
2526
import tempfile
2627

2728
# Third Party
@@ -87,7 +88,7 @@
8788
num_train_epochs=5,
8889
per_device_train_batch_size=4,
8990
per_device_eval_batch_size=4,
90-
gradient_accumulation_steps=4,
91+
gradient_accumulation_steps=1,
9192
learning_rate=0.00001,
9293
weight_decay=0,
9394
warmup_ratio=0.03,
@@ -1142,7 +1143,13 @@ def _validate_hf_resource_scanner_file(tempdir):
11421143

11431144

11441145
def _get_checkpoint_path(dir_path):
1145-
return os.path.join(dir_path, "checkpoint-5")
1146+
checkpoint_dirs = [
1147+
d
1148+
for d in os.listdir(dir_path)
1149+
if os.path.isdir(os.path.join(dir_path, d)) and re.match(r"^checkpoint-\d+$", d)
1150+
]
1151+
checkpoint_dirs.sort(key=lambda name: int(name.split("-")[-1]))
1152+
return os.path.join(dir_path, checkpoint_dirs[-1])
11461153

11471154

11481155
def _get_adapter_config(dir_path):

0 commit comments

Comments
 (0)