From ba742247a237b6910f532a3e79099f89f34a0e3c Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Tue, 16 Dec 2025 19:54:01 +0800 Subject: [PATCH 1/7] add staleness control --- tests/buffer/experience_storage_test.py | 1 + tests/buffer/sample_strategy_test.py | 184 ++++++++++++++++++ tests/buffer/sql_test.py | 2 + tests/common/experience_test.py | 1 + .../sample_strategy/sample_strategy.py | 18 ++ trinity/buffer/buffer_reader.py | 4 +- trinity/buffer/reader/file_reader.py | 8 +- trinity/buffer/reader/queue_reader.py | 8 +- trinity/buffer/reader/sql_reader.py | 12 +- trinity/buffer/schema/sql_schema.py | 9 +- trinity/buffer/storage/queue.py | 62 ++++-- trinity/buffer/storage/sql.py | 20 +- trinity/explorer/explorer.py | 13 +- 13 files changed, 297 insertions(+), 45 deletions(-) create mode 100644 tests/buffer/sample_strategy_test.py diff --git a/tests/buffer/experience_storage_test.py b/tests/buffer/experience_storage_test.py index da0f80df8f..f308e1ee10 100644 --- a/tests/buffer/experience_storage_test.py +++ b/tests/buffer/experience_storage_test.py @@ -108,6 +108,7 @@ async def test_sql_experience_buffer(self): prompt_length=i, reward=float(i), logprobs=torch.tensor([0.1]), + info={"model_version": 0}, ) for i in range(1, self.put_batch_size + 1) ] diff --git a/tests/buffer/sample_strategy_test.py b/tests/buffer/sample_strategy_test.py new file mode 100644 index 0000000000..c6318e2fa1 --- /dev/null +++ b/tests/buffer/sample_strategy_test.py @@ -0,0 +1,184 @@ +import asyncio +from collections import deque +import shutil +import torch +from tests.tools import RayUnittestBaseAysnc, get_template_config +from trinity.algorithm.sample_strategy.sample_strategy import SAMPLE_STRATEGY, SampleStrategy +from trinity.buffer.buffer import get_buffer_writer +from trinity.common.config import ExperienceBufferConfig +from trinity.common.constants import StorageType +from trinity.common.experience import Experience + + +class ExperienceStorageTest(RayUnittestBaseAysnc): + def setUp(self): + self.config = get_template_config() + self.num_steps = 20 + self.exp_write_batch_size = 3 # 3 # 6 + + def _default_exp_list(self): + return [ + [ + Experience( + tokens=torch.tensor([float(k) for k in range(j + 3)]), + reward=float(i), # using reward to carry model_version for testing + prompt_length=2, + info={"model_version": i, "use_count": 0}, + ) + for j in range(self.exp_write_batch_size) + ] + for i in range(self.num_steps) + ] + + def _default_steps(self): + return [0, 5, 10, 15] + + async def _verify_model_version(self, step, expected_versions): + batch, metrics, _ = await self.sample_strategy.sample(step=step) + self.assertEqual(batch.rewards.tolist(), expected_versions, f"Model versions mismatch at step {step}") + self.assertEqual(metrics['sample/model_version/min'], min(expected_versions), f"Min model version mismatch at step {step}") + self.assertEqual(metrics['sample/model_version/max'], max(expected_versions), f"Max model version mismatch at step {step}") + self.assertEqual(metrics['sample/model_version/mean'], sum(expected_versions) / len(expected_versions), f"Mean model version mismatch at step {step}") + + async def _verify_sampling_model_versions(self, exps_list, expected_model_versions_map): + # Initialize buffer writer and sample strategy + self.buffer_writer = get_buffer_writer( + self.config.buffer.trainer_input.experience_buffer, # type: ignore [arg-type] + ) + self.sample_strategy: SampleStrategy = SAMPLE_STRATEGY.get( + self.config.algorithm.sample_strategy + )( + buffer_config=self.config.buffer, + **self.config.algorithm.sample_strategy_args, + ) + + # Write experiences to buffer, while sample and validate model versions + current_task = None + for step, exps in enumerate(exps_list): + await self.buffer_writer.write_async(exps) + if step in expected_model_versions_map: + if current_task: + await current_task + current_task = asyncio.create_task( + self._verify_model_version(step, expected_model_versions_map[step]) + ) + await asyncio.sleep(0.1) + + if current_task: + await current_task + + async def test_default_queue_default_sample_strategy(self): + self.config.check_and_update() + self.config.buffer.trainer_input.experience_buffer.name = "default_queue_default_strategy" + + # init testing data + exps_list = self._default_exp_list() + steps = self._default_steps() + train_batch_size = self.config.buffer.train_batch_size + expected_model_versions_map = {} + for idx, step in enumerate(steps): + start_idx = idx * train_batch_size + batch_versions = [ + (start_idx + offset) // self.exp_write_batch_size + for offset in range(train_batch_size) + ] + expected_model_versions_map[step] = batch_versions + + await self._verify_sampling_model_versions(exps_list, expected_model_versions_map) + + async def test_default_queue_staleness_control_sample_strategy(self): + staleness_limit = 3 + self.config.algorithm.sample_strategy = "staleness_control" + self.config.algorithm.sample_strategy_args = {"staleness_limit": staleness_limit} + self.config.check_and_update() + self.config.buffer.trainer_input.experience_buffer.name = "default_queue_staleness_control" + + # init testing data + exps_list = self._default_exp_list() + steps = self._default_steps() + expected_model_versions_map = {} + for step in steps: + predict_version = max(step - staleness_limit, 0) + expected_model_versions_map[step] = [ + predict_version + i // self.exp_write_batch_size + for i in range(self.config.buffer.train_batch_size) + ] + + await self._verify_sampling_model_versions(exps_list, expected_model_versions_map) + + def _simulate_priority_queue(self, steps, staleness_limit = float('inf')): + expected_model_versions_map = {} + buffer = deque() + exp_pool = deque() + step_idx = 0 + train_batch_size = self.config.buffer.train_batch_size + for i in range(self.num_steps): + buffer.append([i] * self.exp_write_batch_size) + step = steps[step_idx] + if i < step: + continue + batch_versions = expected_model_versions_map.get(step, []) + if len(batch_versions) < train_batch_size: + while len(buffer) > 0: + if len(exp_pool) == 0: + exp_pool.extend(buffer.pop()) + while len(exp_pool) > 0 and len(batch_versions) < train_batch_size: + exp_version = exp_pool.popleft() + if exp_version < step - staleness_limit: + continue + batch_versions.append(exp_version) + if len(batch_versions) >= train_batch_size: + step_idx += 1 + break + expected_model_versions_map[step] = batch_versions + if step_idx >= len(steps): + break + return expected_model_versions_map + + async def test_priority_queue_default_sample_strategy(self): + self.config.check_and_update() + self.config.buffer.trainer_input.experience_buffer.replay_buffer.enable = True + self.config.buffer.trainer_input.experience_buffer.name = "priority_queue_default_strategy" + + # init testing data + exps_list = self._default_exp_list() + steps = self._default_steps() + expected_model_versions_map = self._simulate_priority_queue(steps) + + await self._verify_sampling_model_versions(exps_list, expected_model_versions_map) + + async def test_priority_queue_staleness_control_sample_strategy(self): + staleness_limit = 2 + self.config.algorithm.sample_strategy = "staleness_control" + self.config.algorithm.sample_strategy_args = {"staleness_limit": staleness_limit} + self.config.check_and_update() + self.config.buffer.trainer_input.experience_buffer.replay_buffer.enable = True + self.config.buffer.trainer_input.experience_buffer.name = "priority_queue_staleness_control" + + # init testing data + exps_list = self._default_exp_list() + steps = self._default_steps() + expected_model_versions_map = self._simulate_priority_queue(steps, staleness_limit) + + await self._verify_sampling_model_versions(exps_list, expected_model_versions_map) + + # async def test_sql_default_sample_strategy(self): # debuging + # self.config.buffer.trainer_input.experience_buffer = ExperienceBufferConfig( + # name="sql_default_strategy", + # storage_type=StorageType.SQL.value, + # ) + # self.config.check_and_update() + + # # init testing data + # exps_list = self._default_exp_list() + # steps = self._default_steps() + # expected_model_versions_map = self._simulate_priority_queue(steps) + # from pprint import pprint + # pprint(expected_model_versions_map) + + # await self._verify_sampling_model_versions(exps_list, expected_model_versions_map) + + def tearDown(self): + asyncio.run(self.buffer_writer.release()) + shutil.rmtree(self.config.checkpoint_job_dir) + return super().tearDown() diff --git a/tests/buffer/sql_test.py b/tests/buffer/sql_test.py index 44b81e1495..1e742a54bc 100644 --- a/tests/buffer/sql_test.py +++ b/tests/buffer/sql_test.py @@ -34,6 +34,7 @@ async def test_sql_exp_buffer_read_write(self) -> None: prompt_length=i, reward=float(i), logprobs=torch.tensor([0.1]), + info={"model_version": i}, ) for i in range(1, put_batch_size + 1) ] @@ -52,6 +53,7 @@ async def test_sql_exp_buffer_read_write(self) -> None: reward=float(i), logprobs=torch.tensor([0.1]), action_mask=torch.tensor([j % 2 for j in range(i + 1)]), + info={"model_version": i}, ) for i in range(1, put_batch_size * 2 + 1) ] diff --git a/tests/common/experience_test.py b/tests/common/experience_test.py index 195aaa61ae..55eabca721 100644 --- a/tests/common/experience_test.py +++ b/tests/common/experience_test.py @@ -253,6 +253,7 @@ def test_experience_model_experience_conversion(self): reward=reward, prompt_length=prompt_length, logprobs=logprobs, + info={"model_version": 0}, ) model = ExperienceModel.from_experience(experience) diff --git a/trinity/algorithm/sample_strategy/sample_strategy.py b/trinity/algorithm/sample_strategy/sample_strategy.py index 27b021146f..4817bde3e4 100644 --- a/trinity/algorithm/sample_strategy/sample_strategy.py +++ b/trinity/algorithm/sample_strategy/sample_strategy.py @@ -80,6 +80,24 @@ def load_state_dict(self, state_dict: dict) -> None: self.exp_buffer.load_state_dict(state_dict) +@SAMPLE_STRATEGY.register_module("staleness_control") +class StalenessControlSampleStrategy(DefaultSampleStrategy): + def __init__(self, buffer_config: BufferConfig, **kwargs): + super().__init__(buffer_config) + self.staleness_limit = kwargs.get("staleness_limit", float("inf")) + + async def sample(self, step: int, **kwargs) -> Tuple[Experiences, Dict, List]: + oldest_valid_version = max(step - self.staleness_limit, -1) + metrics = {} + with Timer(metrics, "time/read_experience"): + exp_list = await self.exp_buffer.read_async(oldest_valid_version=oldest_valid_version) + repr_samples = representative_sample(exp_list) + self.set_model_version_metric(exp_list, metrics) + with Timer(metrics, "time/gather_experience"): + exps = Experiences.gather_experiences(exp_list, self.pad_token_id) # type: ignore + return exps, metrics, repr_samples + + @Deprecated @SAMPLE_STRATEGY.register_module("warmup") class WarmupSampleStrategy(DefaultSampleStrategy): diff --git a/trinity/buffer/buffer_reader.py b/trinity/buffer/buffer_reader.py index d47d80ace1..ad4d414547 100644 --- a/trinity/buffer/buffer_reader.py +++ b/trinity/buffer/buffer_reader.py @@ -7,11 +7,11 @@ class BufferReader(ABC): """Interface of the buffer reader.""" @abstractmethod - def read(self, batch_size: Optional[int] = None) -> List: + def read(self, batch_size: Optional[int] = None, **kwargs) -> List: """Read from buffer.""" @abstractmethod - async def read_async(self, batch_size: Optional[int] = None) -> List: + async def read_async(self, batch_size: Optional[int] = None, **kwargs) -> List: """Read from buffer asynchronously.""" def __len__(self) -> int: diff --git a/trinity/buffer/reader/file_reader.py b/trinity/buffer/reader/file_reader.py index 8fa3d4c03d..a31009e4c4 100644 --- a/trinity/buffer/reader/file_reader.py +++ b/trinity/buffer/reader/file_reader.py @@ -86,7 +86,7 @@ def select_batch(self, indices: List[int]) -> List: class BaseFileReader(BufferReader): - async def read_async(self, batch_size: Optional[int] = None): + async def read_async(self, batch_size: Optional[int] = None, **kwargs): try: return self.read(batch_size) except StopIteration as e: @@ -103,7 +103,7 @@ def __init__(self, config: StorageConfig): else: self.reader = TaskFileReader(config) - def read(self, batch_size: Optional[int] = None) -> List: + def read(self, batch_size: Optional[int] = None, **kwargs) -> List: return self.reader.read(batch_size) def read_with_indices(self, indices: List[int]) -> List: @@ -142,7 +142,7 @@ def __init__(self, config: StorageConfig): enable_progress_bar=config.enable_progress_bar, ) - def read(self, batch_size: Optional[int] = None) -> List: + def read(self, batch_size: Optional[int] = None, **kwargs) -> List: samples, _ = self.dataset.read_batch(batch_size or self.read_batch_size) exp_list = [] for sample in samples: @@ -189,7 +189,7 @@ def _get_tasks(self, samples: List, indices: List) -> List: tasks.append(task) return tasks - def read(self, batch_size: Optional[int] = None) -> List: + def read(self, batch_size: Optional[int] = None, **kwargs) -> List: batch_size = batch_size or self.read_batch_size samples, indices = self.dataset.read_batch(batch_size) return self._get_tasks(samples, indices) diff --git a/trinity/buffer/reader/queue_reader.py b/trinity/buffer/reader/queue_reader.py index e8debdada3..4d340bd1b3 100644 --- a/trinity/buffer/reader/queue_reader.py +++ b/trinity/buffer/reader/queue_reader.py @@ -21,10 +21,10 @@ def __init__(self, config: StorageConfig): self.read_batch_size = config.batch_size self.queue = QueueStorage.get_wrapper(config) - def read(self, batch_size: Optional[int] = None) -> List: + def read(self, batch_size: Optional[int] = None, **kwargs) -> List: try: batch_size = batch_size or self.read_batch_size - exps = ray.get(self.queue.get_batch.remote(batch_size, timeout=self.timeout)) + exps = ray.get(self.queue.get_batch.remote(batch_size, timeout=self.timeout, **kwargs)) if len(exps) != batch_size: raise TimeoutError( f"Read incomplete batch ({len(exps)}/{batch_size}), please check your workflow." @@ -33,9 +33,9 @@ def read(self, batch_size: Optional[int] = None) -> List: raise StopIteration() return exps - async def read_async(self, batch_size: Optional[int] = None) -> List: + async def read_async(self, batch_size: Optional[int] = None, **kwargs) -> List: batch_size = batch_size or self.read_batch_size - exps = await self.queue.get_batch.remote(batch_size, timeout=self.timeout) + exps = await self.queue.get_batch.remote(batch_size, timeout=self.timeout, **kwargs) if len(exps) != batch_size: raise TimeoutError( f"Read incomplete batch ({len(exps)}/{batch_size}), please check your workflow." diff --git a/trinity/buffer/reader/sql_reader.py b/trinity/buffer/reader/sql_reader.py index d13feeed7f..8f7dd57907 100644 --- a/trinity/buffer/reader/sql_reader.py +++ b/trinity/buffer/reader/sql_reader.py @@ -20,20 +20,20 @@ def __init__(self, config: StorageConfig) -> None: self.wrap_in_ray = config.wrap_in_ray self.storage = SQLStorage.get_wrapper(config) - def read(self, batch_size: Optional[int] = None) -> List: + def read(self, batch_size: Optional[int] = None, **kwargs) -> List: if self.wrap_in_ray: - return ray.get(self.storage.read.remote(batch_size)) + return ray.get(self.storage.read.remote(batch_size, **kwargs)) else: - return self.storage.read(batch_size) + return self.storage.read(batch_size, **kwargs) - async def read_async(self, batch_size: Optional[int] = None) -> List: + async def read_async(self, batch_size: Optional[int] = None, **kwargs) -> List: if self.wrap_in_ray: try: - return ray.get(self.storage.read.remote(batch_size)) + return await self.storage.read.remote(batch_size, **kwargs) except StopIteration: raise StopAsyncIteration else: - return self.storage.read(batch_size) + return self.storage.read(batch_size, **kwargs) def state_dict(self) -> Dict: # SQL Not supporting state dict yet diff --git a/trinity/buffer/schema/sql_schema.py b/trinity/buffer/schema/sql_schema.py index e0df0b1e8e..c3bc13411e 100644 --- a/trinity/buffer/schema/sql_schema.py +++ b/trinity/buffer/schema/sql_schema.py @@ -43,6 +43,9 @@ class ExperienceModel(Base): # type: ignore # for multi turn message_list = Column(JSON, nullable=True) reward = Column(Float, nullable=True) + # for step info + train_step = Column(Integer, nullable=True) + explore_step = Column(Integer, nullable=True) # serialized experience object experience_bytes = Column(LargeBinary, nullable=True) consumed = Column(Integer, default=0, index=True) @@ -55,11 +58,13 @@ def to_experience(self) -> Experience: def from_experience(cls, experience: Experience): """Save the experience to database.""" return cls( - experience_bytes=experience.serialize(), - reward=experience.reward, prompt=experience.prompt_text, response=experience.response_text, message_list=experience.messages, + reward=experience.reward, + train_step=experience.info["model_version"], + explore_step=experience.eid.batch, + experience_bytes=experience.serialize(), ) diff --git a/trinity/buffer/storage/queue.py b/trinity/buffer/storage/queue.py index 523cde5b18..33403a769d 100644 --- a/trinity/buffer/storage/queue.py +++ b/trinity/buffer/storage/queue.py @@ -106,6 +106,9 @@ def default_config(cls) -> Dict: class QueueBuffer(ABC): + async def set_oldest_valid_version(self, oldest_valid_version: int): + self.oldest_valid_version = oldest_valid_version + @abstractmethod async def put(self, exps: List[Experience]) -> None: """Put a list of experiences into the queue.""" @@ -155,6 +158,21 @@ def __init__(self, capacity: int): """ super().__init__(maxsize=capacity) self._closed = False + self.oldest_valid_version = -1 + + async def put(self, item: List[Experience]): + if len(item) == 0: + return + await super().put(item) + + async def get(self): + while True: + item = await super().get() + if ( + self.oldest_valid_version < 0 + or item[0].info["model_version"] >= self.oldest_valid_version + ): + return item async def close(self) -> None: """Close the queue.""" @@ -208,6 +226,7 @@ def __init__( self.reuse_cooldown_time = reuse_cooldown_time self._condition = asyncio.Condition() # For thread-safe operations self._closed = False + self.oldest_valid_version = -1 async def _put(self, item: List[Experience], delay: float = 0) -> None: """ @@ -259,16 +278,23 @@ async def get(self) -> List[Experience]: - After retrieval, the item is optionally reinserted after a cooldown period. """ async with self._condition: - while len(self.priority_groups) == 0: - if self._closed: - raise StopAsyncIteration() - await self._condition.wait() + while True: + while len(self.priority_groups) == 0: + if self._closed: + raise StopAsyncIteration() + await self._condition.wait() + + _, item_queue = self.priority_groups.peekitem(index=-1) + item = item_queue.popleft() + self.item_count -= 1 + if not item_queue: + self.priority_groups.popitem(index=-1) - _, item_queue = self.priority_groups.peekitem(index=-1) - item = item_queue.popleft() - self.item_count -= 1 - if not item_queue: - self.priority_groups.popitem(index=-1) + if ( + self.oldest_valid_version < 0 + or item[0].info["model_version"] >= self.oldest_valid_version + ): + break for exp in item: exp.info["use_count"] += 1 @@ -352,10 +378,22 @@ async def put_batch(self, exp_list: List) -> None: if self.writer is not None: self.writer.write(exp_list) - async def get_batch(self, batch_size: int, timeout: float) -> List: + async def get_batch( + self, batch_size: int, timeout: float, oldest_valid_version: int = -1 + ) -> List: """Get batch of experience.""" + await self.queue.set_oldest_valid_version(oldest_valid_version) start_time = time.time() - while len(self.exp_pool) < batch_size: + result = [] + while len(result) < batch_size: + while len(self.exp_pool) > 0 and len(result) < batch_size: + exp = self.exp_pool.popleft() + if oldest_valid_version >= 0 and exp.info["model_version"] < oldest_valid_version: + continue + result.append(exp) + if len(result) >= batch_size: + break + if self.queue.stopped(): # If the queue is stopped, ignore the rest of the experiences in the pool raise StopAsyncIteration("Queue is closed and no more items to get.") @@ -372,7 +410,7 @@ async def get_batch(self, batch_size: int, timeout: float) -> List: batch = list(self.exp_pool) self.exp_pool.clear() return batch - return [self.exp_pool.popleft() for _ in range(batch_size)] + return result @classmethod def get_wrapper(cls, config: StorageConfig): diff --git a/trinity/buffer/storage/sql.py b/trinity/buffer/storage/sql.py index e3068fd896..2e20c1475f 100644 --- a/trinity/buffer/storage/sql.py +++ b/trinity/buffer/storage/sql.py @@ -111,7 +111,7 @@ def write(self, data: List[Experience]) -> None: session.add_all(experience_models) self.logger.info(f"Write {len(experience_models)} experiences to SQL storage.") - def _read_fifo(self, batch_size: int) -> List[Experience]: + def _read_fifo(self, batch_size: int, oldest_valid_version: int = -1) -> List[Experience]: """Read experiences in FIFO order.""" exp_list = [] start_time = time.time() @@ -143,7 +143,7 @@ def _read_fifo(self, batch_size: int) -> List[Experience]: time.sleep(1) return exp_list - def _read_priority(self, batch_size: int) -> List[Experience]: + def _read_priority(self, batch_size: int, oldest_valid_version: int = -1) -> List[Experience]: exp_list = [] start_time = time.time() latest_size = 0 @@ -158,13 +158,12 @@ def _read_priority(self, batch_size: int) -> List[Experience]: with retry_session( self.session, self.max_retry_times, self.max_retry_interval ) as session: - experiences = ( - session.query(self.table_model_cls) - .order_by(asc(self.table_model_cls.consumed), desc(self.table_model_cls.id)) - .limit(batch_size) - .with_for_update() - .all() + query = session.query(self.table_model_cls).order_by( + asc(self.table_model_cls.consumed), desc(self.table_model_cls.id) ) + if oldest_valid_version >= 0: + query = query.filter(self.table_model_cls.train_step >= oldest_valid_version) + experiences = query.limit(batch_size).with_for_update().all() if len(experiences) != batch_size: if latest_size != len(experiences): latest_size = len(experiences) @@ -186,12 +185,13 @@ def _read_priority(self, batch_size: int) -> List[Experience]: time.sleep(1) return exp_list - def read(self, batch_size: Optional[int] = None) -> List[Experience]: + def read(self, batch_size: Optional[int] = None, **kwargs) -> List[Experience]: if self.stopped: raise StopIteration() batch_size = batch_size or self.batch_size - return self._read_method(batch_size) + oldest_valid_version = kwargs.pop("oldest_valid_version", -1) + return self._read_method(batch_size, oldest_valid_version=oldest_valid_version) @classmethod def load_from_dataset(cls, dataset: Dataset, config: StorageConfig) -> "SQLExperienceStorage": diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index 458c1ba626..c690bc407d 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -47,8 +47,8 @@ def __init__(self, config: Config): ) explorer_state = self.state.load_explorer() self.explore_step_num = explorer_state.get("latest_iteration", 0) - self.last_sync_step = self.explore_step_num if self.explore_step_num > 0 else -1 - self.last_monitored_step = self.explore_step_num if self.explore_step_num > 0 else -1 + self.last_sync_step = self.explore_step_num + self.last_monitored_step = self.explore_step_num self.synchronizer = Synchronizer.get_actor(config) self.config = config self.models, self.auxiliary_models = create_inference_models(config) @@ -328,9 +328,12 @@ async def benchmark(self) -> bool: async def save_checkpoint(self, sync_weight: bool = False) -> None: if self.scheduler: - await self._finish_steps( - self.last_monitored_step + 1, self.explore_step_num, self.model_version - ) + if self.explore_step_num == 0: + await self._finish_eval_step(step=0) + else: + await self._finish_steps( + self.last_monitored_step + 1, self.explore_step_num, self.model_version + ) self.last_monitored_step = self.explore_step_num if sync_weight: From 08c51446266990e65e252c70b1690661aa04dbb7 Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Tue, 16 Dec 2025 19:54:34 +0800 Subject: [PATCH 2/7] pre commit fix --- tests/buffer/sample_strategy_test.py | 39 ++++++++++++++++++++-------- 1 file changed, 28 insertions(+), 11 deletions(-) diff --git a/tests/buffer/sample_strategy_test.py b/tests/buffer/sample_strategy_test.py index c6318e2fa1..ca2d1eca9b 100644 --- a/tests/buffer/sample_strategy_test.py +++ b/tests/buffer/sample_strategy_test.py @@ -1,12 +1,15 @@ import asyncio -from collections import deque import shutil +from collections import deque + import torch + from tests.tools import RayUnittestBaseAysnc, get_template_config -from trinity.algorithm.sample_strategy.sample_strategy import SAMPLE_STRATEGY, SampleStrategy +from trinity.algorithm.sample_strategy.sample_strategy import ( + SAMPLE_STRATEGY, + SampleStrategy, +) from trinity.buffer.buffer import get_buffer_writer -from trinity.common.config import ExperienceBufferConfig -from trinity.common.constants import StorageType from trinity.common.experience import Experience @@ -29,16 +32,30 @@ def _default_exp_list(self): ] for i in range(self.num_steps) ] - + def _default_steps(self): return [0, 5, 10, 15] async def _verify_model_version(self, step, expected_versions): batch, metrics, _ = await self.sample_strategy.sample(step=step) - self.assertEqual(batch.rewards.tolist(), expected_versions, f"Model versions mismatch at step {step}") - self.assertEqual(metrics['sample/model_version/min'], min(expected_versions), f"Min model version mismatch at step {step}") - self.assertEqual(metrics['sample/model_version/max'], max(expected_versions), f"Max model version mismatch at step {step}") - self.assertEqual(metrics['sample/model_version/mean'], sum(expected_versions) / len(expected_versions), f"Mean model version mismatch at step {step}") + self.assertEqual( + batch.rewards.tolist(), expected_versions, f"Model versions mismatch at step {step}" + ) + self.assertEqual( + metrics["sample/model_version/min"], + min(expected_versions), + f"Min model version mismatch at step {step}", + ) + self.assertEqual( + metrics["sample/model_version/max"], + max(expected_versions), + f"Max model version mismatch at step {step}", + ) + self.assertEqual( + metrics["sample/model_version/mean"], + sum(expected_versions) / len(expected_versions), + f"Mean model version mismatch at step {step}", + ) async def _verify_sampling_model_versions(self, exps_list, expected_model_versions_map): # Initialize buffer writer and sample strategy @@ -51,7 +68,7 @@ async def _verify_sampling_model_versions(self, exps_list, expected_model_versio buffer_config=self.config.buffer, **self.config.algorithm.sample_strategy_args, ) - + # Write experiences to buffer, while sample and validate model versions current_task = None for step, exps in enumerate(exps_list): @@ -106,7 +123,7 @@ async def test_default_queue_staleness_control_sample_strategy(self): await self._verify_sampling_model_versions(exps_list, expected_model_versions_map) - def _simulate_priority_queue(self, steps, staleness_limit = float('inf')): + def _simulate_priority_queue(self, steps, staleness_limit=float("inf")): expected_model_versions_map = {} buffer = deque() exp_pool = deque() From 2926a94bf9b0a71b1347b0284b6021fe087f0599 Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Thu, 18 Dec 2025 11:02:14 +0800 Subject: [PATCH 3/7] fix unittest and apply reviews --- tests/buffer/sample_strategy_test.py | 118 +++++++++++++----- tests/explorer/explorer_test.py | 2 + tests/service/data_juicer_test.py | 4 + .../sample_strategy/sample_strategy.py | 2 +- trinity/buffer/schema/sql_schema.py | 6 +- trinity/buffer/storage/queue.py | 14 +-- trinity/buffer/storage/sql.py | 11 +- 7 files changed, 107 insertions(+), 50 deletions(-) diff --git a/tests/buffer/sample_strategy_test.py b/tests/buffer/sample_strategy_test.py index ca2d1eca9b..f0685dd412 100644 --- a/tests/buffer/sample_strategy_test.py +++ b/tests/buffer/sample_strategy_test.py @@ -3,6 +3,7 @@ from collections import deque import torch +from parameterized import parameterized_class from tests.tools import RayUnittestBaseAysnc, get_template_config from trinity.algorithm.sample_strategy.sample_strategy import ( @@ -10,14 +11,22 @@ SampleStrategy, ) from trinity.buffer.buffer import get_buffer_writer +from trinity.common.config import ExperienceBufferConfig, ReplayBufferConfig +from trinity.common.constants import StorageType from trinity.common.experience import Experience +@parameterized_class( + ("exp_write_batch_size",), + [ + (3,), + (6,), + ], +) class ExperienceStorageTest(RayUnittestBaseAysnc): def setUp(self): self.config = get_template_config() self.num_steps = 20 - self.exp_write_batch_size = 3 # 3 # 6 def _default_exp_list(self): return [ @@ -36,6 +45,18 @@ def _default_exp_list(self): def _default_steps(self): return [0, 5, 10, 15] + def _init_buffer_writer_and_sample_strategy(self): + # Initialize buffer writer and sample strategy + self.buffer_writer = get_buffer_writer( + self.config.buffer.trainer_input.experience_buffer, # type: ignore [arg-type] + ) + self.sample_strategy: SampleStrategy = SAMPLE_STRATEGY.get( + self.config.algorithm.sample_strategy + )( + buffer_config=self.config.buffer, + **self.config.algorithm.sample_strategy_args, + ) + async def _verify_model_version(self, step, expected_versions): batch, metrics, _ = await self.sample_strategy.sample(step=step) self.assertEqual( @@ -58,16 +79,7 @@ async def _verify_model_version(self, step, expected_versions): ) async def _verify_sampling_model_versions(self, exps_list, expected_model_versions_map): - # Initialize buffer writer and sample strategy - self.buffer_writer = get_buffer_writer( - self.config.buffer.trainer_input.experience_buffer, # type: ignore [arg-type] - ) - self.sample_strategy: SampleStrategy = SAMPLE_STRATEGY.get( - self.config.algorithm.sample_strategy - )( - buffer_config=self.config.buffer, - **self.config.algorithm.sample_strategy_args, - ) + self._init_buffer_writer_and_sample_strategy() # Write experiences to buffer, while sample and validate model versions current_task = None @@ -84,9 +96,41 @@ async def _verify_sampling_model_versions(self, exps_list, expected_model_versio if current_task: await current_task + async def _flexible_verify_model_version(self, step, staleness_limit): + _, metrics, _ = await self.sample_strategy.sample(step=step) + self.assertGreaterEqual( + metrics["sample/model_version/min"], + step - staleness_limit, + f"Min model version mismatch at step {step}", + ) + + async def _flexible_verify_sampling_model_versions( + self, exps_list, check_steps, staleness_limit + ): + self._init_buffer_writer_and_sample_strategy() + + # Write experiences to buffer, while sample and validate model versions + current_task = None + for step, exps in enumerate(exps_list): + await self.buffer_writer.write_async(exps) + if step in check_steps: + if current_task: + await current_task + current_task = asyncio.create_task( + self._flexible_verify_model_version(step, staleness_limit) + ) + await asyncio.sleep(0.1) + + if current_task: + await current_task + async def test_default_queue_default_sample_strategy(self): + self.config.buffer.trainer_input.experience_buffer = ExperienceBufferConfig( + name="default_queue_default_strategy", + storage_type=StorageType.QUEUE.value, + replay_buffer=ReplayBufferConfig(enable=False), + ) self.config.check_and_update() - self.config.buffer.trainer_input.experience_buffer.name = "default_queue_default_strategy" # init testing data exps_list = self._default_exp_list() @@ -107,8 +151,12 @@ async def test_default_queue_staleness_control_sample_strategy(self): staleness_limit = 3 self.config.algorithm.sample_strategy = "staleness_control" self.config.algorithm.sample_strategy_args = {"staleness_limit": staleness_limit} + self.config.buffer.trainer_input.experience_buffer = ExperienceBufferConfig( + name="default_queue_staleness_control", + storage_type=StorageType.QUEUE.value, + replay_buffer=ReplayBufferConfig(enable=False), + ) self.config.check_and_update() - self.config.buffer.trainer_input.experience_buffer.name = "default_queue_staleness_control" # init testing data exps_list = self._default_exp_list() @@ -153,9 +201,12 @@ def _simulate_priority_queue(self, steps, staleness_limit=float("inf")): return expected_model_versions_map async def test_priority_queue_default_sample_strategy(self): + self.config.buffer.trainer_input.experience_buffer = ExperienceBufferConfig( + name="priority_queue_default_strategy", + storage_type=StorageType.QUEUE.value, + replay_buffer=ReplayBufferConfig(enable=True), + ) self.config.check_and_update() - self.config.buffer.trainer_input.experience_buffer.replay_buffer.enable = True - self.config.buffer.trainer_input.experience_buffer.name = "priority_queue_default_strategy" # init testing data exps_list = self._default_exp_list() @@ -168,9 +219,12 @@ async def test_priority_queue_staleness_control_sample_strategy(self): staleness_limit = 2 self.config.algorithm.sample_strategy = "staleness_control" self.config.algorithm.sample_strategy_args = {"staleness_limit": staleness_limit} + self.config.buffer.trainer_input.experience_buffer = ExperienceBufferConfig( + name="priority_queue_staleness_control", + storage_type=StorageType.QUEUE.value, + replay_buffer=ReplayBufferConfig(enable=True), + ) self.config.check_and_update() - self.config.buffer.trainer_input.experience_buffer.replay_buffer.enable = True - self.config.buffer.trainer_input.experience_buffer.name = "priority_queue_staleness_control" # init testing data exps_list = self._default_exp_list() @@ -179,21 +233,21 @@ async def test_priority_queue_staleness_control_sample_strategy(self): await self._verify_sampling_model_versions(exps_list, expected_model_versions_map) - # async def test_sql_default_sample_strategy(self): # debuging - # self.config.buffer.trainer_input.experience_buffer = ExperienceBufferConfig( - # name="sql_default_strategy", - # storage_type=StorageType.SQL.value, - # ) - # self.config.check_and_update() - - # # init testing data - # exps_list = self._default_exp_list() - # steps = self._default_steps() - # expected_model_versions_map = self._simulate_priority_queue(steps) - # from pprint import pprint - # pprint(expected_model_versions_map) - - # await self._verify_sampling_model_versions(exps_list, expected_model_versions_map) + async def test_sql_staleness_control_sample_strategy(self): + staleness_limit = 2 + self.config.algorithm.sample_strategy = "staleness_control" + self.config.algorithm.sample_strategy_args = {"staleness_limit": staleness_limit} + self.config.buffer.trainer_input.experience_buffer = ExperienceBufferConfig( + name="sql_staleness_control", + storage_type=StorageType.SQL.value, + ) + self.config.check_and_update() + + # init testing data + exps_list = self._default_exp_list() + steps = self._default_steps() + + await self._flexible_verify_sampling_model_versions(exps_list, steps, staleness_limit) def tearDown(self): asyncio.run(self.buffer_writer.release()) diff --git a/tests/explorer/explorer_test.py b/tests/explorer/explorer_test.py index c061099437..b1282d7c7a 100644 --- a/tests/explorer/explorer_test.py +++ b/tests/explorer/explorer_test.py @@ -5,6 +5,7 @@ import os import random import shutil +import unittest from datetime import datetime import httpx @@ -200,6 +201,7 @@ def setUp(self): if multiprocessing.get_start_method(allow_none=True) != "spawn": multiprocessing.set_start_method("spawn", force=True) + @unittest.skip("Require improvement for agent mode") async def test_serve(self): # noqa: C901 serve_process = multiprocessing.Process(target=run_serve, args=(self.config,)) serve_process.start() diff --git a/tests/service/data_juicer_test.py b/tests/service/data_juicer_test.py index 2fdd89ad85..60440e0d6e 100644 --- a/tests/service/data_juicer_test.py +++ b/tests/service/data_juicer_test.py @@ -182,24 +182,28 @@ async def test_data_juicer_operators(self): prompt_length=3, prompt_text="Hello, how are you?", response_text="Hi, I am fine.", + info={"model_version": 0}, ), Experience( # too short response tokens=torch.tensor([1, 2, 3, 4, 5]), prompt_length=3, prompt_text="What is your name?", response_text="Trinity.", + info={"model_version": 0}, ), Experience( # repeated words tokens=torch.tensor([1, 2, 3, 4, 5]), prompt_length=3, prompt_text="What day is it today?", response_text="Today is Sunday Sunday Sunday Sunday Sunday and it's a happy day!", + info={"model_version": 0}, ), Experience( tokens=torch.tensor([1, 2, 3, 4, 5]), prompt_length=3, prompt_text="What is your favorite color?", response_text="My favorite color is blue.", + info={"model_version": 0}, ), ] metrics = await pipeline.process.remote(exps) diff --git a/trinity/algorithm/sample_strategy/sample_strategy.py b/trinity/algorithm/sample_strategy/sample_strategy.py index 4817bde3e4..7229b4b5d2 100644 --- a/trinity/algorithm/sample_strategy/sample_strategy.py +++ b/trinity/algorithm/sample_strategy/sample_strategy.py @@ -87,7 +87,7 @@ def __init__(self, buffer_config: BufferConfig, **kwargs): self.staleness_limit = kwargs.get("staleness_limit", float("inf")) async def sample(self, step: int, **kwargs) -> Tuple[Experiences, Dict, List]: - oldest_valid_version = max(step - self.staleness_limit, -1) + oldest_valid_version = max(step - self.staleness_limit, 0) metrics = {} with Timer(metrics, "time/read_experience"): exp_list = await self.exp_buffer.read_async(oldest_valid_version=oldest_valid_version) diff --git a/trinity/buffer/schema/sql_schema.py b/trinity/buffer/schema/sql_schema.py index c3bc13411e..dd6f5056db 100644 --- a/trinity/buffer/schema/sql_schema.py +++ b/trinity/buffer/schema/sql_schema.py @@ -44,8 +44,7 @@ class ExperienceModel(Base): # type: ignore message_list = Column(JSON, nullable=True) reward = Column(Float, nullable=True) # for step info - train_step = Column(Integer, nullable=True) - explore_step = Column(Integer, nullable=True) + model_version = Column(Integer, nullable=True) # serialized experience object experience_bytes = Column(LargeBinary, nullable=True) consumed = Column(Integer, default=0, index=True) @@ -62,8 +61,7 @@ def from_experience(cls, experience: Experience): response=experience.response_text, message_list=experience.messages, reward=experience.reward, - train_step=experience.info["model_version"], - explore_step=experience.eid.batch, + model_version=experience.info["model_version"], experience_bytes=experience.serialize(), ) diff --git a/trinity/buffer/storage/queue.py b/trinity/buffer/storage/queue.py index 33403a769d..4f760c0a6b 100644 --- a/trinity/buffer/storage/queue.py +++ b/trinity/buffer/storage/queue.py @@ -107,7 +107,7 @@ def default_config(cls) -> Dict: class QueueBuffer(ABC): async def set_oldest_valid_version(self, oldest_valid_version: int): - self.oldest_valid_version = oldest_valid_version + self.oldest_valid_version = max(oldest_valid_version, 0) @abstractmethod async def put(self, exps: List[Experience]) -> None: @@ -158,7 +158,7 @@ def __init__(self, capacity: int): """ super().__init__(maxsize=capacity) self._closed = False - self.oldest_valid_version = -1 + self.oldest_valid_version = 0 async def put(self, item: List[Experience]): if len(item) == 0: @@ -169,7 +169,7 @@ async def get(self): while True: item = await super().get() if ( - self.oldest_valid_version < 0 + self.oldest_valid_version <= 0 or item[0].info["model_version"] >= self.oldest_valid_version ): return item @@ -226,7 +226,7 @@ def __init__( self.reuse_cooldown_time = reuse_cooldown_time self._condition = asyncio.Condition() # For thread-safe operations self._closed = False - self.oldest_valid_version = -1 + self.oldest_valid_version = 0 async def _put(self, item: List[Experience], delay: float = 0) -> None: """ @@ -291,7 +291,7 @@ async def get(self) -> List[Experience]: self.priority_groups.popitem(index=-1) if ( - self.oldest_valid_version < 0 + self.oldest_valid_version <= 0 or item[0].info["model_version"] >= self.oldest_valid_version ): break @@ -379,7 +379,7 @@ async def put_batch(self, exp_list: List) -> None: self.writer.write(exp_list) async def get_batch( - self, batch_size: int, timeout: float, oldest_valid_version: int = -1 + self, batch_size: int, timeout: float, oldest_valid_version: int = 0 ) -> List: """Get batch of experience.""" await self.queue.set_oldest_valid_version(oldest_valid_version) @@ -388,7 +388,7 @@ async def get_batch( while len(result) < batch_size: while len(self.exp_pool) > 0 and len(result) < batch_size: exp = self.exp_pool.popleft() - if oldest_valid_version >= 0 and exp.info["model_version"] < oldest_valid_version: + if oldest_valid_version > 0 and exp.info["model_version"] < oldest_valid_version: continue result.append(exp) if len(result) >= batch_size: diff --git a/trinity/buffer/storage/sql.py b/trinity/buffer/storage/sql.py index 2e20c1475f..646ebeba69 100644 --- a/trinity/buffer/storage/sql.py +++ b/trinity/buffer/storage/sql.py @@ -111,7 +111,7 @@ def write(self, data: List[Experience]) -> None: session.add_all(experience_models) self.logger.info(f"Write {len(experience_models)} experiences to SQL storage.") - def _read_fifo(self, batch_size: int, oldest_valid_version: int = -1) -> List[Experience]: + def _read_fifo(self, batch_size: int) -> List[Experience]: """Read experiences in FIFO order.""" exp_list = [] start_time = time.time() @@ -143,7 +143,7 @@ def _read_fifo(self, batch_size: int, oldest_valid_version: int = -1) -> List[Ex time.sleep(1) return exp_list - def _read_priority(self, batch_size: int, oldest_valid_version: int = -1) -> List[Experience]: + def _read_priority(self, batch_size: int, oldest_valid_version: int = 0) -> List[Experience]: exp_list = [] start_time = time.time() latest_size = 0 @@ -161,8 +161,8 @@ def _read_priority(self, batch_size: int, oldest_valid_version: int = -1) -> Lis query = session.query(self.table_model_cls).order_by( asc(self.table_model_cls.consumed), desc(self.table_model_cls.id) ) - if oldest_valid_version >= 0: - query = query.filter(self.table_model_cls.train_step >= oldest_valid_version) + if oldest_valid_version > 0: + query = query.filter(self.table_model_cls.model_version >= oldest_valid_version) experiences = query.limit(batch_size).with_for_update().all() if len(experiences) != batch_size: if latest_size != len(experiences): @@ -190,8 +190,7 @@ def read(self, batch_size: Optional[int] = None, **kwargs) -> List[Experience]: raise StopIteration() batch_size = batch_size or self.batch_size - oldest_valid_version = kwargs.pop("oldest_valid_version", -1) - return self._read_method(batch_size, oldest_valid_version=oldest_valid_version) + return self._read_method(batch_size, **kwargs) @classmethod def load_from_dataset(cls, dataset: Dataset, config: StorageConfig) -> "SQLExperienceStorage": From de0e39ea98760fb7dd047b784478b46c3b11f67b Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Thu, 18 Dec 2025 11:07:08 +0800 Subject: [PATCH 4/7] fix merge --- trinity/algorithm/sample_strategy/__init__.py | 1 + trinity/algorithm/sample_strategy/sample_strategy.py | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/trinity/algorithm/sample_strategy/__init__.py b/trinity/algorithm/sample_strategy/__init__.py index 9e2700fb4a..067b45a2e2 100644 --- a/trinity/algorithm/sample_strategy/__init__.py +++ b/trinity/algorithm/sample_strategy/__init__.py @@ -6,6 +6,7 @@ default_mapping={ "default": "trinity.algorithm.sample_strategy.sample_strategy.DefaultSampleStrategy", "warmup": "trinity.algorithm.sample_strategy.sample_strategy.WarmupSampleStrategy", + "staleness_control": "trinity.algorithm.sample_strategy.sample_strategy.StalenessControlSampleStrategy", "mix": "trinity.algorithm.sample_strategy.mix_sample_strategy.MixSampleStrategy", }, ) diff --git a/trinity/algorithm/sample_strategy/sample_strategy.py b/trinity/algorithm/sample_strategy/sample_strategy.py index 7df9716212..9c1a65f448 100644 --- a/trinity/algorithm/sample_strategy/sample_strategy.py +++ b/trinity/algorithm/sample_strategy/sample_strategy.py @@ -76,7 +76,6 @@ def load_state_dict(self, state_dict: dict) -> None: self.exp_buffer.load_state_dict(state_dict) -@SAMPLE_STRATEGY.register_module("staleness_control") class StalenessControlSampleStrategy(DefaultSampleStrategy): def __init__(self, buffer_config: BufferConfig, **kwargs): super().__init__(buffer_config) From 4aa826b619e777703d4c30d314a180aa3d0dc933 Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Thu, 18 Dec 2025 12:10:38 +0800 Subject: [PATCH 5/7] fix unittest --- tests/buffer/sample_strategy_test.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/buffer/sample_strategy_test.py b/tests/buffer/sample_strategy_test.py index f0685dd412..63d99b1d7a 100644 --- a/tests/buffer/sample_strategy_test.py +++ b/tests/buffer/sample_strategy_test.py @@ -6,10 +6,8 @@ from parameterized import parameterized_class from tests.tools import RayUnittestBaseAysnc, get_template_config -from trinity.algorithm.sample_strategy.sample_strategy import ( - SAMPLE_STRATEGY, - SampleStrategy, -) +from trinity.algorithm.sample_strategy import SAMPLE_STRATEGY +from trinity.algorithm.sample_strategy.sample_strategy import SampleStrategy from trinity.buffer.buffer import get_buffer_writer from trinity.common.config import ExperienceBufferConfig, ReplayBufferConfig from trinity.common.constants import StorageType From d71d3c8b66a66cc77f558d349dbeb9f47461e2ee Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Thu, 18 Dec 2025 14:36:13 +0800 Subject: [PATCH 6/7] apply reviews --- .../sample_strategy/sample_strategy.py | 4 ++-- trinity/buffer/storage/queue.py | 24 +++++++++---------- trinity/buffer/storage/sql.py | 6 ++--- 3 files changed, 16 insertions(+), 18 deletions(-) diff --git a/trinity/algorithm/sample_strategy/sample_strategy.py b/trinity/algorithm/sample_strategy/sample_strategy.py index 9c1a65f448..ce70383c15 100644 --- a/trinity/algorithm/sample_strategy/sample_strategy.py +++ b/trinity/algorithm/sample_strategy/sample_strategy.py @@ -82,10 +82,10 @@ def __init__(self, buffer_config: BufferConfig, **kwargs): self.staleness_limit = kwargs.get("staleness_limit", float("inf")) async def sample(self, step: int, **kwargs) -> Tuple[Experiences, Dict, List]: - oldest_valid_version = max(step - self.staleness_limit, 0) + min_model_version = max(step - self.staleness_limit, 0) metrics = {} with Timer(metrics, "time/read_experience"): - exp_list = await self.exp_buffer.read_async(oldest_valid_version=oldest_valid_version) + exp_list = await self.exp_buffer.read_async(min_model_version=min_model_version) repr_samples = representative_sample(exp_list) self.set_model_version_metric(exp_list, metrics) with Timer(metrics, "time/gather_experience"): diff --git a/trinity/buffer/storage/queue.py b/trinity/buffer/storage/queue.py index 9d14e5bb25..a0da043895 100644 --- a/trinity/buffer/storage/queue.py +++ b/trinity/buffer/storage/queue.py @@ -100,8 +100,8 @@ def default_config(cls) -> Dict: class QueueBuffer(ABC): - async def set_oldest_valid_version(self, oldest_valid_version: int): - self.oldest_valid_version = max(oldest_valid_version, 0) + async def set_min_model_version(self, min_model_version: int): + self.min_model_version = max(min_model_version, 0) @abstractmethod async def put(self, exps: List[Experience]) -> None: @@ -152,7 +152,7 @@ def __init__(self, capacity: int): """ super().__init__(maxsize=capacity) self._closed = False - self.oldest_valid_version = 0 + self.min_model_version = 0 async def put(self, item: List[Experience]): if len(item) == 0: @@ -163,8 +163,8 @@ async def get(self): while True: item = await super().get() if ( - self.oldest_valid_version <= 0 - or item[0].info["model_version"] >= self.oldest_valid_version + self.min_model_version <= 0 + or item[0].info["model_version"] >= self.min_model_version ): return item @@ -222,7 +222,7 @@ def __init__( self.reuse_cooldown_time = reuse_cooldown_time self._condition = asyncio.Condition() # For thread-safe operations self._closed = False - self.oldest_valid_version = 0 + self.min_model_version = 0 async def _put(self, item: List[Experience], delay: float = 0) -> None: """ @@ -287,8 +287,8 @@ async def get(self) -> List[Experience]: self.priority_groups.popitem(index=-1) if ( - self.oldest_valid_version <= 0 - or item[0].info["model_version"] >= self.oldest_valid_version + self.min_model_version <= 0 + or item[0].info["model_version"] >= self.min_model_version ): break @@ -374,17 +374,15 @@ async def put_batch(self, exp_list: List) -> None: if self.writer is not None: self.writer.write(exp_list) - async def get_batch( - self, batch_size: int, timeout: float, oldest_valid_version: int = 0 - ) -> List: + async def get_batch(self, batch_size: int, timeout: float, min_model_version: int = 0) -> List: """Get batch of experience.""" - await self.queue.set_oldest_valid_version(oldest_valid_version) + await self.queue.set_min_model_version(min_model_version) start_time = time.time() result = [] while len(result) < batch_size: while len(self.exp_pool) > 0 and len(result) < batch_size: exp = self.exp_pool.popleft() - if oldest_valid_version > 0 and exp.info["model_version"] < oldest_valid_version: + if min_model_version > 0 and exp.info["model_version"] < min_model_version: continue result.append(exp) if len(result) >= batch_size: diff --git a/trinity/buffer/storage/sql.py b/trinity/buffer/storage/sql.py index 6287866665..3947aa76a3 100644 --- a/trinity/buffer/storage/sql.py +++ b/trinity/buffer/storage/sql.py @@ -143,7 +143,7 @@ def _read_fifo(self, batch_size: int) -> List[Experience]: time.sleep(1) return exp_list - def _read_priority(self, batch_size: int, oldest_valid_version: int = 0) -> List[Experience]: + def _read_priority(self, batch_size: int, min_model_version: int = 0) -> List[Experience]: exp_list = [] start_time = time.time() latest_size = 0 @@ -161,8 +161,8 @@ def _read_priority(self, batch_size: int, oldest_valid_version: int = 0) -> List query = session.query(self.table_model_cls).order_by( asc(self.table_model_cls.consumed), desc(self.table_model_cls.id) ) - if oldest_valid_version > 0: - query = query.filter(self.table_model_cls.model_version >= oldest_valid_version) + if min_model_version > 0: + query = query.filter(self.table_model_cls.model_version >= min_model_version) experiences = query.limit(batch_size).with_for_update().all() if len(experiences) != batch_size: if latest_size != len(experiences): From f6dd28e674e8820492647ae41b0b05b43fa58cd1 Mon Sep 17 00:00:00 2001 From: chenyushuo <297086016@qq.com> Date: Thu, 18 Dec 2025 15:12:02 +0800 Subject: [PATCH 7/7] apply reviews --- .../source/tutorial/trinity_configs.md | 2 +- .../source_zh/tutorial/trinity_configs.md | 2 +- tests/buffer/sample_strategy_test.py | 32 +++++++++---------- .../sample_strategy/sample_strategy.py | 4 +-- trinity/buffer/schema/sql_schema.py | 2 +- trinity/buffer/storage/sql.py | 13 +++++--- 6 files changed, 29 insertions(+), 26 deletions(-) diff --git a/docs/sphinx_doc/source/tutorial/trinity_configs.md b/docs/sphinx_doc/source/tutorial/trinity_configs.md index c07fe07921..b1d08f03b0 100644 --- a/docs/sphinx_doc/source/tutorial/trinity_configs.md +++ b/docs/sphinx_doc/source/tutorial/trinity_configs.md @@ -112,7 +112,7 @@ algorithm: - `optimizer`: Optimizer configuration for actor. - `lr`: Learning rate for actor. - `warmup_style`: Warmup style for actor's learning rate. -- `sample_strategy`: The sampling strategy used for loading experiences from experience buffer. +- `sample_strategy`: The sampling strategy used for loading experiences from experience buffer. Supported types: `default`, `staleness_control`, `mix`. - `advantage_fn`: The advantage function used for computing advantages. - `kl_penalty_fn`: The KL penalty function used for computing KL penalty applied in reward. - `kl_loss_fn`: The KL loss function used for computing KL loss. diff --git a/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md b/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md index 7e7aec64b8..0fff6cd91f 100644 --- a/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md +++ b/docs/sphinx_doc/source_zh/tutorial/trinity_configs.md @@ -112,7 +112,7 @@ algorithm: - `optimizer`: Actor 优化器的参数。 - `lr`: 优化器的学习率。 - `warmup_style`: 学习率的预热策略。 -- `sample_strategy`: 从 experience buffer 加载 experience 时使用的采样策略。 +- `sample_strategy`: 从 experience buffer 加载 experience 时使用的采样策略。支持类型:`default`、`staleness_control`、`mix`。 - `advantage_fn`: 用于计算优势值的函数。 - `kl_penalty_fn`: 用于在奖励中计算 KL 惩罚的函数。 - `kl_loss_fn`: 用于计算 KL 损失的函数。 diff --git a/tests/buffer/sample_strategy_test.py b/tests/buffer/sample_strategy_test.py index 63d99b1d7a..32ea84bdb7 100644 --- a/tests/buffer/sample_strategy_test.py +++ b/tests/buffer/sample_strategy_test.py @@ -94,17 +94,15 @@ async def _verify_sampling_model_versions(self, exps_list, expected_model_versio if current_task: await current_task - async def _flexible_verify_model_version(self, step, staleness_limit): + async def _flexible_verify_model_version(self, step, max_staleness): _, metrics, _ = await self.sample_strategy.sample(step=step) self.assertGreaterEqual( metrics["sample/model_version/min"], - step - staleness_limit, + step - max_staleness, f"Min model version mismatch at step {step}", ) - async def _flexible_verify_sampling_model_versions( - self, exps_list, check_steps, staleness_limit - ): + async def _flexible_verify_sampling_model_versions(self, exps_list, check_steps, max_staleness): self._init_buffer_writer_and_sample_strategy() # Write experiences to buffer, while sample and validate model versions @@ -115,7 +113,7 @@ async def _flexible_verify_sampling_model_versions( if current_task: await current_task current_task = asyncio.create_task( - self._flexible_verify_model_version(step, staleness_limit) + self._flexible_verify_model_version(step, max_staleness) ) await asyncio.sleep(0.1) @@ -146,9 +144,9 @@ async def test_default_queue_default_sample_strategy(self): await self._verify_sampling_model_versions(exps_list, expected_model_versions_map) async def test_default_queue_staleness_control_sample_strategy(self): - staleness_limit = 3 + max_staleness = 3 self.config.algorithm.sample_strategy = "staleness_control" - self.config.algorithm.sample_strategy_args = {"staleness_limit": staleness_limit} + self.config.algorithm.sample_strategy_args = {"max_staleness": max_staleness} self.config.buffer.trainer_input.experience_buffer = ExperienceBufferConfig( name="default_queue_staleness_control", storage_type=StorageType.QUEUE.value, @@ -161,7 +159,7 @@ async def test_default_queue_staleness_control_sample_strategy(self): steps = self._default_steps() expected_model_versions_map = {} for step in steps: - predict_version = max(step - staleness_limit, 0) + predict_version = max(step - max_staleness, 0) expected_model_versions_map[step] = [ predict_version + i // self.exp_write_batch_size for i in range(self.config.buffer.train_batch_size) @@ -169,7 +167,7 @@ async def test_default_queue_staleness_control_sample_strategy(self): await self._verify_sampling_model_versions(exps_list, expected_model_versions_map) - def _simulate_priority_queue(self, steps, staleness_limit=float("inf")): + def _simulate_priority_queue(self, steps, max_staleness=float("inf")): expected_model_versions_map = {} buffer = deque() exp_pool = deque() @@ -187,7 +185,7 @@ def _simulate_priority_queue(self, steps, staleness_limit=float("inf")): exp_pool.extend(buffer.pop()) while len(exp_pool) > 0 and len(batch_versions) < train_batch_size: exp_version = exp_pool.popleft() - if exp_version < step - staleness_limit: + if exp_version < step - max_staleness: continue batch_versions.append(exp_version) if len(batch_versions) >= train_batch_size: @@ -214,9 +212,9 @@ async def test_priority_queue_default_sample_strategy(self): await self._verify_sampling_model_versions(exps_list, expected_model_versions_map) async def test_priority_queue_staleness_control_sample_strategy(self): - staleness_limit = 2 + max_staleness = 2 self.config.algorithm.sample_strategy = "staleness_control" - self.config.algorithm.sample_strategy_args = {"staleness_limit": staleness_limit} + self.config.algorithm.sample_strategy_args = {"max_staleness": max_staleness} self.config.buffer.trainer_input.experience_buffer = ExperienceBufferConfig( name="priority_queue_staleness_control", storage_type=StorageType.QUEUE.value, @@ -227,14 +225,14 @@ async def test_priority_queue_staleness_control_sample_strategy(self): # init testing data exps_list = self._default_exp_list() steps = self._default_steps() - expected_model_versions_map = self._simulate_priority_queue(steps, staleness_limit) + expected_model_versions_map = self._simulate_priority_queue(steps, max_staleness) await self._verify_sampling_model_versions(exps_list, expected_model_versions_map) async def test_sql_staleness_control_sample_strategy(self): - staleness_limit = 2 + max_staleness = 2 self.config.algorithm.sample_strategy = "staleness_control" - self.config.algorithm.sample_strategy_args = {"staleness_limit": staleness_limit} + self.config.algorithm.sample_strategy_args = {"max_staleness": max_staleness} self.config.buffer.trainer_input.experience_buffer = ExperienceBufferConfig( name="sql_staleness_control", storage_type=StorageType.SQL.value, @@ -245,7 +243,7 @@ async def test_sql_staleness_control_sample_strategy(self): exps_list = self._default_exp_list() steps = self._default_steps() - await self._flexible_verify_sampling_model_versions(exps_list, steps, staleness_limit) + await self._flexible_verify_sampling_model_versions(exps_list, steps, max_staleness) def tearDown(self): asyncio.run(self.buffer_writer.release()) diff --git a/trinity/algorithm/sample_strategy/sample_strategy.py b/trinity/algorithm/sample_strategy/sample_strategy.py index ce70383c15..2ab63032cb 100644 --- a/trinity/algorithm/sample_strategy/sample_strategy.py +++ b/trinity/algorithm/sample_strategy/sample_strategy.py @@ -79,10 +79,10 @@ def load_state_dict(self, state_dict: dict) -> None: class StalenessControlSampleStrategy(DefaultSampleStrategy): def __init__(self, buffer_config: BufferConfig, **kwargs): super().__init__(buffer_config) - self.staleness_limit = kwargs.get("staleness_limit", float("inf")) + self.max_staleness = kwargs.get("max_staleness", float("inf")) async def sample(self, step: int, **kwargs) -> Tuple[Experiences, Dict, List]: - min_model_version = max(step - self.staleness_limit, 0) + min_model_version = max(step - self.max_staleness, 0) metrics = {} with Timer(metrics, "time/read_experience"): exp_list = await self.exp_buffer.read_async(min_model_version=min_model_version) diff --git a/trinity/buffer/schema/sql_schema.py b/trinity/buffer/schema/sql_schema.py index bc515c3923..997c661a23 100644 --- a/trinity/buffer/schema/sql_schema.py +++ b/trinity/buffer/schema/sql_schema.py @@ -39,7 +39,7 @@ class ExperienceModel(Base): # type: ignore message_list = Column(JSON, nullable=True) reward = Column(Float, nullable=True) # for step info - model_version = Column(Integer, nullable=True) + model_version = Column(Integer, nullable=True, index=True) # serialized experience object experience_bytes = Column(LargeBinary, nullable=True) consumed = Column(Integer, default=0, index=True) diff --git a/trinity/buffer/storage/sql.py b/trinity/buffer/storage/sql.py index 3947aa76a3..08ff06fb8c 100644 --- a/trinity/buffer/storage/sql.py +++ b/trinity/buffer/storage/sql.py @@ -158,12 +158,17 @@ def _read_priority(self, batch_size: int, min_model_version: int = 0) -> List[Ex with retry_session( self.session, self.max_retry_times, self.max_retry_interval ) as session: - query = session.query(self.table_model_cls).order_by( - asc(self.table_model_cls.consumed), desc(self.table_model_cls.id) - ) + query = session.query(self.table_model_cls) if min_model_version > 0: query = query.filter(self.table_model_cls.model_version >= min_model_version) - experiences = query.limit(batch_size).with_for_update().all() + experiences = ( + query.order_by( + asc(self.table_model_cls.consumed), desc(self.table_model_cls.id) + ) + .limit(batch_size) + .with_for_update() + .all() + ) if len(experiences) != batch_size: if latest_size != len(experiences): latest_size = len(experiences)