
Commit cfa7f85 (1 parent: 36344ec)

Fix priority queue implementation and enhance testing (#135)
9 files changed: +324 −249 lines changed

.github/workflows/unittest.yaml

Lines changed: 2 additions & 3 deletions

@@ -69,9 +69,8 @@ jobs:
           docker compose exec trinity-node-1 pytest tests -v -s --ignore=tests/data --ctrf report.json
           echo "tests_run=true" >> $GITHUB_ENV
         elif [ "$TYPE" = "diff" ]; then
-          ROOT_DIR=trinity-${{ github.run_id }}
-          if [ -s "$ROOT_DIR/test_dirs.txt" ]; then
-            TEST_DIRS=$(cat "$ROOT_DIR/test_dirs.txt" | xargs)
+          if [ -s ../../../test_dirs.txt ]; then
+            TEST_DIRS=$(cat ../../../test_dirs.txt | xargs)
             docker compose exec trinity-node-1 pytest $TEST_DIRS -v -s --ignore=tests/data --ctrf report.json
             echo "tests_run=true" >> $GITHUB_ENV
           else

docs/sphinx_doc/source/tutorial/trinity_programming_guide.md

Lines changed: 9 additions & 7 deletions

@@ -100,10 +100,12 @@ class Workflow(ABC):
 
     def __init__(
         self,
-        model: ModelWrapper,
+        *,
         task: Task,
+        model: ModelWrapper,
         auxiliary_models: Optional[List[openai.OpenAI]] = None,
     ):
+        self.task = task
         self.model = model
         self.auxiliary_models = auxiliary_models
 
@@ -116,22 +118,22 @@ class Workflow(ABC):
 
 During initialization, `Workflow` receives the following parameters:
 
-- `model`({class}`trinity.common.models.model.ModelWrapper`): The model being trained, which provides an interface similar to OpenAI, capable of receiving a list of conversation messages and returning content generated by the LLM (including reply text `response_text`, full sequence token ids `tokens`, prompt part token length `prompt_length`, and a list of output token logprobs `logprobs`).
 - `task`({class}`trinity.common.workflows.Task`): A single data item from the task dataset.
+- `model`({class}`trinity.common.models.model.ModelWrapper`): The model being trained, which provides an interface similar to OpenAI, capable of receiving a list of conversation messages and returning content generated by the LLM (including reply text `response_text`, full sequence token ids `tokens`, prompt part token length `prompt_length`, and a list of output token logprobs `logprobs`).
 - `auxiliary_models`(`List[openai.OpenAI]`): A list of auxiliary models not involved in training. All are provided via OpenAI-compatible APIs.
 
 ```{tip}
 You can switch to using the OpenAI API by setting `explorer.rollout_model.enable_openai_api` to `true` in your config file and calling `model.get_openai_client()` to get an `openai.OpenAI` instance in your workflow.
-And the `model` field when calling openai API can be obtained via `openai_client.models.list().data[0].id`.
+And the `model` field when calling openai API can be obtained via `openai_client.models.list().data[0].id` or `openai_client.model_path`.
 ```
 
 Here's an example of initializing a simple workflow using only `raw_task` and `rollout_args`. In more complex cases, you can use the `format_args` for further customization.
 
 ```python
 class ExampleWorkflow(Workflow):
 
-    def __init__(self, model: ModelWrapper, task: Task, auxiliary_models: List):
-        super().__init__(model, task, auxiliary_models)
+    def __init__(self, task: Task, model: ModelWrapper, auxiliary_models: List):
+        super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)
         self.question = task.raw_task.get("question")
         self.answer = task.raw_task.get("answer")
         self.rollout_args = task.rollout_args
@@ -244,8 +246,8 @@ class ExampleWorkflow(Workflow):
 @WORKFLOWS.register_module("example_workflow")
 class ExampleWorkflow(Workflow):
 
-    def __init__(self, model: ModelWrapper, task: Task, auxiliary_models: List):
-        super().__init__(model, task, auxiliary_models)
+    def __init__(self, task: Task, model: ModelWrapper, auxiliary_models: List):
+        super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)
         self.question = task.raw_task.get("question")
         self.answer = task.raw_task.get("answer")
         self.rollout_args = task.rollout_args
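To make the updated tip concrete, here is a hedged sketch of how a workflow might call the rollout model through its OpenAI-compatible endpoint. Only `get_openai_client()`, `models.list().data[0].id`, and `model_path` come from the documentation above; the helper name, the chat-completion call, and the message are illustrative.

```python
from trinity.common.models.model import ModelWrapper


def ask_via_openai_api(model: ModelWrapper, question: str) -> str:
    """Hypothetical helper: query the rollout model via its OpenAI-compatible API.

    Assumes `explorer.rollout_model.enable_openai_api` is set to `true`
    in the config, as the tip above requires.
    """
    openai_client = model.get_openai_client()
    # Either expression from the tip yields the value for the `model` field.
    model_name = openai_client.models.list().data[0].id  # or: openai_client.model_path
    response = openai_client.chat.completions.create(
        model=model_name,
        messages=[{"role": "user", "content": question}],
    )
    return response.choices[0].message.content
```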

tests/buffer/queue_test.py

Lines changed: 27 additions & 9 deletions

@@ -2,6 +2,7 @@
 import threading
 import time
 
+import ray
 import torch
 from parameterized import parameterized
 
@@ -75,27 +76,25 @@ async def test_queue_buffer(self, name, use_priority_queue):
         self.assertRaises(StopIteration, reader.read)
         with open(BUFFER_FILE_PATH, "r") as f:
             self.assertEqual(len(f.readlines()), self.total_num + self.put_batch_size * 2)
-        st = time.time()
-        self.assertRaises(TimeoutError, reader.read, batch_size=1)
-        et = time.time()
-        self.assertTrue(et - st > 2)
+        self.assertRaises(StopIteration, reader.read, batch_size=1)
 
     async def test_priority_queue_capacity(self):
         # test queue capacity
+        self.config.read_batch_size = 4
         meta = StorageConfig(
             name="test_buffer_small",
             algorithm_type="ppo",
             storage_type=StorageType.QUEUE,
             max_read_timeout=1,
-            capacity=2,
+            capacity=100,  # priority will use 2 * read_batch_size as capacity (8)
             path=BUFFER_FILE_PATH,
             use_priority_queue=True,
             replay_buffer_kwargs={"priority_fn": "linear_decay", "decay": 0.6},
         )
         writer = QueueWriter(meta, self.config)
         reader = QueueReader(meta, self.config)
 
-        for i in range(4):
+        for i in range(12):
             writer.write(
                 [
                     Experience(
@@ -106,15 +105,34 @@ async def test_priority_queue_capacity(self):
                 ]
             )
 
-        exps = reader.read(batch_size=2)
-        self.assertEqual(exps[0].info["model_version"], 3)
+        self.assertEqual(ray.get(reader.queue.length.remote()), 8)
+
+        exps = reader.read(batch_size=8)
+        self.assertEqual(exps[0].info["model_version"], 11)
         self.assertEqual(exps[0].info["use_count"], 1)
-        self.assertEqual(exps[1].info["model_version"], 2)
+        self.assertEqual(exps[1].info["model_version"], 10)
         self.assertEqual(exps[1].info["use_count"], 1)
+        self.assertEqual(exps[7].info["model_version"], 4)
 
         with self.assertRaises(TimeoutError):
             reader.read(batch_size=1)
 
+        for i in range(12):
+            writer.write(
+                [
+                    Experience(
+                        tokens=torch.tensor([1, 2, 3]),
+                        prompt_length=2,
+                        info={"model_version": i, "use_count": 0},
+                    ),
+                ]
+            )
+        await writer.release()
+        exps = reader.read(batch_size=8)
+
+        with self.assertRaises(StopIteration):
+            reader.read(batch_size=1)
+
     async def test_queue_buffer_capacity(self):
         # test queue capacity
         meta = StorageConfig(
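To make the numbers in this test concrete: with `read_batch_size = 4` the priority queue caps itself at 2 × 4 = 8 experiences regardless of the configured `capacity`, so writing 12 items keeps only the 8 newest, and a linear-decay priority returns them newest-first. Below is a toy sketch that reproduces those assertions. It is not Trinity's actual implementation; the priority formula `model_version - decay * use_count` and the evict-lowest-priority policy are assumptions made for illustration.

```python
import heapq


class ToyPriorityBuffer:
    """Toy replay buffer mimicking the behaviour the test above asserts."""

    def __init__(self, read_batch_size: int, decay: float = 0.6):
        # Mirrors the comment in the diff: effective capacity is
        # 2 * read_batch_size, not the configured `capacity`.
        self.capacity = 2 * read_batch_size
        self.decay = decay
        self._heap: list = []  # max-heap via negated priority

    def put(self, exp: dict) -> None:
        # Assumed linear-decay priority: newer model versions rank higher,
        # and each reuse lowers the priority by `decay`. (Real code would
        # need a tie-breaker for equal priorities.)
        priority = exp["model_version"] - self.decay * exp["use_count"]
        heapq.heappush(self._heap, (-priority, exp))
        if len(self._heap) > self.capacity:
            # Evict the lowest-priority item once over capacity.
            self._heap.remove(max(self._heap))
            heapq.heapify(self._heap)

    def get(self) -> dict:
        _, exp = heapq.heappop(self._heap)
        exp["use_count"] += 1
        return exp


buf = ToyPriorityBuffer(read_batch_size=4)
for i in range(12):
    buf.put({"model_version": i, "use_count": 0})
assert len(buf._heap) == 8               # matches ray.get(reader.queue.length.remote())
batch = [buf.get() for _ in range(8)]
assert batch[0]["model_version"] == 11   # newest experience comes out first
assert batch[7]["model_version"] == 4    # oldest surviving experience comes out last
```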

trinity/buffer/priority_queue.py

Lines changed: 0 additions & 126 deletions
This file was deleted.
