2 changes: 1 addition & 1 deletion docs/sphinx_doc/source/tutorial/example_mix_algo.md
@@ -137,7 +137,7 @@ We also need to add an `is_expert_mask` field when transforming to DataProto to
     cumsum = torch.cumsum(attention_mask, dim=-1)
     position_ids = torch.clip(cumsum - 1, 0, None).long()
     batch_dict = {
-        "uid": np.array(experiences.run_ids),
+        "uid": np.array(experiences.group_ids),
         "position_ids": position_ids,
         "input_ids": experiences.tokens.long(),
         "responses": experiences.tokens[:, experiences.prompt_length :].long(),
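
Context for this one-line change: the `uid` field is what GRPO-style advantage estimators group by when normalizing rewards. With `run_ids` every rollout carried a unique id, so each "group" had size one; with `group_ids`, all rollouts of the same task share one id. A minimal sketch of the grouping this enables (function name and shapes are illustrative, not the project's actual estimator):

```python
from collections import defaultdict

import torch


def grpo_advantages(rewards: torch.Tensor, uids: list) -> torch.Tensor:
    """Group-relative advantages: center each reward on its uid-group mean."""
    groups = defaultdict(list)
    for i, uid in enumerate(uids):
        groups[uid].append(i)
    advantages = torch.empty_like(rewards)
    for indices in groups.values():
        advantages[indices] = rewards[indices] - rewards[indices].mean()
    return advantages
```

With one rollout per group, the group mean equals the reward itself and every advantage collapses to zero, which is presumably why grouping by `group_ids` matters here.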
2 changes: 1 addition & 1 deletion tests/buffer/queue_test.py
@@ -68,7 +68,7 @@ async def test_queue_buffer(self):
         with open(BUFFER_FILE_PATH, "r") as f:
             self.assertEqual(len(f.readlines()), total_num + put_batch_size * 2)
         st = time.time()
-        self.assertRaises(StopIteration, reader.read, batch_size=1)
+        self.assertRaises(TimeoutError, reader.read, batch_size=1)
         et = time.time()
         self.assertTrue(et - st > 2)
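
A note on why `TimeoutError` fits better than `StopIteration` here (reasoning about Python semantics, not project documentation): `StopIteration` is the control-flow signal for exhausted iterators, and under PEP 479 one that escapes inside a generator is converted to `RuntimeError`, so an "incomplete batch" error can be masked or misreported. A sketch of the consumer pattern that motivates the distinction, with `reader` assumed to be a `QueueReader`:

```python
def iter_batches(reader):
    # Hypothetical consumer. If read() signaled an incomplete batch with
    # StopIteration, this generator would simply end, indistinguishable from a
    # cleanly drained queue; a TimeoutError instead propagates as a real failure.
    while True:
        try:
            yield reader.read(batch_size=1)
        except StopIteration:
            return  # queue closed: normal end of stream
```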
76 changes: 70 additions & 6 deletions tests/explorer/scheduler_test.py
@@ -8,7 +8,7 @@

 from tests.tools import get_template_config
 from trinity.buffer.reader.queue_reader import QueueReader
-from trinity.common.config import StorageConfig
+from trinity.common.config import GenerationConfig, StorageConfig
 from trinity.common.constants import StorageType
 from trinity.common.experience import Experience
 from trinity.common.models.model import InferenceModel
@@ -23,6 +23,7 @@ def __init__(self, model, task, auxiliary_models):
         super().__init__(model, task, auxiliary_models)
         self.error_type = task.raw_task.get("error_type", "")
         self.seconds = None
+        self.repeat_times = task.rollout_args.n
         if "timeout" in self.error_type:
             parts = self.error_type.split("_")
             if len(parts) > 1:
@@ -42,8 +43,12 @@ def run(self) -> List[Experience]:

         return [
             Experience(
-                tokens=torch.zeros(5), prompt_length=2, prompt_text=self.error_type or "success"
+                tokens=torch.zeros(5),
+                prompt_length=2,
+                prompt_text=self.error_type or "success",
+                info={"repeat_times": self.repeat_times},
             )
+            for _ in range(self.repeat_times)
         ]

@@ -98,7 +103,11 @@ def api_server_ready(self) -> Tuple[str, str]:


 def generate_tasks(
-    total_num: int, timeout_num: int = 0, exception_num: int = 0, timeout_seconds: int = 10
+    total_num: int,
+    timeout_num: int = 0,
+    exception_num: int = 0,
+    timeout_seconds: int = 10,
+    repeat_times: int = 1,
 ):
     """Generate some tasks for testing

@@ -108,7 +117,10 @@ def generate_tasks(
         exception_num: number of exception tasks
         timeout_seconds: the timeout for timeout tasks
     """
-    tasks = [Task(workflow=DummyWorkflow, raw_task={}) for _ in range(total_num)]
+    tasks = [
+        Task(workflow=DummyWorkflow, raw_task={}, rollout_args=GenerationConfig(n=repeat_times))
+        for _ in range(total_num)
+    ]

     tasks.extend(
         [
@@ -150,6 +162,9 @@ def setUp(self):
             algorithm_type="ppo",
             path="",
         )
+        self.config.buffer.trainer_input.experience_buffer.max_read_timeout = 1
+        self.config.algorithm.repeat_times = 1
+        self.config.check_and_update()
         self.queue = QueueReader(
             self.config.buffer.trainer_input.experience_buffer, self.config.buffer
         )
@@ -163,6 +178,9 @@ async def test_get_results(self):

         results = await scheduler.get_results(batch_id=0, min_num=8, timeout=20)
         self.assertEqual(len(results), 8)
+        self.assertEqual(len(self.queue.read(batch_size=8)), 8)
+        with self.assertRaises(TimeoutError):
+            self.queue.read(batch_size=1)

         for result in results:
             self.assertTrue(result.ok)
@@ -176,13 +194,17 @@ async def test_get_results(self):
         results = await scheduler.get_results(batch_id=batch_id, min_num=4, timeout=10)
         self.assertEqual(len(results), 4)
         self.assertFalse(scheduler.has_step(batch_id))
+        self.assertEqual(len(self.queue.read(batch_size=4)), 4)
+        with self.assertRaises(TimeoutError):
+            self.queue.read(batch_size=1)

         tasks = generate_tasks(3)
         scheduler.schedule(tasks, batch_id=4)
         self.assertTrue(scheduler.has_step(4))
         results = await scheduler.get_results(batch_id=4)
         self.assertEqual(len(results), 3)
         self.assertFalse(scheduler.has_step(4))
+        self.assertEqual(len(self.queue.read(batch_size=3)), 3)

         # test timeout
         tasks = generate_tasks(2, timeout_num=2, timeout_seconds=10)
@@ -194,6 +216,7 @@

         self.assertLessEqual(end_time - start_time, 5)
         self.assertEqual(len(results), 2)
+        self.assertEqual(len(self.queue.read(batch_size=2)), 2)

         # test run tasks after timeout
         tasks = generate_tasks(4)
@@ -204,8 +227,10 @@
         self.assertEqual(len(results), 4)

         success_count = sum(1 for r in results if r.ok)
-
-        self.assertEqual(success_count, sum(1 for r in results if r.ok))
+        self.assertEqual(success_count, 4)
+        self.assertEqual(len(self.queue.read(batch_size=4)), 4)
+        with self.assertRaises(TimeoutError):
+            self.queue.read(batch_size=1)

         # test exception tasks
         tasks = generate_tasks(1, exception_num=3)
@@ -215,14 +240,21 @@

         success_count = sum(1 for r in results if r.ok)
         self.assertEqual(success_count, 1)
+        self.assertEqual(len(self.queue.read(batch_size=1)), 1)
+        with self.assertRaises(TimeoutError):
+            self.queue.read(batch_size=1)

         # test clear_timeout_tasks
         tasks = generate_tasks(3, timeout_num=1, timeout_seconds=3)
         scheduler.schedule(tasks, batch_id=2)
         results = await scheduler.get_results(batch_id=2, timeout=2, clear_timeout_tasks=False)
         self.assertEqual(len(results), 3)
+        self.assertEqual(len(self.queue.read(batch_size=3)), 3)
         results = await scheduler.get_results(batch_id=2, timeout=2, clear_timeout_tasks=False)
         self.assertEqual(len(results), 1)
+        self.assertEqual(len(self.queue.read(batch_size=1)), 1)
+        with self.assertRaises(TimeoutError):
+            self.queue.read(batch_size=1)

         await scheduler.stop()
@@ -366,6 +398,38 @@ async def test_scheduler_all_methods(self):
         self.assertFalse(scheduler.has_step(2))
         await scheduler.stop()

+    async def test_split_tasks(self):
+        self.config.explorer.max_repeat_times_per_runner = 2
+        self.config.check_and_update()
+        scheduler = Scheduler(self.config, [DummyModel.remote(), DummyModel.remote()])
+        await scheduler.start()
+
+        tasks = generate_tasks(4, repeat_times=8)  # ceil(8 / 2) == 4
+        scheduler.schedule(tasks, batch_id=1)
+        results = await scheduler.get_results(batch_id=1)
+        self.assertEqual(len(results), 4 * 4)
+        self.assertEqual(len(self.queue.read(batch_size=4 * 8)), 4 * 8)
+        with self.assertRaises(TimeoutError):
+            self.queue.read(batch_size=1)
+
+        tasks = generate_tasks(4, repeat_times=5)  # ceil(5 / 2) == 3
+        scheduler.schedule(tasks, batch_id=1)
+        results = await scheduler.get_results(batch_id=1)
+        self.assertEqual(len(results), 4 * 3)
+        self.assertEqual(len(self.queue.read(batch_size=4 * 5)), 4 * 5)
+        with self.assertRaises(TimeoutError):
+            self.queue.read(batch_size=1)
+
+        tasks = generate_tasks(3, repeat_times=1)  # ceil(1 / 2) == 1
+        scheduler.schedule(tasks, batch_id=1)
+        results = await scheduler.get_results(batch_id=1)
+        self.assertEqual(len(results), 3 * 1)
+        self.assertEqual(len(self.queue.read(batch_size=3 * 1)), 3 * 1)
+        with self.assertRaises(TimeoutError):
+            self.queue.read(batch_size=1)
+
+        await scheduler.stop()
+
     def tearDown(self):
         try:
             ray.shutdown()
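
The arithmetic behind `test_split_tasks`, spelled out: with `max_repeat_times_per_runner = 2`, a task with `n` rollouts is split into `ceil(n / 2)` runner invocations, each of which reports one result, while all `n` rollouts still land in the experience queue. A quick sanity check of the three cases (plain Python, mirroring the test's expectations):

```python
import math

MAX_PER_RUNNER = 2
for num_tasks, n in [(4, 8), (4, 5), (3, 1)]:
    runs_per_task = math.ceil(n / MAX_PER_RUNNER)  # one result per runner invocation
    print(num_tasks * runs_per_task, num_tasks * n)  # len(results), queue size
# prints: 16 32, then 12 20, then 3 3
```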
2 changes: 1 addition & 1 deletion trinity/algorithm/sample_strategy/mix_sample_strategy.py
@@ -90,7 +90,7 @@ def to_data_proto_mix(experiences: Experiences, is_expert_mask: torch.tensor) ->
     cumsum = torch.cumsum(attention_mask, dim=-1)
     position_ids = torch.clip(cumsum - 1, 0, None).long()
     batch_dict = {
-        "uid": np.array(experiences.run_ids),
+        "uid": np.array(experiences.group_ids),
         "position_ids": position_ids,
         "input_ids": experiences.tokens.long(),
         "responses": experiences.tokens[:, experiences.prompt_length :].long(),
2 changes: 1 addition & 1 deletion trinity/algorithm/sample_strategy/utils.py
@@ -13,7 +13,7 @@ def to_data_proto(experiences: Experiences) -> DataProto:
     cumsum = torch.cumsum(attention_mask, dim=-1)
     position_ids = torch.clip(cumsum - 1, 0, None).long()
     batch_dict = {
-        "uid": np.array(experiences.run_ids),
+        "uid": np.array(experiences.group_ids),
         "position_ids": position_ids,
         "input_ids": experiences.tokens.long(),
         "responses": experiences.tokens[:, experiences.prompt_length :].long(),
4 changes: 3 additions & 1 deletion trinity/buffer/reader/queue_reader.py
@@ -31,7 +31,9 @@ def read(
             batch_size = batch_size or self.read_batch_size
             exps = ray.get(self.queue.get_batch.remote(batch_size, timeout=self.timeout))
             if len(exps) != batch_size:
-                raise StopIteration("Read incomplete batch, please check your workflow.")
+                raise TimeoutError(
+                    f"Read incomplete batch ({len(exps)}/{batch_size}), please check your workflow."
+                )
         except StopAsyncIteration:
             raise StopIteration()
         return exps
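
For callers, the practical change: an incomplete batch now surfaces as `TimeoutError` carrying the received/requested counts, while `StopIteration` is still raised once the queue is closed for good. A hedged usage sketch (`reader` is a `QueueReader` as in the tests above):

```python
try:
    exps = reader.read(batch_size=32)
except TimeoutError as err:
    # Fewer experiences than requested arrived within max_read_timeout,
    # e.g. because an upstream workflow failed or is still running.
    print(f"incomplete batch: {err}")
except StopIteration:
    # The queue has been closed; no more experiences will arrive.
    exps = []
```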
6 changes: 4 additions & 2 deletions trinity/common/config.py
@@ -310,6 +310,9 @@ class ExplorerConfig:
     max_timeout: int = 1800  # wait each task for 30 minutes
     max_retry_times: int = 2  # retry each task up to 2 times if it fails or times out
     env_vars: dict = field(default_factory=dict)  # environment variables for workflow runner
+    max_repeat_times_per_runner: Optional[
+        int
+    ] = None  # the number of times to repeat each task in a single workflow runner (for GRPO-like algorithms)

     runner_num: Optional[int] = None  # deprecated

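Judging from `test_split_tasks` above, this option caps how many of a task's `repeat_times` rollouts a single workflow runner executes, so one task fans out to `ceil(repeat_times / max_repeat_times_per_runner)` runner invocations. A hypothetical helper showing the partition (not the project's actual scheduler code):

```python
import math
from typing import List, Optional


def split_repeats(repeat_times: int, max_per_runner: Optional[int]) -> List[int]:
    """Partition repeat_times into per-runner chunks of at most max_per_runner."""
    if max_per_runner is None:
        return [repeat_times]  # default: one runner executes all repeats
    full, rest = divmod(repeat_times, max_per_runner)
    return [max_per_runner] * full + ([rest] if rest else [])


assert split_repeats(8, 2) == [2, 2, 2, 2]  # 4 invocations, matching ceil(8 / 2)
assert split_repeats(5, 2) == [2, 2, 1]     # 3 invocations, matching ceil(5 / 2)
```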
@@ -358,9 +361,8 @@ class MonitorConfig:

 @dataclass
 class SynchronizerConfig:
-    """Configs for model weight synchronization"""
+    """Configs for model weight synchronization."""

-    # TODO: rename to "checkpoint", "nccl", "ipc"
     sync_method: SyncMethod = SyncMethod.NCCL
     # sync weights every `sync_interval` steps
     sync_interval: int = 1
16 changes: 8 additions & 8 deletions trinity/common/experience.py
@@ -26,7 +26,7 @@ class Experience:
     rejected: Optional[Tensor] = None  # for dpo
     info: Optional[dict] = None
     metrics: Optional[dict[str, float]] = None
-    run_id: str = ""
+    group_id: str = ""  # for grpo

     def __post_init__(self):
         if self.action_mask is not None:
@@ -95,7 +95,7 @@ class Experiences:
     action_masks: Optional[Tensor]
     prompt_length: int
     logprobs: Optional[Tensor]
-    run_ids: List[str]
+    group_ids: List[str]

     @property
     def batch_size(self) -> int:
@@ -118,11 +118,11 @@ def gather_experiences(
                 action_masks=torch.empty(0, dtype=torch.bool),
                 logprobs=torch.empty(0, dtype=torch.float32),
                 prompt_length=torch.empty(0, dtype=torch.int32),
-                run_ids=[],
+                group_ids=[],
             )
         max_prompt_length = max([exp.prompt_length for exp in experiences])
         max_response_length = max([len(exp.tokens) - exp.prompt_length for exp in experiences])
-        run_ids = [exp.run_id for exp in experiences]
+        group_ids = [exp.group_id for exp in experiences]
         tokens_dtype = experiences[0].tokens.dtype
         tokens = torch.stack(
             [
@@ -208,7 +208,7 @@ def gather_experiences(
             logprobs = None

         return cls(
-            run_ids=run_ids,
+            group_ids=group_ids,
             tokens=tokens,
             rewards=rewards,
             attention_masks=attention_masks,
@@ -249,7 +249,7 @@ def gather_dpo_experiences(
                 action_masks=torch.empty(0, dtype=torch.bool),
                 logprobs=torch.empty(0, dtype=torch.float32),
                 prompt_length=torch.empty(0, dtype=torch.int32),
-                run_ids=[],
+                group_ids=[],
             )

         # TODO: exp.tokens in DPO are prompt tokens
@@ -261,7 +261,7 @@
         response_tokens = list(chain.from_iterable(zip(chosen_tokens, rejected_tokens)))
         max_response_length = max([len(response) for response in response_tokens])  # type: ignore

-        run_ids = list(chain.from_iterable([repeat(exp.run_id, 2) for exp in experiences]))
+        group_ids = list(chain.from_iterable([repeat(exp.group_id, 2) for exp in experiences]))
         tokens_dtype = experiences[0].tokens.dtype
         tokens = torch.stack(
             [
@@ -297,7 +297,7 @@ def gather_dpo_experiences(
         assert len(tokens) == 2 * len(experiences)

         return cls(
-            run_ids=run_ids,
+            group_ids=group_ids,
             tokens=tokens,
             attention_masks=attention_masks,
             prompt_length=max_prompt_length,
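
A small sketch of the renamed field in use, built only from constructors that appear in this diff: every rollout of a task carries the same `group_id`, and the gathered batch exposes them as `group_ids` (the id value is illustrative):

```python
import torch

from trinity.common.experience import Experience

# Three rollouts of one task form a single GRPO-style group.
exps = [
    Experience(tokens=torch.zeros(5), prompt_length=2, group_id="task-0")
    for _ in range(3)
]
assert {e.group_id for e in exps} == {"task-0"}
```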
2 changes: 2 additions & 0 deletions trinity/common/workflows/workflow.py
@@ -35,6 +35,8 @@ class Task:
     reward_fn: Optional[Type[RewardFn]] = None
     raw_task: Optional[dict] = None  # The raw data sample

+    group_id: Optional[str] = None  # for GRPO-like algorithms, automatically assigned
+
     def to_workflow(
         self, model: Any, auxiliary_models: Optional[List[openai.OpenAI]] = None
     ) -> Workflow:
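
`group_id` is described as automatically assigned; presumably the scheduler stamps the same id onto every sub-task created when `max_repeat_times_per_runner` splits a task, so downstream grouping still sees one group per original task. A hypothetical, self-contained illustration (the `Task` here is a trimmed stand-in, not the real class):

```python
import uuid
from dataclasses import dataclass, replace
from typing import List, Optional


@dataclass
class Task:  # trimmed stand-in for trinity.common.workflows.Task
    raw_task: dict
    group_id: Optional[str] = None


def split_task(task: Task, num_splits: int) -> List[Task]:
    gid = str(uuid.uuid4())  # one id shared by the whole group
    return [replace(task, group_id=gid) for _ in range(num_splits)]


subs = split_task(Task(raw_task={}), 4)
assert len({t.group_id for t in subs}) == 1
```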