Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ classifiers = [
]
requires-python = ">=3.10"
dependencies = [
"verl==0.3.0.post1",
"verl==0.4.0",
"ray[default]>=2.45.0",
"vllm==0.8.5.post1",
"tensordict==0.6.2",
Expand Down
10 changes: 10 additions & 0 deletions trinity/common/verl_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ class Actor:
tau: float = 0.001 # strength of regularization w.r.t. old / ref policy
opmd_baseline: str = "mean" # mean / logavgexp, applicable to opmd
use_uid: bool = False # True / False, applicable to pairwise_opmd
loss_agg_mode: str = "token-mean" # do not set


@dataclass
Expand All @@ -99,12 +100,20 @@ class _ValKwargs:
do_sample: bool = False


@dataclass
class _MultiTurn:
enable: bool = False


@dataclass
class Rollout:
# do not set
val_kwargs: _ValKwargs = field(default_factory=_ValKwargs)
multi_turn: _MultiTurn = field(default_factory=_MultiTurn)
temperature: float = 1.0
n: int = 1 # > 1 for grpo
log_prob_micro_batch_size: Optional[int] = None
log_prob_micro_batch_size_per_gpu: int = 1


@dataclass
Expand Down Expand Up @@ -148,6 +157,7 @@ class Critic:
cliprange_value: float = 0.0
checkpoint: Checkpoint = field(default_factory=Checkpoint)
rollout_n: int = 1
loss_agg_mode: str = "token-mean"


@dataclass
Expand Down
Loading