diff --git a/trinity/algorithm/sample_strategy/mix_sample_strategy.py b/trinity/algorithm/sample_strategy/mix_sample_strategy.py
index 38d20118c5..306e1b0836 100644
--- a/trinity/algorithm/sample_strategy/mix_sample_strategy.py
+++ b/trinity/algorithm/sample_strategy/mix_sample_strategy.py
@@ -85,6 +85,7 @@ async def sample(self, step: int) -> Tuple[Experiences, Dict, List]:
         exp_list = usual_exp_list + expert_exp_list
         repr_samples = representative_sample(exp_list)
+        self.set_model_version_metric(exp_list, metrics)
         with Timer(metrics, "time/gather_experience"):
             exps = Experiences.gather_experiences(
                 experiences=exp_list,
diff --git a/trinity/algorithm/sample_strategy/sample_strategy.py b/trinity/algorithm/sample_strategy/sample_strategy.py
index 9d2bddf798..db3dc4f012 100644
--- a/trinity/algorithm/sample_strategy/sample_strategy.py
+++ b/trinity/algorithm/sample_strategy/sample_strategy.py
@@ -1,11 +1,12 @@
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Tuple
+from typing import Dict, List, Tuple
 
 from trinity.algorithm.sample_strategy.utils import representative_sample
 from trinity.buffer import get_buffer_reader
 from trinity.common.config import BufferConfig
-from trinity.common.experience import Experiences
+from trinity.common.experience import Experience, Experiences
 from trinity.utils.annotations import Deprecated
+from trinity.utils.monitor import gather_metrics
 from trinity.utils.registry import Registry
 from trinity.utils.timer import Timer
@@ -16,6 +17,14 @@ class SampleStrategy(ABC):
     def __init__(self, buffer_config: BufferConfig, **kwargs) -> None:
         self.pad_token_id = buffer_config.pad_token_id
 
+    def set_model_version_metric(self, exp_list: List[Experience], metrics: Dict):
+        metric_list = [
+            {"model_version": exp.info["model_version"]}
+            for exp in exp_list
+            if "model_version" in exp.info
+        ]
+        metrics.update(gather_metrics(metric_list, "sample"))
+
     @abstractmethod
     async def sample(self, step: int) -> Tuple[Experiences, Dict, List]:
         """Sample data from buffer.
@@ -41,11 +50,12 @@ def __init__(self, buffer_config: BufferConfig, **kwargs):
         super().__init__(buffer_config)
         self.exp_buffer = get_buffer_reader(buffer_config.trainer_input.experience_buffer)  # type: ignore[arg-type]
 
-    async def sample(self, step: int, **kwargs) -> Tuple[Any, Dict, List]:
+    async def sample(self, step: int, **kwargs) -> Tuple[Experiences, Dict, List]:
         metrics = {}
         with Timer(metrics, "time/read_experience"):
             exp_list = await self.exp_buffer.read_async()
         repr_samples = representative_sample(exp_list)
+        self.set_model_version_metric(exp_list, metrics)
         with Timer(metrics, "time/gather_experience"):
             exps = Experiences.gather_experiences(exp_list, self.pad_token_id)  # type: ignore
         return exps, metrics, repr_samples
diff --git a/trinity/trainer/trainer.py b/trinity/trainer/trainer.py
index fcea2b111a..bcc65ccebd 100644
--- a/trinity/trainer/trainer.py
+++ b/trinity/trainer/trainer.py
@@ -56,6 +56,7 @@ def __init__(self, config: Config) -> None:
         )
         self.save_interval = config.trainer.save_interval
         self.last_sync_step = None
+        self.last_sync_time = None
         self.total_steps = config.trainer.total_steps or float("inf")
 
     async def prepare(self) -> None:
@@ -68,22 +69,20 @@ async def train(self) -> str:
         """Train the model."""
         while self.train_step_num < self.total_steps:
             try:
-                st = time.time()
+                metrics = {}
                 # sample may be blocked due to explorer does not generate enough data
                 self.logger.info(f"Sample data for step {self.train_step_num + 1} started.")
                 sample_task = asyncio.create_task(self._sample_data())
                 while not sample_task.done():
                     # sync weight to make sure the explorer can continue to explore and generate enough data
                     if await self.need_sync():
-                        # Currently, we do not record the metrics of sync_weight here
-                        await self.sync_weight()
+                        metrics.update(await self.sync_weight())
                     await asyncio.sleep(1)
-                exps, metrics, repr_samples = await sample_task
+                exps, sample_metrics, repr_samples = await sample_task
+                metrics.update(sample_metrics)
                 self.logger.info(f"Sample data for step {self.train_step_num + 1} finished.")
                 metrics.update(await self.train_step(exps))
                 if await self.need_sync():
-                    # Record the time: sample_experience + train_step (>=1)
-                    metrics.update({"time/trainer_sync_interval": time.time() - st})
                     metrics.update(await self.sync_weight())
                 if self.need_save():
                     metrics.update(self.save_checkpoint())
@@ -126,7 +125,7 @@ async def _sample_data(self) -> Tuple[Experiences, Dict, List[Dict]]:
             List[Dict]: A list of representative samples for logging.
""" batch, metrics, repr_samples = await self.sample_strategy.sample(self.train_step_num + 1) - metrics["sample/task_count"] = len(set(eid.task for eid in batch.eids)) + metrics["sample/task_count"] = len(set(eid.tid for eid in batch.eids)) return batch, metrics, repr_samples async def need_sync(self) -> bool: @@ -155,6 +154,8 @@ async def sync_weight(self) -> Dict: """Sync the model weight.""" self.logger.info(f"Trainer sync_weights at step {self.train_step_num} started.") metrics = {} + if self.last_sync_time is not None: + metrics["time/trainer_sync_interval"] = time.time() - self.last_sync_time with Timer(metrics, "time/sync_weight"): if self.config.synchronizer.sync_method == SyncMethod.NCCL: result = await self.synchronizer.ready_to_nccl_sync.remote( @@ -170,6 +171,7 @@ async def sync_weight(self) -> Dict: elif self.config.synchronizer.sync_method == SyncMethod.MEMORY: self.engine.upload_state_dict() self.last_sync_step = self.train_step_num + self.last_sync_time = time.time() await self.synchronizer.set_trainer_status.remote(RunningStatus.RUNNING) self.logger.info(f"Trainer sync_weights at step {self.train_step_num} finished.") return metrics