Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ agent = [
"agentscope"
]
rm_gallery = [
"rm-gallery"
"rm-gallery>=0.1.1"
]
dev = [
"pre-commit>=2.17.0",
Expand Down
27 changes: 20 additions & 7 deletions tests/explorer/scheduler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,31 +403,44 @@ async def test_split_tasks(self):
self.config.check_and_update()
scheduler = Scheduler(self.config, [DummyModel.remote(), DummyModel.remote()])
await scheduler.start()
exp_list = []

tasks = generate_tasks(4, repeat_times=8) # ceil(8 / 2) == 4
scheduler.schedule(tasks, batch_id=1)
results = await scheduler.get_results(batch_id=1)
self.assertEqual(len(results), 4 * 4)
self.assertEqual(len(self.queue.read(batch_size=4 * 8)), 4 * 8)
exps = self.queue.read(batch_size=4 * 8)
self.assertEqual(len(exps), 4 * 8)
exp_list.extend(exps)
with self.assertRaises(TimeoutError):
self.queue.read(batch_size=1)

tasks = generate_tasks(4, repeat_times=5) # ceil(5 / 2) == 3
scheduler.schedule(tasks, batch_id=1)
results = await scheduler.get_results(batch_id=1)
scheduler.schedule(tasks, batch_id=2)
results = await scheduler.get_results(batch_id=2)
self.assertEqual(len(results), 4 * 3)
self.assertEqual(len(self.queue.read(batch_size=4 * 5)), 4 * 5)
exps = self.queue.read(batch_size=4 * 5)
self.assertEqual(len(exps), 4 * 5)
exp_list.extend(exps)
with self.assertRaises(TimeoutError):
self.queue.read(batch_size=1)

tasks = generate_tasks(3, repeat_times=1) # ceil(1 / 2) == 1
scheduler.schedule(tasks, batch_id=1)
results = await scheduler.get_results(batch_id=1)
scheduler.schedule(tasks, batch_id=3)
results = await scheduler.get_results(batch_id=3)
self.assertEqual(len(results), 3 * 1)
self.assertEqual(len(self.queue.read(batch_size=3 * 1)), 3 * 1)
exps = self.queue.read(batch_size=3 * 1)
self.assertEqual(len(exps), 3 * 1)
exp_list.extend(exps)
with self.assertRaises(TimeoutError):
self.queue.read(batch_size=1)

# test group_id and unique_id
group_ids = [exp.group_id for exp in exp_list]
self.assertEqual(len(set(group_ids)), 11) # 4 + 4 + 3
unique_ids = [exp.unique_id for exp in exp_list]
self.assertEqual(len(unique_ids), len(set(unique_ids)))

await scheduler.stop()

def tearDown(self):
Expand Down
2 changes: 1 addition & 1 deletion tests/explorer/workflow_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class MockResponse:
reward: float = 0.0
metrics: Optional[Dict[str, float]] = None
info: Optional[Dict] = None
unique_id: Optional[str] = "0"


class DummyWorkflow(Workflow):
Expand Down Expand Up @@ -237,7 +238,6 @@ def test_gsm8k_workflow(self) -> None:
self.assertEqual(experiences[2].reward, -0.1)
self.assertEqual(experiences[3].reward, 1.1)

@unittest.skip("Skip for now, need to fix import issues of RM-Gallery")
def test_rm_gallery_workflow(self) -> None:
model = MagicMock()
model.chat.return_value = [
Expand Down
1 change: 1 addition & 0 deletions trinity/algorithm/sample_strategy/mix_sample_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def to_data_proto_mix(experiences: Experiences, is_expert_mask: torch.tensor) ->
position_ids = torch.clip(cumsum - 1, 0, None).long()
batch_dict = {
"uid": np.array(experiences.group_ids),
"unique_ids": np.array(experiences.unique_ids),
"position_ids": position_ids,
"input_ids": experiences.tokens.long(),
"responses": experiences.tokens[:, experiences.prompt_length :].long(),
Expand Down
1 change: 1 addition & 0 deletions trinity/algorithm/sample_strategy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def to_data_proto(experiences: Experiences) -> DataProto:
position_ids = torch.clip(cumsum - 1, 0, None).long()
batch_dict = {
"uid": np.array(experiences.group_ids),
"unique_ids": np.array(experiences.unique_ids),
"position_ids": position_ids,
"input_ids": experiences.tokens.long(),
"responses": experiences.tokens[:, experiences.prompt_length :].long(),
Expand Down
12 changes: 12 additions & 0 deletions trinity/common/experience.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class Experience:
info: Optional[dict] = None
metrics: Optional[dict[str, float]] = None
group_id: str = "" # for grpo
unique_id: str = ""

def __post_init__(self):
if self.action_mask is not None:
Expand Down Expand Up @@ -96,6 +97,7 @@ class Experiences:
prompt_length: int
logprobs: Optional[Tensor]
group_ids: List[str]
unique_ids: List[str]

@property
def batch_size(self) -> int:
Expand All @@ -119,10 +121,12 @@ def gather_experiences(
logprobs=torch.empty(0, dtype=torch.float32),
prompt_length=torch.empty(0, dtype=torch.int32),
group_ids=[],
unique_ids=[],
)
max_prompt_length = max([exp.prompt_length for exp in experiences])
max_response_length = max([len(exp.tokens) - exp.prompt_length for exp in experiences])
group_ids = [exp.group_id for exp in experiences]
unique_ids = [exp.unique_id for exp in experiences]
tokens_dtype = experiences[0].tokens.dtype
tokens = torch.stack(
[
Expand Down Expand Up @@ -209,6 +213,7 @@ def gather_experiences(

return cls(
group_ids=group_ids,
unique_ids=unique_ids,
tokens=tokens,
rewards=rewards,
attention_masks=attention_masks,
Expand Down Expand Up @@ -250,6 +255,7 @@ def gather_dpo_experiences(
logprobs=torch.empty(0, dtype=torch.float32),
prompt_length=torch.empty(0, dtype=torch.int32),
group_ids=[],
unique_ids=[],
)

# TODO: exp.tokens in DPO are prompt tokens
Expand All @@ -262,6 +268,11 @@ def gather_dpo_experiences(
max_response_length = max([len(response) for response in response_tokens]) # type: ignore

group_ids = list(chain.from_iterable([repeat(exp.group_id, 2) for exp in experiences]))
unique_ids = list(
chain.from_iterable(
[(f"{exp.unique_id}/1", f"{exp.unique_id}/0") for exp in experiences]
)
)
tokens_dtype = experiences[0].tokens.dtype
tokens = torch.stack(
[
Expand Down Expand Up @@ -298,6 +309,7 @@ def gather_dpo_experiences(

return cls(
group_ids=group_ids,
unique_ids=unique_ids,
tokens=tokens,
attention_masks=attention_masks,
prompt_length=max_prompt_length,
Expand Down
2 changes: 1 addition & 1 deletion trinity/common/rewards/reward_fn.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def _build_sample_from_experience(
]

sample = DataSample(
unique_id="0", # TODO: Generate unique ID
unique_id=experience.unique_id,
input=to_rm_gallery_messages(messages),
output=output,
metadata=experience.info,
Expand Down
2 changes: 1 addition & 1 deletion trinity/explorer/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def _create_runner(self):
"env_vars": self.config.explorer.env_vars,
},
)
.remote(self.config, self.rollout_model, self.auxiliary_models)
.remote(self.config, self.rollout_model, self.auxiliary_models, self.runner_id)
)

async def run_with_retry(self, task: TaskWrapper) -> Tuple[Status, int]:
Expand Down
8 changes: 7 additions & 1 deletion trinity/explorer/workflow_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
"""The Workflow Runner Moudle."""
import time
import traceback
import uuid
from collections import defaultdict
from dataclasses import dataclass
from typing import List, Optional
Expand Down Expand Up @@ -31,6 +32,7 @@ def __init__(
config: Config,
model: InferenceModel,
auxiliary_models: Optional[List[InferenceModel]] = None,
runner_id: Optional[int] = None,
) -> None:
self.config = config
self.experience_buffer = get_buffer_writer(
Expand All @@ -52,6 +54,7 @@ def __init__(
self.auxiliary_models.append(api_client)
self.logger = get_logger(__name__)
self.workflow_instance = None
self.runner_id = runner_id

def is_alive(self):
return True
Expand All @@ -78,8 +81,11 @@ def run_task(self, task: Task) -> Status:
assert exps is not None and len(exps) > 0, "An empty experience is generated"
metrics: dict[str, List[float]] = defaultdict(list)
# set the group id and a unique id on every experience
for exp in exps:
for idx, exp in enumerate(exps):
setattr(exp, "group_id", task.group_id)
setattr(
exp, "unique_id", f"{task.group_id}/{self.runner_id}/{str(uuid.uuid4())[:6]}"
)

if not hasattr(exp, "info") or exp.info is None:
exp.info = {}
Expand Down