diff --git a/tests/template/data/countdown/train.jsonl b/tests/template/data/countdown/train.jsonl
index 3e9adf0092..9329509906 100644
--- a/tests/template/data/countdown/train.jsonl
+++ b/tests/template/data/countdown/train.jsonl
@@ -14,3 +14,4 @@
{"question": "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.\nUser: Using the numbers [19, 25, 89], create an equation that equals 95. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in tags. And return the final answer in tags, for example (1 + 2) / 3 .\nAssistant: Let me solve this step by step.\n", "answer": "{\"numbers\": [19, 25, 89], \"target\": 95}"}
{"question": "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.\nUser: Using the numbers [8, 62, 43], create an equation that equals 27. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in tags. And return the final answer in tags, for example (1 + 2) / 3 .\nAssistant: Let me solve this step by step.\n", "answer": "{\"numbers\": [8, 62, 43], \"target\": 27}"}
{"question": "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.\nUser: Using the numbers [74, 5, 20, 88], create an equation that equals 50. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in tags. And return the final answer in tags, for example (1 + 2) / 3 .\nAssistant: Let me solve this step by step.\n", "answer": "{\"numbers\": [74, 5, 20, 88], \"target\": 50}"}
+{"question": "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer.\nUser: Using the numbers [15, 51, 73], create an equation that equals 37. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. Show your work in tags. And return the final answer in tags, for example (1 + 2) / 3 .\nAssistant: Let me solve this step by step.\n", "answer": "{\"numbers\": [15, 51, 73], \"target\": 37}"}
diff --git a/tests/tools.py b/tests/tools.py
index 2bbb8c8931..46d6b4859b 100644
--- a/tests/tools.py
+++ b/tests/tools.py
@@ -42,7 +42,7 @@ def get_checkpoint_path() -> str:
def get_unittest_dataset_config(
dataset_name: str = "countdown", split: str = "train"
) -> StorageConfig:
- """Countdown dataset with 16 samples."""
+ """Countdown dataset with 17 samples."""
if dataset_name == "countdown" or dataset_name == "copy_countdown":
return StorageConfig(
name=dataset_name,
diff --git a/trinity/buffer/reader/file_reader.py b/trinity/buffer/reader/file_reader.py
index 50727641c6..0812c64ed9 100644
--- a/trinity/buffer/reader/file_reader.py
+++ b/trinity/buffer/reader/file_reader.py
@@ -26,12 +26,14 @@ def __init__(
name: str,
max_epoch: int = 1,
offset: int = 0,
+ drop_last: bool = True,
):
self.dataset = dataset
self.dataset_size = len(dataset)
self.name = name
self.current_batch_size = None
self.max_epoch = max_epoch
+ self.drop_last = drop_last
if offset >= self.dataset_size:
self.current_epoch = offset // self.dataset_size
self.current_offset = offset % self.dataset_size
@@ -70,7 +72,7 @@ def read_batch(self, batch_size: int) -> List:
self.current_offset = 0
if self.current_epoch >= self.max_epoch:
- if len(batch) > 0:
+ if not self.drop_last and len(batch) > 0:
return batch
else:
self.progress_bar.close()
@@ -96,6 +98,7 @@ def __init__(self, meta: StorageConfig, config: BufferConfig):
load_dataset(meta.path, name=subset_name, split=self.split, trust_remote_code=True),
name=meta.name,
max_epoch=meta.total_epochs,
+ drop_last=True,
) # TODO: support resume
self.tokenizer = transformers.AutoTokenizer.from_pretrained(config.tokenizer_path)
@@ -174,6 +177,7 @@ def __init__(self, meta: StorageConfig, config: BufferConfig):
load_dataset(meta.path, name=subset_name, split=self.split, trust_remote_code=True),
name=meta.name,
max_epoch=meta.total_epochs,
+ drop_last=True,
) # TODO: support resume
self.tokenizer = transformers.AutoTokenizer.from_pretrained(config.tokenizer_path)
@@ -248,6 +252,7 @@ def __init__(self, meta: StorageConfig, config: BufferConfig):
name=meta.name,
max_epoch=self.meta.total_epochs if meta.task_type == TaskType.EXPLORE else 1,
offset=self.meta.index,
+ drop_last=self.meta.task_type == TaskType.EXPLORE,
)
self.read_batch_size = config.batch_size
self.prompt_key = meta.format.prompt_key
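
# Behavior sketch for the drop_last flag introduced above. This is a simplified,
# self-contained model (the class name is hypothetical and the control flow is
# condensed relative to the real reader): with 17 samples and batch_size=4,
# drop_last=True discards the trailing partial batch at the end of the final
# epoch, while drop_last=False returns it as a short batch.

from typing import List


class ToyEpochReader:
    """Simplified stand-in for the patched reader; names are illustrative."""

    def __init__(self, dataset: List, max_epoch: int = 1, drop_last: bool = True):
        self.dataset = dataset
        self.max_epoch = max_epoch
        self.drop_last = drop_last
        self.current_epoch = 0
        self.current_offset = 0

    def read_batch(self, batch_size: int) -> List:
        batch: List = []
        while len(batch) < batch_size:
            if self.current_epoch >= self.max_epoch:
                # Final epoch exhausted: keep the partial batch only when
                # drop_last is disabled, mirroring the change above.
                if not self.drop_last and len(batch) > 0:
                    return batch
                raise StopIteration
            batch.append(self.dataset[self.current_offset])
            self.current_offset += 1
            if self.current_offset >= len(self.dataset):
                self.current_epoch += 1
                self.current_offset = 0
        return batch


# 17 samples, batch_size=4: drop_last=True -> [4, 4, 4, 4] (17th sample dropped);
# drop_last=False -> [4, 4, 4, 4, 1] (trailing short batch returned).
reader = ToyEpochReader(list(range(17)), max_epoch=1, drop_last=False)
sizes = []
try:
    while True:
        sizes.append(len(reader.read_batch(4)))
except StopIteration:
    pass
print(sizes)  # [4, 4, 4, 4, 1]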