Skip to content

Commit d71d3c8

Browse files
committed
apply reviews
1 parent 4aa826b commit d71d3c8

File tree

3 files changed

+16
-18
lines changed

3 files changed

+16
-18
lines changed

trinity/algorithm/sample_strategy/sample_strategy.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,10 +82,10 @@ def __init__(self, buffer_config: BufferConfig, **kwargs):
8282
self.staleness_limit = kwargs.get("staleness_limit", float("inf"))
8383

8484
async def sample(self, step: int, **kwargs) -> Tuple[Experiences, Dict, List]:
85-
oldest_valid_version = max(step - self.staleness_limit, 0)
85+
min_model_version = max(step - self.staleness_limit, 0)
8686
metrics = {}
8787
with Timer(metrics, "time/read_experience"):
88-
exp_list = await self.exp_buffer.read_async(oldest_valid_version=oldest_valid_version)
88+
exp_list = await self.exp_buffer.read_async(min_model_version=min_model_version)
8989
repr_samples = representative_sample(exp_list)
9090
self.set_model_version_metric(exp_list, metrics)
9191
with Timer(metrics, "time/gather_experience"):

trinity/buffer/storage/queue.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,8 @@ def default_config(cls) -> Dict:
100100

101101

102102
class QueueBuffer(ABC):
103-
async def set_oldest_valid_version(self, oldest_valid_version: int):
104-
self.oldest_valid_version = max(oldest_valid_version, 0)
103+
async def set_min_model_version(self, min_model_version: int):
104+
self.min_model_version = max(min_model_version, 0)
105105

106106
@abstractmethod
107107
async def put(self, exps: List[Experience]) -> None:
@@ -152,7 +152,7 @@ def __init__(self, capacity: int):
152152
"""
153153
super().__init__(maxsize=capacity)
154154
self._closed = False
155-
self.oldest_valid_version = 0
155+
self.min_model_version = 0
156156

157157
async def put(self, item: List[Experience]):
158158
if len(item) == 0:
@@ -163,8 +163,8 @@ async def get(self):
163163
while True:
164164
item = await super().get()
165165
if (
166-
self.oldest_valid_version <= 0
167-
or item[0].info["model_version"] >= self.oldest_valid_version
166+
self.min_model_version <= 0
167+
or item[0].info["model_version"] >= self.min_model_version
168168
):
169169
return item
170170

@@ -222,7 +222,7 @@ def __init__(
222222
self.reuse_cooldown_time = reuse_cooldown_time
223223
self._condition = asyncio.Condition() # For thread-safe operations
224224
self._closed = False
225-
self.oldest_valid_version = 0
225+
self.min_model_version = 0
226226

227227
async def _put(self, item: List[Experience], delay: float = 0) -> None:
228228
"""
@@ -287,8 +287,8 @@ async def get(self) -> List[Experience]:
287287
self.priority_groups.popitem(index=-1)
288288

289289
if (
290-
self.oldest_valid_version <= 0
291-
or item[0].info["model_version"] >= self.oldest_valid_version
290+
self.min_model_version <= 0
291+
or item[0].info["model_version"] >= self.min_model_version
292292
):
293293
break
294294

@@ -374,17 +374,15 @@ async def put_batch(self, exp_list: List) -> None:
374374
if self.writer is not None:
375375
self.writer.write(exp_list)
376376

377-
async def get_batch(
378-
self, batch_size: int, timeout: float, oldest_valid_version: int = 0
379-
) -> List:
377+
async def get_batch(self, batch_size: int, timeout: float, min_model_version: int = 0) -> List:
380378
"""Get batch of experience."""
381-
await self.queue.set_oldest_valid_version(oldest_valid_version)
379+
await self.queue.set_min_model_version(min_model_version)
382380
start_time = time.time()
383381
result = []
384382
while len(result) < batch_size:
385383
while len(self.exp_pool) > 0 and len(result) < batch_size:
386384
exp = self.exp_pool.popleft()
387-
if oldest_valid_version > 0 and exp.info["model_version"] < oldest_valid_version:
385+
if min_model_version > 0 and exp.info["model_version"] < min_model_version:
388386
continue
389387
result.append(exp)
390388
if len(result) >= batch_size:

trinity/buffer/storage/sql.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ def _read_fifo(self, batch_size: int) -> List[Experience]:
143143
time.sleep(1)
144144
return exp_list
145145

146-
def _read_priority(self, batch_size: int, oldest_valid_version: int = 0) -> List[Experience]:
146+
def _read_priority(self, batch_size: int, min_model_version: int = 0) -> List[Experience]:
147147
exp_list = []
148148
start_time = time.time()
149149
latest_size = 0
@@ -161,8 +161,8 @@ def _read_priority(self, batch_size: int, oldest_valid_version: int = 0) -> List
161161
query = session.query(self.table_model_cls).order_by(
162162
asc(self.table_model_cls.consumed), desc(self.table_model_cls.id)
163163
)
164-
if oldest_valid_version > 0:
165-
query = query.filter(self.table_model_cls.model_version >= oldest_valid_version)
164+
if min_model_version > 0:
165+
query = query.filter(self.table_model_cls.model_version >= min_model_version)
166166
experiences = query.limit(batch_size).with_for_update().all()
167167
if len(experiences) != batch_size:
168168
if latest_size != len(experiences):

0 commit comments

Comments
 (0)