Skip to content

Commit 6b109b0

Browse files
authored
Add unique_id for each experience (#120)
1 parent 63d4920 commit 6b109b0

File tree

9 files changed

+45
-12
lines changed

9 files changed

+45
-12
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -54,7 +54,7 @@ agent = [
5454
"agentscope"
5555
]
5656
rm_gallery = [
57-
"rm-gallery"
57+
"rm-gallery>=0.1.1"
5858
]
5959
dev = [
6060
"pre-commit>=2.17.0",

tests/explorer/scheduler_test.py

Lines changed: 20 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -403,31 +403,44 @@ async def test_split_tasks(self):
403403
self.config.check_and_update()
404404
scheduler = Scheduler(self.config, [DummyModel.remote(), DummyModel.remote()])
405405
await scheduler.start()
406+
exp_list = []
406407

407408
tasks = generate_tasks(4, repeat_times=8) # ceil(8 / 2) == 4
408409
scheduler.schedule(tasks, batch_id=1)
409410
results = await scheduler.get_results(batch_id=1)
410411
self.assertEqual(len(results), 4 * 4)
411-
self.assertEqual(len(self.queue.read(batch_size=4 * 8)), 4 * 8)
412+
exps = self.queue.read(batch_size=4 * 8)
413+
self.assertEqual(len(exps), 4 * 8)
414+
exp_list.extend(exps)
412415
with self.assertRaises(TimeoutError):
413416
self.queue.read(batch_size=1)
414417

415418
tasks = generate_tasks(4, repeat_times=5) # ceil(5 / 2) == 3
416-
scheduler.schedule(tasks, batch_id=1)
417-
results = await scheduler.get_results(batch_id=1)
419+
scheduler.schedule(tasks, batch_id=2)
420+
results = await scheduler.get_results(batch_id=2)
418421
self.assertEqual(len(results), 4 * 3)
419-
self.assertEqual(len(self.queue.read(batch_size=4 * 5)), 4 * 5)
422+
exps = self.queue.read(batch_size=4 * 5)
423+
self.assertEqual(len(exps), 4 * 5)
424+
exp_list.extend(exps)
420425
with self.assertRaises(TimeoutError):
421426
self.queue.read(batch_size=1)
422427

423428
tasks = generate_tasks(3, repeat_times=1) # ceil(1 / 2) == 1
424-
scheduler.schedule(tasks, batch_id=1)
425-
results = await scheduler.get_results(batch_id=1)
429+
scheduler.schedule(tasks, batch_id=3)
430+
results = await scheduler.get_results(batch_id=3)
426431
self.assertEqual(len(results), 3 * 1)
427-
self.assertEqual(len(self.queue.read(batch_size=3 * 1)), 3 * 1)
432+
exps = self.queue.read(batch_size=3 * 1)
433+
self.assertEqual(len(exps), 3 * 1)
434+
exp_list.extend(exps)
428435
with self.assertRaises(TimeoutError):
429436
self.queue.read(batch_size=1)
430437

438+
# test group_id and unique_id
439+
group_ids = [exp.group_id for exp in exp_list]
440+
self.assertEqual(len(set(group_ids)), 11) # 4 + 4 + 3
441+
unique_ids = [exp.unique_id for exp in exp_list]
442+
self.assertEqual(len(unique_ids), len(set(unique_ids)))
443+
431444
await scheduler.stop()
432445

433446
def tearDown(self):

tests/explorer/workflow_test.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,7 @@ class MockResponse:
2222
reward: float = 0.0
2323
metrics: Optional[Dict[str, float]] = None
2424
info: Optional[Dict] = None
25+
unique_id: Optional[str] = "0"
2526

2627

2728
class DummyWorkflow(Workflow):
@@ -237,7 +238,6 @@ def test_gsm8k_workflow(self) -> None:
237238
self.assertEqual(experiences[2].reward, -0.1)
238239
self.assertEqual(experiences[3].reward, 1.1)
239240

240-
@unittest.skip("Skip for now, need to fix import issues of RM-Gallery")
241241
def test_rm_gallery_workflow(self) -> None:
242242
model = MagicMock()
243243
model.chat.return_value = [

trinity/algorithm/sample_strategy/mix_sample_strategy.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -91,6 +91,7 @@ def to_data_proto_mix(experiences: Experiences, is_expert_mask: torch.tensor) ->
9191
position_ids = torch.clip(cumsum - 1, 0, None).long()
9292
batch_dict = {
9393
"uid": np.array(experiences.group_ids),
94+
"unique_ids": np.array(experiences.unique_ids),
9495
"position_ids": position_ids,
9596
"input_ids": experiences.tokens.long(),
9697
"responses": experiences.tokens[:, experiences.prompt_length :].long(),

trinity/algorithm/sample_strategy/utils.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,7 @@ def to_data_proto(experiences: Experiences) -> DataProto:
1414
position_ids = torch.clip(cumsum - 1, 0, None).long()
1515
batch_dict = {
1616
"uid": np.array(experiences.group_ids),
17+
"unique_ids": np.array(experiences.unique_ids),
1718
"position_ids": position_ids,
1819
"input_ids": experiences.tokens.long(),
1920
"responses": experiences.tokens[:, experiences.prompt_length :].long(),

trinity/common/experience.py

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -27,6 +27,7 @@ class Experience:
2727
info: Optional[dict] = None
2828
metrics: Optional[dict[str, float]] = None
2929
group_id: str = "" # for grpo
30+
unique_id: str = ""
3031

3132
def __post_init__(self):
3233
if self.action_mask is not None:
@@ -96,6 +97,7 @@ class Experiences:
9697
prompt_length: int
9798
logprobs: Optional[Tensor]
9899
group_ids: List[str]
100+
unique_ids: List[str]
99101

100102
@property
101103
def batch_size(self) -> int:
@@ -119,10 +121,12 @@ def gather_experiences(
119121
logprobs=torch.empty(0, dtype=torch.float32),
120122
prompt_length=torch.empty(0, dtype=torch.int32),
121123
group_ids=[],
124+
unique_ids=[],
122125
)
123126
max_prompt_length = max([exp.prompt_length for exp in experiences])
124127
max_response_length = max([len(exp.tokens) - exp.prompt_length for exp in experiences])
125128
group_ids = [exp.group_id for exp in experiences]
129+
unique_ids = [exp.unique_id for exp in experiences]
126130
tokens_dtype = experiences[0].tokens.dtype
127131
tokens = torch.stack(
128132
[
@@ -209,6 +213,7 @@ def gather_experiences(
209213

210214
return cls(
211215
group_ids=group_ids,
216+
unique_ids=unique_ids,
212217
tokens=tokens,
213218
rewards=rewards,
214219
attention_masks=attention_masks,
@@ -250,6 +255,7 @@ def gather_dpo_experiences(
250255
logprobs=torch.empty(0, dtype=torch.float32),
251256
prompt_length=torch.empty(0, dtype=torch.int32),
252257
group_ids=[],
258+
unique_ids=[],
253259
)
254260

255261
# TODO: exp.tokens in DPO are prompt tokens
@@ -262,6 +268,11 @@ def gather_dpo_experiences(
262268
max_response_length = max([len(response) for response in response_tokens]) # type: ignore
263269

264270
group_ids = list(chain.from_iterable([repeat(exp.group_id, 2) for exp in experiences]))
271+
unique_ids = list(
272+
chain.from_iterable(
273+
[(f"{exp.unique_id}/1", f"{exp.unique_id}/0") for exp in experiences]
274+
)
275+
)
265276
tokens_dtype = experiences[0].tokens.dtype
266277
tokens = torch.stack(
267278
[
@@ -298,6 +309,7 @@ def gather_dpo_experiences(
298309

299310
return cls(
300311
group_ids=group_ids,
312+
unique_ids=unique_ids,
301313
tokens=tokens,
302314
attention_masks=attention_masks,
303315
prompt_length=max_prompt_length,

trinity/common/rewards/reward_fn.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -69,7 +69,7 @@ def _build_sample_from_experience(
6969
]
7070

7171
sample = DataSample(
72-
unique_id="0", # TODO: Generate unique ID
72+
unique_id=experience.unique_id,
7373
input=to_rm_gallery_messages(messages),
7474
output=output,
7575
metadata=experience.info,

trinity/explorer/scheduler.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -56,7 +56,7 @@ def _create_runner(self):
5656
"env_vars": self.config.explorer.env_vars,
5757
},
5858
)
59-
.remote(self.config, self.rollout_model, self.auxiliary_models)
59+
.remote(self.config, self.rollout_model, self.auxiliary_models, self.runner_id)
6060
)
6161

6262
async def run_with_retry(self, task: TaskWrapper) -> Tuple[Status, int]:

trinity/explorer/workflow_runner.py

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22
"""The Workflow Runner Module."""
33
import time
44
import traceback
5+
import uuid
56
from collections import defaultdict
67
from dataclasses import dataclass
78
from typing import List, Optional
@@ -31,6 +32,7 @@ def __init__(
3132
config: Config,
3233
model: InferenceModel,
3334
auxiliary_models: Optional[List[InferenceModel]] = None,
35+
runner_id: Optional[int] = None,
3436
) -> None:
3537
self.config = config
3638
self.experience_buffer = get_buffer_writer(
@@ -52,6 +54,7 @@ def __init__(
5254
self.auxiliary_models.append(api_client)
5355
self.logger = get_logger(__name__)
5456
self.workflow_instance = None
57+
self.runner_id = runner_id
5558

5659
def is_alive(self):
5760
return True
@@ -78,8 +81,11 @@ def run_task(self, task: Task) -> Status:
7881
assert exps is not None and len(exps) > 0, "An empty experience is generated"
7982
metrics: dict[str, List[float]] = defaultdict(list)
8083
# set group id and unique id
81-
for exp in exps:
84+
for idx, exp in enumerate(exps):
8285
setattr(exp, "group_id", task.group_id)
86+
setattr(
87+
exp, "unique_id", f"{task.group_id}/{self.runner_id}/{str(uuid.uuid4())[:6]}"
88+
)
8389

8490
if not hasattr(exp, "info") or exp.info is None:
8591
exp.info = {}

0 commit comments

Comments
 (0)