Skip to content

Commit 4462fed

Browse files
committed
refactor _should_stop
1 parent fedc49a commit 4462fed

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed

trinity/buffer/task_scheduler.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,13 @@ def __init__(self, explorer_state: Dict, config: Config):
9090
self.epoch = self.step * self.read_batch_size // len(self.base_taskset_ids)
9191
self.orders = self.build_orders(self.epoch)
9292

93+
if self.config.buffer.total_steps:
94+
self.max_steps = self.config.buffer.total_steps
95+
else:
96+
self.max_steps = (
97+
self.config.buffer.total_epochs * len(self.base_taskset_ids) // self.read_batch_size
98+
)
99+
93100
def build_orders(self, epoch: int):
94101
"""
95102
Creates a shuffled sequence of taskset IDs to control sampling priority per step.
@@ -109,13 +116,7 @@ def build_orders(self, epoch: int):
109116
return taskset_ids
110117

111118
def _should_stop(self) -> bool:
112-
if self.config.buffer.total_steps:
113-
if self.step >= self.config.buffer.total_steps:
114-
return True
115-
else:
116-
if self.epoch >= self.config.buffer.total_epochs:
117-
return True
118-
return False
119+
return self.step >= self.max_steps
119120

120121
async def read_async(self) -> List:
121122
"""

0 commit comments

Comments
 (0)