Skip to content

Commit a1b3b01

Browse files
committed
Merge branch 'main' into feat/exp_pipeline
2 parents a1cdc7f + 8bc98cd commit a1b3b01

File tree

3 files changed

+8
-2
lines changed

3 files changed

+8
-2
lines changed

tests/template/data/countdown/train.jsonl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,4 @@
1414
{"question": "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.\nUser: Using the numbers [19, 25, 89], create an equation that equals 95. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in <think> </think> tags. And return the final answer in <answer> </answer> tags, for example <answer> (1 + 2) / 3 </answer>.\nAssistant: Let me solve this step by step.\n<think>", "answer": "{\"numbers\": [19, 25, 89], \"target\": 95}"}
1515
{"question": "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.\nUser: Using the numbers [8, 62, 43], create an equation that equals 27. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in <think> </think> tags. And return the final answer in <answer> </answer> tags, for example <answer> (1 + 2) / 3 </answer>.\nAssistant: Let me solve this step by step.\n<think>", "answer": "{\"numbers\": [8, 62, 43], \"target\": 27}"}
1616
{"question": "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.\nUser: Using the numbers [74, 5, 20, 88], create an equation that equals 50. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in <think> </think> tags. And return the final answer in <answer> </answer> tags, for example <answer> (1 + 2) / 3 </answer>.\nAssistant: Let me solve this step by step.\n<think>", "answer": "{\"numbers\": [74, 5, 20, 88], \"target\": 50}"}
17+
{"question": "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.\nUser: Using the numbers [15, 51, 73], create an equation that equals 37. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in <think> </think> tags. And return the final answer in <answer> </answer> tags, for example <answer> (1 + 2) / 3 </answer>.\nAssistant: Let me solve this step by step.\n<think>", "answer": "{\"numbers\": [15, 51, 73], \"target\": 37}"}

tests/tools.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def get_checkpoint_path() -> str:
4242
def get_unittest_dataset_config(
4343
dataset_name: str = "countdown", split: str = "train"
4444
) -> StorageConfig:
45-
"""Countdown dataset with 16 samples."""
45+
"""Countdown dataset with 17 samples."""
4646
if dataset_name == "countdown" or dataset_name == "copy_countdown":
4747
return StorageConfig(
4848
name=dataset_name,

trinity/buffer/reader/file_reader.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,14 @@ def __init__(
2626
name: str,
2727
max_epoch: int = 1,
2828
offset: int = 0,
29+
drop_last: bool = True,
2930
):
3031
self.dataset = dataset
3132
self.dataset_size = len(dataset)
3233
self.name = name
3334
self.current_batch_size = None
3435
self.max_epoch = max_epoch
36+
self.drop_last = drop_last
3537
if offset >= self.dataset_size:
3638
self.current_epoch = offset // self.dataset_size
3739
self.current_offset = offset % self.dataset_size
@@ -70,7 +72,7 @@ def read_batch(self, batch_size: int) -> List:
7072
self.current_offset = 0
7173

7274
if self.current_epoch >= self.max_epoch:
73-
if len(batch) > 0:
75+
if not self.drop_last and len(batch) > 0:
7476
return batch
7577
else:
7678
self.progress_bar.close()
@@ -96,6 +98,7 @@ def __init__(self, meta: StorageConfig, config: BufferConfig):
9698
load_dataset(meta.path, name=subset_name, split=self.split, trust_remote_code=True),
9799
name=meta.name,
98100
max_epoch=meta.total_epochs,
101+
drop_last=True,
99102
) # TODO: support resume
100103
self.tokenizer = transformers.AutoTokenizer.from_pretrained(config.tokenizer_path)
101104

@@ -174,6 +177,7 @@ def __init__(self, meta: StorageConfig, config: BufferConfig):
174177
load_dataset(meta.path, name=subset_name, split=self.split, trust_remote_code=True),
175178
name=meta.name,
176179
max_epoch=meta.total_epochs,
180+
drop_last=True,
177181
) # TODO: support resume
178182
self.tokenizer = transformers.AutoTokenizer.from_pretrained(config.tokenizer_path)
179183

@@ -248,6 +252,7 @@ def __init__(self, meta: StorageConfig, config: BufferConfig):
248252
name=meta.name,
249253
max_epoch=self.meta.total_epochs if meta.task_type == TaskType.EXPLORE else 1,
250254
offset=self.meta.index,
255+
drop_last=self.meta.task_type == TaskType.EXPLORE,
251256
)
252257
self.read_batch_size = config.batch_size
253258
self.prompt_key = meta.format.prompt_key

0 commit comments

Comments
 (0)