diff --git a/docs/sphinx_doc/source/tutorial/example_mix_algo.md b/docs/sphinx_doc/source/tutorial/example_mix_algo.md
index 61ecec33b1..de664cae4a 100644
--- a/docs/sphinx_doc/source/tutorial/example_mix_algo.md
+++ b/docs/sphinx_doc/source/tutorial/example_mix_algo.md
@@ -46,7 +46,7 @@ class MIXAlgorithm(AlgorithmType):
     schema: type = ExperienceModel

     @classmethod
-    def get_default_config(cls) -> Dict:
+    def default_config(cls) -> Dict:
         return {
             "repeat_times": 8,
             "policy_loss_fn": "mix",
diff --git a/docs/sphinx_doc/source/tutorial/trinity_configs.md b/docs/sphinx_doc/source/tutorial/trinity_configs.md
index dbb8402ceb..8c0cab9a0a 100644
--- a/docs/sphinx_doc/source/tutorial/trinity_configs.md
+++ b/docs/sphinx_doc/source/tutorial/trinity_configs.md
@@ -376,11 +376,6 @@ actor_rollout_ref:
     use_dynamic_bsz: True
     ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length}
     grad_clip: 1.0
-    clip_ratio: 0.2
-    entropy_coeff: 0.001
-    use_kl_loss: False # True for GRPO
-    kl_loss_coef: 0.001 # for grpo
-    kl_loss_type: low_var_kl # for grpo
     ppo_epochs: 1
     shuffle: False
     ulysses_sequence_parallel_size: 1 # sp size
@@ -399,10 +394,6 @@ actor_rollout_ref:
       param_offload: False
       optimizer_offload: False
       fsdp_size: -1
-    # --- below: opmd ---
-    tau: 0.000 # strength of regularization w.r.t. old / ref policy
-    opmd_baseline: mean # mean / logavgexp, applicable to opmd
-    use_uid: False # True / False, applicable to pairwise_opmd
   ref:
     fsdp_config:
       param_offload: False
@@ -447,22 +438,6 @@ critic:
   grad_clip: 1.0
   cliprange_value: 0.5

-custom_reward_function:
-  path: null
-  name: compute_score
-
-algorithm:
-  gamma: 1.0
-  lam: 1.0
-  norm_adv_by_std_in_grpo: True
-  use_kl_in_reward: False
-  kl_penalty: kl # how to estimate kl divergence
-  kl_ctrl:
-    type: fixed
-    kl_coef: 0.001
-    horizon: 10000
-    target_kl: 0.1
-
 trainer:
   balance_batch: True
   # total_training_steps: null
@@ -483,11 +458,7 @@ trainer:
 - `actor_rollout_ref.model.use_remove_padding`: Whether to remove pad tokens, which will reduce training time.
 - `actor_rollout_ref.actor.use_dynamic_bsz`: Whether to reorganize the batch data, specifically to splice the shorter data to reduce the batch size in the actual training process.
 - `actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu`: Batch size for one GPU in one forward pass.
-- `actor_rollout_ref.actor.kl_loss_type`: How to compute kl loss, optional value is `kl`, `abs`, `mse` or `low_var_kl`.
 - `actor_rollout_ref.actor.ulysses_sequence_parallel_size`: Ulysses sequence parallel size.
-- `actor_rollout_ref.actor.tau`: strength of regularization w.r.t. old / ref policy.
-- `actor_rollout_ref.actor.opmd_baseline`: mean / logavgexp, applicable to opmd.
-- `actor_rollout_ref.actor.use_uid`: True / False, applicable to pairwise_opmd.
 - `actor_rollout_ref.actor.optim.lr`: Learning rate for actor model.
 - `actor_rollout_ref.actor.optim.lr_warmup_steps_ratio`: Ratio of warmup steps for learning rate.
 - `actor_rollout_ref.actor.optim.warmup_style`: Warmup style for learning rate.
@@ -505,8 +476,6 @@ trainer:
 - `critic.grad_clip`: Gradient clip for critic model training.
 - `critic.cliprange_value`: Used for compute value loss.

-- `algorithm`: Training algorithm settings.
-
 - `trainer.balance_batch`: Whether to balance batch size between GPUs during training.
 - `trainer.resume_mode`: Resume mode for training. Support `disable`, `auto` and `resume_path`.
 - `trainer.resume_from_path`: Path to resume from.
diff --git a/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md b/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md index 931cb81506..e07e6bb3dc 100644 --- a/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md +++ b/docs/sphinx_doc/source/tutorial/trinity_programming_guide.md @@ -443,13 +443,13 @@ The `AlgorithmType` class includes the following attributes and methods: - `use_advantage`: Whether to calculate Advantage; if False, the `AdvantageFn` call will be skipped - `can_balance_batch`: Whether the algorithm allows automatic balancing when splitting a batch into microbatches (which permute the order of samples) - `schema`: The format of experience data corresponding to the algorithm -- `get_default_config`: Gets the default configuration of the algorithm, which will override attributes with the same name in `ALGORITHM_TYPE` +- `default_config`: Gets the default configuration of the algorithm, which will override attributes with the same name in `ALGORITHM_TYPE` Similarly, after implementation, you need to register this module through `ALGORITHM_TYPE`. Below is the implementation for the OPMD algorithm. Since the OPMD algorithm doesn't need to use the Critic model, `use_critic` is set to `False`. -The dictionary returned by the `get_default_config` method indicates that OPMD will use the `opmd` type `AdvantageFn` and `PolicyLossFn` implemented in Step 1, will not apply KL Penalty on rewards, but will add a `k2` type KL loss when calculating the final loss. +The dictionary returned by the `default_config` method indicates that OPMD will use the `opmd` type `AdvantageFn` and `PolicyLossFn` implemented in Step 1, will not apply KL Penalty on rewards, but will add a `k2` type KL loss when calculating the final loss. ```python @ALGORITHM_TYPE.register_module("opmd") @@ -463,7 +463,7 @@ class OPMDAlgorithm(AlgorithmType): schema: type = ExperienceModel @classmethod - def get_default_config(cls) -> Dict: + def default_config(cls) -> Dict: return { "repeat_times": 2, "sample_strategy": "warmup", diff --git a/examples/async_gsm8k/verl_config.yaml b/examples/async_gsm8k/verl_config.yaml index de1b08f590..fc44fdad94 100644 --- a/examples/async_gsm8k/verl_config.yaml +++ b/examples/async_gsm8k/verl_config.yaml @@ -12,11 +12,6 @@ actor_rollout_ref: use_dynamic_bsz: True # False ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length} grad_clip: 1.0 - clip_ratio: 0.2 - entropy_coeff: 0.001 - use_kl_loss: True # True for GRPO - kl_loss_coef: 0.001 # for grpo - kl_loss_type: low_var_kl # for grpo ppo_epochs: 1 shuffle: False ulysses_sequence_parallel_size: 1 # sp size @@ -33,10 +28,6 @@ actor_rollout_ref: param_offload: False optimizer_offload: False fsdp_size: -1 - # --- below: opmd --- - tau: 0.000 # strength of regularization w.r.t. 
old / ref policy - opmd_baseline: mean # mean / logavgexp, applicable to opmd - use_uid: False # True / False, applicable to pairwise_opmd ref: fsdp_config: param_offload: False @@ -48,18 +39,6 @@ actor_rollout_ref: log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size -custom_reward_function: - path: null - name: compute_score - -algorithm: - gamma: 1.0 - lam: 1.0 - kl_penalty: kl # how to estimate kl divergence - kl_ctrl: - type: fixed - kl_coef: 0.001 - trainer: balance_batch: True # total_training_steps: null diff --git a/examples/dpo_humanlike/train_dpo.yaml b/examples/dpo_humanlike/train_dpo.yaml index 028c997e06..d5074848b0 100644 --- a/examples/dpo_humanlike/train_dpo.yaml +++ b/examples/dpo_humanlike/train_dpo.yaml @@ -12,11 +12,6 @@ actor_rollout_ref: use_dynamic_bsz: False ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length} grad_clip: 1.0 - clip_ratio: 0.2 - entropy_coeff: 0.001 - use_kl_loss: True - kl_loss_coef: 0.1 # NOTE: beta for DPO - kl_loss_type: low_var_kl # for grpo ppo_epochs: 1 shuffle: False ulysses_sequence_parallel_size: 1 # sp size @@ -46,18 +41,6 @@ actor_rollout_ref: log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size -custom_reward_function: - path: null - name: compute_score - -algorithm: - gamma: 1.0 - lam: 1.0 - kl_penalty: kl - kl_ctrl: - type: fixed - kl_coef: 0.001 - trainer: balance_batch: False total_training_steps: 783 # diff --git a/examples/grpo_alfworld/train_alfworld.yaml b/examples/grpo_alfworld/train_alfworld.yaml index 215b1817ab..5b73ec7403 100644 --- a/examples/grpo_alfworld/train_alfworld.yaml +++ b/examples/grpo_alfworld/train_alfworld.yaml @@ -12,11 +12,6 @@ actor_rollout_ref: use_dynamic_bsz: False ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length} grad_clip: 1.0 - clip_ratio: 0.2 - entropy_coeff: 0.001 - use_kl_loss: True # True for GRPO - kl_loss_coef: 0.001 # for grpo - kl_loss_type: low_var_kl # for grpo ppo_epochs: 1 shuffle: False ulysses_sequence_parallel_size: 1 # sp size @@ -44,18 +39,6 @@ actor_rollout_ref: log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size -custom_reward_function: - path: null - name: compute_score - -algorithm: - gamma: 1.0 - lam: 1.0 - kl_penalty: kl # how to estimate kl divergence - kl_ctrl: - type: fixed - kl_coef: 0.001 - trainer: balance_batch: True # total_training_steps: null diff --git a/examples/grpo_gsm8k/train_gsm8k.yaml b/examples/grpo_gsm8k/train_gsm8k.yaml index de1b08f590..fc44fdad94 100644 --- a/examples/grpo_gsm8k/train_gsm8k.yaml +++ b/examples/grpo_gsm8k/train_gsm8k.yaml @@ -12,11 +12,6 @@ actor_rollout_ref: use_dynamic_bsz: True # False ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length} grad_clip: 1.0 - clip_ratio: 0.2 - entropy_coeff: 0.001 - use_kl_loss: True # True for GRPO - kl_loss_coef: 0.001 # for grpo - kl_loss_type: low_var_kl # for grpo ppo_epochs: 1 shuffle: False ulysses_sequence_parallel_size: 1 # sp size @@ -33,10 +28,6 @@ actor_rollout_ref: param_offload: False optimizer_offload: False fsdp_size: -1 - # --- below: opmd --- - tau: 0.000 # strength of 
regularization w.r.t. old / ref policy - opmd_baseline: mean # mean / logavgexp, applicable to opmd - use_uid: False # True / False, applicable to pairwise_opmd ref: fsdp_config: param_offload: False @@ -48,18 +39,6 @@ actor_rollout_ref: log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size -custom_reward_function: - path: null - name: compute_score - -algorithm: - gamma: 1.0 - lam: 1.0 - kl_penalty: kl # how to estimate kl divergence - kl_ctrl: - type: fixed - kl_coef: 0.001 - trainer: balance_batch: True # total_training_steps: null diff --git a/examples/grpo_math/train_math.yaml b/examples/grpo_math/train_math.yaml index 78bcb862c6..0a46bd1788 100644 --- a/examples/grpo_math/train_math.yaml +++ b/examples/grpo_math/train_math.yaml @@ -12,11 +12,6 @@ actor_rollout_ref: use_dynamic_bsz: True # False ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length} grad_clip: 1.0 - clip_ratio: 0.2 - entropy_coeff: 0.001 - use_kl_loss: True # True for GRPO - kl_loss_coef: 0.0001 # for grpo - kl_loss_type: low_var_kl # for grpo ppo_epochs: 1 shuffle: False ulysses_sequence_parallel_size: 1 # sp size @@ -33,10 +28,6 @@ actor_rollout_ref: param_offload: False optimizer_offload: False fsdp_size: -1 - # --- below: opmd --- - tau: 0.000 # strength of regularization w.r.t. old / ref policy - opmd_baseline: mean # mean / logavgexp, applicable to opmd - use_uid: False # True / False, applicable to pairwise_opmd ref: fsdp_config: param_offload: False @@ -48,18 +39,6 @@ actor_rollout_ref: log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size -custom_reward_function: - path: null - name: compute_score - -algorithm: - gamma: 1.0 - lam: 1.0 - kl_penalty: kl # how to estimate kl divergence - kl_ctrl: - type: fixed - kl_coef: 0.0001 - trainer: balance_batch: True # auto: find the last ckpt to resume. 
If can't find, start from scratch diff --git a/examples/grpo_sciworld/train_sciworld.yaml b/examples/grpo_sciworld/train_sciworld.yaml index 215b1817ab..5b73ec7403 100644 --- a/examples/grpo_sciworld/train_sciworld.yaml +++ b/examples/grpo_sciworld/train_sciworld.yaml @@ -12,11 +12,6 @@ actor_rollout_ref: use_dynamic_bsz: False ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length} grad_clip: 1.0 - clip_ratio: 0.2 - entropy_coeff: 0.001 - use_kl_loss: True # True for GRPO - kl_loss_coef: 0.001 # for grpo - kl_loss_type: low_var_kl # for grpo ppo_epochs: 1 shuffle: False ulysses_sequence_parallel_size: 1 # sp size @@ -44,18 +39,6 @@ actor_rollout_ref: log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size -custom_reward_function: - path: null - name: compute_score - -algorithm: - gamma: 1.0 - lam: 1.0 - kl_penalty: kl # how to estimate kl divergence - kl_ctrl: - type: fixed - kl_coef: 0.001 - trainer: balance_batch: True # total_training_steps: null diff --git a/examples/grpo_webshop/train_webshop.yaml b/examples/grpo_webshop/train_webshop.yaml index 215b1817ab..5b73ec7403 100644 --- a/examples/grpo_webshop/train_webshop.yaml +++ b/examples/grpo_webshop/train_webshop.yaml @@ -12,11 +12,6 @@ actor_rollout_ref: use_dynamic_bsz: False ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length} grad_clip: 1.0 - clip_ratio: 0.2 - entropy_coeff: 0.001 - use_kl_loss: True # True for GRPO - kl_loss_coef: 0.001 # for grpo - kl_loss_type: low_var_kl # for grpo ppo_epochs: 1 shuffle: False ulysses_sequence_parallel_size: 1 # sp size @@ -44,18 +39,6 @@ actor_rollout_ref: log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size -custom_reward_function: - path: null - name: compute_score - -algorithm: - gamma: 1.0 - lam: 1.0 - kl_penalty: kl # how to estimate kl divergence - kl_ctrl: - type: fixed - kl_coef: 0.001 - trainer: balance_batch: True # total_training_steps: null diff --git a/examples/mix_math/train_mix_math.yaml b/examples/mix_math/train_mix_math.yaml index 7b14a87fad..ca072b78f6 100644 --- a/examples/mix_math/train_mix_math.yaml +++ b/examples/mix_math/train_mix_math.yaml @@ -12,11 +12,6 @@ actor_rollout_ref: use_dynamic_bsz: True # False ppo_max_token_len_per_gpu: 25600 # n * ${data.max_prompt_length} + ${data.max_response_length} grad_clip: 1.0 - clip_ratio: 0.2 - entropy_coeff: 0.001 - use_kl_loss: True # True for GRPO - kl_loss_coef: 0.0001 # for grpo - kl_loss_type: low_var_kl # for grpo ppo_epochs: 1 shuffle: False ulysses_sequence_parallel_size: 1 # sp size @@ -33,10 +28,6 @@ actor_rollout_ref: param_offload: False optimizer_offload: False fsdp_size: -1 - # --- below: opmd --- - tau: 0.000 # strength of regularization w.r.t. 
old / ref policy - opmd_baseline: mean # mean / logavgexp, applicable to opmd - use_uid: False # True / False, applicable to pairwise_opmd ref: fsdp_config: param_offload: False @@ -48,18 +39,6 @@ actor_rollout_ref: log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size -custom_reward_function: - path: null - name: compute_score - -algorithm: - gamma: 1.0 - lam: 1.0 - kl_penalty: kl # how to estimate kl divergence - kl_ctrl: - type: fixed - kl_coef: 0.0001 - trainer: balance_batch: True # auto: find the last ckpt to resume. If can't find, start from scratch diff --git a/examples/opmd_gsm8k/train_opmd_gsm8k.yaml b/examples/opmd_gsm8k/train_opmd_gsm8k.yaml index 44a0111d64..5ddd5124ee 100644 --- a/examples/opmd_gsm8k/train_opmd_gsm8k.yaml +++ b/examples/opmd_gsm8k/train_opmd_gsm8k.yaml @@ -36,11 +36,6 @@ actor_rollout_ref: use_dynamic_bsz: True ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length} grad_clip: 1.0 - clip_ratio: 0.2 - entropy_coeff: 0.000 - use_kl_loss: True - kl_loss_coef: 0.001 - kl_loss_type: mse ppo_epochs: 1 shuffle: False ulysses_sequence_parallel_size: 1 # sp size @@ -58,10 +53,6 @@ actor_rollout_ref: param_offload: False optimizer_offload: False fsdp_size: -1 - # --- below: opmd --- - tau: 4.0 # strength of regularization w.r.t. old / ref policy - opmd_baseline: mean # mean / logavgexp, applicable to opmd - use_uid: False # True / False, applicable to pairwise_opmd ref: fsdp_config: param_offload: False @@ -73,18 +64,6 @@ actor_rollout_ref: log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size -custom_reward_function: - path: null - name: compute_score - -algorithm: - gamma: 1.0 - lam: 1.0 - kl_penalty: kl # how to estimate kl divergence - kl_ctrl: - type: fixed - kl_coef: 0.000 - trainer: balance_batch: True # total_training_steps: null diff --git a/examples/ppo_countdown/train_countdown.yaml b/examples/ppo_countdown/train_countdown.yaml index ae16122ef7..191c345b90 100644 --- a/examples/ppo_countdown/train_countdown.yaml +++ b/examples/ppo_countdown/train_countdown.yaml @@ -12,11 +12,6 @@ actor_rollout_ref: use_dynamic_bsz: True ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length} grad_clip: 1.0 - clip_ratio: 0.2 - entropy_coeff: 0.001 - use_kl_loss: False # True for GRPO - kl_loss_coef: 0.001 # for grpo - kl_loss_type: low_var_kl # for grpo ppo_epochs: 1 shuffle: False ulysses_sequence_parallel_size: 1 # sp size @@ -35,10 +30,6 @@ actor_rollout_ref: param_offload: False optimizer_offload: False fsdp_size: -1 - # --- below: opmd --- - tau: 0.000 # strength of regularization w.r.t. 
old / ref policy - opmd_baseline: mean # mean / logavgexp, applicable to opmd - use_uid: False # True / False, applicable to pairwise_opmd ref: fsdp_config: param_offload: False @@ -82,18 +73,6 @@ critic: grad_clip: 1.0 cliprange_value: 0.5 -custom_reward_function: - path: null - name: compute_score - -algorithm: - gamma: 1.0 - lam: 1.0 - kl_penalty: kl # how to estimate kl divergence - kl_ctrl: - type: fixed - kl_coef: 0.001 - trainer: balance_batch: True # total_training_steps: null diff --git a/tests/template/verl_config.yaml b/tests/template/verl_config.yaml index b17fc87958..d6dcf4a997 100644 --- a/tests/template/verl_config.yaml +++ b/tests/template/verl_config.yaml @@ -12,11 +12,6 @@ actor_rollout_ref: use_dynamic_bsz: True ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length} grad_clip: 1.0 - clip_ratio: 0.2 - entropy_coeff: 0.001 - use_kl_loss: False # True for GRPO - kl_loss_coef: 0.001 # for grpo - kl_loss_type: low_var_kl # for grpo ppo_epochs: 1 shuffle: False ulysses_sequence_parallel_size: 1 # sp size @@ -35,10 +30,6 @@ actor_rollout_ref: param_offload: False optimizer_offload: False fsdp_size: -1 - # --- below: opmd --- - tau: 0.000 # strength of regularization w.r.t. old / ref policy - opmd_baseline: mean # mean / logavgexp, applicable to opmd - use_uid: False # True / False, applicable to pairwise_opmd ref: fsdp_config: param_offload: False @@ -82,14 +73,6 @@ critic: grad_clip: 1.0 cliprange_value: 0.5 -algorithm: - gamma: 1.0 - lam: 1.0 - kl_penalty: kl # how to estimate kl divergence - kl_ctrl: - type: fixed - kl_coef: 0.001 - trainer: balance_batch: True # total_training_steps: null diff --git a/trinity/algorithm/algorithm.py b/trinity/algorithm/algorithm.py index 805dd8f213..54f5c3d296 100644 --- a/trinity/algorithm/algorithm.py +++ b/trinity/algorithm/algorithm.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- """Algorithm classes.""" -from abc import ABC, ABCMeta +from abc import ABC, ABCMeta, abstractmethod from typing import Dict from trinity.buffer.schema.sql_schema import DPODataModel, ExperienceModel, SFTDataModel @@ -30,7 +30,8 @@ class AlgorithmType(ABC, metaclass=ConstantMeta): schema: type @classmethod - def get_default_config(cls) -> Dict: + @abstractmethod + def default_config(cls) -> Dict: raise NotImplementedError @classmethod @@ -53,7 +54,7 @@ class SFTAlgorithm(AlgorithmType): schema: type = SFTDataModel @classmethod - def get_default_config(cls) -> Dict: + def default_config(cls) -> Dict: return { "sample_strategy": "default", "policy_loss_fn": "sft", @@ -73,7 +74,7 @@ class PPOAlgorithm(AlgorithmType): schema: type = ExperienceModel @classmethod - def get_default_config(cls) -> Dict: + def default_config(cls) -> Dict: return { "repeat_times": 1, "sample_strategy": "warmup", @@ -96,7 +97,7 @@ class GRPOAlgorithm(AlgorithmType): schema: type = ExperienceModel @classmethod - def get_default_config(cls) -> Dict: + def default_config(cls) -> Dict: return { "repeat_times": 2, "sample_strategy": "warmup", @@ -119,7 +120,7 @@ class OPMDAlgorithm(AlgorithmType): schema: type = ExperienceModel @classmethod - def get_default_config(cls) -> Dict: + def default_config(cls) -> Dict: return { "repeat_times": 2, "sample_strategy": "warmup", @@ -142,9 +143,8 @@ class DPOAlgorithm(AlgorithmType): schema: type = DPODataModel @classmethod - def get_default_config(cls) -> Dict: + def default_config(cls) -> Dict: return { - "repeat_times": 2, # fake repeat times "sample_strategy": "dpo", "policy_loss_fn": "dpo", "kl_loss_fn": "k2", @@ 
-170,10 +170,10 @@ def check_config(cls, config: Config) -> None: "DPO only supports checkpoint synchronization, set `synchronizer.sync_method` to `checkpoint`." ) if config.algorithm.repeat_times != 2: - config.algorithm.repeat_times = 2 - logger.warning( - "DPO only supports 2 repeat times, set `algorithm.repeat_times` to 2." - ) # no need to warn + config.algorithm.repeat_times = 2 # Fake repeat times + if config.algorithm.kl_loss_fn in {"none", None}: + config.algorithm.kl_loss_fn = "k2" + logger.warning("DPO must use KL loss. Set `algorithm.kl_loss_fn` to `k2`") @ALGORITHM_TYPE.register_module("mix") @@ -188,7 +188,7 @@ class MIXAlgorithm(AlgorithmType): schema: type = ExperienceModel @classmethod - def get_default_config(cls) -> Dict: + def default_config(cls) -> Dict: return { "repeat_times": 8, "policy_loss_fn": "mix", diff --git a/trinity/algorithm/algorithm_manager.py b/trinity/algorithm/algorithm_manager.py index 3c2983c80b..82cef5ebbd 100644 --- a/trinity/algorithm/algorithm_manager.py +++ b/trinity/algorithm/algorithm_manager.py @@ -12,7 +12,7 @@ class AlgorithmManager: def __init__(self, config: Config): self.config = config sft_type = ALGORITHM_TYPE.get("sft") - sft_default_config = sft_type.get_default_config() + sft_default_config = sft_type.default_config() self.sft_algorithm_config = AlgorithmConfig( algorithm_type="sft", **sft_default_config, diff --git a/trinity/algorithm/sample_strategy/mix_sample_strategy.py b/trinity/algorithm/sample_strategy/mix_sample_strategy.py index acdd340b24..25811e9190 100644 --- a/trinity/algorithm/sample_strategy/mix_sample_strategy.py +++ b/trinity/algorithm/sample_strategy/mix_sample_strategy.py @@ -79,7 +79,7 @@ def sample(self, step: int) -> Tuple[Any, Dict, List]: raise NotImplementedError(f"backend {self.trainer_type} is not supported") @classmethod - def get_default_config(cls) -> Dict: + def default_args(cls) -> Dict: return { "expert_data_ratio": 0.5, } diff --git a/trinity/buffer/reader/file_reader.py b/trinity/buffer/reader/file_reader.py index 8bb9dbcd28..0dd9aef75e 100644 --- a/trinity/buffer/reader/file_reader.py +++ b/trinity/buffer/reader/file_reader.py @@ -67,7 +67,8 @@ def __init__(self, meta: StorageConfig, config: BufferConfig): self.response_key = meta.format.response_key self.read_batch_size = config.read_batch_size self.dataset = _HFBatchReader( - load_dataset(meta.path, name=subset_name, split=self.split) + load_dataset(meta.path, name=subset_name, split=self.split), + max_epoch=meta.total_epochs, ) # TODO: support resume self.tokenizer = transformers.AutoTokenizer.from_pretrained(config.tokenizer_path) @@ -143,7 +144,8 @@ def __init__(self, meta: StorageConfig, config: BufferConfig): self.rejected_key = meta.format.rejected_key self.read_batch_size = config.read_batch_size self.dataset = _HFBatchReader( - load_dataset(meta.path, name=subset_name, split=self.split) + load_dataset(meta.path, name=subset_name, split=self.split), + max_epoch=meta.total_epochs, ) # TODO: support resume self.tokenizer = transformers.AutoTokenizer.from_pretrained(config.tokenizer_path) diff --git a/trinity/common/config.py b/trinity/common/config.py index 52f8c433fc..9c45627d32 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -206,15 +206,6 @@ class AlgorithmConfig: # TODO: move this to SFT warmup use_token_level_loss: bool = True - # do not set - algorithm_manager: Optional[Any] = None - - def get_current_algorithm_config(self, global_steps: int): - return 
self.algorithm_manager.get_current_algorithm_config(global_steps) - - def need_save(self, global_steps: int): - return self.algorithm_manager.need_save(global_steps) - @dataclass class ClusterConfig: @@ -303,7 +294,6 @@ class TrainerConfig: # trainer configs actor_grad_clip: Optional[float] = None - actor_clip_ratio: Optional[float] = None # TODO: extract more train-related params from underlying trainer engine # Only one needs to be set for `trainer_config` and `trainer_config_path` @@ -525,7 +515,7 @@ def _check_algorithm(self) -> None: "kl_loss_fn": "k2", "entropy_loss_fn": "default", } - default_config.update(algorithm.get_default_config()) + default_config.update(algorithm.default_config()) for key, value in default_config.items(): if getattr(self.algorithm, key, None) is None: setattr(self.algorithm, key, value) diff --git a/trinity/common/verl_config.py b/trinity/common/verl_config.py index e6b1b9e4e1..1ec0653503 100644 --- a/trinity/common/verl_config.py +++ b/trinity/common/verl_config.py @@ -4,7 +4,6 @@ from omegaconf import OmegaConf -from trinity.algorithm.algorithm import DPOAlgorithm from trinity.common.config import BufferConfig, Config, SynchronizerConfig from trinity.utils.log import get_logger @@ -66,22 +65,19 @@ class Actor: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length} ) grad_clip: float = 1.0 - clip_ratio: float = 0.2 - entropy_coeff: float = 0.001 - use_kl_loss: bool = False - kl_loss_coef: float = 0.001 - kl_loss_type: str = "low_var_kl" ppo_epochs: int = 1 shuffle: bool = False ulysses_sequence_parallel_size: int = 1 checkpoint: Checkpoint = field(default_factory=Checkpoint) optim: Optim = field(default_factory=Optim) fsdp_config: FSDPConfig = field(default_factory=FSDPConfig) - algorithm_type: str = "ppo" # TODO - tau: float = 0.001 # strength of regularization w.r.t. old / ref policy - opmd_baseline: str = "mean" # mean / logavgexp, applicable to opmd - use_uid: bool = False # True / False, applicable to pairwise_opmd - loss_agg_mode: str = "token-mean" # do not set + # do not set + loss_agg_mode: str = "token-mean" + clip_ratio: float = 0.2 + entropy_coeff: float = 0.001 + use_kl_loss: bool = False + kl_loss_coef: float = 0.001 + kl_loss_type: str = "low_var_kl" @dataclass @@ -208,10 +204,6 @@ class Algorithm: kl_penalty: str = "kl" kl_ctrl: KL_Ctrl = field(default_factory=KL_Ctrl) - # ! 
DO NOT SET THE FOLLOWING PARAMETERS - policy_loss_fn: str = "ppo" - policy_loss_fn_args: Optional[dict] = None - @dataclass class Trainer: @@ -323,33 +315,19 @@ def synchronize_config(self, config: Config) -> None: # noqa: C901 if config.trainer.actor_grad_clip is not None: self.actor_rollout_ref.actor.grad_clip = config.trainer.actor_grad_clip - if config.trainer.actor_clip_ratio is not None: - self.actor_rollout_ref.actor.clip_ratio = config.trainer.actor_clip_ratio # Algorithm related config - adv_fn_args = config.algorithm.advantage_fn_args - if adv_fn_args is not None and "gamma" in adv_fn_args: - self.algorithm.gamma = adv_fn_args["gamma"] - if adv_fn_args is not None and "lam" in adv_fn_args: - self.algorithm.lam = adv_fn_args["lam"] - self.actor_rollout_ref.actor.algorithm_type = config.algorithm.algorithm_type self.actor_rollout_ref.actor.use_kl_loss = config.algorithm.kl_loss_fn != "none" - self.actor_rollout_ref.actor.kl_loss_coef = config.algorithm.kl_loss_fn_args["kl_coef"] # type: ignore - self.actor_rollout_ref.actor.entropy_coeff = config.algorithm.entropy_loss_fn_args[ # type: ignore - "entropy_coef" - ] + self.algorithm.use_kl_in_reward = config.algorithm.kl_penalty_fn != "none" # TODO (yanxi): it seems that adv_estimator now only affects whether use_critic is set to # True or False in RayPPOTrainer.__init__() (and hence in VerlPPOTrainerWrapper). # Need to double check whether this is indeed the case, # and see if adv_estimator can be removed completely. - if isinstance(self.actor_rollout_ref.actor.algorithm_type, DPOAlgorithm): # for DPO - if not self.actor_rollout_ref.actor.use_kl_loss: - self.actor_rollout_ref.actor.use_kl_loss = True - logger.warning("DPO must use KL loss.") + if config.algorithm.algorithm_type == "dpo": # for DPO logger.warning("DPO micro batch size is doubled for computing loss.") - self.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu *= 2 # type: ignore - self.actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu *= 2 # type: ignore + self.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu *= 2 + self.actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu *= 2 if self.actor_rollout_ref.rollout.n != 2: self.actor_rollout_ref.rollout.n = 2 # TODO: check other fields diff --git a/trinity/manager/config_manager.py b/trinity/manager/config_manager.py index 80b8992b3b..de4305a9cc 100644 --- a/trinity/manager/config_manager.py +++ b/trinity/manager/config_manager.py @@ -7,10 +7,25 @@ import streamlit as st import yaml -from trinity.common.constants import AlgorithmType, StorageType +from trinity.algorithm.advantage_fn.advantage_fn import ADVANTAGE_FN +from trinity.algorithm.algorithm import ALGORITHM_TYPE +from trinity.algorithm.entropy_loss_fn.entropy_loss_fn import ENTROPY_LOSS_FN +from trinity.algorithm.kl_fn.kl_fn import KL_FN +from trinity.algorithm.policy_loss_fn.policy_loss_fn import POLICY_LOSS_FN +from trinity.algorithm.sample_strategy.sample_strategy import SAMPLE_STRATEGY +from trinity.common.constants import StorageType from trinity.manager.config_registry.config_registry import CONFIG_GENERATORS from trinity.manager.config_registry.trainer_config_manager import use_critic +register_map = { + "sample_strategy": SAMPLE_STRATEGY, + "policy_loss_fn": POLICY_LOSS_FN, + "advantage_fn": ADVANTAGE_FN, + "kl_loss_fn": KL_FN, + "kl_penalty_fn": KL_FN, + "entropy_loss_fn": ENTROPY_LOSS_FN, +} + class ConfigManager: def __init__(self): @@ -47,55 +62,48 @@ def maintain_session_state(self): for key in CONFIG_GENERATORS.default_config: 
             st.session_state[key] = st.session_state[key]

-        eval_dataset_keys = [
+        def maintain_list_state(prefix, key_list):
+            last_idx, del_num = 0, 0
+            for idx in range(st.session_state[f"_{prefix}_num"]):
+                if st.session_state.get(f"{prefix}_{idx}_del_flag", False):
+                    del_num += 1
+                    continue
+                for key in key_list:
+                    full_key = f"{prefix}_{idx}_{key}"
+                    last_full_key = f"{prefix}_{last_idx}_{key}"
+                    st.session_state[last_full_key] = st.session_state[full_key]
+                last_idx += 1
+            st.session_state[f"_{prefix}_num"] -= del_num
+
+        self.eval_dataset_keys = [
             "name",
             "path",
-            "subset_name",
             "split",
+            "subset_name",
             "prompt_key",
             "response_key",
             "temperature",
             "logprobs",
             "n",
         ]
-        last_idx, del_num = 0, 0
-        for idx in range(st.session_state["_eval_tasksets_num"]):
-            if st.session_state.get(f"eval_taskset_{idx}_del_flag", False):
-                del_num += 1
-                continue
-            for key in eval_dataset_keys:
-                full_key = f"eval_taskset_{idx}_{key}"
-                last_full_key = f"eval_taskset_{last_idx}_{key}"
-                st.session_state[last_full_key] = st.session_state[full_key]
-            last_idx += 1
-        st.session_state["_eval_tasksets_num"] -= del_num
-
-        auxiliary_model_keys = [
+        maintain_list_state("eval_tasksets", self.eval_dataset_keys)
+
+        self.inference_model_keys = [
             "model_path",
             "engine_type",
             "engine_num",
             "tensor_parallel_size",
-            "gpu_memory_utilization",
-            "dtype",
-            "seed",
             "use_v1",
             "enforce_eager",
             "enable_prefix_caching",
             "enable_chunked_prefill",
+            "gpu_memory_utilization",
+            "dtype",
+            "seed",
             "enable_thinking",
             "enable_openai_api",
         ]
-        last_idx, del_num = 0, 0
-        for idx in range(st.session_state["_auxiliary_models_num"]):
-            if st.session_state.get(f"auxiliary_model_{idx}_del_flag", False):
-                del_num += 1
-                continue
-            for key in auxiliary_model_keys:
-                full_key = f"auxiliary_model_{idx}_{key}"
-                last_full_key = f"auxiliary_model_{last_idx}_{key}"
-                st.session_state[last_full_key] = st.session_state[full_key]
-            last_idx += 1
-        st.session_state["_auxiliary_models_num"] -= del_num
+        maintain_list_state("auxiliary_models", self.inference_model_keys)

     def get_configs(self, *config_names: str, columns_spec: List[int] = None):
         CONFIG_GENERATORS.get_configs(*config_names, columns_spec=columns_spec)
@@ -108,7 +116,7 @@ def beginner_mode(self):
         self.get_configs("checkpoint_root_dir")

-        if st.session_state["algorithm_type"] != AlgorithmType.DPO.value:
+        if st.session_state["algorithm_type"] != "dpo":
             self.get_configs("taskset_path")
         else:
             self.get_configs("experience_buffer_path")

@@ -126,7 +134,7 @@ def beginner_mode(self):
         self.get_configs("sync_interval", "eval_interval", "save_interval")

-        if st.session_state["algorithm_type"] != AlgorithmType.DPO.value:
+        if st.session_state["algorithm_type"] != "dpo":
             self.get_configs("taskset_args")
         else:
             self.get_configs("dpo_dataset_kwargs")

@@ -136,9 +144,6 @@ def beginner_mode(self):
         self.get_configs("default_workflow_type", "default_reward_fn_type")

-        self.get_configs("actor_use_kl_loss")
-        self.get_configs("actor_kl_loss_coef", "actor_kl_loss_type")
-
         self.get_configs(
             "actor_ppo_micro_batch_size_per_gpu",
             "actor_lr",
@@ -165,7 +170,7 @@ def _expert_buffer_part(self):
         self.get_configs("system_prompt")
         self.get_configs("reply_prefix")

-        if st.session_state["algorithm_type"] != AlgorithmType.DPO.value:
+        if st.session_state["algorithm_type"] != "dpo":
             with st.expander("Taskset Configs", expanded=True):
                 self.get_configs("taskset_path")
                 self.get_configs("taskset_args")
@@ -182,7 +187,7 @@ def _expert_buffer_part(self):
             self.get_configs("sft_warmup_dataset_path")
self.get_configs("sft_warmup_dataset_args") - if st.session_state["algorithm_type"] != AlgorithmType.DPO.value: + if st.session_state["algorithm_type"] != "dpo": with st.expander("Experiences Buffer Configs", expanded=True): self.get_configs("storage_type") self.get_configs("experience_buffer_path") @@ -213,8 +218,30 @@ def _expert_explorer_part(self): self.get_configs("auxiliary_models") def _expert_trainer_part(self): - self.get_configs("algorithm_type", "gamma", "lam") - self.get_configs("repeat_times", "save_interval") + self.get_configs("algorithm_type", "repeat_times", "save_interval") + self.get_configs("sample_strategy", "advantage_fn", "entropy_loss_fn") + self.get_configs("policy_loss_fn", "kl_penalty_fn", "kl_loss_fn") + + with st.expander("Advanced Algorithm Config"): + algorithm = ALGORITHM_TYPE.get(st.session_state["algorithm_type"]) + default_config = algorithm.default_config() + config_key_list = [] + for key in default_config.keys(): + value = st.session_state[key] + if key == "repeat_times": + continue + default_args = register_map[key].get(value).default_args() + for sub_key in default_args.keys(): + full_key = sub_key + "_in_" + key + config_key_list.append(full_key) + + idx = 0 + while idx < len(config_key_list): + delta = 3 if len(config_key_list) - idx != 4 else 2 + key_list = config_key_list[idx : idx + delta] + idx += delta + self.get_configs(*key_list) + self.get_configs("enable_preview") if st.session_state["trainer_type"] == "verl": @@ -238,12 +265,6 @@ def _expert_verl_training_part(self): self.get_configs("max_actor_ckpt_to_keep", "max_critic_ckpt_to_keep") - def _expert_verl_algorithm_part(self): - st.subheader("RL Algorithm Config") - self.get_configs("norm_adv_by_std_in_grpo", "use_kl_in_reward") - self.get_configs("kl_penalty", "kl_ctrl_type", "kl_ctrl_coef") - self.get_configs("horizon", "target_kl") - def _expert_verl_actor_part(self): st.subheader("Actor Model Config") self.get_configs( @@ -254,12 +275,7 @@ def _expert_verl_actor_part(self): self.get_configs("actor_lr", "actor_warmup_style", "actor_lr_warmup_steps_ratio") - self.get_configs("actor_grad_clip", "actor_clip_ratio", "actor_entropy_coef") - - self.get_configs("actor_use_kl_loss", "actor_use_uid") - self.get_configs("actor_kl_loss_coef", "actor_kl_loss_type") - - self.get_configs("actor_tau", "actor_opmd_baseline") + self.get_configs("actor_grad_clip") self.get_configs("actor_checkpoint") @@ -277,7 +293,6 @@ def _expert_verl_critic_part(self): def _expert_verl_trainer_part(self): name2func = { "RL Training Config": self._expert_verl_training_part, - "RL Algorithm Config": self._expert_verl_algorithm_part, "Actor and Ref Config": self._expert_verl_actor_part, } if use_critic(): @@ -359,9 +374,6 @@ def _generate_verl_config(self): ), }, "fsdp_config": copy.deepcopy(fsdp_config), - "tau": st.session_state["actor_tau"], - "opmd_baseline": st.session_state["actor_opmd_baseline"], - "use_uid": st.session_state["actor_use_uid"], }, "ref": { "fsdp_config": copy.deepcopy(fsdp_config), @@ -375,14 +387,7 @@ def _generate_verl_config(self): ], }, }, - "custom_reward_function": {"path": None, "name": "compute_score"}, - "algorithm": { - "kl_penalty": st.session_state["kl_penalty"], - "kl_ctrl": { - "type": st.session_state["kl_ctrl_type"], - "kl_coef": st.session_state["kl_ctrl_coef"], - }, - }, + "critic": {}, "trainer": { "balance_batch": balance_batch, "resume_mode": st.session_state["resume_mode"], @@ -436,11 +441,35 @@ def _generate_verl_config(self): "cliprange_value": 
st.session_state["critic_cliprange_value"], "checkpoint": {"contents": st.session_state["critic_checkpoint"]}, } + else: + del trainer_config["critic"] return trainer_config + def _gen_algorithm_config(self): + algorithm_config = { + "algorithm_type": st.session_state["algorithm_type"], + } + algorithm = ALGORITHM_TYPE.get(st.session_state["algorithm_type"]) + default_config = algorithm.default_config() + current_config = {} + for key in default_config.keys(): + current_config[key] = value = st.session_state[key] + if key == "repeat_times": + continue + default_args = register_map[key].get(value).default_args() + args = {} + for sub_key in default_args.keys(): + full_key = sub_key + "_in_" + key + args[sub_key] = st.session_state.get(full_key, default_args[sub_key]) + if default_args != args: + current_config[key + "_args"] = args + if default_config != current_config: + algorithm_config.update(current_config) + return algorithm_config + def _gen_buffer_config(self): experience_buffer_path = st.session_state["experience_buffer_path"].strip() - if st.session_state["algorithm_type"] != AlgorithmType.DPO.value: + if st.session_state["algorithm_type"] != "dpo": if ( not experience_buffer_path and st.session_state["storage_type"] == StorageType.SQL.value @@ -456,6 +485,7 @@ def _gen_buffer_config(self): buffer_config = { "batch_size": st.session_state["train_batch_size"], "total_epochs": st.session_state["total_epochs"], + "explorer_input": {}, "trainer_input": { "experience_buffer": { "name": "experience_buffer", @@ -497,13 +527,25 @@ def _gen_buffer_config(self): { "name": st.session_state[f"eval_taskset_{idx}_name"], "path": st.session_state[f"eval_taskset_{idx}_path"], - "subset_name": st.session_state[f"eval_taskset_{idx}_subset_name"], "split": st.session_state[f"eval_taskset_{idx}_split"], - "prompt_key": st.session_state[f"eval_taskset_{idx}_prompt_key"], - "response_key": st.session_state[f"eval_taskset_{idx}_response_key"], + "subset_name": st.session_state[f"eval_taskset_{idx}_subset_name"], + "format": { + "prompt_key": st.session_state[f"eval_taskset_{idx}_prompt_key"], + "response_key": st.session_state[ + f"eval_taskset_{idx}_response_key" + ], + }, + "rollout_args": { + "temperature": st.session_state[f"eval_taskset_{idx}_temperature"], + "logprobs": st.session_state[f"eval_taskset_{idx}_logprobs"], + "n": st.session_state[f"eval_taskset_{idx}_n"], + }, } ) - if st.session_state["algorithm_type"] == AlgorithmType.DPO.value: + else: + del buffer_config["explorer_input"] + + if st.session_state["algorithm_type"] == "dpo": experience_buffer = buffer_config["trainer_input"]["experience_buffer"] experience_buffer["split"] = st.session_state["dpo_dataset_train_split"] experience_buffer["format"] = { @@ -534,26 +576,23 @@ def _gen_explorer_config(self): "max_timeout": st.session_state["max_timeout"], "max_retry_times": st.session_state["explorer_max_retry_times"], "rollout_model": { - "engine_type": st.session_state["engine_type"], - "engine_num": st.session_state["engine_num"], - "tensor_parallel_size": st.session_state["tensor_parallel_size"], - "use_v1": st.session_state["use_v1"], - "enforce_eager": st.session_state["enforce_eager"], - "enable_prefix_caching": st.session_state["enable_prefix_caching"], - "enable_chunked_prefill": st.session_state["enable_chunked_prefill"], - "gpu_memory_utilization": st.session_state["gpu_memory_utilization"], - "dtype": st.session_state["dtype"], - "seed": st.session_state["seed"], + key: st.session_state[key] + for key in 
self.inference_model_keys + if key != "model_path" # "max_prompt_tokens": None, # TODO # "max_response_tokens": None, # TODO # "chat_template": None, # TODO: add chat template - "enable_thinking": st.session_state["enable_thinking"], - "enable_openai_api": st.session_state["enable_openai_api"], }, "auxiliary_models": [], "eval_interval": st.session_state["eval_interval"], "eval_on_latest_checkpoint": st.session_state["eval_on_latest_checkpoint"], } + for i in range(st.session_state["_auxiliary_models_num"]): + auxiliary_model_config = { + key: st.session_state[f"auxiliary_model_{i}_{key}"] + for key in self.inference_model_keys + } + explorer_config["auxiliary_models"].append(auxiliary_model_config) return explorer_config def generate_config(self): @@ -585,12 +624,7 @@ def generate_config(self): "project": st.session_state["project"], "name": st.session_state["exp_name"], "checkpoint_root_dir": st.session_state["checkpoint_root_dir"], - "algorithm": { - "algorithm_type": st.session_state["algorithm_type"], - "repeat_times": st.session_state["repeat_times"], - "gamma": st.session_state["gamma"], - "lam": st.session_state["lam"], - }, + "algorithm": self._gen_algorithm_config(), "data_processor": {}, # TODO: Add data processor config "model": { "model_path": st.session_state["model_path"], @@ -607,11 +641,7 @@ def generate_config(self): "trainer_type": st.session_state["trainer_type"], "save_interval": st.session_state["save_interval"], "enable_preview": st.session_state["enable_preview"], - "actor_use_kl_loss": st.session_state["actor_use_kl_loss"], - "actor_kl_loss_coef": st.session_state["actor_kl_loss_coef"], - "actor_entropy_coef": st.session_state["actor_entropy_coef"], "actor_grad_clip": st.session_state["actor_grad_clip"], - "actor_clip_ratio": st.session_state["actor_clip_ratio"], "trainer_config": trainer_config, }, "monitor": { diff --git a/trinity/manager/config_registry/__init__.py b/trinity/manager/config_registry/__init__.py index e62c565fb4..3896582755 100644 --- a/trinity/manager/config_registry/__init__.py +++ b/trinity/manager/config_registry/__init__.py @@ -1,3 +1,4 @@ +import trinity.manager.config_registry.algorithm_config_manager as algorithm_config_manager import trinity.manager.config_registry.buffer_config_manager as buffer_config_manager import trinity.manager.config_registry.explorer_config_manager as explorer_config_manager import trinity.manager.config_registry.model_config_manager as model_config_manager @@ -6,6 +7,7 @@ __all__ = [ "CONFIG_GENERATORS", + "algorithm_config_manager", "buffer_config_manager", "explorer_config_manager", "model_config_manager", diff --git a/trinity/manager/config_registry/algorithm_config_manager.py b/trinity/manager/config_registry/algorithm_config_manager.py new file mode 100644 index 0000000000..c9694dec25 --- /dev/null +++ b/trinity/manager/config_registry/algorithm_config_manager.py @@ -0,0 +1,371 @@ +import streamlit as st + +from trinity.algorithm.advantage_fn import ( + ADVANTAGE_FN, + GRPOAdvantageFn, + OPMDAdvantageFn, + PPOAdvantageFn, +) +from trinity.algorithm.algorithm import ALGORITHM_TYPE, PPOAlgorithm +from trinity.algorithm.entropy_loss_fn.entropy_loss_fn import ( + ENTROPY_LOSS_FN, + EntropyLossFn, +) +from trinity.algorithm.kl_fn.kl_fn import KL_FN, KLFn +from trinity.algorithm.policy_loss_fn import ( + POLICY_LOSS_FN, + DPOLossFn, + MIXPolicyLossFn, + OPMDPolicyLossFn, + PPOPolicyLossFn, + SFTLossFn, +) +from trinity.algorithm.sample_strategy import SAMPLE_STRATEGY, MixSampleStrategy +from 
trinity.manager.config_registry.config_registry import CONFIG_GENERATORS +from trinity.manager.config_registry.model_config_manager import set_trainer_gpu_num + + +@CONFIG_GENERATORS.register_config( + default_value="ppo", + other_configs={"mode": "both", "_current_default_config": PPOAlgorithm.default_config()}, +) +def set_algorithm_type(**kwargs): + def on_change(): + if st.session_state["algorithm_type"] == "dpo": + st.session_state["mode"] = "train" + else: + st.session_state["mode"] = "both" + algorithm = ALGORITHM_TYPE.get(st.session_state["algorithm_type"]) + default_config = algorithm.default_config() + st.session_state["_current_default_config"] = default_config + for key, value in default_config.items(): + st.session_state[key] = value + set_trainer_gpu_num() + + candidates = list(ALGORITHM_TYPE.modules.keys()) + st.selectbox( + "Algorithm Type", + candidates, + on_change=on_change, + **kwargs, + ) + + +@CONFIG_GENERATORS.register_config( + default_value=PPOAlgorithm.default_config()["repeat_times"], + visible=lambda: "repeat_times" in st.session_state["_current_default_config"], + other_configs={ + "_grouped_adv_repeat_times": 2, + "_not_grouped_adv_repeat_times": 1, + }, +) +def set_repeat_times(**kwargs): # TODO + key = kwargs.get("key") + grouped_adv_algorithms = [ + "grpo", + "opmd", # TODO: may add rloo + ] + if st.session_state["algorithm_type"] in grouped_adv_algorithms: + min_repeat_times = 2 + st.session_state[key] = st.session_state["_grouped_adv_repeat_times"] + else: + min_repeat_times = 1 + st.session_state[key] = st.session_state["_not_grouped_adv_repeat_times"] + + def on_change(): + if st.session_state["algorithm_type"] in grouped_adv_algorithms: + st.session_state["_grouped_adv_repeat_times"] = st.session_state[key] + else: + st.session_state["_not_grouped_adv_repeat_times"] = st.session_state[key] + + st.number_input( + "Repeat Times", + min_value=min_repeat_times, + help="`repeat_times` is used to set how many experiences each task can generate, " + "and it must be greater than `1` when `algorithm_type` is `opmd` or `grpo`.", + on_change=on_change, + **kwargs, + ) + + +# Sample_strategy Configs + + +@CONFIG_GENERATORS.register_config( + default_value=PPOAlgorithm.default_config()["sample_strategy"], + visible=lambda: "sample_strategy" in st.session_state["_current_default_config"], +) +def set_sample_strategy(**kwargs): + candidates = list(SAMPLE_STRATEGY.modules.keys()) + st.selectbox( + "Sample Strategy", + candidates, + help="The sample strategy used to obtain experiences.", + **kwargs, + ) + + +@CONFIG_GENERATORS.register_config( + default_value=MixSampleStrategy.default_args()["expert_data_ratio"], + visible=lambda: st.session_state["sample_strategy"] == "mix", +) +def set_expert_data_ratio_in_sample_strategy(**kwargs): + st.number_input( + "Expert Data Ratio", + min_value=0.0, + max_value=1.0, + value=0.5, + help="The ratio of expert data to be used in the training.", + **kwargs, + ) + + +# Advantage Configs + + +@CONFIG_GENERATORS.register_config( + default_value=PPOAlgorithm.default_config()["advantage_fn"], + visible=lambda: "advantage_fn" in st.session_state["_current_default_config"], +) +def set_advantage_fn(**kwargs): + candidates = list(ADVANTAGE_FN.modules.keys()) + st.selectbox( + "Advantage Function", + candidates, + help="The advantage function used to compute advantages.", + **kwargs, + ) + + +@CONFIG_GENERATORS.register_config( + default_value=PPOAdvantageFn.default_args()["gamma"], + visible=lambda: st.session_state["advantage_fn"] in 
{"ppo", "reinforceplusplus"}, +) +def set_gamma_in_advantage_fn(**kwargs): + st.number_input(r"Gamma :blue-badge[$\gamma$]", **kwargs) + + +@CONFIG_GENERATORS.register_config( + default_value=PPOAdvantageFn.default_args()["lam"], + visible=lambda: st.session_state["advantage_fn"] == "ppo", +) +def set_lam_in_advantage_fn(**kwargs): + st.number_input(r"Lambda :blue-badge[$\lambda$]", **kwargs) + + +@CONFIG_GENERATORS.register_config( + default_value=GRPOAdvantageFn.default_args()["epsilon"], + visible=lambda: st.session_state["advantage_fn"] == "grpo", +) +def set_epsilon_in_advantage_fn(**kwargs): # TODO: update help message + st.number_input( + r"GRPO Epsilon", + help=r""" +```python +scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + epsilon) +``` +""", + **kwargs, + ) + + +@CONFIG_GENERATORS.register_config( + default_value=OPMDAdvantageFn.default_args()["opmd_baseline"], + visible=lambda: st.session_state["advantage_fn"] == "opmd", +) +def set_opmd_baseline_in_advantage_fn(**kwargs): + st.selectbox( + "OPMD Baseline", + ["mean", "logavgexp"], + **kwargs, + ) + + +@CONFIG_GENERATORS.register_config( + default_value=OPMDAdvantageFn.default_args()["tau"], + visible=lambda: st.session_state["advantage_fn"] == "opmd" + and st.session_state["opmd_baseline_in_advantage_fn"] == "logavgexp", +) +def set_tau_in_advantage_fn(**kwargs): + st.number_input("Tau for OPMD Adv.", min_value=0.0, format="%.1e", **kwargs) + + +# KL Loss Configs + + +@CONFIG_GENERATORS.register_config( + default_value=PPOAlgorithm.default_config()["kl_loss_fn"], + visible=lambda: "kl_loss_fn" in st.session_state["_current_default_config"], +) +def set_kl_loss_fn(**kwargs): + candidates = list(KL_FN.modules.keys()) + st.selectbox( + "KL Loss Type", + candidates, + **kwargs, + ) + + +@CONFIG_GENERATORS.register_config( + default_value=KLFn.default_args()["kl_coef"], + visible=lambda: st.session_state["kl_loss_fn"] != "none", +) +def set_kl_coef_in_kl_loss_fn(**kwargs): + st.number_input( + r"KL Loss Coef :blue-badge[$\beta$]", + min_value=0.0, + max_value=1.0, + format="%.1e", + **kwargs, + ) + + +# KL Penalty Configs + + +@CONFIG_GENERATORS.register_config( + default_value=PPOAlgorithm.default_config()["kl_penalty_fn"], + visible=lambda: "kl_penalty_fn" in st.session_state["_current_default_config"], +) +def set_kl_penalty_fn(**kwargs): + candidates = list(KL_FN.modules.keys()) + st.selectbox( + "KL Penalty Type", + candidates, + **kwargs, + ) + + +@CONFIG_GENERATORS.register_config( + default_value=KLFn.default_args()["adaptive"], + visible=lambda: st.session_state["kl_penalty_fn"] != "none", +) +def set_adaptive_in_kl_penalty_fn(**kwargs): + st.checkbox( + "Adaptive KL Penalty", + **kwargs, + ) + + +@CONFIG_GENERATORS.register_config( + default_value=KLFn.default_args()["kl_coef"], + visible=lambda: st.session_state["kl_penalty_fn"] != "none", +) +def set_kl_coef_in_kl_penalty_fn(**kwargs): + st.number_input( + r"KL Penalty Coef", + min_value=0.0, + max_value=1.0, + format="%.1e", + **kwargs, + ) + + +# TODO: target_kl and horizon + +# Policy Loss Configs + + +@CONFIG_GENERATORS.register_config( + default_value=PPOAlgorithm.default_config()["policy_loss_fn"], + visible=lambda: "policy_loss_fn" in st.session_state["_current_default_config"], +) +def set_policy_loss_fn(**kwargs): + candidates = list(POLICY_LOSS_FN.modules.keys()) + st.selectbox( + "Policy Loss Fn", + candidates, + **kwargs, + ) + + +@CONFIG_GENERATORS.register_config( + default_value=PPOPolicyLossFn.default_args()["clip_range"], + 
visible=lambda: st.session_state["policy_loss_fn"] in {"ppo", "mix"},
+)
+def set_clip_range_in_policy_loss_fn(**kwargs):
+    st.number_input(
+        "Clip Range",
+        min_value=0.0,
+        max_value=1.0,
+        **kwargs,
+    )
+
+
+@CONFIG_GENERATORS.register_config(
+    default_value=SFTLossFn.default_args()["use_token_level_loss"],
+    visible=lambda: st.session_state["policy_loss_fn"] == "sft",
+)
+def set_use_token_level_loss_in_policy_loss_fn(**kwargs):
+    st.checkbox(
+        "Use Token Level Loss",
+        **kwargs,
+    )
+
+
+@CONFIG_GENERATORS.register_config(
+    default_value=DPOLossFn.default_args()["beta"],
+    visible=lambda: st.session_state["policy_loss_fn"] == "dpo",
+)
+def set_beta_in_policy_loss_fn(**kwargs):
+    st.number_input(
+        "Beta for DPO",
+        min_value=0.0,
+        max_value=1.0,
+        **kwargs,
+    )
+
+
+@CONFIG_GENERATORS.register_config(
+    default_value=DPOLossFn.default_args()["label_smoothing"],
+    visible=lambda: st.session_state["policy_loss_fn"] == "dpo",
+)
+def set_label_smoothing_in_policy_loss_fn(**kwargs):
+    st.number_input(
+        "Label Smoothing",
+        min_value=0.0,
+        max_value=1.0,
+        **kwargs,
+    )
+
+
+@CONFIG_GENERATORS.register_config(
+    default_value=OPMDPolicyLossFn.default_args()["tau"],
+    visible=lambda: st.session_state["policy_loss_fn"] == "opmd",
+)
+def set_tau_in_policy_loss_fn(**kwargs):
+    st.number_input("Tau for OPMD Loss", min_value=0.0, format="%.1e", **kwargs)
+
+
+@CONFIG_GENERATORS.register_config(
+    default_value=MIXPolicyLossFn.default_args()["mu"],
+    visible=lambda: st.session_state["policy_loss_fn"] == "mix",
+)
+def set_mu_in_policy_loss_fn(**kwargs):
+    st.number_input("Mu for Mix Policy Loss", min_value=0.0, **kwargs)
+
+
+# Entropy Loss Configs
+
+
+@CONFIG_GENERATORS.register_config(
+    default_value=PPOAlgorithm.default_config()["entropy_loss_fn"],
+    visible=lambda: "entropy_loss_fn" in st.session_state["_current_default_config"],
+)
+def set_entropy_loss_fn(**kwargs):
+    candidates = list(ENTROPY_LOSS_FN.modules.keys())
+    st.selectbox("Entropy Loss Function", candidates, **kwargs)
+
+
+@CONFIG_GENERATORS.register_config(
+    default_value=EntropyLossFn.default_args()["entropy_coef"],
+    visible=lambda: st.session_state["entropy_loss_fn"] != "none",
+)
+def set_entropy_coef_in_entropy_loss_fn(**kwargs):
+    st.number_input(
+        "Entropy Coeff",
+        min_value=0.0,
+        max_value=1.0,
+        format="%.1e",
+        **kwargs,
+    )
diff --git a/trinity/manager/config_registry/buffer_config_manager.py b/trinity/manager/config_registry/buffer_config_manager.py
index 044f982e94..f704d0ecd2 100644
--- a/trinity/manager/config_registry/buffer_config_manager.py
+++ b/trinity/manager/config_registry/buffer_config_manager.py
@@ -1,6 +1,6 @@
 import streamlit as st

-from trinity.common.constants import AlgorithmType, PromptType, StorageType
+from trinity.common.constants import PromptType, StorageType
 from trinity.common.rewards.reward_fn import REWARD_FUNCTIONS
 from trinity.common.workflows.workflow import WORKFLOWS
 from trinity.manager.config_registry.config_registry import CONFIG_GENERATORS
@@ -264,7 +264,7 @@ def set_reply_prefix(**kwargs):
 )
 def set_storage_type(**kwargs):
     key = kwargs.get("key")
-    if st.session_state["algorithm_type"] == AlgorithmType.DPO.value:
+    if st.session_state["algorithm_type"] == "dpo":
         st.session_state[key] = st.session_state["_dpo_storage_type"]
         storage_candidates = [StorageType.FILE.value, StorageType.SQL.value]
     else:
@@ -272,7 +272,7 @@ def set_storage_type(**kwargs):
         storage_candidates = [StorageType.QUEUE.value, StorageType.SQL.value]

     def on_change():
-        if st.session_state["algorithm_type"] == AlgorithmType.DPO.value:
+        if st.session_state["algorithm_type"] == "dpo":
             st.session_state["_dpo_storage_type"] = st.session_state[key]
         else:
             st.session_state["_not_dpo_storage_type"] = st.session_state[key]
@@ -294,7 +294,7 @@ def on_change():
 )
 def set_experience_buffer_path(**kwargs):  # TODO
     key = kwargs.get("key")
-    if st.session_state["algorithm_type"] == AlgorithmType.DPO.value:
+    if st.session_state["algorithm_type"] == "dpo":
         if st.session_state["taskset_path"] and not st.session_state["_dpo_experience_buffer_path"]:
             st.session_state["_dpo_experience_buffer_path"] = st.session_state["taskset_path"]
         st.session_state[key] = st.session_state["_dpo_experience_buffer_path"]
@@ -314,7 +314,7 @@ def set_experience_buffer_path(**kwargs):  # TODO
 if `storage_type == StorageType.SQL`, default to `sqlite:///{os.path.join(checkpoint_root_dir, '.cache', project_name, experiment_name)}/data.db`."""

     def on_change():
-        if st.session_state["algorithm_type"] == AlgorithmType.DPO.value:
+        if st.session_state["algorithm_type"] == "dpo":
             st.session_state["_dpo_experience_buffer_path"] = st.session_state[key]
         else:
             st.session_state["_not_dpo_experience_buffer_path"] = st.session_state[key]
@@ -324,7 +324,7 @@ def on_change():

 @CONFIG_GENERATORS.register_check()
 def check_experience_buffer_path(unfinished_fields: set, key: str):
-    if st.session_state["algorithm_type"] == AlgorithmType.DPO.value:
+    if st.session_state["algorithm_type"] == "dpo":
         if not st.session_state[key].strip():
             unfinished_fields.add(key)
             st.warning("Please input DPO dataset path.")
diff --git a/trinity/manager/config_registry/explorer_config_manager.py b/trinity/manager/config_registry/explorer_config_manager.py
index 9393187f60..12e8034a30 100644
--- a/trinity/manager/config_registry/explorer_config_manager.py
+++ b/trinity/manager/config_registry/explorer_config_manager.py
@@ -1,6 +1,6 @@
 import streamlit as st

-from trinity.common.constants import AlgorithmType, SyncMethod
+from trinity.common.constants import SyncMethod
 from trinity.manager.config_registry.config_registry import CONFIG_GENERATORS
 from trinity.manager.config_registry.model_config_manager import set_trainer_gpu_num

@@ -255,7 +255,7 @@ def check_auxiliary_models(unfinished_fields: set, key: str):
 )
 def set_sync_method(**kwargs):
     key = kwargs.get("key")
-    if st.session_state["algorithm_type"] == AlgorithmType.DPO.value:
+    if st.session_state["algorithm_type"] == "dpo":
         st.session_state[key] = SyncMethod.CHECKPOINT.value
         disabled = True
     else:
@@ -263,7 +263,7 @@ def set_sync_method(**kwargs):
         disabled = False

     def on_change():
-        if st.session_state["algorithm_type"] != AlgorithmType.DPO.value:
+        if st.session_state["algorithm_type"] != "dpo":
             st.session_state["_not_dpo_sync_method"] = st.session_state[key]

     st.selectbox(
diff --git a/trinity/manager/config_registry/model_config_manager.py b/trinity/manager/config_registry/model_config_manager.py
index 837bf27679..f9014e58a1 100644
--- a/trinity/manager/config_registry/model_config_manager.py
+++ b/trinity/manager/config_registry/model_config_manager.py
@@ -2,10 +2,9 @@

 import streamlit as st

-from trinity.common.constants import AlgorithmType, MonitorType
+from trinity.common.constants import MonitorType
 from trinity.manager.config_registry.config_registry import CONFIG_GENERATORS
 from trinity.manager.config_registry.trainer_config_manager import use_critic
-from trinity.trainer.verl.ray_trainer import AdvantageEstimator


 def set_total_gpu_num():
@@ -64,91 +63,6 @@ def set_monitor_type(**kwargs):
     )


-# Algorithm Configs
-
-
-@CONFIG_GENERATORS.register_config(
-    default_value=AlgorithmType.PPO.value,
-    other_configs={"mode": "both", "adv_estimator": AdvantageEstimator.GAE.value},
-)
-def set_algorithm_type(**kwargs):
-    def on_change():
-        if st.session_state["algorithm_type"] == AlgorithmType.PPO.value:
-            st.session_state["mode"] = "both"
-            st.session_state["adv_estimator"] = AdvantageEstimator.GAE.value
-        elif st.session_state["algorithm_type"] == AlgorithmType.GRPO.value:
-            st.session_state["mode"] = "both"
-            st.session_state["adv_estimator"] = AdvantageEstimator.GRPO.value
-        elif st.session_state["algorithm_type"] == AlgorithmType.DPO.value:
-            st.session_state["mode"] = "train"
-            st.session_state["adv_estimator"] = AdvantageEstimator.GRPO.value
-        elif st.session_state["algorithm_type"] == AlgorithmType.OPMD.value:
-            st.session_state["mode"] = "both"
-            st.session_state["adv_estimator"] = AdvantageEstimator.GRPO.value
-        else:  # TODO: add more algorithms
-            pass
-        set_trainer_gpu_num()
-
-    st.selectbox(
-        "Algorithm Type",
-        [
-            AlgorithmType.PPO.value,
-            AlgorithmType.GRPO.value,
-            AlgorithmType.DPO.value,
-            AlgorithmType.OPMD.value,
-        ],
-        key="algorithm_type",
-        on_change=on_change,
-    )
-
-
-@CONFIG_GENERATORS.register_config(
-    default_value=1,
-    visible=lambda: st.session_state["mode"] == "both",
-    other_configs={
-        "_grouped_adv_repeat_times": 2,
-        "_not_grouped_adv_repeat_times": 1,
-    },
-)
-def set_repeat_times(**kwargs):  # TODO
-    key = kwargs.get("key")
-    grouped_adv_algorithms = [
-        AlgorithmType.GRPO.value,
-        AlgorithmType.OPMD.value,  # TODO: may add rloo
-    ]
-    if st.session_state["algorithm_type"] in grouped_adv_algorithms:
-        min_repeat_times = 2
-        st.session_state[key] = st.session_state["_grouped_adv_repeat_times"]
-    else:
-        min_repeat_times = 1
-        st.session_state[key] = st.session_state["_not_grouped_adv_repeat_times"]
-
-    def on_change():
-        if st.session_state["algorithm_type"] in grouped_adv_algorithms:
-            st.session_state["_grouped_adv_repeat_times"] = st.session_state[key]
-        else:
-            st.session_state["_not_grouped_adv_repeat_times"] = st.session_state[key]
-
-    st.number_input(
-        "Repeat Times",
-        min_value=min_repeat_times,
-        help="`repeat_times` is used to set how many experiences each task can generate, "
-        "and it must be greater than `1` when `algorithm_type` is `opmd` or `grpo`.",
-        on_change=on_change,
-        **kwargs,
-    )
-
-
-@CONFIG_GENERATORS.register_config(default_value=1.0)
-def set_gamma(**kwargs):
-    st.number_input(r"Gamma :blue-badge[$\gamma$]", **kwargs)
-
-
-@CONFIG_GENERATORS.register_config(default_value=1.0)
-def set_lam(**kwargs):
-    st.number_input(r"Lambda :blue-badge[$\lambda$]", **kwargs)
-
-
 # Model Configs

diff --git a/trinity/manager/config_registry/trainer_config_manager.py b/trinity/manager/config_registry/trainer_config_manager.py
index d0f5d26897..9b3e5f3ea9 100644
--- a/trinity/manager/config_registry/trainer_config_manager.py
+++ b/trinity/manager/config_registry/trainer_config_manager.py
@@ -1,12 +1,13 @@
 import streamlit as st

-from trinity.common.constants import AlgorithmType, SyncMethod
+from trinity.algorithm.algorithm import ALGORITHM_TYPE
+from trinity.common.constants import SyncMethod
 from trinity.manager.config_registry.config_registry import CONFIG_GENERATORS
-from trinity.trainer.verl.ray_trainer import AdvantageEstimator


 def use_critic():
-    return st.session_state["adv_estimator"] == AdvantageEstimator.GAE.value
+    algorithm = ALGORITHM_TYPE.get(st.session_state["algorithm_type"])
+    return algorithm.use_critic


 @CONFIG_GENERATORS.register_config(default_value="verl")
@@ -18,7 +19,7 @@ def set_trainer_type(**kwargs):
 def set_save_interval(**kwargs):
     key = kwargs.get("key")
     if (
-        st.session_state["algorithm_type"] == AlgorithmType.DPO.value
+        st.session_state["algorithm_type"] == "dpo"
         or st.session_state["sync_method"] == SyncMethod.NCCL.value
     ):
         st.session_state[key] = st.session_state["_nccl_save_interval"]
@@ -29,7 +30,7 @@ def set_save_interval(**kwargs):

     def on_change():
         if (
-            st.session_state["algorithm_type"] == AlgorithmType.DPO.value
+            st.session_state["algorithm_type"] == "dpo"
             or st.session_state["sync_method"] == SyncMethod.NCCL.value
         ):
             st.session_state["_nccl_save_interval"] = st.session_state[key]
@@ -49,54 +50,6 @@ def set_enable_preview(**kwargs):
     st.checkbox("Enable Preview", **kwargs)


-def _actor_use_kl_loss_visible():
-    if st.session_state["algorithm_type"] == AlgorithmType.DPO.value:
-        st.session_state["actor_use_kl_loss"] = True
-        return False
-    return True
-
-
-@CONFIG_GENERATORS.register_config(
-    default_value=True,
-    visible=_actor_use_kl_loss_visible,
-    other_configs={"_not_dpo_actor_use_kl_loss": True},
-)
-def set_actor_use_kl_loss(**kwargs):
-    key = kwargs.get("key")
-    st.session_state[key] = st.session_state["_not_dpo_actor_use_kl_loss"]
-
-    def on_change():
-        st.session_state["_not_dpo_actor_use_kl_loss"] = st.session_state[key]
-
-    st.checkbox("Use KL Loss", on_change=on_change, **kwargs)
-
-
-@CONFIG_GENERATORS.register_config(
-    default_value=0.001, visible=lambda: st.session_state["actor_use_kl_loss"]
-)
-def set_actor_kl_loss_coef(**kwargs):
-    st.number_input(
-        r"KL Loss Coef :blue-badge[$\beta$]",
-        min_value=0.0,
-        max_value=1.0,
-        format="%.1e",
-        **kwargs,
-    )
-
-
-@CONFIG_GENERATORS.register_config(
-    default_value=0.001, visible=lambda: st.session_state["actor_use_kl_loss"]
-)
-def set_actor_entropy_coef(**kwargs):
-    st.number_input(
-        "Entropy Coeff",
-        min_value=0.0,
-        max_value=1.0,
-        format="%.1e",
-        **kwargs,
-    )
-
-
 @CONFIG_GENERATORS.register_config(default_value=1.0)
 def set_actor_grad_clip(**kwargs):
     st.number_input(
@@ -108,16 +61,6 @@ def set_actor_grad_clip(**kwargs):
     )


-@CONFIG_GENERATORS.register_config(default_value=0.2)
-def set_actor_clip_ratio(**kwargs):
-    st.number_input(
-        r"Clip Ratio :blue-badge[$\epsilon$]",
-        min_value=0.0,
-        max_value=1.0,
-        **kwargs,
-    )
-
-
 # veRL Trainer Configs

@@ -322,31 +265,6 @@ def set_actor_lr_warmup_steps_ratio(**kwargs):
     )


-@CONFIG_GENERATORS.register_config(
-    default_value=0.0, visible=lambda: st.session_state["algorithm_type"] == "opmd"
-)
-def set_actor_tau(**kwargs):
-    st.number_input("Tau for OPMD", min_value=0.0, format="%.1e", **kwargs)
-
-
-@CONFIG_GENERATORS.register_config(
-    default_value="mean", visible=lambda: st.session_state["algorithm_type"] == "opmd"
-)
-def set_actor_opmd_baseline(**kwargs):
-    st.selectbox(
-        "OPMD Baseline",
-        ["mean", "logavgexp"],
-        **kwargs,
-    )
-
-
-@CONFIG_GENERATORS.register_config(
-    default_value=False, visible=lambda: st.session_state["algorithm_type"] == "opmd"
-)
-def set_actor_use_uid(**kwargs):
-    st.checkbox("Use UID for OPMD", **kwargs)
-
-
 @CONFIG_GENERATORS.register_config(default_value="low_var_kl")
 def set_actor_kl_loss_type(**kwargs):
     st.selectbox(