Commit f6dd28e

apply reviews

1 parent d71d3c8 · commit f6dd28e

6 files changed: +29, -26 lines


docs/sphinx_doc/source/tutorial/trinity_configs.md (1 addition, 1 deletion)

```diff
@@ -112,7 +112,7 @@ algorithm:
 - `optimizer`: Optimizer configuration for actor.
 - `lr`: Learning rate for actor.
 - `warmup_style`: Warmup style for actor's learning rate.
-- `sample_strategy`: The sampling strategy used for loading experiences from experience buffer.
+- `sample_strategy`: The sampling strategy used for loading experiences from experience buffer. Supported types: `default`, `staleness_control`, `mix`.
 - `advantage_fn`: The advantage function used for computing advantages.
 - `kl_penalty_fn`: The KL penalty function used for computing KL penalty applied in reward.
 - `kl_loss_fn`: The KL loss function used for computing KL loss.
```

docs/sphinx_doc/source_zh/tutorial/trinity_configs.md (1 addition, 1 deletion)

```diff
@@ -112,7 +112,7 @@ algorithm:
 - `optimizer`: Actor 优化器的参数。
 - `lr`: 优化器的学习率。
 - `warmup_style`: 学习率的预热策略。
-- `sample_strategy`: 从 experience buffer 加载 experience 时使用的采样策略。
+- `sample_strategy`: 从 experience buffer 加载 experience 时使用的采样策略。支持类型:`default`、`staleness_control`、`mix`。
 - `advantage_fn`: 用于计算优势值的函数。
 - `kl_penalty_fn`: 用于在奖励中计算 KL 惩罚的函数。
 - `kl_loss_fn`: 用于计算 KL 损失的函数。
```
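
For reference, a minimal sketch (not part of this commit) of how the documented `sample_strategy` option is selected programmatically; the `config` object and its fields mirror the setup used in `tests/buffer/sample_strategy_test.py` below:

```python
# Sketch only: `config` is assumed to be an already-loaded Trinity config object.
config.algorithm.sample_strategy = "staleness_control"
# Note the renamed argument: `max_staleness` (previously `staleness_limit`).
config.algorithm.sample_strategy_args = {"max_staleness": 3}
```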

tests/buffer/sample_strategy_test.py (15 additions, 17 deletions)

```diff
@@ -94,17 +94,15 @@ async def _verify_sampling_model_versions(self, exps_list, expected_model_versio
         if current_task:
             await current_task

-    async def _flexible_verify_model_version(self, step, staleness_limit):
+    async def _flexible_verify_model_version(self, step, max_staleness):
         _, metrics, _ = await self.sample_strategy.sample(step=step)
         self.assertGreaterEqual(
             metrics["sample/model_version/min"],
-            step - staleness_limit,
+            step - max_staleness,
             f"Min model version mismatch at step {step}",
         )

-    async def _flexible_verify_sampling_model_versions(
-        self, exps_list, check_steps, staleness_limit
-    ):
+    async def _flexible_verify_sampling_model_versions(self, exps_list, check_steps, max_staleness):
         self._init_buffer_writer_and_sample_strategy()

         # Write experiences to buffer, while sample and validate model versions
@@ -115,7 +113,7 @@ async def _flexible_verify_sampling_model_versions(
             if current_task:
                 await current_task
             current_task = asyncio.create_task(
-                self._flexible_verify_model_version(step, staleness_limit)
+                self._flexible_verify_model_version(step, max_staleness)
             )
             await asyncio.sleep(0.1)

@@ -146,9 +144,9 @@ async def test_default_queue_default_sample_strategy(self):
         await self._verify_sampling_model_versions(exps_list, expected_model_versions_map)

     async def test_default_queue_staleness_control_sample_strategy(self):
-        staleness_limit = 3
+        max_staleness = 3
         self.config.algorithm.sample_strategy = "staleness_control"
-        self.config.algorithm.sample_strategy_args = {"staleness_limit": staleness_limit}
+        self.config.algorithm.sample_strategy_args = {"max_staleness": max_staleness}
         self.config.buffer.trainer_input.experience_buffer = ExperienceBufferConfig(
             name="default_queue_staleness_control",
             storage_type=StorageType.QUEUE.value,
@@ -161,15 +159,15 @@ async def test_default_queue_staleness_control_sample_strategy(self):
         steps = self._default_steps()
         expected_model_versions_map = {}
         for step in steps:
-            predict_version = max(step - staleness_limit, 0)
+            predict_version = max(step - max_staleness, 0)
             expected_model_versions_map[step] = [
                 predict_version + i // self.exp_write_batch_size
                 for i in range(self.config.buffer.train_batch_size)
             ]

         await self._verify_sampling_model_versions(exps_list, expected_model_versions_map)

-    def _simulate_priority_queue(self, steps, staleness_limit=float("inf")):
+    def _simulate_priority_queue(self, steps, max_staleness=float("inf")):
         expected_model_versions_map = {}
         buffer = deque()
         exp_pool = deque()
@@ -187,7 +185,7 @@ def _simulate_priority_queue(self, steps, staleness_limit=float("inf")):
                 exp_pool.extend(buffer.pop())
             while len(exp_pool) > 0 and len(batch_versions) < train_batch_size:
                 exp_version = exp_pool.popleft()
-                if exp_version < step - staleness_limit:
+                if exp_version < step - max_staleness:
                     continue
                 batch_versions.append(exp_version)
                 if len(batch_versions) >= train_batch_size:
@@ -214,9 +212,9 @@ async def test_priority_queue_default_sample_strategy(self):
         await self._verify_sampling_model_versions(exps_list, expected_model_versions_map)

     async def test_priority_queue_staleness_control_sample_strategy(self):
-        staleness_limit = 2
+        max_staleness = 2
         self.config.algorithm.sample_strategy = "staleness_control"
-        self.config.algorithm.sample_strategy_args = {"staleness_limit": staleness_limit}
+        self.config.algorithm.sample_strategy_args = {"max_staleness": max_staleness}
         self.config.buffer.trainer_input.experience_buffer = ExperienceBufferConfig(
             name="priority_queue_staleness_control",
             storage_type=StorageType.QUEUE.value,
@@ -227,14 +225,14 @@ async def test_priority_queue_staleness_control_sample_strategy(self):
         # init testing data
         exps_list = self._default_exp_list()
         steps = self._default_steps()
-        expected_model_versions_map = self._simulate_priority_queue(steps, staleness_limit)
+        expected_model_versions_map = self._simulate_priority_queue(steps, max_staleness)

         await self._verify_sampling_model_versions(exps_list, expected_model_versions_map)

     async def test_sql_staleness_control_sample_strategy(self):
-        staleness_limit = 2
+        max_staleness = 2
         self.config.algorithm.sample_strategy = "staleness_control"
-        self.config.algorithm.sample_strategy_args = {"staleness_limit": staleness_limit}
+        self.config.algorithm.sample_strategy_args = {"max_staleness": max_staleness}
         self.config.buffer.trainer_input.experience_buffer = ExperienceBufferConfig(
             name="sql_staleness_control",
             storage_type=StorageType.SQL.value,
@@ -245,7 +243,7 @@ async def test_sql_staleness_control_sample_strategy(self):
         exps_list = self._default_exp_list()
         steps = self._default_steps()

-        await self._flexible_verify_sampling_model_versions(exps_list, steps, staleness_limit)
+        await self._flexible_verify_sampling_model_versions(exps_list, steps, max_staleness)

     def tearDown(self):
         asyncio.run(self.buffer_writer.release())
```
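
To make the arithmetic verified above concrete, here is a small standalone sketch (illustrative only, not part of the test suite) of the expected-version calculation in `test_default_queue_staleness_control_sample_strategy`; the batch sizes are made-up values:

```python
# Illustrative numbers only: reproduce the expected-version calculation.
max_staleness = 3
exp_write_batch_size = 4   # assumed value for illustration
train_batch_size = 8       # assumed value for illustration

step = 5
predict_version = max(step - max_staleness, 0)  # oldest model version allowed at this step
expected = [predict_version + i // exp_write_batch_size for i in range(train_batch_size)]
print(expected)  # [2, 2, 2, 2, 3, 3, 3, 3]
```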

trinity/algorithm/sample_strategy/sample_strategy.py (2 additions, 2 deletions)

```diff
@@ -79,10 +79,10 @@ def load_state_dict(self, state_dict: dict) -> None:
 class StalenessControlSampleStrategy(DefaultSampleStrategy):
     def __init__(self, buffer_config: BufferConfig, **kwargs):
         super().__init__(buffer_config)
-        self.staleness_limit = kwargs.get("staleness_limit", float("inf"))
+        self.max_staleness = kwargs.get("max_staleness", float("inf"))

     async def sample(self, step: int, **kwargs) -> Tuple[Experiences, Dict, List]:
-        min_model_version = max(step - self.staleness_limit, 0)
+        min_model_version = max(step - self.max_staleness, 0)
         metrics = {}
         with Timer(metrics, "time/read_experience"):
             exp_list = await self.exp_buffer.read_async(min_model_version=min_model_version)
```
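
The staleness bound itself is a one-liner; a tiny self-contained sketch (not Trinity code) of the relation between `step`, `max_staleness`, and the minimum acceptable `model_version`:

```python
def min_acceptable_model_version(step: int, max_staleness: float = float("inf")):
    """Mirror of `min_model_version = max(step - self.max_staleness, 0)` above."""
    return max(step - max_staleness, 0)

assert min_acceptable_model_version(step=10, max_staleness=2) == 8
assert min_acceptable_model_version(step=1, max_staleness=3) == 0
assert min_acceptable_model_version(step=5) == 0  # default: no staleness bound
```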

trinity/buffer/schema/sql_schema.py (1 addition, 1 deletion)

```diff
@@ -39,7 +39,7 @@ class ExperienceModel(Base):  # type: ignore
     message_list = Column(JSON, nullable=True)
     reward = Column(Float, nullable=True)
     # for step info
-    model_version = Column(Integer, nullable=True)
+    model_version = Column(Integer, nullable=True, index=True)
     # serialized experience object
     experience_bytes = Column(LargeBinary, nullable=True)
     consumed = Column(Integer, default=0, index=True)
```

trinity/buffer/storage/sql.py (9 additions, 4 deletions)

```diff
@@ -158,12 +158,17 @@ def _read_priority(self, batch_size: int, min_model_version: int = 0) -> List[Ex
         with retry_session(
             self.session, self.max_retry_times, self.max_retry_interval
         ) as session:
-            query = session.query(self.table_model_cls).order_by(
-                asc(self.table_model_cls.consumed), desc(self.table_model_cls.id)
-            )
+            query = session.query(self.table_model_cls)
             if min_model_version > 0:
                 query = query.filter(self.table_model_cls.model_version >= min_model_version)
-            experiences = query.limit(batch_size).with_for_update().all()
+            experiences = (
+                query.order_by(
+                    asc(self.table_model_cls.consumed), desc(self.table_model_cls.id)
+                )
+                .limit(batch_size)
+                .with_for_update()
+                .all()
+            )
             if len(experiences) != batch_size:
                 if latest_size != len(experiences):
                     latest_size = len(experiences)
```
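
For context, a generic SQLAlchemy sketch (not Trinity code; the `Model` class and `session` are assumptions) of the same pattern: start from a base query, apply the optional staleness filter, then order, limit, and lock in one chained expression. The filter targets the `model_version` column that the schema change above now indexes.

```python
from sqlalchemy import asc, desc


def read_priority_batch(session, Model, batch_size: int, min_model_version: int = 0):
    """Sketch of the filter-then-order/limit/lock chaining used above."""
    query = session.query(Model)
    if min_model_version > 0:
        # Range filter over the indexed model_version column
        query = query.filter(Model.model_version >= min_model_version)
    return (
        query.order_by(asc(Model.consumed), desc(Model.id))
        .limit(batch_size)
        .with_for_update()
        .all()
    )
```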
