1 change: 1 addition & 0 deletions trinity/algorithm/sample_strategy/mix_sample_strategy.py
@@ -85,6 +85,7 @@ async def sample(self, step: int) -> Tuple[Experiences, Dict, List]:
         exp_list = usual_exp_list + expert_exp_list
         repr_samples = representative_sample(exp_list)

+        self.set_model_version_metric(exp_list, metrics)
         with Timer(metrics, "time/gather_experience"):
             exps = Experiences.gather_experiences(
                 experiences=exp_list,
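For context, metrics here is a plain dict that each stage of sampling updates in place, and Timer(metrics, key) is used as a context manager that records the elapsed wall-clock time of its block under key. A minimal sketch of that contract as implied by the call sites (the real implementation lives in trinity.utils.timer and may differ):

import time
from typing import Dict

class Timer:
    # Sketch of the contract implied by the call sites, not the actual
    # trinity.utils.timer implementation.
    def __init__(self, metrics: Dict[str, float], key: str) -> None:
        self.metrics = metrics
        self.key = key

    def __enter__(self) -> "Timer":
        self.start = time.perf_counter()
        return self

    def __exit__(self, *exc) -> None:
        # Store the elapsed seconds of the block under the given key.
        self.metrics[self.key] = time.perf_counter() - self.start

metrics: Dict[str, float] = {}
with Timer(metrics, "time/gather_experience"):
    time.sleep(0.05)  # stands in for Experiences.gather_experiences(...)
print(metrics)  # {"time/gather_experience": ~0.05}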
16 changes: 13 additions & 3 deletions trinity/algorithm/sample_strategy/sample_strategy.py
@@ -1,11 +1,12 @@
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Tuple
+from typing import Dict, List, Tuple

 from trinity.algorithm.sample_strategy.utils import representative_sample
 from trinity.buffer import get_buffer_reader
 from trinity.common.config import BufferConfig
-from trinity.common.experience import Experiences
+from trinity.common.experience import Experience, Experiences
 from trinity.utils.annotations import Deprecated
+from trinity.utils.monitor import gather_metrics
 from trinity.utils.registry import Registry
 from trinity.utils.timer import Timer

@@ -16,6 +17,14 @@ class SampleStrategy(ABC):
     def __init__(self, buffer_config: BufferConfig, **kwargs) -> None:
         self.pad_token_id = buffer_config.pad_token_id

+    def set_model_version_metric(self, exp_list: List[Experience], metrics: Dict):
+        metric_list = [
+            {"model_version": exp.info["model_version"]}
+            for exp in exp_list
+            if "model_version" in exp.info
+        ]
+        metrics.update(gather_metrics(metric_list, "sample"))
+
     @abstractmethod
     async def sample(self, step: int) -> Tuple[Experiences, Dict, List]:
         """Sample data from buffer.
@@ -41,11 +50,12 @@ def __init__(self, buffer_config: BufferConfig, **kwargs):
         super().__init__(buffer_config)
         self.exp_buffer = get_buffer_reader(buffer_config.trainer_input.experience_buffer)  # type: ignore[arg-type]

-    async def sample(self, step: int, **kwargs) -> Tuple[Any, Dict, List]:
+    async def sample(self, step: int, **kwargs) -> Tuple[Experiences, Dict, List]:
         metrics = {}
         with Timer(metrics, "time/read_experience"):
             exp_list = await self.exp_buffer.read_async()
         repr_samples = representative_sample(exp_list)
+        self.set_model_version_metric(exp_list, metrics)
         with Timer(metrics, "time/gather_experience"):
             exps = Experiences.gather_experiences(exp_list, self.pad_token_id)  # type: ignore
         return exps, metrics, repr_samples
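The new base-class helper set_model_version_metric is what both strategies now call before gathering: it collects info["model_version"] from every sampled experience that carries one and passes the list to gather_metrics under the "sample" prefix. A hedged stand-in for that aggregation, assuming gather_metrics reduces a list of metric dicts into per-name summary statistics (the actual reduction lives in trinity.utils.monitor and may differ):

from typing import Dict, List

def gather_metrics_sketch(metric_list: List[Dict[str, float]], prefix: str) -> Dict[str, float]:
    # Hypothetical stand-in for trinity.utils.monitor.gather_metrics:
    # reduce a list of metric dicts into mean/min/max summaries per name.
    out: Dict[str, float] = {}
    names = {name for m in metric_list for name in m}
    for name in sorted(names):
        values = [m[name] for m in metric_list if name in m]
        out[f"{prefix}/{name}/mean"] = sum(values) / len(values)
        out[f"{prefix}/{name}/min"] = min(values)
        out[f"{prefix}/{name}/max"] = max(values)
    return out

metrics: Dict[str, float] = {}
metrics.update(gather_metrics_sketch([{"model_version": 3}, {"model_version": 4}], "sample"))
# metrics == {"sample/model_version/mean": 3.5, "sample/model_version/min": 3, "sample/model_version/max": 4}

A gap between the min and max would indicate that a single training batch mixes experiences generated under different model versions, which is presumably what this metric is meant to surface.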
16 changes: 9 additions & 7 deletions trinity/trainer/trainer.py
@@ -56,6 +56,7 @@ def __init__(self, config: Config) -> None:
         )
         self.save_interval = config.trainer.save_interval
         self.last_sync_step = None
+        self.last_sync_time = None
         self.total_steps = config.trainer.total_steps or float("inf")

     async def prepare(self) -> None:
@@ -68,22 +69,20 @@ async def train(self) -> str:
         """Train the model."""
         while self.train_step_num < self.total_steps:
             try:
-                st = time.time()
+                metrics = {}
                 # sample may be blocked due to explorer does not generate enough data
                 self.logger.info(f"Sample data for step {self.train_step_num + 1} started.")
                 sample_task = asyncio.create_task(self._sample_data())
                 while not sample_task.done():
                     # sync weight to make sure the explorer can continue to explore and generate enough data
                     if await self.need_sync():
-                        # Currently, we do not record the metrics of sync_weight here
-                        await self.sync_weight()
+                        metrics.update(await self.sync_weight())
                     await asyncio.sleep(1)
-                exps, metrics, repr_samples = await sample_task
+                exps, sample_metrics, repr_samples = await sample_task
+                metrics.update(sample_metrics)
                 self.logger.info(f"Sample data for step {self.train_step_num + 1} finished.")
                 metrics.update(await self.train_step(exps))
                 if await self.need_sync():
-                    # Record the time: sample_experience + train_step (>=1)
-                    metrics.update({"time/trainer_sync_interval": time.time() - st})
                     metrics.update(await self.sync_weight())
                 if self.need_save():
                     metrics.update(self.save_checkpoint())
@@ -126,7 +125,7 @@ async def _sample_data(self) -> Tuple[Experiences, Dict, List[Dict]]:
             List[Dict]: A list of representative samples for logging.
         """
         batch, metrics, repr_samples = await self.sample_strategy.sample(self.train_step_num + 1)
-        metrics["sample/task_count"] = len(set(eid.task for eid in batch.eids))
+        metrics["sample/task_count"] = len(set(eid.tid for eid in batch.eids))
         return batch, metrics, repr_samples

     async def need_sync(self) -> bool:
@@ -155,6 +154,8 @@ async def sync_weight(self) -> Dict:
         """Sync the model weight."""
         self.logger.info(f"Trainer sync_weights at step {self.train_step_num} started.")
         metrics = {}
+        if self.last_sync_time is not None:
+            metrics["time/trainer_sync_interval"] = time.time() - self.last_sync_time
         with Timer(metrics, "time/sync_weight"):
             if self.config.synchronizer.sync_method == SyncMethod.NCCL:
                 result = await self.synchronizer.ready_to_nccl_sync.remote(
@@ -170,6 +171,7 @@
             elif self.config.synchronizer.sync_method == SyncMethod.MEMORY:
                 self.engine.upload_state_dict()
             self.last_sync_step = self.train_step_num
+            self.last_sync_time = time.time()
         await self.synchronizer.set_trainer_status.remote(RunningStatus.RUNNING)
         self.logger.info(f"Trainer sync_weights at step {self.train_step_num} finished.")
         return metrics
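Taken together, the trainer.py changes move the trainer_sync_interval measurement out of the training loop and into sync_weight itself: rather than timing from the top of the current step (the removed st = time.time()), the trainer stamps last_sync_time at the end of each sync and reports the gap at the start of the next one, so syncs triggered while the sampler is still waiting for data are now covered as well. (Separately, sample/task_count now counts distinct eid.tid values, presumably the correct task-ID field on experience IDs.) A condensed, self-contained sketch of the new timing logic; the class below is scaffolding, and only the field and metric names mirror the diff:

import time
from typing import Dict, Optional

class SyncIntervalTracker:
    # Scaffolding: reports the wall-clock gap between consecutive weight syncs.
    def __init__(self) -> None:
        self.last_sync_time: Optional[float] = None

    def sync_weight(self) -> Dict[str, float]:
        metrics: Dict[str, float] = {}
        if self.last_sync_time is not None:
            # Gap since the previous sync, regardless of which code path triggered it.
            metrics["time/trainer_sync_interval"] = time.time() - self.last_sync_time
        # ... perform the actual weight sync here ...
        self.last_sync_time = time.time()
        return metrics

tracker = SyncIntervalTracker()
tracker.sync_weight()          # first sync: no interval reported yet
time.sleep(0.1)
print(tracker.sync_weight())   # -> {"time/trainer_sync_interval": ~0.1}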