diff --git a/docs/sphinx_doc/source/tutorial/trinity_configs.md b/docs/sphinx_doc/source/tutorial/trinity_configs.md index 8cb8856fbc..dbb8402ceb 100644 --- a/docs/sphinx_doc/source/tutorial/trinity_configs.md +++ b/docs/sphinx_doc/source/tutorial/trinity_configs.md @@ -79,14 +79,25 @@ Specifies the algorithm type and its related hyperparameters. algorithm: algorithm_type: grpo repeat_times: 1 - gamma: 1.0 - lam: 1.0 + + # The following parameters are optional + # If not specified, they will automatically be set based on the `algorithm_type` + sample_strategy: "default" + advantage_fn: "ppo" + kl_penalty_fn: "none" + kl_loss_fn: "k2" + entropy_loss_fn: "default" ``` - `algorithm_type`: Type of reinforcement learning algorithm. Supported types: `ppo`, `grpo`, `opmd`, `dpo`. - `repeat_times`: Number of times each task is repeated. Default is `1`. In `dpo`, this is automatically set to `2`. -- `gamma`: Discount factor for future rewards. Default is `1.0`. -- `lam`: Lambda value for Generalized Advantage Estimation (GAE). Default is `1.0`. + +- `sample_strategy`: The sampling strategy used for loading experiences from experience buffer. +- `advantage_fn`: The advantage function used for computing advantages. +- `kl_penalty_fn`: The KL penalty function used for computing KL penalty. +- `kl_loss_fn`: The KL loss function used for computing KL loss. +- `entropy_loss_fn`: The entropy loss function used for computing entropy loss. + --- diff --git a/trinity/algorithm/__init__.py b/trinity/algorithm/__init__.py index 101364c57c..ff52f609e5 100644 --- a/trinity/algorithm/__init__.py +++ b/trinity/algorithm/__init__.py @@ -2,6 +2,7 @@ from trinity.algorithm.entropy_loss_fn import ENTROPY_LOSS_FN, EntropyLossFn from trinity.algorithm.kl_fn import KL_FN, KLFn from trinity.algorithm.policy_loss_fn import POLICY_LOSS_FN, PolicyLossFn +from trinity.algorithm.sample_strategy import SAMPLE_STRATEGY, SampleStrategy __all__ = [ "AdvantageFn", @@ -12,4 +13,6 @@ "KL_FN", "EntropyLossFn", "ENTROPY_LOSS_FN", + "SampleStrategy", + "SAMPLE_STRATEGY", ] diff --git a/trinity/algorithm/algorithm.py b/trinity/algorithm/algorithm.py index f94798fe85..88b9b946b7 100644 --- a/trinity/algorithm/algorithm.py +++ b/trinity/algorithm/algorithm.py @@ -7,7 +7,6 @@ from trinity.buffer.schema.sql_schema import DPODataModel, ExperienceModel, SFTDataModel from trinity.common.config import Config from trinity.common.constants import SyncMethod -from trinity.common.experience import Experience, Experiences from trinity.utils.log import get_logger from trinity.utils.registry import Registry @@ -31,10 +30,6 @@ class AlgorithmType(ABC, metaclass=ConstantMeta): can_balance_batch: bool schema: type - @classmethod - def gather_experience(cls, exps: list[Experience], pad_token_id: int = 0) -> Experiences: - return Experiences.gather_experiences(exps, pad_token_id) - @classmethod def get_default_config(cls) -> Dict: raise NotImplementedError @@ -62,6 +57,7 @@ class SFTAlgorithm(AlgorithmType): @classmethod def get_default_config(cls) -> Dict: return { + "sample_strategy": "default", "policy_loss_fn": "sft", "kl_loss_fn": "none", "entropy_loss_fn": "none", @@ -83,11 +79,12 @@ class PPOAlgorithm(AlgorithmType): def get_default_config(cls) -> Dict: return { "repeat_times": 1, + "sample_strategy": "warmup", "policy_loss_fn": "ppo", "advantage_fn": "ppo", "kl_penalty_fn": "none", "kl_loss_fn": "k2", - "entropy_loss_fn": "basic", + "entropy_loss_fn": "default", } @@ -106,11 +103,12 @@ class GRPOAlgorithm(AlgorithmType): def get_default_config(cls) 
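Reviewer note on the docs change above ("If not specified, they will automatically be set based on the `algorithm_type`"): those values come from each algorithm's `get_default_config()`, as in the minimal sketch below. The actual merge lives in `Config._check_algorithm` further down in this patch; `user_cfg` and `resolved` are illustrative names only, and the dict merge is just the shape of the logic, not the real implementation.

```python
from trinity.algorithm.algorithm import ALGORITHM_TYPE

# Look up the algorithm class by its registry name, then read its defaults.
algorithm = ALGORITHM_TYPE.get("grpo")
defaults = algorithm.get_default_config()
# -> {"repeat_times": 2, "sample_strategy": "warmup", "policy_loss_fn": "ppo",
#     "advantage_fn": "grpo", "kl_penalty_fn": "none", "kl_loss_fn": "k2",
#     "entropy_loss_fn": "default"}

# Only keys the user left unset fall back to these defaults.
user_cfg = {"algorithm_type": "grpo", "repeat_times": 8}
resolved = {**defaults, **user_cfg}
```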
-> Dict: return { "repeat_times": 2, + "sample_strategy": "warmup", "policy_loss_fn": "ppo", "advantage_fn": "grpo", "kl_penalty_fn": "none", "kl_loss_fn": "k2", - "entropy_loss_fn": "basic", + "entropy_loss_fn": "default", } @@ -129,11 +127,12 @@ class OPMDAlgorithm(AlgorithmType): def get_default_config(cls) -> Dict: return { "repeat_times": 2, + "sample_strategy": "warmup", "policy_loss_fn": "opmd", "advantage_fn": "opmd", "kl_penalty_fn": "none", "kl_loss_fn": "k2", - "entropy_loss_fn": "basic", + "entropy_loss_fn": "default", } @@ -148,17 +147,14 @@ class DPOAlgorithm(AlgorithmType): can_balance_batch: bool = False schema: type = DPODataModel - @classmethod - def gather_experience(cls, exps: list[Experience], pad_token_id: int = 0) -> Experiences: - return Experiences.gather_dpo_experiences(exps, pad_token_id) - @classmethod def get_default_config(cls) -> Dict: return { "repeat_times": 2, # fake repeat times + "sample_strategy": "dpo", "policy_loss_fn": "dpo", "kl_loss_fn": "k2", - "entropy_loss_fn": "basic", + "entropy_loss_fn": "default", } @classmethod diff --git a/trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py b/trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py index 41583ec3ba..d6179a832c 100644 --- a/trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py +++ b/trinity/algorithm/entropy_loss_fn/entropy_loss_fn.py @@ -40,8 +40,8 @@ def default_args(cls) -> Dict: return {"entropy_coef": 0.0} -@ENTROPY_LOSS_FN.register_module("basic") -class BasicEntropyLossFn(EntropyLossFn): +@ENTROPY_LOSS_FN.register_module("default") +class DefaultEntropyLossFn(EntropyLossFn): """ Basic entropy loss function. """ diff --git a/trinity/algorithm/sample_strategy/__init__.py b/trinity/algorithm/sample_strategy/__init__.py new file mode 100644 index 0000000000..60f2e268ae --- /dev/null +++ b/trinity/algorithm/sample_strategy/__init__.py @@ -0,0 +1,13 @@ +from trinity.algorithm.sample_strategy.sample_strategy import ( + SAMPLE_STRATEGY, + DefaultSampleStrategy, + SampleStrategy, + WarmupSampleStrategy, +) + +__all__ = [ + "SAMPLE_STRATEGY", + "SampleStrategy", + "DefaultSampleStrategy", + "WarmupSampleStrategy", +] diff --git a/trinity/algorithm/sample_strategy/sample_strategy.py b/trinity/algorithm/sample_strategy/sample_strategy.py new file mode 100644 index 0000000000..8686a0d497 --- /dev/null +++ b/trinity/algorithm/sample_strategy/sample_strategy.py @@ -0,0 +1,114 @@ +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Tuple + +from trinity.algorithm.sample_strategy.utils import representative_sample, to_data_proto +from trinity.buffer import get_buffer_reader +from trinity.common.config import BufferConfig +from trinity.common.experience import Experiences +from trinity.utils.registry import Registry +from trinity.utils.timer import Timer + +SAMPLE_STRATEGY = Registry("sample_strategy") + + +class SampleStrategy(ABC): + def __init__(self, buffer_config: BufferConfig, trainer_type: str, **kwargs): + self.pad_token_id = buffer_config.pad_token_id + self.trainer_type = trainer_type + + @abstractmethod + def sample(self, step: int) -> Tuple[Any, Dict, List]: + """Sample experiences from buffer. + + Args: + step (`int`): The step number of current step. + + Returns: + `Any`: The sampled experiences. + `Dict`: Metrics for logging. + `List`: Representative experiences for logging. 
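Reviewer note: the `SAMPLE_STRATEGY` registry and the abstract `sample()` above are the extension point this patch introduces. A hypothetical sketch of a user-defined strategy follows; the class name, the `"nonzero_reward"` key, and the filtering rule are illustrative, not part of this patch. It reuses the patch's `to_data_proto` helper and, like the built-in strategies below, only supports the verl backend.

```python
from typing import Any, Dict, List, Tuple

from trinity.algorithm.sample_strategy import SAMPLE_STRATEGY, SampleStrategy
from trinity.algorithm.sample_strategy.utils import to_data_proto
from trinity.buffer import get_buffer_reader
from trinity.common.config import BufferConfig
from trinity.common.experience import Experiences


@SAMPLE_STRATEGY.register_module("nonzero_reward")
class NonZeroRewardSampleStrategy(SampleStrategy):
    """Hypothetical strategy that drops zero-reward experiences before training."""

    def __init__(self, buffer_config: BufferConfig, trainer_type: str, **kwargs):
        super().__init__(buffer_config, trainer_type)
        self.exp_buffer = get_buffer_reader(
            buffer_config.trainer_input.experience_buffer, buffer_config
        )

    def sample(self, step: int, **kwargs) -> Tuple[Any, Dict, List]:
        # Keep only experiences whose reward is non-zero.
        exp_list = [exp for exp in self.exp_buffer.read() if exp.reward]
        metrics = {"kept": len(exp_list)}
        exps = Experiences.gather_experiences(exp_list, self.pad_token_id)
        if self.trainer_type != "verl":
            raise NotImplementedError(f"backend {self.trainer_type} is not supported")
        return to_data_proto(exps), metrics, []
```

Such a strategy would then be selected with `sample_strategy: nonzero_reward` in the `algorithm` section of the YAML; note that filtering can shrink the batch, which a production strategy would have to compensate for.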
+ """ + + @classmethod + def default_args(cls) -> dict: + return {} + + +@SAMPLE_STRATEGY.register_module("warmup") +class WarmupSampleStrategy(SampleStrategy): + """The default sample strategy.""" + + def __init__(self, buffer_config: BufferConfig, trainer_type: str, **kwargs): + super().__init__(buffer_config, trainer_type) + self.exp_buffer = get_buffer_reader( + buffer_config.trainer_input.experience_buffer, buffer_config # type: ignore + ) + self.sft_warmup_steps = buffer_config.trainer_input.sft_warmup_steps + if self.sft_warmup_steps > 0 and buffer_config.trainer_input.sft_warmup_dataset is None: + raise ValueError("sft_warmup_dataset is required when sft_warmup_steps > 0") + if buffer_config.trainer_input.sft_warmup_dataset is not None: + self.sft_buffer = get_buffer_reader( + buffer_config.trainer_input.sft_warmup_dataset, buffer_config + ) + else: + self.sft_buffer = None + + def sample(self, step: int, **kwargs) -> Tuple[Any, Dict, List]: + metrics = {} + with Timer(metrics, "read_time"): + if step <= self.sft_warmup_steps: + exp_list = self.sft_buffer.read() + else: + exp_list = self.exp_buffer.read() + repr_samples = representative_sample(exp_list) + with Timer(metrics, "gather_time"): + exps = Experiences.gather_experiences(exp_list, self.pad_token_id) # type: ignore + if self.trainer_type == "verl": + with Timer(metrics, "convert_time"): + data = to_data_proto(exps) + return data, metrics, repr_samples + else: + raise NotImplementedError(f"backend {self.trainer_type} is not supported") + + +@SAMPLE_STRATEGY.register_module("default") +class DefaultSampleStrategy(SampleStrategy): + def __init__(self, buffer_config: BufferConfig, trainer_type: str, **kwargs): + super().__init__(buffer_config, trainer_type) + self.exp_buffer = get_buffer_reader( + buffer_config.trainer_input.experience_buffer, buffer_config # type: ignore + ) + + def sample(self, step: int, **kwargs) -> Tuple[Any, Dict, List]: + metrics = {} + with Timer(metrics, "read_time"): + exp_list = self.exp_buffer.read() + repr_samples = representative_sample(exp_list) + with Timer(metrics, "gather_time"): + exps = Experiences.gather_experiences(exp_list, self.pad_token_id) # type: ignore + if self.trainer_type == "verl": + with Timer(metrics, "convert_time"): + data = to_data_proto(exps) + return data, metrics, repr_samples + else: + raise NotImplementedError(f"backend {self.trainer_type} is not supported") + + +@SAMPLE_STRATEGY.register_module("dpo") +class DPOSampleStrategy(WarmupSampleStrategy): + def sample(self, step: int, **kwargs) -> Tuple[Any, Dict, List]: + metrics = {} + with Timer(metrics, "read_time"): + if step <= self.sft_warmup_steps: + exp_list = self.sft_buffer.read() + else: + exp_list = self.exp_buffer.read() + repr_samples = representative_sample(exp_list) + with Timer(metrics, "gather_time"): + exps = Experiences.gather_dpo_experiences(exp_list, pad_token_id=self.pad_token_id) # type: ignore + if self.trainer_type == "verl": + with Timer(metrics, "convert_time"): + data = to_data_proto(exps) + return data, metrics, repr_samples + else: + raise NotImplementedError(f"backend {self.trainer_type} is not supported") diff --git a/trinity/algorithm/sample_strategy/utils.py b/trinity/algorithm/sample_strategy/utils.py new file mode 100644 index 0000000000..8c443a20b1 --- /dev/null +++ b/trinity/algorithm/sample_strategy/utils.py @@ -0,0 +1,78 @@ +import random +from typing import List + +import numpy as np +import torch +from verl.trainer.ppo.ray_trainer import DataProto + +from trinity.common.experience 
import Experience, Experiences + + +def to_data_proto(experiences: Experiences) -> DataProto: + attention_mask = experiences.attention_masks + cumsum = torch.cumsum(attention_mask, dim=-1) + position_ids = torch.clip(cumsum - 1, 0, None).long() + batch_dict = { + "uid": np.array(experiences.run_ids), + "position_ids": position_ids, + "input_ids": experiences.tokens.long(), + "responses": experiences.tokens[:, experiences.prompt_length :].long(), + "attention_mask": attention_mask.long(), + "response_mask": ( + experiences.action_masks[:, experiences.prompt_length :].long() + if hasattr(experiences, "action_masks") and experiences.action_masks is not None + else attention_mask[:, experiences.prompt_length :].long() + ), + } + if experiences.rewards is not None: + token_level_rewards = torch.zeros(attention_mask.shape, dtype=experiences.rewards.dtype) + eos_mask_idx = cumsum.argmax(dim=-1) + token_level_rewards[ + torch.arange(experiences.batch_size), eos_mask_idx + ] = experiences.rewards + token_level_rewards = token_level_rewards[:, experiences.prompt_length :] + batch_dict.update( + { + "token_level_scores": token_level_rewards, + "old_log_probs": experiences.logprobs[:, experiences.prompt_length :], # type: ignore + } + ) + return DataProto.from_single_dict(batch_dict) + + +def representative_sample(experiences: List[Experience]) -> List[dict]: + if experiences[0].reward is None: + sample = random.choice(experiences) + return [ + { + "prompt": sample.prompt_text, + "response": sample.response_text, + } + ] + samples = [] + min_reward_sample = None + max_reward_sample = None + for exp in experiences: + if exp.reward is None: + continue + if min_reward_sample is None or exp.reward < min_reward_sample.reward: + min_reward_sample = exp + if max_reward_sample is None or exp.reward > max_reward_sample.reward: + max_reward_sample = exp + if min_reward_sample is not None: + samples.append( + { + "prompt": min_reward_sample.prompt_text, + "response": min_reward_sample.response_text, + "reward": min_reward_sample.reward, + } + ) + if max_reward_sample is not None: + samples.append( + { + "prompt": max_reward_sample.prompt_text, + "response": max_reward_sample.response_text, + "reward": max_reward_sample.reward, + } + ) + return samples diff --git a/trinity/common/config.py b/trinity/common/config.py index 22d8f3d711..7c371f4bcb 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -176,9 +176,8 @@ class AlgorithmConfig: # for GRPO-like algorithms, repeat each task for `repeat_times` times repeat_times: int = 1 - policy_loss_fn: Optional[str] = None # "ppo" - # If not set, use PolicyLossFn.default_args() - policy_loss_fn_args: Optional[dict] = None + sample_strategy: Optional[str] = None + sample_strategy_args: Optional[dict] = None advantage_fn: Optional[str] = None # "ppo" # If not set, use AdvantageFn.default_args() @@ -188,11 +187,15 @@ class AlgorithmConfig: # If not set, use kl_penalty_fn.default_args() kl_penalty_fn_args: Optional[dict] = None + policy_loss_fn: Optional[str] = None # "ppo" + # If not set, use PolicyLossFn.default_args() + policy_loss_fn_args: Optional[dict] = None + kl_loss_fn: Optional[str] = None # "k2" # set to "none" to disable kl loss # If not set, use kl_loss_fn.default_args() kl_loss_fn_args: Optional[dict] = None - entropy_loss_fn: Optional[str] = None # "basic" + entropy_loss_fn: Optional[str] = None # "default" # If not set, use entropy_loss_fn.default_args() entropy_loss_fn_args: Optional[dict] = None @@ -489,23 +492,32 @@ def 
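Reviewer note on the `AlgorithmConfig` reshuffle above: every `sample_strategy` / `*_fn` field keeps the same None-means-default contract. A small sketch, assuming the dataclass can be constructed directly (normally the values come from YAML, and `_check_algorithm`, shown next, fills the gaps):

```python
from trinity.common.config import AlgorithmConfig

algo = AlgorithmConfig(
    algorithm_type="grpo",
    repeat_times=8,
    sample_strategy=None,       # None -> filled with the GRPO default ("warmup")
    sample_strategy_args=None,  # None -> filled from SampleStrategy.default_args()
    advantage_fn="grpo",
    entropy_loss_fn="default",
    # policy_loss_fn / kl_penalty_fn / kl_loss_fn follow the same None-means-default rule
)
```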
_check_algorithm(self) -> None: ENTROPY_LOSS_FN, KL_FN, POLICY_LOSS_FN, + SAMPLE_STRATEGY, ) from trinity.algorithm.algorithm import ALGORITHM_TYPE algorithm = ALGORITHM_TYPE.get(self.algorithm.algorithm_type) algorithm.check_config(self) default_config = { + "sample_strategy": "warmup", "policy_loss_fn": "ppo", "advantage_fn": "ppo", "kl_penalty_fn": "none", "kl_loss_fn": "k2", - "entropy_loss_fn": "basic", + "entropy_loss_fn": "default", } default_config.update(algorithm.get_default_config()) for key, value in default_config.items(): if getattr(self.algorithm, key, None) is None: setattr(self.algorithm, key, value) + # TODO: simplify the following code + sample_strategy_cls = SAMPLE_STRATEGY.get(self.algorithm.sample_strategy) + if sample_strategy_cls is None: + raise ValueError(f"Invalid sample_strategy: {self.algorithm.sample_strategy}") + if self.algorithm.sample_strategy_args is None: + self.algorithm.sample_strategy_args = sample_strategy_cls.default_args() + policy_fn_cls = POLICY_LOSS_FN.get(self.algorithm.policy_loss_fn) if policy_fn_cls is None: raise ValueError(f"Invalid policy_loss_fn: {self.algorithm.policy_loss_fn}") diff --git a/trinity/trainer/trainer.py b/trinity/trainer/trainer.py index 95859685ee..2920604fbb 100644 --- a/trinity/trainer/trainer.py +++ b/trinity/trainer/trainer.py @@ -12,9 +12,7 @@ import ray -from trinity.algorithm.algorithm import ALGORITHM_TYPE, SFTAlgorithm from trinity.algorithm.algorithm_manager import AlgorithmManager -from trinity.buffer import get_buffer_reader from trinity.common.config import Config from trinity.common.constants import SyncMethod from trinity.utils.log import get_logger @@ -28,18 +26,6 @@ def __init__(self, config: Config) -> None: self.config = config self.logger = get_logger(__name__) self.algorithm_manager = AlgorithmManager(config) - self.train_buffer = get_buffer_reader( - self.config.buffer.trainer_input.experience_buffer, # type: ignore - self.config.buffer, - ) - self.sft_warmup_buffer = ( - get_buffer_reader( - self.config.buffer.trainer_input.sft_warmup_dataset, # type: ignore - self.config.buffer, - ) - if self.config.buffer.trainer_input.sft_warmup_steps > 0 - else None - ) self.engine = get_trainer_wrapper(config) def prepare(self) -> None: @@ -71,29 +57,7 @@ def train_step(self) -> Tuple[bool, int]: Returns: bool: Whether to continue training. """ - algo_config = self.algorithm_manager.get_current_algorithm_config( - self.engine.train_step_num + 1 - ) - algo_type = algo_config.algorithm_type - algorithm = ALGORITHM_TYPE.get(algo_type) - if algorithm.use_rollout: - strategy = self.config.buffer.trainer_input.read_experience_strategy - else: - strategy = None - try: - if algorithm == SFTAlgorithm: - exps = self.sft_warmup_buffer.read() - else: - exps = self.train_buffer.read(strategy=strategy) - except StopIteration: - self.logger.warning("No more data to train. 
Stop training.") - return False, self.engine.train_step_num - - experiences = algorithm.gather_experience( - exps, - pad_token_id=self.config.buffer.pad_token_id, # type: ignore - ) - return self.engine.train_step(experiences) + return self.engine.train_step() def sync_weight(self) -> None: """Sync the model weight.""" @@ -126,7 +90,7 @@ def train_step_num(self) -> int: """Get the current training step number.""" @abstractmethod - def train_step(self, experiences) -> Tuple[bool, int]: + def train_step(self) -> Tuple[bool, int]: """Training.""" @abstractmethod diff --git a/trinity/trainer/verl_trainer.py b/trinity/trainer/verl_trainer.py index 84da4cbf98..110a54a7db 100644 --- a/trinity/trainer/verl_trainer.py +++ b/trinity/trainer/verl_trainer.py @@ -6,9 +6,8 @@ import os import sys from pprint import pprint -from typing import Tuple +from typing import Dict, List, Tuple -import numpy as np import pandas as pd import ray import torch @@ -20,7 +19,6 @@ reduce_metrics, ) from verl.trainer.ppo.ray_trainer import ( - DataProto, RayClassWithInitArgs, RayPPOTrainer, RayWorkerGroup, @@ -33,7 +31,7 @@ from verl.utils import hf_tokenizer from verl.utils.fs import copy_local_path_from_hdfs -from trinity.algorithm import ADVANTAGE_FN, KL_FN +from trinity.algorithm import ADVANTAGE_FN, KL_FN, SAMPLE_STRATEGY from trinity.algorithm.algorithm import ALGORITHM_TYPE, SFTAlgorithm from trinity.algorithm.algorithm_manager import AlgorithmManager from trinity.algorithm.utils import prefix_metrics @@ -135,7 +133,11 @@ def __init__( self.kl_fn = KL_FN.get(self.algorithm_config.kl_penalty_fn)( **self.algorithm_config.kl_penalty_fn_args ) - + self.sample_strategy = SAMPLE_STRATEGY.get(global_config.algorithm.sample_strategy)( + buffer_config=global_config.buffer, + trainer_type=global_config.trainer.trainer_type, + **global_config.algorithm.sample_strategy_args, + ) super().__init__( config, tokenizer, @@ -237,9 +239,7 @@ def init_workers(self): self.actor_rollout_wg.init_model() def reset_experiences_example_table(self): - self.experiences_example_table = pd.DataFrame( - columns=["step", "reward", "prompt", "response"] - ) + self.sample_exps_to_log = [] @property def train_step_num(self) -> int: @@ -270,9 +270,15 @@ def _create_dataloader(self): # TODO: compute total training steps self.total_training_steps = self.config.trainer.total_training_steps or sys.maxsize - def train_step(self, experiences: Experiences) -> Tuple[bool, int]: - self.global_steps += 1 + def train_step(self) -> Tuple[bool, int]: # noqa C901 metrics = {} + try: + batch, sample_metrics, exp_samples = self.sample_strategy.sample(self.global_steps + 1) + prefix_metrics(sample_metrics, "sample", metrics) + except StopIteration: + print("No more data to train. 
Stop training.") + return False, self.global_steps + self.global_steps += 1 timing_raw = {} algorithm_config = self.algorithm_manager.get_current_algorithm_config(self.global_steps) algorithm = ALGORITHM_TYPE.get(algorithm_config.algorithm_type) @@ -283,39 +289,6 @@ def train_step(self, experiences: Experiences) -> Tuple[bool, int]: self.algorithm = algorithm with _timer("step", timing_raw): - # Convert rewards to token_level_rewards - attention_mask = experiences.attention_masks - cumsum = torch.cumsum(attention_mask, dim=-1) - position_ids = torch.clip(cumsum - 1, 0, None).long() - batch_dict = { - "uid": np.array(experiences.run_ids), - "position_ids": position_ids, - "input_ids": experiences.tokens.long(), - "responses": experiences.tokens[:, experiences.prompt_length :].long(), - "attention_mask": attention_mask.long(), - "response_mask": ( - experiences.action_masks[:, experiences.prompt_length :].long() - if hasattr(experiences, "action_masks") and experiences.action_masks is not None - else attention_mask[:, experiences.prompt_length :].long() - ), - } - if self.algorithm.use_advantage: - token_level_rewards = torch.zeros( - attention_mask.shape, dtype=experiences.rewards.dtype - ) - eos_mask_idx = cumsum.argmax(dim=-1) - token_level_rewards[ - torch.arange(experiences.batch_size), eos_mask_idx - ] = experiences.rewards - token_level_rewards = token_level_rewards[:, experiences.prompt_length :] - batch_dict.update( - { - "token_level_scores": token_level_rewards, - "old_log_probs": experiences.logprobs[:, experiences.prompt_length :], # type: ignore - } - ) - - batch = DataProto.from_single_dict(batch_dict) batch.meta_info["temperature"] = self.config.actor_rollout_ref.rollout.temperature if self.algorithm.can_balance_batch and self.config.trainer.balance_batch: @@ -381,7 +354,7 @@ def train_step(self, experiences: Experiences) -> Tuple[bool, int]: ) if self.algorithm.use_advantage and self.config.enable_preview: # TODO - self._log_experiences(experiences) + self._log_experiences(exp_samples) # TODO: make a canonical logger that supports various backend self.logger.log(data=metrics, step=self.global_steps) @@ -419,21 +392,13 @@ def _log_single_experience( "response": [response_text], } ) - self.experiences_example_table = pd.concat( - [self.experiences_example_table, new_row], ignore_index=True - ) - - def _log_experiences(self, experiences: Experiences) -> None: - skip_special_tokens = False - reward_max_id = torch.argmax(experiences.rewards) - self._log_single_experience(experiences, reward_max_id, skip_special_tokens) - - reward_min_id = torch.argmin(experiences.rewards) - self._log_single_experience(experiences, reward_min_id, skip_special_tokens) + self.sample_exps_to_log = pd.concat([self.sample_exps_to_log, new_row], ignore_index=True) + def _log_experiences(self, samples: List[Dict]) -> None: + self.sample_exps_to_log.extend(samples) if self.global_steps % self.config.trainer.sync_freq == 0: self.logger.log_table( - "rollout_examples", self.experiences_example_table, self.global_steps + "rollout_examples", pd.DataFrame(self.sample_exps_to_log), self.global_steps ) self.reset_experiences_example_table() diff --git a/trinity/utils/timer.py b/trinity/utils/timer.py new file mode 100644 index 0000000000..5e80f406b8 --- /dev/null +++ b/trinity/utils/timer.py @@ -0,0 +1,18 @@ +"""Timer context manager""" + +import time + + +class Timer: + def __init__(self, metrics_dict, key_name): + self.metrics = metrics_dict + self.key = key_name + + def __enter__(self): + self.start_time = 
time.time()
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        end_time = time.time()
+        elapsed_time = end_time - self.start_time
+        self.metrics[self.key] = elapsed_time
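Reviewer note: a tiny usage sketch of the new `Timer` helper, matching how the sample strategies above populate their `metrics` dicts. The `time.sleep` is a stand-in for real work such as a buffer read.

```python
import time

from trinity.utils.timer import Timer

metrics: dict = {}
with Timer(metrics, "read_time"):
    time.sleep(0.1)  # stand-in for e.g. exp_buffer.read()

print(metrics)  # {'read_time': ~0.1}
```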