@@ -68,6 +68,7 @@ class Checkpoint:
6868class Actor :
6969 strategy : str = "fsdp"
7070 ppo_mini_batch_size : int = 256
71+ ppo_micro_batch_size : Any = None
7172 ppo_micro_batch_size_per_gpu : int = 1
7273 use_dynamic_bsz : bool = False
7374 ppo_max_token_len_per_gpu : int = (
@@ -94,6 +95,7 @@ class Actor:
9495@dataclass
9596class Ref :
9697 fsdp_config : FSDPConfig = field (default_factory = FSDPConfig )
98+ log_prob_micro_batch_size : Any = None
9799 log_prob_micro_batch_size_per_gpu : int = 1
98100 log_prob_use_dynamic_bsz : bool = False
99101 log_prob_max_token_len_per_gpu : int = 0
@@ -119,6 +121,7 @@ class Rollout:
119121 max_num_batched_tokens : int = 8192
120122 max_model_len : Optional [int ] = None
121123 max_num_seqs : int = 1024
124+ log_prob_micro_batch_size : Any = None
122125 log_prob_micro_batch_size_per_gpu : int = 1
123126 log_prob_use_dynamic_bsz : bool = False
124127 log_prob_max_token_len_per_gpu : int = 0
@@ -155,6 +158,7 @@ class Critic:
155158 optim : Optim = field (default_factory = Optim )
156159 model : CriticModel = field (default_factory = CriticModel )
157160 ppo_mini_batch_size : int = 0
161+ ppo_micro_batch_size : Any = None
158162 ppo_micro_batch_size_per_gpu : int = 1
159163 forward_micro_batch_size : Optional [int ] = None
160164 forward_micro_batch_size_per_gpu : Optional [int ] = None
0 commit comments