Skip to content

Commit fac2629

Browse files
committed
fix buffer
1 parent 62d6a57 commit fac2629

File tree

3 files changed

+16
-11
lines changed

3 files changed

+16
-11
lines changed

tests/buffer/sql_test.py

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -75,8 +75,7 @@ async def test_sql_task_buffer_read_write(self) -> None:
7575
)
7676
sql_writer = SQLWriter(config.to_storage_config())
7777
tasks = [
78-
{"task_id": i, "raw_task": {"question": f"question_{i}", "answer": f"answer_{i}"}}
79-
for i in range(total_samples)
78+
{"question": f"question_{i}", "answer": f"answer_{i}"} for i in range(total_samples)
8079
]
8180
self.assertEqual(await sql_writer.acquire(), 1)
8281
sql_writer.write(tasks)
@@ -89,8 +88,8 @@ async def test_sql_task_buffer_read_write(self) -> None:
8988
except StopIteration:
9089
pass
9190
self.assertEqual(len(read_tasks), total_samples)
92-
self.assertIn("question", read_tasks[0]["raw_task"])
93-
self.assertIn("answer", read_tasks[0]["raw_task"])
91+
self.assertIn("question", read_tasks[0].raw_task)
92+
self.assertIn("answer", read_tasks[0].raw_task)
9493
db_wrapper = ray.get_actor("sql-test_task_buffer")
9594
self.assertIsNotNone(db_wrapper)
9695
self.assertEqual(await sql_writer.release(), 0)

trinity/buffer/reader/file_reader.py

Lines changed: 9 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -86,9 +86,6 @@ def select_batch(self, indices: List[int]) -> List:
8686

8787

8888
class BaseFileReader(BufferReader):
89-
def __len__(self):
90-
return self.dataset.dataset_size
91-
9289
async def read_async(self, batch_size: Optional[int] = None):
9390
try:
9491
return self.read(batch_size)
@@ -123,6 +120,9 @@ def state_dict(self):
123120
def load_state_dict(self, state_dict):
124121
return self.reader.load_state_dict(state_dict)
125122

123+
def __len__(self):
124+
return self.reader.__len__()
125+
126126

127127
class ExperienceFileReader(BaseFileReader):
128128
"""Reader for SFT / DPO file data."""
@@ -156,6 +156,9 @@ def state_dict(self):
156156
def load_state_dict(self, state_dict):
157157
self.dataset.current_offset = state_dict["current_index"]
158158

159+
def __len__(self):
160+
return self.dataset.dataset_size
161+
159162

160163
class TaskFileReader(BaseFileReader):
161164
"""A Reader for task file data."""
@@ -205,3 +208,6 @@ def state_dict(self):
205208

206209
def load_state_dict(self, state_dict):
207210
self.dataset.current_offset = state_dict["current_index"]
211+
212+
def __len__(self):
213+
return self.dataset.dataset_size

trinity/buffer/task_scheduler.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -121,7 +121,7 @@ def __init__(self, explorer_state: Dict, config: Config):
121121
self.read_batch_size = config.buffer.batch_size
122122
taskset_configs = config.buffer.explorer_input.tasksets
123123

124-
from trinity.buffer.reader.file_reader import TaskFileReader
124+
from trinity.buffer.reader.file_reader import FileReader
125125

126126
taskset_states = explorer_state.get(
127127
"taskset_states", [{"current_index": 0}] * len(taskset_configs)
@@ -131,15 +131,15 @@ def __init__(self, explorer_state: Dict, config: Config):
131131
for taskset_config, taskset_state in zip(taskset_configs, taskset_states):
132132
assert not taskset_config.is_eval # assume drop last
133133
taskset = get_buffer_reader(taskset_config)
134-
if not isinstance(taskset, TaskFileReader):
134+
if not isinstance(taskset, FileReader):
135135
raise TypeError(
136136
f"Taskset '{taskset_config.name}' has an unsupported type '{type(taskset).__name__}'."
137-
f"Currently, only 'TaskFileReader' is supported by TasksetScheduler."
137+
f"Currently, only 'FileReader' is supported by TasksetScheduler."
138138
)
139139

140140
# Create selector based on type specified in config (e.g., 'sequential', 'shuffle')
141141
selector = SELECTORS.get(taskset_config.task_selector.selector_type)(
142-
taskset.dataset, taskset_config.task_selector
142+
taskset.reader.dataset, taskset_config.task_selector
143143
)
144144
selector.load_state_dict(taskset_state) # Restore any prior state
145145

0 commit comments

Comments
 (0)