Skip to content

Commit a52cc3a

Browse files
authored
Fix metrics in trainer (#381)
1 parent 73c81b7 commit a52cc3a

File tree

3 files changed

+23
-10
lines changed

3 files changed

+23
-10
lines changed

trinity/algorithm/sample_strategy/mix_sample_strategy.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ async def sample(self, step: int) -> Tuple[Experiences, Dict, List]:
8585
exp_list = usual_exp_list + expert_exp_list
8686
repr_samples = representative_sample(exp_list)
8787

88+
self.set_model_version_metric(exp_list, metrics)
8889
with Timer(metrics, "time/gather_experience"):
8990
exps = Experiences.gather_experiences(
9091
experiences=exp_list,

trinity/algorithm/sample_strategy/sample_strategy.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
from abc import ABC, abstractmethod
2-
from typing import Any, Dict, List, Tuple
2+
from typing import Dict, List, Tuple
33

44
from trinity.algorithm.sample_strategy.utils import representative_sample
55
from trinity.buffer import get_buffer_reader
66
from trinity.common.config import BufferConfig
7-
from trinity.common.experience import Experiences
7+
from trinity.common.experience import Experience, Experiences
88
from trinity.utils.annotations import Deprecated
9+
from trinity.utils.monitor import gather_metrics
910
from trinity.utils.registry import Registry
1011
from trinity.utils.timer import Timer
1112

@@ -16,6 +17,14 @@ class SampleStrategy(ABC):
1617
def __init__(self, buffer_config: BufferConfig, **kwargs) -> None:
1718
self.pad_token_id = buffer_config.pad_token_id
1819

20+
def set_model_version_metric(self, exp_list: List[Experience], metrics: Dict):
21+
metric_list = [
22+
{"model_version": exp.info["model_version"]}
23+
for exp in exp_list
24+
if "model_version" in exp.info
25+
]
26+
metrics.update(gather_metrics(metric_list, "sample"))
27+
1928
@abstractmethod
2029
async def sample(self, step: int) -> Tuple[Experiences, Dict, List]:
2130
"""Sample data from buffer.
@@ -41,11 +50,12 @@ def __init__(self, buffer_config: BufferConfig, **kwargs):
4150
super().__init__(buffer_config)
4251
self.exp_buffer = get_buffer_reader(buffer_config.trainer_input.experience_buffer) # type: ignore[arg-type]
4352

44-
async def sample(self, step: int, **kwargs) -> Tuple[Any, Dict, List]:
53+
async def sample(self, step: int, **kwargs) -> Tuple[Experiences, Dict, List]:
4554
metrics = {}
4655
with Timer(metrics, "time/read_experience"):
4756
exp_list = await self.exp_buffer.read_async()
4857
repr_samples = representative_sample(exp_list)
58+
self.set_model_version_metric(exp_list, metrics)
4959
with Timer(metrics, "time/gather_experience"):
5060
exps = Experiences.gather_experiences(exp_list, self.pad_token_id) # type: ignore
5161
return exps, metrics, repr_samples

trinity/trainer/trainer.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def __init__(self, config: Config) -> None:
5656
)
5757
self.save_interval = config.trainer.save_interval
5858
self.last_sync_step = None
59+
self.last_sync_time = None
5960
self.total_steps = config.trainer.total_steps or float("inf")
6061

6162
async def prepare(self) -> None:
@@ -68,22 +69,20 @@ async def train(self) -> str:
6869
"""Train the model."""
6970
while self.train_step_num < self.total_steps:
7071
try:
71-
st = time.time()
72+
metrics = {}
7273
# sample may be blocked due to explorer does not generate enough data
7374
self.logger.info(f"Sample data for step {self.train_step_num + 1} started.")
7475
sample_task = asyncio.create_task(self._sample_data())
7576
while not sample_task.done():
7677
# sync weight to make sure the explorer can continue to explore and generate enough data
7778
if await self.need_sync():
78-
# Currently, we do not record the metrics of sync_weight here
79-
await self.sync_weight()
79+
metrics.update(await self.sync_weight())
8080
await asyncio.sleep(1)
81-
exps, metrics, repr_samples = await sample_task
81+
exps, sample_metrics, repr_samples = await sample_task
82+
metrics.update(sample_metrics)
8283
self.logger.info(f"Sample data for step {self.train_step_num + 1} finished.")
8384
metrics.update(await self.train_step(exps))
8485
if await self.need_sync():
85-
# Record the time: sample_experience + train_step (>=1)
86-
metrics.update({"time/trainer_sync_interval": time.time() - st})
8786
metrics.update(await self.sync_weight())
8887
if self.need_save():
8988
metrics.update(self.save_checkpoint())
@@ -126,7 +125,7 @@ async def _sample_data(self) -> Tuple[Experiences, Dict, List[Dict]]:
126125
List[Dict]: A list of representative samples for logging.
127126
"""
128127
batch, metrics, repr_samples = await self.sample_strategy.sample(self.train_step_num + 1)
129-
metrics["sample/task_count"] = len(set(eid.task for eid in batch.eids))
128+
metrics["sample/task_count"] = len(set(eid.tid for eid in batch.eids))
130129
return batch, metrics, repr_samples
131130

132131
async def need_sync(self) -> bool:
@@ -155,6 +154,8 @@ async def sync_weight(self) -> Dict:
155154
"""Sync the model weight."""
156155
self.logger.info(f"Trainer sync_weights at step {self.train_step_num} started.")
157156
metrics = {}
157+
if self.last_sync_time is not None:
158+
metrics["time/trainer_sync_interval"] = time.time() - self.last_sync_time
158159
with Timer(metrics, "time/sync_weight"):
159160
if self.config.synchronizer.sync_method == SyncMethod.NCCL:
160161
result = await self.synchronizer.ready_to_nccl_sync.remote(
@@ -170,6 +171,7 @@ async def sync_weight(self) -> Dict:
170171
elif self.config.synchronizer.sync_method == SyncMethod.MEMORY:
171172
self.engine.upload_state_dict()
172173
self.last_sync_step = self.train_step_num
174+
self.last_sync_time = time.time()
173175
await self.synchronizer.set_trainer_status.remote(RunningStatus.RUNNING)
174176
self.logger.info(f"Trainer sync_weights at step {self.train_step_num} finished.")
175177
return metrics

0 commit comments

Comments
 (0)