diff --git a/.github/workflows/unittest.yaml b/.github/workflows/unittest.yaml
index e1a497c9de..ecd846f789 100644
--- a/.github/workflows/unittest.yaml
+++ b/.github/workflows/unittest.yaml
@@ -69,9 +69,8 @@ jobs:
           docker compose exec trinity-node-1 pytest tests -v -s --ignore=tests/data --ctrf report.json
           echo "tests_run=true" >> $GITHUB_ENV
         elif [ "$TYPE" = "diff" ]; then
-          ROOT_DIR=trinity-${{ github.run_id }}
-          if [ -s "$ROOT_DIR/test_dirs.txt" ]; then
-            TEST_DIRS=$(cat "$ROOT_DIR/test_dirs.txt" | xargs)
+          if [ -s ../../../test_dirs.txt ]; then
+            TEST_DIRS=$(cat ../../../test_dirs.txt | xargs)
             docker compose exec trinity-node-1 pytest $TEST_DIRS -v -s --ignore=tests/data --ctrf report.json
             echo "tests_run=true" >> $GITHUB_ENV
           else
diff --git a/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md b/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md
index a7c92bef61..7ecb6e30bc 100644
--- a/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md
+++ b/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md
@@ -100,10 +100,12 @@ class Workflow(ABC):
 
     def __init__(
         self,
-        model: ModelWrapper,
+        *,
         task: Task,
+        model: ModelWrapper,
         auxiliary_models: Optional[List[openai.OpenAI]] = None,
     ):
+        self.task = task
         self.model = model
         self.auxiliary_models = auxiliary_models
 
@@ -116,13 +118,13 @@ class Workflow(ABC):
 
 During initialization, `Workflow` receives the following parameters:
 
-- `model`({class}`trinity.common.models.model.ModelWrapper`): The model being trained, which provides an interface similar to OpenAI, capable of receiving a list of conversation messages and returning content generated by the LLM (including reply text `response_text`, full sequence token ids `tokens`, prompt part token length `prompt_length`, and a list of output token logprobs `logprobs`).
 - `task`({class}`trinity.common.workflows.Task`): A single data item from the task dataset.
+- `model`({class}`trinity.common.models.model.ModelWrapper`): The model being trained, which provides an interface similar to OpenAI, capable of receiving a list of conversation messages and returning content generated by the LLM (including reply text `response_text`, full sequence token ids `tokens`, prompt part token length `prompt_length`, and a list of output token logprobs `logprobs`).
 - `auxiliary_models` (`List[openai.OpenAI]`): A list of auxiliary models not involved in training. All are provided via OpenAI-compatible APIs.
 
 ```{tip}
 You can switch to using the OpenAI API by setting `explorer.rollout_model.enable_openai_api` to `true` in your config file and calling `model.get_openai_client()` to get an `openai.OpenAI` instance in your workflow.
-And the `model` field when calling openai API can be obtained via `openai_client.models.list().data[0].id`.
+The `model` field to use when calling the OpenAI API can be obtained via `openai_client.models.list().data[0].id` or `openai_client.model_path`.
 ```
 
 Here's an example of initializing a simple workflow using only `raw_task` and `rollout_args`. In more complex cases, you can use the `format_args` for further customization.
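Editor's note (not part of the patch): a minimal sketch of the tip above, assuming `explorer.rollout_model.enable_openai_api` is set to `true`; only `get_openai_client()`, `models.list().data[0].id`, and `model_path` come from the documentation, the message content and print are illustrative.

```python
# Inside a workflow method: obtain an OpenAI-compatible client for the
# model being trained and look up the model id to pass on each call.
client = self.model.get_openai_client()
model_id = client.model_path  # or: client.models.list().data[0].id
completion = client.chat.completions.create(
    model=model_id,
    messages=[{"role": "user", "content": "Say hello."}],
)
print(completion.choices[0].message.content)
```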
@@ -130,8 +132,8 @@ Here's an example of initializing a simple workflow using only `raw_task` and `r
 
 ```python
 class ExampleWorkflow(Workflow):
-    def __init__(self, model: ModelWrapper, task: Task, auxiliary_models: List):
-        super().__init__(model, task, auxiliary_models)
+    def __init__(self, task: Task, model: ModelWrapper, auxiliary_models: List):
+        super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)
         self.question = task.raw_task.get("question")
         self.answer = task.raw_task.get("answer")
         self.rollout_args = task.rollout_args
@@ -244,8 +246,8 @@ class ExampleWorkflow(Workflow):
 
 @WORKFLOWS.register_module("example_workflow")
 class ExampleWorkflow(Workflow):
-    def __init__(self, model: ModelWrapper, task: Task, auxiliary_models: List):
-        super().__init__(model, task, auxiliary_models)
+    def __init__(self, task: Task, model: ModelWrapper, auxiliary_models: List):
+        super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)
         self.question = task.raw_task.get("question")
         self.answer = task.raw_task.get("answer")
         self.rollout_args = task.rollout_args
diff --git a/tests/buffer/queue_test.py b/tests/buffer/queue_test.py
index 11526c223f..6406229a98 100644
--- a/tests/buffer/queue_test.py
+++ b/tests/buffer/queue_test.py
@@ -2,6 +2,7 @@
 import threading
 import time
 
+import ray
 import torch
 from parameterized import parameterized
 
@@ -75,19 +76,17 @@ async def test_queue_buffer(self, name, use_priority_queue):
         self.assertRaises(StopIteration, reader.read)
         with open(BUFFER_FILE_PATH, "r") as f:
             self.assertEqual(len(f.readlines()), self.total_num + self.put_batch_size * 2)
-        st = time.time()
-        self.assertRaises(TimeoutError, reader.read, batch_size=1)
-        et = time.time()
-        self.assertTrue(et - st > 2)
+        self.assertRaises(StopIteration, reader.read, batch_size=1)
 
     async def test_priority_queue_capacity(self):
         # test queue capacity
+        self.config.read_batch_size = 4
         meta = StorageConfig(
             name="test_buffer_small",
             algorithm_type="ppo",
             storage_type=StorageType.QUEUE,
             max_read_timeout=1,
-            capacity=2,
+            capacity=100,  # priority will use 2 * read_batch_size as capacity (8)
             path=BUFFER_FILE_PATH,
             use_priority_queue=True,
             replay_buffer_kwargs={"priority_fn": "linear_decay", "decay": 0.6},
@@ -95,7 +94,7 @@ async def test_priority_queue_capacity(self):
         writer = QueueWriter(meta, self.config)
         reader = QueueReader(meta, self.config)
 
-        for i in range(4):
+        for i in range(12):
             writer.write(
                 [
                     Experience(
@@ -106,15 +105,34 @@ async def test_priority_queue_capacity(self):
                 ]
             )
 
-        exps = reader.read(batch_size=2)
-        self.assertEqual(exps[0].info["model_version"], 3)
+        self.assertEqual(ray.get(reader.queue.length.remote()), 8)
+
+        exps = reader.read(batch_size=8)
+        self.assertEqual(exps[0].info["model_version"], 11)
         self.assertEqual(exps[0].info["use_count"], 1)
-        self.assertEqual(exps[1].info["model_version"], 2)
+        self.assertEqual(exps[1].info["model_version"], 10)
         self.assertEqual(exps[1].info["use_count"], 1)
+        self.assertEqual(exps[7].info["model_version"], 4)
 
         with self.assertRaises(TimeoutError):
             reader.read(batch_size=1)
 
+        for i in range(12):
+            writer.write(
+                [
+                    Experience(
+                        tokens=torch.tensor([1, 2, 3]),
+                        prompt_length=2,
+                        info={"model_version": i, "use_count": 0},
+                    ),
+                ]
+            )
+        await writer.release()
+        exps = reader.read(batch_size=8)
+
+        with self.assertRaises(StopIteration):
+            reader.read(batch_size=1)
+
     async def test_queue_buffer_capacity(self):
         # test queue capacity
         meta = StorageConfig(
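Editor's note on the numbers asserted in the test above: with `read_batch_size = 4`, the priority queue clamps its effective capacity to `2 * 4 = 8`, so after 12 single-experience writes only model versions 11 through 4 survive; `linear_decay` with `decay=0.6` scores each item as `model_version - 0.6 * use_count`. A standalone sketch of that arithmetic (illustrative, not part of the patch):

```python
# Reproduces the capacity/priority arithmetic from test_priority_queue_capacity.
read_batch_size = 4
capacity = min(100, 2 * read_batch_size)  # -> 8, as the inline comment notes

def linear_decay_priority(model_version: int, use_count: int, decay: float = 0.6) -> float:
    # Mirrors the "linear_decay" priority_fn: newer model versions rank higher,
    # and every reuse lowers an item's priority.
    return model_version - decay * use_count

kept = sorted(range(12), key=lambda v: linear_decay_priority(v, 0), reverse=True)[:capacity]
print(kept)  # [11, 10, 9, 8, 7, 6, 5, 4] -- matches exps[0] and exps[7] in the test
```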
diff --git a/trinity/buffer/priority_queue.py b/trinity/buffer/priority_queue.py
deleted file mode 100644
index 7878a3e28e..0000000000
--- a/trinity/buffer/priority_queue.py
+++ /dev/null
@@ -1,126 +0,0 @@
-"""An Async PriorityQueue."""
-import asyncio
-from collections import deque
-from typing import List, Optional, Union
-
-import numpy as np
-from sortedcontainers import SortedDict
-
-from trinity.common.experience import Experience
-from trinity.utils.registry import Registry
-
-PRIORITY_FUNC = Registry("priority_fn")
-
-
-@PRIORITY_FUNC.register_module("linear_decay")
-def linear_decay_priority(item: List[Experience], decay: float = 0.1):
-    return item[0].info["model_version"] - decay * item[0].info["use_count"]  # type: ignore
-
-
-class AsyncPriorityQueue:
-    """
-    An asynchronous priority queue that manages a fixed-size buffer of experience items.
-    Items are prioritized using a user-defined function and reinserted after a cooldown period.
-
-    Attributes:
-        capacity (int): Maximum number of items the queue can hold.
-        priority_groups (SortedDict): Maps priorities to deques of items with the same priority.
-        priority_fn (callable): Function used to determine the priority of an item.
-        reuse_cooldown_time (float): Delay before reusing an item (set to infinity to disable).
-    """
-
-    def __init__(
-        self,
-        capacity: int,
-        reuse_cooldown_time: Optional[float] = None,
-        priority_fn: str = "linear_decay",
-        **kwargs,
-    ):
-        """
-        Initialize the async priority queue.
-
-        Args:
-            capacity (`int`): The maximum number of items the queue can store.
-            reuse_cooldown_time (`float`): Time to wait before reusing an item. Set to None to disable reuse.
-            priority_fn (`str`): Name of the function to use for determining item priority.
-            kwargs: Additional keyword arguments for the priority function.
-        """
-        self.capacity = capacity
-        self.priority_groups = SortedDict()  # Maps priority -> deque of items
-        priority_fn = PRIORITY_FUNC.get(priority_fn)
-        from trinity.buffer.queue import QueueActor
-
-        # TODO: remove FINISHE_MESSAGE and use a more elegant solution
-        self.FINISH_MESSAGE = QueueActor.FINISH_MESSAGE
-
-        self.priority_fn = (
-            lambda item: priority_fn(item, **kwargs) if item != self.FINISH_MESSAGE else -np.inf  # type: ignore
-        )
-        self.reuse_cooldown_time = reuse_cooldown_time
-        self._condition = asyncio.Condition()  # For thread-safe operations
-
-    async def put(self, item: Union[List[Experience], str], delay: float = 0) -> None:
-        """
-        Insert an item into the queue, possibly replacing the lowest-priority item if full.
-
-        Args:
-            item (`List[Experience]`): A list of experiences to add.
-            delay (`float`): Optional delay before insertion (for simulating timing behavior).
-        """
-        if delay > 0:
-            await asyncio.sleep(delay)
-
-        priority = self.priority_fn(item)
-        async with self._condition:
-            if len(self.priority_groups) == self.capacity:
-                # If full, only insert if new item has higher or equal priority than the lowest
-                lowest_priority, item_queue = self.priority_groups.peekitem(index=0)
-                if lowest_priority > priority:
-                    return  # Skip insertion if lower priority
-                # Remove the lowest priority item
-                item_queue.popleft()
-                if not item_queue:
-                    self.priority_groups.popitem(index=0)
-
-            # Add the new item
-            if priority not in self.priority_groups:
-                self.priority_groups[priority] = deque()
-            self.priority_groups[priority].append(item)
-            self._condition.notify()
-
-    async def get(self) -> List[Experience]:
-        """
-        Retrieve the highest-priority item from the queue.
-
-        Returns:
-            List[Experience]: The highest-priority item (list of experiences).
-
-        Notes:
-            - After retrieval, the item is optionally reinserted after a cooldown period.
-        """
-        async with self._condition:
-            while len(self.priority_groups) == 0:
-                await self._condition.wait()
-
-            _, item_queue = self.priority_groups.peekitem(index=-1)
-            item = item_queue.popleft()
-            if not item_queue:
-                self.priority_groups.popitem(index=-1)
-
-            if item != self.FINISH_MESSAGE:
-                for exp in item:
-                    exp.info["use_count"] += 1
-                # Optionally resubmit the item after a cooldown
-                if self.reuse_cooldown_time is not None:
-                    asyncio.create_task(self.put(item, self.reuse_cooldown_time))
-
-            return item
-
-    def size(self) -> int:
-        """
-        Get the current number of items in the queue.
-
-        Returns:
-            int: Number of items currently stored.
-        """
-        return len(self.priority_groups)
diff --git a/trinity/buffer/queue.py b/trinity/buffer/queue.py
index c2d45029b3..e5d76e57d4 100644
--- a/trinity/buffer/queue.py
+++ b/trinity/buffer/queue.py
@@ -1,112 +1,191 @@
-"""A queue implemented by Ray Actor."""
+"""Implementation of async queue buffers."""
 import asyncio
-from copy import deepcopy
-from typing import List
+from abc import ABC, abstractmethod
+from collections import deque
+from functools import partial
+from typing import List, Optional
 
-import ray
+from sortedcontainers import SortedDict
 
-from trinity.buffer.priority_queue import AsyncPriorityQueue
-from trinity.buffer.writer.file_writer import JSONWriter
-from trinity.buffer.writer.sql_writer import SQLWriter
 from trinity.common.config import BufferConfig, StorageConfig
-from trinity.common.constants import StorageType
+from trinity.common.experience import Experience
 from trinity.utils.log import get_logger
+from trinity.utils.registry import Registry
 
+PRIORITY_FUNC = Registry("priority_fn")
 
-def is_database_url(path: str) -> bool:
-    return any(path.startswith(prefix) for prefix in ["sqlite:///", "postgresql://", "mysql://"])
 
+@PRIORITY_FUNC.register_module("linear_decay")
+def linear_decay_priority(item: List[Experience], decay: float = 0.1):
+    return item[0].info["model_version"] - decay * item[0].info["use_count"]  # type: ignore
 
-def is_json_file(path: str) -> bool:
-    return path.endswith(".json") or path.endswith(".jsonl")
 
+class QueueBuffer(ABC):
+    @abstractmethod
+    async def put(self, exps: List[Experience]) -> None:
+        """Put a list of experiences into the queue."""
 
-class QueueActor:
-    """An asyncio.Queue based queue actor."""
+    @abstractmethod
+    async def get(self) -> List[Experience]:
+        """Get a list of experiences from the queue."""
 
-    FINISH_MESSAGE = "$FINISH$"
+    @abstractmethod
+    def qsize(self) -> int:
+        """Get the current size of the queue."""
 
-    def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None:
-        self.logger = get_logger(__name__)
-        self.config = config
-        self.capacity = storage_config.capacity
+    @abstractmethod
+    async def close(self) -> None:
+        """Close the queue."""
+
+    @abstractmethod
+    def stopped(self) -> bool:
+        """Check if there is no more data to read."""
+
+    @classmethod
+    def get_queue(cls, storage_config: StorageConfig, config: BufferConfig) -> "QueueBuffer":
+        """Get a queue instance based on the storage configuration."""
+        logger = get_logger(__name__)
         if storage_config.use_priority_queue:
             reuse_cooldown_time = storage_config.reuse_cooldown_time
             replay_buffer_kwargs = storage_config.replay_buffer_kwargs
-            self.queue = AsyncPriorityQueue(
-                self.capacity, reuse_cooldown_time, **replay_buffer_kwargs
+            capacity = min(storage_config.capacity, config.read_batch_size * 2)
+            logger.info(
+                f"Using AsyncPriorityQueue with capacity {capacity}, reuse_cooldown_time {reuse_cooldown_time}."
             )
+            return AsyncPriorityQueue(capacity, reuse_cooldown_time, **replay_buffer_kwargs)
         else:
-            self.queue = asyncio.Queue(self.capacity)
-        st_config = deepcopy(storage_config)
-        st_config.wrap_in_ray = False
-        if st_config.path is not None:
-            if is_database_url(st_config.path):
-                st_config.storage_type = StorageType.SQL
-                self.writer = SQLWriter(st_config, self.config)
-            elif is_json_file(st_config.path):
-                st_config.storage_type = StorageType.FILE
-                self.writer = JSONWriter(st_config, self.config)
-            else:
-                self.logger.warning("Unknown supported storage path: %s", st_config.path)
-                self.writer = None
-        else:
-            st_config.storage_type = StorageType.FILE
-            self.writer = JSONWriter(st_config, self.config)
-            self.logger.warning(f"Save experiences in {st_config.path}.")
-        self.ref_count = 0
-
-    async def acquire(self) -> int:
-        self.ref_count += 1
-        return self.ref_count
-
-    async def release(self) -> int:
-        """Release the queue."""
-        self.ref_count -= 1
-        if self.ref_count <= 0:
-            await self.queue.put(self.FINISH_MESSAGE)
-            await self.writer.release()
-        return self.ref_count
-
-    def length(self) -> int:
-        """The length of the queue."""
-        return self.queue.qsize()
-
-    async def put_batch(self, exp_list: List) -> None:
-        """Put batch of experience."""
-        await self.queue.put(exp_list)
-        if self.writer is not None:
-            self.writer.write(exp_list)
-
-    async def get_batch(self, batch_size: int, timeout: float) -> List:
-        """Get batch of experience."""
-        batch = []
-        while True:
-            try:
-                exp_list = await asyncio.wait_for(self.queue.get(), timeout=timeout)
-            except asyncio.TimeoutError:
-                self.logger.error(
-                    f"Timeout when waiting for experience, only get {len(batch)} experiences.\n"
-                    "This phenomenon is usually caused by the workflow not returning enough "
-                    "experiences or running timeout. Please check your workflow implementation."
-                )
-                return batch
-            if exp_list == self.FINISH_MESSAGE:
-                raise StopAsyncIteration()
-            batch.extend(exp_list)
-            if len(batch) >= batch_size:
-                break
-        return batch
-
-    @classmethod
-    def get_actor(cls, storage_config: StorageConfig, config: BufferConfig):
-        """Get the queue actor."""
-        return (
-            ray.remote(cls)
-            .options(
-                name=f"queue-{storage_config.name}",
-                namespace=storage_config.ray_namespace or ray.get_runtime_context().namespace,
-                get_if_exists=True,
-            )
-            .remote(storage_config, config)
-        )
+            return AsyncQueue(capacity=storage_config.capacity)
+
+
+class AsyncQueue(asyncio.Queue, QueueBuffer):
+    def __init__(self, capacity: int):
+        """
+        Initialize the async queue with a specified capacity.
+
+        Args:
+            capacity (`int`): The maximum number of items the queue can hold.
+        """
+        super().__init__(maxsize=capacity)
+        self._closed = False
+
+    async def close(self) -> None:
+        """Close the queue."""
+        self._closed = True
+
+    def stopped(self) -> bool:
+        """Check if there is no more data to read."""
+        return self._closed and self.empty()
+
+
+class AsyncPriorityQueue(QueueBuffer):
+    """
+    An asynchronous priority queue that manages a fixed-size buffer of experience items.
+    Items are prioritized using a user-defined function and reinserted after a cooldown period.
+
+    Attributes:
+        capacity (int): Maximum number of items the queue can hold. This value is automatically
+            adjusted to be at most twice the read batch size.
+        priority_groups (SortedDict): Maps priorities to deques of items with the same priority.
+        priority_fn (callable): Function used to determine the priority of an item.
+        reuse_cooldown_time (float): Delay before reusing an item (set to None to disable).
+    """
+
+    def __init__(
+        self,
+        capacity: int,
+        reuse_cooldown_time: Optional[float] = None,
+        priority_fn: str = "linear_decay",
+        **kwargs,
+    ):
+        """
+        Initialize the async priority queue.
+
+        Args:
+            capacity (`int`): The maximum number of items the queue can store.
+            reuse_cooldown_time (`float`): Time to wait before reusing an item. Set to None to disable reuse.
+            priority_fn (`str`): Name of the function to use for determining item priority.
+            kwargs: Additional keyword arguments for the priority function.
+        """
+        self.capacity = capacity
+        self.priority_groups = SortedDict()  # Maps priority -> deque of items
+        self.priority_fn = partial(PRIORITY_FUNC.get(priority_fn), **kwargs)
+        self.reuse_cooldown_time = reuse_cooldown_time
+        self._condition = asyncio.Condition()  # For coroutine-safe operations
+        self._closed = False
+
+    async def _put(self, item: List[Experience], delay: float = 0) -> None:
+        """
+        Insert an item into the queue, replacing the lowest-priority item if full.
+
+        Args:
+            item (`List[Experience]`): A list of experiences to add.
+            delay (`float`): Optional delay before insertion (for simulating timing behavior).
+        """
+        if delay > 0:
+            await asyncio.sleep(delay)
+
+        priority = self.priority_fn(item=item)
+        async with self._condition:
+            if len(self.priority_groups) == self.capacity:
+                # If full, only insert if the new item's priority is higher than or equal to the lowest
+                lowest_priority, item_queue = self.priority_groups.peekitem(index=0)
+                if lowest_priority > priority:
+                    return  # Skip insertion if lower priority
+                # Remove the lowest priority item
+                item_queue.popleft()
+                if not item_queue:
+                    self.priority_groups.popitem(index=0)
+
+            # Add the new item
+            if priority not in self.priority_groups:
+                self.priority_groups[priority] = deque()
+            self.priority_groups[priority].append(item)
+            self._condition.notify()
+
+    async def put(self, item: List[Experience]) -> None:
+        await self._put(item, delay=0)
+
+    async def get(self) -> List[Experience]:
+        """
+        Retrieve the highest-priority item from the queue.
+
+        Returns:
+            List[Experience]: The highest-priority item (list of experiences).
+
+        Notes:
+            - After retrieval, the item is optionally reinserted after a cooldown period.
+        """
+        async with self._condition:
+            while len(self.priority_groups) == 0:
+                if self._closed:
+                    raise StopAsyncIteration()
+                await self._condition.wait()
+
+            _, item_queue = self.priority_groups.peekitem(index=-1)
+            item = item_queue.popleft()
+            if not item_queue:
+                self.priority_groups.popitem(index=-1)
+
+            for exp in item:
+                exp.info["use_count"] += 1
+            # Optionally resubmit the item after a cooldown
+            if self.reuse_cooldown_time is not None:
+                asyncio.create_task(self._put(item, self.reuse_cooldown_time))
+
+            return item
+
+    def qsize(self):
+        return len(self.priority_groups)
+
+    async def close(self) -> None:
+        """
+        Close the queue.
+        """
+        async with self._condition:
+            self._closed = True
+            # No more items will be added, but existing items can still be processed.
+            self.reuse_cooldown_time = None
+            self._condition.notify_all()
+
+    def stopped(self) -> bool:
+        return self._closed and len(self.priority_groups) == 0
diff --git a/trinity/buffer/ray_wrapper.py b/trinity/buffer/ray_wrapper.py
index 33a9d8633b..b89b0415fa 100644
--- a/trinity/buffer/ray_wrapper.py
+++ b/trinity/buffer/ray_wrapper.py
@@ -1,6 +1,10 @@
+"""Ray actor wrapper for different buffers."""
+import asyncio
 import json
 import os
 import time
+from collections import deque
+from copy import deepcopy
 from typing import List, Optional
 
 import ray
@@ -9,10 +13,11 @@
 from sqlalchemy.orm import sessionmaker
 from sqlalchemy.pool import NullPool
 
+from trinity.buffer.queue import QueueBuffer
 from trinity.buffer.schema import Base, create_dynamic_table
 from trinity.buffer.utils import default_storage_path, retry_session
 from trinity.common.config import BufferConfig, StorageConfig
-from trinity.common.constants import ReadStrategy
+from trinity.common.constants import ReadStrategy, StorageType
 from trinity.common.experience import Experience
 from trinity.common.workflows import Task
 from trinity.utils.log import get_logger
@@ -199,3 +204,101 @@ def release(self) -> int:
         if self.ref_count <= 0:
             self.file.close()
         return self.ref_count
+
+
+def is_database_url(path: str) -> bool:
+    return any(path.startswith(prefix) for prefix in ["sqlite:///", "postgresql://", "mysql://"])
+
+
+def is_json_file(path: str) -> bool:
+    return path.endswith(".json") or path.endswith(".jsonl")
+
+
+class QueueWrapper:
+    """A wrapper around an async queue."""
+
+    def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None:
+        self.logger = get_logger(__name__)
+        self.config = config
+        self.capacity = storage_config.capacity
+        self.queue = QueueBuffer.get_queue(storage_config, config)
+        st_config = deepcopy(storage_config)
+        st_config.wrap_in_ray = False
+        if st_config.path is not None:
+            if is_database_url(st_config.path):
+                from trinity.buffer.writer.sql_writer import SQLWriter
+
+                st_config.storage_type = StorageType.SQL
+                self.writer = SQLWriter(st_config, self.config)
+            elif is_json_file(st_config.path):
+                from trinity.buffer.writer.file_writer import JSONWriter
+
+                st_config.storage_type = StorageType.FILE
+                self.writer = JSONWriter(st_config, self.config)
+            else:
+                self.logger.warning("Unsupported storage path: %s", st_config.path)
+                self.writer = None
+        else:
+            from trinity.buffer.writer.file_writer import JSONWriter
+
+            st_config.storage_type = StorageType.FILE
+            self.writer = JSONWriter(st_config, self.config)
+            self.logger.warning(f"Saving experiences to {st_config.path}.")
+        self.ref_count = 0
+        self.exp_pool = deque()  # A pool to store experiences
+        self.closed = False
+
+    async def acquire(self) -> int:
+        self.ref_count += 1
+        return self.ref_count
+
+    async def release(self) -> int:
+        """Release the queue."""
+        self.ref_count -= 1
+        if self.ref_count <= 0:
+            await self.queue.close()
+            await self.writer.release()
+        return self.ref_count
+
+    def length(self) -> int:
+        """The length of the queue."""
+        return self.queue.qsize()
+
+    async def put_batch(self, exp_list: List) -> None:
+        """Put a batch of experiences."""
+        await self.queue.put(exp_list)
+        if self.writer is not None:
+            self.writer.write(exp_list)
+
+    async def get_batch(self, batch_size: int, timeout: float) -> List:
+        """Get a batch of experiences."""
+        while len(self.exp_pool) < batch_size:
+            if self.queue.stopped():
+                # If the queue is stopped, ignore the rest of the experiences in the pool
+                raise StopAsyncIteration("Queue is closed and no more items to get.")
+            try:
+                exp_list = await asyncio.wait_for(self.queue.get(), timeout=timeout)
+            except asyncio.TimeoutError:
+                self.logger.error(
+                    f"Timeout when waiting for experience, only get {len(self.exp_pool)} experiences.\n"
+                    "This phenomenon is usually caused by the workflow not returning enough "
+                    "experiences or running timeout. Please check your workflow implementation."
+                )
+                batch = list(self.exp_pool)
+                self.exp_pool.clear()
+                return batch
+            self.exp_pool.extend(exp_list)
+        return [self.exp_pool.popleft() for _ in range(batch_size)]
+
+    @classmethod
+    def get_wrapper(cls, storage_config: StorageConfig, config: BufferConfig):
+        """Get the queue actor."""
+        return (
+            ray.remote(cls)
+            .options(
+                name=f"queue-{storage_config.name}",
+                namespace=storage_config.ray_namespace or ray.get_runtime_context().namespace,
+                get_if_exists=True,
+            )
+            .remote(storage_config, config)
+        )
diff --git a/trinity/buffer/reader/queue_reader.py b/trinity/buffer/reader/queue_reader.py
index e6b65e4e6c..4e59b9e297 100644
--- a/trinity/buffer/reader/queue_reader.py
+++ b/trinity/buffer/reader/queue_reader.py
@@ -5,7 +5,7 @@
 import ray
 
 from trinity.buffer.buffer_reader import BufferReader
-from trinity.buffer.queue import QueueActor
+from trinity.buffer.ray_wrapper import QueueWrapper
 from trinity.common.config import BufferConfig, StorageConfig
 from trinity.common.constants import ReadStrategy, StorageType
 from trinity.utils.log import get_logger
@@ -20,7 +20,7 @@ def __init__(self, storage_config: StorageConfig, config: BufferConfig):
         assert storage_config.storage_type == StorageType.QUEUE
         self.timeout = storage_config.max_read_timeout
         self.read_batch_size = config.read_batch_size
-        self.queue = QueueActor.get_actor(storage_config, config)
+        self.queue = QueueWrapper.get_wrapper(storage_config, config)
 
     def read(
         self, batch_size: Optional[int] = None, strategy: Optional[ReadStrategy] = None
diff --git a/trinity/buffer/writer/queue_writer.py b/trinity/buffer/writer/queue_writer.py
index 9b13262b80..4b8034d716 100644
--- a/trinity/buffer/writer/queue_writer.py
+++ b/trinity/buffer/writer/queue_writer.py
@@ -4,7 +4,7 @@
 import ray
 
 from trinity.buffer.buffer_writer import BufferWriter
-from trinity.buffer.queue import QueueActor
+from trinity.buffer.ray_wrapper import QueueWrapper
 from trinity.common.config import BufferConfig, StorageConfig
 from trinity.common.constants import StorageType
 from trinity.utils.log import get_logger
@@ -18,7 +18,7 @@ class QueueWriter(BufferWriter):
     def __init__(self, meta: StorageConfig, config: BufferConfig):
         assert meta.storage_type == StorageType.QUEUE
         self.config = config
-        self.queue = QueueActor.get_actor(meta, config)
+        self.queue = QueueWrapper.get_wrapper(meta, config)
 
     def write(self, data: List) -> None:
         ray.get(self.queue.put_batch.remote(data))
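Editor's note: the reader/writer changes above only swap `QueueActor` for `QueueWrapper`; both sides still talk to one named Ray actor. A sketch of that interaction (the reader-side `get_batch` call is inferred from `QueueWrapper.get_batch` above, since the reader's `read` body is not shown in this hunk):

```python
import ray

# get_if_exists=True inside get_wrapper() means the writer and the reader
# resolve to the same named actor for a given storage name.
queue = QueueWrapper.get_wrapper(storage_config, config)
ray.get(queue.put_batch.remote(exp_list))                     # writer side
batch = ray.get(queue.get_batch.remote(batch_size, timeout))  # reader side (inferred)
```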
diff --git a/trinity/common/experience.py b/trinity/common/experience.py
index 0c5f98e89c..652ff79eb7 100644
--- a/trinity/common/experience.py
+++ b/trinity/common/experience.py
@@ -25,7 +25,7 @@ class EID(dict):
     batch: int = 0
     # Task number, e.g., the task sequence in the batch, the first task in the batch has task=0
     # Automatically set by the workflow runner
-    task: int = 0  # Task sequence in the batch, e.g., the first task in the batch has task=0
+    task: int = 0  # Run id, e.g., the first run in the task has run=0
     # User should set this field in custom workflows when creating experiences
     run: int = 0
@@ -202,10 +202,10 @@ def to_dict(self) -> dict:
         res["response_text"] = self.response_text
         if self.messages is not None:
             res["messages"] = self.messages
-        if self.chosen is not None:
-            res["chosen"] = self.chosen.tolist()
-        if self.rejected is not None:
-            res["rejected"] = self.rejected.tolist()
+        if self.chosen_text is not None:
+            res["chosen_text"] = self.chosen_text
+        if self.rejected_text is not None:
+            res["rejected_text"] = self.rejected_text
         if self.reward is not None:
             res["reward"] = float(self.reward)
         return res
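Editor's summary sketch of the end-of-stream behavior this diff introduces, distilled from `test_priority_queue_capacity` above (runs inside an async test; `make_experience` is a hypothetical helper, and `writer`/`reader` are set up as in the test):

```python
# Before this patch: an exhausted queue made reader.read() block until TimeoutError.
# After: once every writer has released, the queue closes, and a drained reader
# gets StopIteration immediately instead of waiting out the timeout.
for i in range(12):
    writer.write([make_experience(model_version=i)])  # hypothetical helper

await writer.release()            # ref_count -> 0, QueueWrapper closes the queue
exps = reader.read(batch_size=8)  # still drains what the priority queue kept
reader.read(batch_size=1)         # raises StopIteration instead of TimeoutError
```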