Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions trinity/algorithm/sample_strategy/sample_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from trinity.common.config import BufferConfig
from trinity.common.experience import Experiences
from trinity.utils.annotations import Deprecated
from trinity.utils.monitor import gather_metrics
from trinity.utils.registry import Registry
from trinity.utils.timer import Timer

Expand Down Expand Up @@ -46,6 +47,8 @@ async def sample(self, step: int, **kwargs) -> Tuple[Any, Dict, List]:
with Timer(metrics, "time/read_experience"):
exp_list = await self.exp_buffer.read_async()
repr_samples = representative_sample(exp_list)
metric_list = [{"model_version": exp.info["model_version"]} for exp in exp_list]
metrics.update(gather_metrics(metric_list, "sample"))
with Timer(metrics, "time/gather_experience"):
exps = Experiences.gather_experiences(exp_list, self.pad_token_id) # type: ignore
return exps, metrics, repr_samples
Expand Down
16 changes: 9 additions & 7 deletions trinity/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def __init__(self, config: Config) -> None:
)
self.save_interval = config.trainer.save_interval
self.last_sync_step = None
self.last_sync_time = None
self.total_steps = config.trainer.total_steps or float("inf")

async def prepare(self) -> None:
Expand All @@ -68,22 +69,20 @@ async def train(self) -> str:
"""Train the model."""
while self.train_step_num < self.total_steps:
try:
st = time.time()
metrics = {}
# sample may be blocked due to explorer does not generate enough data
self.logger.info(f"Sample data for step {self.train_step_num + 1} started.")
sample_task = asyncio.create_task(self._sample_data())
while not sample_task.done():
# sync weight to make sure the explorer can continue to explore and generate enough data
if await self.need_sync():
# Currently, we do not record the metrics of sync_weight here
await self.sync_weight()
metrics.update(await self.sync_weight())
await asyncio.sleep(1)
exps, metrics, repr_samples = await sample_task
exps, sample_metrics, repr_samples = await sample_task
metrics.update(sample_metrics)
self.logger.info(f"Sample data for step {self.train_step_num + 1} finished.")
metrics.update(await self.train_step(exps))
if await self.need_sync():
# Record the time: sample_experience + train_step (>=1)
metrics.update({"time/trainer_sync_interval": time.time() - st})
metrics.update(await self.sync_weight())
if self.need_save():
metrics.update(self.save_checkpoint())
Expand Down Expand Up @@ -126,7 +125,7 @@ async def _sample_data(self) -> Tuple[Experiences, Dict, List[Dict]]:
List[Dict]: A list of representative samples for logging.
"""
batch, metrics, repr_samples = await self.sample_strategy.sample(self.train_step_num + 1)
metrics["sample/task_count"] = len(set(eid.task for eid in batch.eids))
metrics["sample/task_count"] = len(set(f"{eid.batch}_{eid.task}" for eid in batch.eids))
return batch, metrics, repr_samples

async def need_sync(self) -> bool:
Expand Down Expand Up @@ -155,6 +154,8 @@ async def sync_weight(self) -> Dict:
"""Sync the model weight."""
self.logger.info(f"Trainer sync_weights at step {self.train_step_num} started.")
metrics = {}
if self.last_sync_time is not None:
metrics["time/trainer_sync_interval"] = time.time() - self.last_sync_time
with Timer(metrics, "time/sync_weight"):
if self.config.synchronizer.sync_method == SyncMethod.NCCL:
result = await self.synchronizer.ready_to_nccl_sync.remote(
Expand All @@ -170,6 +171,7 @@ async def sync_weight(self) -> Dict:
elif self.config.synchronizer.sync_method == SyncMethod.MEMORY:
self.engine.upload_state_dict()
self.last_sync_step = self.train_step_num
self.last_sync_time = time.time()
await self.synchronizer.set_trainer_status.remote(RunningStatus.RUNNING)
self.logger.info(f"Trainer sync_weights at step {self.train_step_num} finished.")
return metrics
Expand Down