5 changes: 2 additions & 3 deletions .github/workflows/unittest.yaml
@@ -69,9 +69,8 @@ jobs:
docker compose exec trinity-node-1 pytest tests -v -s --ignore=tests/data --ctrf report.json
echo "tests_run=true" >> $GITHUB_ENV
elif [ "$TYPE" = "diff" ]; then
-  ROOT_DIR=trinity-${{ github.run_id }}
-  if [ -s "$ROOT_DIR/test_dirs.txt" ]; then
-    TEST_DIRS=$(cat "$ROOT_DIR/test_dirs.txt" | xargs)
+  if [ -s ../../../test_dirs.txt ]; then
+    TEST_DIRS=$(cat ../../../test_dirs.txt | xargs)
docker compose exec trinity-node-1 pytest $TEST_DIRS -v -s --ignore=tests/data --ctrf report.json
echo "tests_run=true" >> $GITHUB_ENV
else
16 changes: 9 additions & 7 deletions docs/sphinx_doc/source/tutorial/trinity_programming_guide.md
@@ -100,10 +100,12 @@ class Workflow(ABC):

def __init__(
self,
model: ModelWrapper,
*,
task: Task,
model: ModelWrapper,
auxiliary_models: Optional[List[openai.OpenAI]] = None,
):
self.task = task
self.model = model
self.auxiliary_models = auxiliary_models

@@ -116,22 +118,22 @@ class Workflow(ABC):

During initialization, `Workflow` receives the following parameters:

- `model`({class}`trinity.common.models.model.ModelWrapper`): The model being trained, which provides an interface similar to OpenAI, capable of receiving a list of conversation messages and returning content generated by the LLM (including reply text `response_text`, full sequence token ids `tokens`, prompt part token length `prompt_length`, and a list of output token logprobs `logprobs`).
- `task`({class}`trinity.common.workflows.Task`): A single data item from the task dataset.
- `model`({class}`trinity.common.models.model.ModelWrapper`): The model being trained, which provides an interface similar to OpenAI, capable of receiving a list of conversation messages and returning content generated by the LLM (including reply text `response_text`, full sequence token ids `tokens`, prompt part token length `prompt_length`, and a list of output token logprobs `logprobs`).
- `auxiliary_models`(`List[openai.OpenAI]`): A list of auxiliary models not involved in training. All are provided via OpenAI-compatible APIs.

```{tip}
You can switch to using the OpenAI API by setting `explorer.rollout_model.enable_openai_api` to `true` in your config file and calling `model.get_openai_client()` to get an `openai.OpenAI` instance in your workflow.
And the `model` field when calling openai API can be obtained via `openai_client.models.list().data[0].id`.
And the `model` field when calling openai API can be obtained via `openai_client.models.list().data[0].id` or `openai_client.model_path`.
```
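
A minimal sketch of this pattern (assuming `enable_openai_api` is set to `true` and that the workflow holds the trained model as `self.model`; the prompt below is illustrative):

```python
# Sketch only: requires `explorer.rollout_model.enable_openai_api: true` in the config.
openai_client = self.model.get_openai_client()

# The served model name can be read back from the endpoint itself,
# or via the `model_path` attribute mentioned above.
model_id = openai_client.models.list().data[0].id  # or: openai_client.model_path

response = openai_client.chat.completions.create(
    model=model_id,
    messages=[{"role": "user", "content": "Say hello."}],
)
print(response.choices[0].message.content)
```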

Here's an example of initializing a simple workflow using only `raw_task` and `rollout_args`. In more complex cases, you can use `format_args` for further customization.

```python
class ExampleWorkflow(Workflow):

def __init__(self, model: ModelWrapper, task: Task, auxiliary_models: List):
super().__init__(model, task, auxiliary_models)
def __init__(self, task: Task, model: ModelWrapper, auxiliary_models: List):
super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)
self.question = task.raw_task.get("question")
self.answer = task.raw_task.get("answer")
self.rollout_args = task.rollout_args
@@ -244,8 +246,8 @@ class ExampleWorkflow(Workflow):
@WORKFLOWS.register_module("example_workflow")
class ExampleWorkflow(Workflow):

def __init__(self, model: ModelWrapper, task: Task, auxiliary_models: List):
super().__init__(model, task, auxiliary_models)
def __init__(self, task: Task, model: ModelWrapper, auxiliary_models: List):
super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)
self.question = task.raw_task.get("question")
self.answer = task.raw_task.get("answer")
self.rollout_args = task.rollout_args
36 changes: 27 additions & 9 deletions tests/buffer/queue_test.py
@@ -2,6 +2,7 @@
import threading
import time

import ray
import torch
from parameterized import parameterized

@@ -75,27 +76,25 @@ async def test_queue_buffer(self, name, use_priority_queue):
self.assertRaises(StopIteration, reader.read)
with open(BUFFER_FILE_PATH, "r") as f:
self.assertEqual(len(f.readlines()), self.total_num + self.put_batch_size * 2)
st = time.time()
self.assertRaises(TimeoutError, reader.read, batch_size=1)
et = time.time()
self.assertTrue(et - st > 2)
self.assertRaises(StopIteration, reader.read, batch_size=1)

async def test_priority_queue_capacity(self):
# test priority queue capacity
self.config.read_batch_size = 4
meta = StorageConfig(
name="test_buffer_small",
algorithm_type="ppo",
storage_type=StorageType.QUEUE,
max_read_timeout=1,
capacity=2,
capacity=100,  # the priority queue uses 2 * read_batch_size as its capacity (8)
path=BUFFER_FILE_PATH,
use_priority_queue=True,
replay_buffer_kwargs={"priority_fn": "linear_decay", "decay": 0.6},
)
writer = QueueWriter(meta, self.config)
reader = QueueReader(meta, self.config)

for i in range(4):
for i in range(12):
writer.write(
[
Experience(
@@ -106,15 +105,34 @@ async def test_priority_queue_capacity(self):
]
)

exps = reader.read(batch_size=2)
self.assertEqual(exps[0].info["model_version"], 3)
self.assertEqual(ray.get(reader.queue.length.remote()), 8)

exps = reader.read(batch_size=8)
self.assertEqual(exps[0].info["model_version"], 11)
self.assertEqual(exps[0].info["use_count"], 1)
self.assertEqual(exps[1].info["model_version"], 2)
self.assertEqual(exps[1].info["model_version"], 10)
self.assertEqual(exps[1].info["use_count"], 1)
self.assertEqual(exps[7].info["model_version"], 4)

with self.assertRaises(TimeoutError):
reader.read(batch_size=1)

for i in range(12):
writer.write(
[
Experience(
tokens=torch.tensor([1, 2, 3]),
prompt_length=2,
info={"model_version": i, "use_count": 0},
),
]
)
await writer.release()
exps = reader.read(batch_size=8)

with self.assertRaises(StopIteration):
reader.read(batch_size=1)

async def test_queue_buffer_capacity(self):
# test queue capacity
meta = StorageConfig(
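
For orientation, here is a toy sketch of the behavior the updated test asserts. It is not the deleted `trinity/buffer/priority_queue.py` implementation; `ToyPriorityBuffer` and its use of `model_version` as the priority are illustrative assumptions. With `read_batch_size = 4`, the queue caps itself at `2 * read_batch_size = 8` experiences, evicts the lowest-priority (oldest) ones on write, and serves reads highest-priority first:

```python
import heapq


class ToyPriorityBuffer:
    """Bounded max-priority buffer: keeps only the `capacity` highest-priority items."""

    def __init__(self, capacity: int):
        self.capacity = capacity
        self._heap = []  # min-heap of (priority, seq, item); lowest priority evicted first
        self._seq = 0    # insertion counter used as a tie-breaker

    def put(self, item, priority: float) -> None:
        heapq.heappush(self._heap, (priority, self._seq, item))
        self._seq += 1
        if len(self._heap) > self.capacity:
            heapq.heappop(self._heap)  # evict the lowest-priority entry

    def get_batch(self, batch_size: int):
        # Take the `batch_size` highest-priority entries, highest first.
        batch = heapq.nlargest(batch_size, self._heap)
        for entry in batch:
            self._heap.remove(entry)
        heapq.heapify(self._heap)
        return [item for _, _, item in batch]


buf = ToyPriorityBuffer(capacity=2 * 4)  # 2 * read_batch_size
for version in range(12):
    buf.put({"model_version": version}, priority=float(version))

batch = buf.get_batch(batch_size=8)
assert [e["model_version"] for e in batch[:2]] == [11, 10]
assert batch[-1]["model_version"] == 4  # versions 0..3 were evicted on write
```

The configured `priority_fn` (`linear_decay` with `decay=0.6`) presumably also factors in fields such as `use_count`, which this sketch ignores.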
126 changes: 0 additions & 126 deletions trinity/buffer/priority_queue.py

This file was deleted.
