Commit 5565309

Refactor Experiences to List[Experience]
1 parent d1d450c commit 5565309

9 files changed, +152 -129 lines changed

tests/trainer/trainer_test.py
Lines changed: 1 addition & 1 deletion

@@ -1325,7 +1325,7 @@ def tearDown(self):


 class TestTinkerTrainer(BaseTrainerCase):
-    @unittest.skip("Require tinker API key")
+    # @unittest.skip("Require tinker API key")
     def test_trainer(self):
         """Test GSM8K on tinker."""
         # test both mode

trinity/algorithm/sample_strategy/mix_sample_strategy.py
Lines changed: 17 additions & 20 deletions

@@ -8,7 +8,7 @@
 from trinity.algorithm.sample_strategy.utils import representative_sample
 from trinity.buffer import get_buffer_reader
 from trinity.common.config import BufferConfig
-from trinity.common.experience import CustomField, Experiences
+from trinity.common.experience import CustomField, Experience
 from trinity.utils.timer import Timer


@@ -53,7 +53,7 @@ def __init__(self, buffer_config: BufferConfig, **kwargs):
             expert_buffer_config,
         )

-    async def sample(self, step: int) -> Tuple[Experiences, Dict, List]:
+    async def sample(self, step: int) -> Tuple[List[Experience], Dict, List]:
         metrics = {}
         with Timer(metrics, "time/read_experience"):
             usual_exp_list = await self.usual_exp_buffer.read_async()
@@ -82,24 +82,21 @@ async def sample(self, step: int) -> Tuple[Experiences, Dict, List]:
         repr_samples = representative_sample(exp_list)

         self.set_model_version_metric(exp_list, metrics)
-        with Timer(metrics, "time/gather_experience"):
-            exps = Experiences.gather_experiences(
-                experiences=exp_list,
-                pad_token_id=self.pad_token_id,  # type: ignore [arg-type]
-                custom_fields=[
-                    CustomField(
-                        source_field="is_expert",
-                        destination_field="expert_mask",
-                        data_type=torch.bool,
-                    ),
-                    CustomField(
-                        source_field="step",
-                        destination_field="step",
-                        data_type=torch.int32,
-                    ),
-                ],
-            )  # type: ignore
-        return exps, metrics, repr_samples
+        custom_fields = [
+            CustomField(
+                source_field="is_expert",
+                destination_field="expert_mask",
+                data_type=torch.bool,
+            ),
+            CustomField(
+                source_field="step",
+                destination_field="step",
+                data_type=torch.int32,
+            ),
+        ]
+        for exp in exp_list:
+            exp.custom_fields = custom_fields
+        return exp_list, metrics, repr_samples

     @classmethod
     def default_args(cls) -> Dict:
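
Note: instead of padding and stacking everything into a single Experiences batch, the strategy now only attaches CustomField descriptors to each raw Experience and lets the trainer backend materialize tensors later. A minimal, self-contained sketch of that pattern follows; the CustomField and Experience classes below are simplified stand-ins mirroring only the fields used in this commit, not the real trinity implementations.

# A minimal, self-contained sketch of the per-Experience custom-field pattern.
# CustomField and Experience here are simplified stand-ins, not the real trinity classes.
from dataclasses import dataclass, field
from typing import Any, Dict, List

import torch


@dataclass
class CustomField:
    source_field: str       # key looked up in exp.info
    destination_field: str  # key written into the trainer's model_inputs
    data_type: torch.dtype


@dataclass
class Experience:
    info: Dict[str, Any]
    custom_fields: List[CustomField] = field(default_factory=list)


# The sample strategy only attaches the descriptors ...
custom_fields = [
    CustomField(source_field="is_expert", destination_field="expert_mask", data_type=torch.bool),
    CustomField(source_field="step", destination_field="step", data_type=torch.int32),
]
exp_list = [Experience(info={"is_expert": True, "step": 3})]
for exp in exp_list:
    exp.custom_fields = custom_fields

# ... and the trainer backend turns them into tensors per experience later on:
model_inputs = {
    cf.destination_field: torch.tensor(exp_list[0].info[cf.source_field], dtype=cf.data_type)
    for cf in exp_list[0].custom_fields
}
print(model_inputs)  # {'expert_mask': tensor(True), 'step': tensor(3, dtype=torch.int32)}

Deferring the tensor conversion to the trainer is also why the sampling path no longer needs pad_token_id (see the SampleStrategy change in the next file).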

trinity/algorithm/sample_strategy/sample_strategy.py
Lines changed: 8 additions & 12 deletions

@@ -4,15 +4,15 @@
 from trinity.algorithm.sample_strategy.utils import representative_sample
 from trinity.buffer import get_buffer_reader
 from trinity.common.config import BufferConfig
-from trinity.common.experience import Experience, Experiences
+from trinity.common.experience import Experience
 from trinity.utils.annotations import Deprecated
 from trinity.utils.monitor import gather_metrics
 from trinity.utils.timer import Timer


 class SampleStrategy(ABC):
     def __init__(self, buffer_config: BufferConfig, **kwargs) -> None:
-        self.pad_token_id = buffer_config.pad_token_id
+        pass

     def set_model_version_metric(self, exp_list: List[Experience], metrics: Dict):
         metric_list = [
@@ -23,14 +23,14 @@ def set_model_version_metric(self, exp_list: List[Experience], metrics: Dict):
         metrics.update(gather_metrics(metric_list, "sample"))

     @abstractmethod
-    async def sample(self, step: int) -> Tuple[Experiences, Dict, List]:
+    async def sample(self, step: int) -> Tuple[List[Experience], Dict, List]:
         """Sample data from buffer.

         Args:
             step (`int`): The step number of current step.

         Returns:
-            `Experiences`: The sampled Experiences data.
+            `List[Experience]`: The sampled List[Experience] data.
             `Dict`: Metrics for logging.
             `List`: Representative data for logging.
         """
@@ -54,15 +54,13 @@ def __init__(self, buffer_config: BufferConfig, **kwargs):
         super().__init__(buffer_config)
         self.exp_buffer = get_buffer_reader(buffer_config.trainer_input.experience_buffer)  # type: ignore[arg-type]

-    async def sample(self, step: int, **kwargs) -> Tuple[Experiences, Dict, List]:
+    async def sample(self, step: int, **kwargs) -> Tuple[List[Experience], Dict, List]:
         metrics = {}
         with Timer(metrics, "time/read_experience"):
             exp_list = await self.exp_buffer.read_async()
             repr_samples = representative_sample(exp_list)
         self.set_model_version_metric(exp_list, metrics)
-        with Timer(metrics, "time/gather_experience"):
-            exps = Experiences.gather_experiences(exp_list, self.pad_token_id)  # type: ignore
-        return exps, metrics, repr_samples
+        return exp_list, metrics, repr_samples

     @classmethod
     def default_args(cls) -> dict:
@@ -81,16 +79,14 @@ def __init__(self, buffer_config: BufferConfig, **kwargs):
         super().__init__(buffer_config)
         self.max_staleness = kwargs.get("max_staleness", float("inf"))

-    async def sample(self, step: int, **kwargs) -> Tuple[Experiences, Dict, List]:
+    async def sample(self, step: int, **kwargs) -> Tuple[List[Experience], Dict, List]:
         min_model_version = max(step - self.max_staleness, 0)
         metrics = {}
         with Timer(metrics, "time/read_experience"):
             exp_list = await self.exp_buffer.read_async(min_model_version=min_model_version)
             repr_samples = representative_sample(exp_list)
         self.set_model_version_metric(exp_list, metrics)
-        with Timer(metrics, "time/gather_experience"):
-            exps = Experiences.gather_experiences(exp_list, self.pad_token_id)  # type: ignore
-        return exps, metrics, repr_samples
+        return exp_list, metrics, repr_samples


 @Deprecated
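
Note: under the new contract every strategy returns the raw list of experiences plus metrics and representative samples; there is no batch-level gather step anymore. A toy strategy written against that shape is sketched below; FakeBufferReader and the Experience stand-in are hypothetical and only illustrate the return types, they are not trinity classes.

# Toy illustration of the new sample() contract returning a plain list of experiences.
import asyncio
from typing import Dict, List, Tuple


class Experience:  # stand-in: only the field used here
    def __init__(self, model_version: int) -> None:
        self.info = {"model_version": model_version}


class FakeBufferReader:  # stand-in for a trinity buffer reader
    async def read_async(self) -> List["Experience"]:
        return [Experience(model_version=7) for _ in range(4)]


class PassthroughSampleStrategy:
    """Reads one batch and returns it untouched; collation now happens in the trainer."""

    def __init__(self) -> None:
        self.exp_buffer = FakeBufferReader()

    async def sample(self, step: int) -> Tuple[List[Experience], Dict, List]:
        exp_list = await self.exp_buffer.read_async()
        metrics = {"sample/count": len(exp_list)}
        repr_samples: List = []
        return exp_list, metrics, repr_samples


exps, metrics, _ = asyncio.run(PassthroughSampleStrategy().sample(step=1))
print(len(exps), metrics)  # 4 {'sample/count': 4}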

trinity/common/experience.py
Lines changed: 4 additions & 0 deletions

@@ -136,6 +136,8 @@ class Experience:
     # for on-policy distillation
     teacher_logprobs: Optional[Tensor] = None  # [resp_length]

+    custom_fields: List[CustomField] = field(default_factory=list)
+
     def __init__(  # noqa: C901
         self,
         *,
@@ -161,6 +163,7 @@ def __init__(  # noqa: C901
         rejected_messages=None,
         multi_modal_inputs=None,
         teacher_logprobs=None,
+        custom_fields=None,
     ):
         if action_mask is not None:
             experience_type = "multi_turn"
@@ -250,6 +253,7 @@ def __init__(  # noqa: C901
             self.rejected = torch.tensor(self.rejected)
         if self.teacher_logprobs is not None and not isinstance(self.teacher_logprobs, Tensor):
             self.teacher_logprobs = torch.tensor(self.teacher_logprobs, dtype=torch.float32)
+        self.custom_fields = custom_fields or []

     def serialize(self) -> bytes:
         """Serialize the experience to bytes."""

trinity/trainer/tinker/utils.py
Lines changed: 38 additions & 37 deletions

@@ -4,7 +4,7 @@
 import torch
 from tinker import types

-from trinity.common.experience import Experiences
+from trinity.common.experience import Experience, split_dpo_experience_to_single_turn


 def pad_to_length(
@@ -23,60 +23,61 @@ def pad_to_length(


 def to_tinker_input(
-    experiences: Experiences, logger: Logger
+    experiences: List[Experience], logger: Logger
 ) -> Tuple[List[types.Datum], List[types.ModelInput], List[dict]]:
-    cumsum = torch.cumsum(experiences.attention_masks, dim=-1)
-    eos_mask_idx = cumsum.argmax(dim=-1)
-    prompt_length = experiences.prompt_length
+    assert len(experiences) > 0, "No experiences provided."
+    if experiences[0].experience_type == "dpo":
+        experiences = split_dpo_experience_to_single_turn(experiences)
+
     batch = []
     batch_input_tokens = []
     model_inputs_list = []
-    for i in range(len(experiences.tokens)):
-        tokens = experiences.tokens[i]
-        attention_mask = experiences.attention_masks[i]
-        response_mask = attention_mask[prompt_length:]
-        input_tokens = tokens[attention_mask].long()
-        exp_seq_length = sum(attention_mask)
-        exp_response_length = sum(response_mask)
+    for exp in experiences:
+        tokens = exp.tokens
+        input_tokens = tokens.long()
+        prompt_length = exp.prompt_length
+        total_length = len(tokens)  # type: ignore
+        response_length = total_length - prompt_length
         loss_fn_inputs = {
-            "weights": pad_to_length(
-                experiences.action_masks[i][response_mask].float(), exp_seq_length - 1  # type: ignore
+            "weights": torch.concat(
+                [
+                    torch.zeros(prompt_length - 1, dtype=torch.float32),
+                    exp.action_mask.float(),
+                ]
             ),
             "target_tokens": input_tokens.tolist()[1:],
         }
         model_inputs = {
-            "total_length": exp_seq_length,
-            "action_mask": experiences.action_masks[i][response_mask],  # type: ignore
+            "total_length": total_length,
+            "action_mask": exp.action_mask,
         }
-        if experiences.rewards is not None or experiences.token_level_rewards is not None:
-            assert experiences.logprobs is not None
-            if experiences.token_level_rewards is not None:
-                if experiences.rewards is not None:
+        if exp.reward is not None or exp.token_level_reward is not None:
+            assert exp.logprobs is not None
+            if exp.token_level_reward is not None:
+                if exp.reward is not None:
                     logger.warning(
-                        "Both experiences.rewards and experiences.token_level_rewards are provided. "
-                        "Using experiences.token_level_rewards."
+                        "Both exp.rewards and exp.token_level_rewards are provided. "
+                        "Using exp.token_level_rewards."
                     )
-                token_level_rewards = experiences.token_level_rewards[i][response_mask]
+                token_level_reward = exp.token_level_reward
             else:
-                token_level_rewards = torch.zeros(
-                    exp_response_length, dtype=experiences.rewards.dtype
-                )
-                token_level_rewards[eos_mask_idx[i] - prompt_length] = experiences.rewards[i]
+                token_level_reward = torch.zeros(response_length, dtype=torch.float32)
+                token_level_reward[-1] = exp.reward
             model_inputs.update(
                 {
-                    "token_level_scores": token_level_rewards,
-                    "old_logprob": experiences.logprobs[i][response_mask],  # type: ignore
+                    "token_level_scores": token_level_reward,
+                    "old_logprob": exp.logprobs,
                 }
             )
-        if experiences.advantages is not None:
-            model_inputs["advantages"] = experiences.advantages[i][response_mask]
-        if experiences.returns is not None:
-            model_inputs["returns"] = experiences.returns[i][response_mask]
+        for attr in ["advantages", "returns", "teacher_logprobs"]:
+            if getattr(exp, attr, None) is not None:
+                model_inputs[attr] = getattr(exp, attr)
         # TODO: if tinker support multi-modal input, we can add it here
-        if experiences.custom_fields:
-            for field in experiences.custom_fields:
-                if hasattr(experiences, field):
-                    model_inputs[field] = getattr(experiences, field)
+        for custom_field in exp.custom_fields:
+            model_inputs[custom_field.destination_field] = torch.tensor(
+                exp.info[custom_field.source_field],
+                dtype=custom_field.data_type,
+            )

         batch.append(
             types.Datum(
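
Note: in the per-experience path, the loss weights are built by prepending prompt_length - 1 zeros to the response action mask, so they line up with target_tokens = tokens[1:] after the usual next-token shift, and when only a sequence-level reward exists it is placed on the last response token. The toy example below works through just that arithmetic with made-up numbers; it is not the real Datum construction.

# Toy example of the per-experience weight / reward construction used above.
import torch

prompt_length = 4
tokens = torch.tensor([11, 12, 13, 14, 21, 22, 23])  # 4 prompt + 3 response tokens
action_mask = torch.ones(3)                           # train on all 3 response tokens
reward = 1.0

total_length = len(tokens)                      # 7
response_length = total_length - prompt_length  # 3

# Weights align with target_tokens = tokens[1:] (length total_length - 1 = 6):
# zeros over the prompt-side targets, the action mask over the response targets.
weights = torch.concat([
    torch.zeros(prompt_length - 1, dtype=torch.float32),  # 3 zeros
    action_mask.float(),                                   # 3 ones
])
target_tokens = tokens.tolist()[1:]
assert len(weights) == len(target_tokens) == total_length - 1

# With only a sequence-level reward, the scalar lands on the final response token.
token_level_reward = torch.zeros(response_length, dtype=torch.float32)
token_level_reward[-1] = reward
print(weights.tolist())             # [0.0, 0.0, 0.0, 1.0, 1.0, 1.0]
print(token_level_reward.tolist())  # [0.0, 0.0, 1.0]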

trinity/trainer/tinker_trainer.py
Lines changed: 4 additions & 4 deletions

@@ -1,5 +1,5 @@
 import os
-from typing import Dict
+from typing import Dict, List

 import ray
 import tinker
@@ -16,7 +16,7 @@
 from trinity.algorithm.policy_loss_fn import POLICY_LOSS_FN
 from trinity.algorithm.utils import prefix_metrics
 from trinity.common.config import Config
-from trinity.common.experience import Experiences
+from trinity.common.experience import Experience
 from trinity.manager.synchronizer import Synchronizer
 from trinity.trainer.tinker.utils import (
     compute_data_metrics,
@@ -196,11 +196,11 @@ def _loss_func(
         avg_metrics = {k: sum(v) / len(v) for k, v in metrics.items()}
         return total_loss, avg_metrics

-    async def train_step(self, batch_exps: Experiences) -> Dict:
+    async def train_step(self, batch_exps: List[Experience]) -> Dict:
         """Training one step.

         Args:
-            batch (Experiences): A batch of experiences to train.
+            batch (List[Experience]): A batch of experiences to train.

         Returns:
             Dict: Metrics of the training step.

trinity/trainer/trainer.py
Lines changed: 7 additions & 7 deletions

@@ -17,7 +17,7 @@
 from trinity.algorithm.sample_strategy.sample_strategy import SampleStrategy
 from trinity.common.config import Config
 from trinity.common.constants import RunningStatus, SyncMethod, SyncStyle
-from trinity.common.experience import Experiences
+from trinity.common.experience import Experience
 from trinity.manager.state_manager import StateManager
 from trinity.manager.synchronizer import Synchronizer
 from trinity.utils.log import get_logger
@@ -108,7 +108,7 @@ async def train(self) -> str:
         self.logger.info("--------------------\n> Trainer finished.\n--------------------")
         return self.config.trainer.name

-    async def train_step(self, exps: Experiences) -> Dict:
+    async def train_step(self, exps: List[Experience]) -> Dict:
         """Train one step.

         Returns:
@@ -123,16 +123,16 @@ async def train_step(self, exps: Experiences) -> Dict:
         metrics.update(train_metrics)
         return metrics

-    async def _sample_data(self) -> Tuple[Experiences, Dict, List[Dict]]:
+    async def _sample_data(self) -> Tuple[List[Experience], Dict, List[Dict]]:
         """Sample a batch of experiences.

         Returns:
-            Experiences: A batch of experiences.
+            List[Experience]: A batch of experiences.
             Dict: Metrics of the sampling step.
             List[Dict]: A list of representative samples for logging.
         """
         batch, metrics, repr_samples = await self.sample_strategy.sample(self.train_step_num + 1)
-        metrics["sample/task_count"] = len(set(eid.tid for eid in batch.eids))
+        metrics["sample/task_count"] = len(set(exp.eid.tid for exp in batch))
         return batch, metrics, repr_samples

     async def need_sync(self) -> bool:
@@ -239,11 +239,11 @@ def train_step_num(self) -> int:
         """Get the current training step number."""

     @abstractmethod
-    async def train_step(self, batch_exps: Experiences) -> Dict:
+    async def train_step(self, batch_exps: List[Experience]) -> Dict:
         """Training one step.

         Args:
-            batch_exps (Experiences): A batch of experiences to train.
+            batch_exps (List[Experience]): A batch of experiences to train.

         Returns:
             Dict: Metrics of the training step.
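
Note: with a plain list, batch-level metrics are now computed by iterating the experiences directly; the task count, for instance, counts distinct task ids carried on each experience id (exp.eid.tid). A stand-in illustration follows; the EID and Experience classes are simplified for the example and only mirror the attribute access in the diff.

# Stand-in illustration of the task-count metric over a List[Experience].
from dataclasses import dataclass


@dataclass
class EID:           # hypothetical stand-in exposing only .tid (task id)
    tid: str
    rid: int = 0     # e.g. repeat/rollout index within a task


@dataclass
class Experience:    # stand-in: only the field the metric needs
    eid: EID


# 4 experiences drawn from 2 distinct tasks
batch = [Experience(EID("task-a", r)) for r in range(3)] + [Experience(EID("task-b"))]
task_count = len(set(exp.eid.tid for exp in batch))
print(task_count)  # 2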
