Skip to content

Commit 87e95de

Browse files
committed
Fix metrics in trainer
1 parent 90b55e8 commit 87e95de

File tree

2 files changed

+12
-7
lines changed

2 files changed

+12
-7
lines changed

trinity/algorithm/sample_strategy/sample_strategy.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from trinity.common.config import BufferConfig
77
from trinity.common.experience import Experiences
88
from trinity.utils.annotations import Deprecated
9+
from trinity.utils.monitor import gather_metrics
910
from trinity.utils.registry import Registry
1011
from trinity.utils.timer import Timer
1112

@@ -46,6 +47,8 @@ async def sample(self, step: int, **kwargs) -> Tuple[Any, Dict, List]:
4647
with Timer(metrics, "time/read_experience"):
4748
exp_list = await self.exp_buffer.read_async()
4849
repr_samples = representative_sample(exp_list)
50+
metric_list = [{"model_version": exp.info["model_version"]} for exp in exp_list]
51+
metrics.update(gather_metrics(metric_list, "sample"))
4952
with Timer(metrics, "time/gather_experience"):
5053
exps = Experiences.gather_experiences(exp_list, self.pad_token_id) # type: ignore
5154
return exps, metrics, repr_samples

trinity/trainer/trainer.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def __init__(self, config: Config) -> None:
5656
)
5757
self.save_interval = config.trainer.save_interval
5858
self.last_sync_step = None
59+
self.last_sync_time = None
5960
self.total_steps = config.trainer.total_steps or float("inf")
6061

6162
async def prepare(self) -> None:
@@ -68,22 +69,20 @@ async def train(self) -> str:
6869
"""Train the model."""
6970
while self.train_step_num < self.total_steps:
7071
try:
71-
st = time.time()
72+
metrics = {}
7273
# sample may be blocked due to explorer does not generate enough data
7374
self.logger.info(f"Sample data for step {self.train_step_num + 1} started.")
7475
sample_task = asyncio.create_task(self._sample_data())
7576
while not sample_task.done():
7677
# sync weight to make sure the explorer can continue to explore and generate enough data
7778
if await self.need_sync():
78-
# Currently, we do not record the metrics of sync_weight here
79-
await self.sync_weight()
79+
metrics.update(await self.sync_weight())
8080
await asyncio.sleep(1)
81-
exps, metrics, repr_samples = await sample_task
81+
exps, sample_metrics, repr_samples = await sample_task
82+
metrics.update(sample_metrics)
8283
self.logger.info(f"Sample data for step {self.train_step_num + 1} finished.")
8384
metrics.update(await self.train_step(exps))
8485
if await self.need_sync():
85-
# Record the time: sample_experience + train_step (>=1)
86-
metrics.update({"time/trainer_sync_interval": time.time() - st})
8786
metrics.update(await self.sync_weight())
8887
if self.need_save():
8988
metrics.update(self.save_checkpoint())
@@ -126,7 +125,7 @@ async def _sample_data(self) -> Tuple[Experiences, Dict, List[Dict]]:
126125
List[Dict]: A list of representative samples for logging.
127126
"""
128127
batch, metrics, repr_samples = await self.sample_strategy.sample(self.train_step_num + 1)
129-
metrics["sample/task_count"] = len(set(eid.task for eid in batch.eids))
128+
metrics["sample/task_count"] = len(set(f"{eid.batch}_{eid.task}" for eid in batch.eids))
130129
return batch, metrics, repr_samples
131130

132131
async def need_sync(self) -> bool:
@@ -155,6 +154,8 @@ async def sync_weight(self) -> Dict:
155154
"""Sync the model weight."""
156155
self.logger.info(f"Trainer sync_weights at step {self.train_step_num} started.")
157156
metrics = {}
157+
if self.last_sync_time is not None:
158+
metrics["time/trainer_sync_interval"] = time.time() - self.last_sync_time
158159
with Timer(metrics, "time/sync_weight"):
159160
if self.config.synchronizer.sync_method == SyncMethod.NCCL:
160161
result = await self.synchronizer.ready_to_nccl_sync.remote(
@@ -170,6 +171,7 @@ async def sync_weight(self) -> Dict:
170171
elif self.config.synchronizer.sync_method == SyncMethod.MEMORY:
171172
self.engine.upload_state_dict()
172173
self.last_sync_step = self.train_step_num
174+
self.last_sync_time = time.time()
173175
await self.synchronizer.set_trainer_status.remote(RunningStatus.RUNNING)
174176
self.logger.info(f"Trainer sync_weights at step {self.train_step_num} finished.")
175177
return metrics

0 commit comments

Comments
 (0)