Skip to content

Commit 585f534

Browse files
committed
deploy: d613940
1 parent e820f5d commit 585f534

File tree

3 files changed

+22
-14
lines changed

3 files changed

+22
-14
lines changed

_modules/trinity/common/config.html

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -356,7 +356,7 @@ <h1>Source code for trinity.common.config</h1><div class="highlight"><pre>
356356
<span class="k">class</span><span class="w"> </span><span class="nc">Config</span><span class="p">:</span>
357357
<span class="w"> </span><span class="sd">&quot;&quot;&quot;Global Configuration&quot;&quot;&quot;</span>
358358

359-
<span class="n">mode</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;both&quot;</span> <span class="c1"># `explore`, `train` or `both`</span>
359+
<span class="n">mode</span><span class="p">:</span> <span class="nb">str</span> <span class="o">=</span> <span class="s2">&quot;both&quot;</span> <span class="c1"># `explore`, `train`, `both` or `bench`</span>
360360
<span class="n">data</span><span class="p">:</span> <span class="n">DataConfig</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default_factory</span><span class="o">=</span><span class="n">DataConfig</span><span class="p">)</span>
361361
<span class="n">model</span><span class="p">:</span> <span class="n">ModelConfig</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default_factory</span><span class="o">=</span><span class="n">ModelConfig</span><span class="p">)</span>
362362
<span class="n">cluster</span><span class="p">:</span> <span class="n">ClusterConfig</span> <span class="o">=</span> <span class="n">field</span><span class="p">(</span><span class="n">default_factory</span><span class="o">=</span><span class="n">ClusterConfig</span><span class="p">)</span>
@@ -425,7 +425,7 @@ <h1>Source code for trinity.common.config</h1><div class="highlight"><pre>
425425
<span class="k">def</span><span class="w"> </span><span class="nf">check_and_update</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="kc">None</span><span class="p">:</span> <span class="c1"># noqa: C901</span>
426426
<span class="w"> </span><span class="sd">&quot;&quot;&quot;Check and update the config.&quot;&quot;&quot;</span>
427427
<span class="c1"># check mode</span>
428-
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode</span> <span class="ow">not</span> <span class="ow">in</span> <span class="p">[</span><span class="s2">&quot;explore&quot;</span><span class="p">,</span> <span class="s2">&quot;train&quot;</span><span class="p">,</span> <span class="s2">&quot;both&quot;</span><span class="p">]:</span>
428+
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode</span> <span class="ow">not</span> <span class="ow">in</span> <span class="p">[</span><span class="s2">&quot;explore&quot;</span><span class="p">,</span> <span class="s2">&quot;train&quot;</span><span class="p">,</span> <span class="s2">&quot;both&quot;</span><span class="p">,</span> <span class="s2">&quot;bench&quot;</span><span class="p">]:</span>
429429
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Invalid mode: </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">mode</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
430430
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">trainer</span><span class="o">.</span><span class="n">algorithm_type</span> <span class="o">==</span> <span class="n">AlgorithmType</span><span class="o">.</span><span class="n">DPO</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode</span> <span class="o">==</span> <span class="s2">&quot;both&quot;</span><span class="p">:</span>
431431
<span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;DPO does not support `both` mode&quot;</span><span class="p">)</span>
@@ -448,6 +448,11 @@ <h1>Source code for trinity.common.config</h1><div class="highlight"><pre>
448448
<span class="bp">self</span><span class="o">.</span><span class="n">explorer</span><span class="o">.</span><span class="n">engine_num</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">explorer</span><span class="o">.</span><span class="n">tensor_parallel_size</span>
449449
<span class="p">)</span>
450450
<span class="bp">self</span><span class="o">.</span><span class="n">synchronizer</span><span class="o">.</span><span class="n">backend</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">explorer</span><span class="o">.</span><span class="n">backend</span>
451+
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">mode</span> <span class="o">==</span> <span class="s2">&quot;bench&quot;</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">synchronizer</span><span class="o">.</span><span class="n">sync_method</span> <span class="o">!=</span> <span class="n">SyncMethod</span><span class="o">.</span><span class="n">CHECKPOINT</span><span class="p">:</span>
452+
<span class="bp">self</span><span class="o">.</span><span class="n">synchronizer</span><span class="o">.</span><span class="n">sync_method</span> <span class="o">=</span> <span class="s2">&quot;checkpoint&quot;</span>
453+
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span>
454+
<span class="s2">&quot;Bench mode only supports checkpoint synchronization, set `synchronizer.sync_method` to `checkpoint`.&quot;</span>
455+
<span class="p">)</span>
451456
<span class="k">if</span> <span class="p">(</span>
452457
<span class="bp">self</span><span class="o">.</span><span class="n">trainer</span><span class="o">.</span><span class="n">algorithm_type</span> <span class="o">==</span> <span class="n">AlgorithmType</span><span class="o">.</span><span class="n">DPO</span>
453458
<span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">synchronizer</span><span class="o">.</span><span class="n">sync_method</span> <span class="o">!=</span> <span class="n">SyncMethod</span><span class="o">.</span><span class="n">CHECKPOINT</span>

_modules/trinity/common/verl_config.html

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -459,9 +459,10 @@ <h1>Source code for trinity.common.verl_config</h1><div class="highlight"><pre>
459459
<span class="bp">self</span><span class="o">.</span><span class="n">actor_rollout_ref</span><span class="o">.</span><span class="n">actor</span><span class="o">.</span><span class="n">use_kl_loss</span> <span class="o">=</span> <span class="kc">True</span>
460460
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span><span class="s2">&quot;DPO must use KL loss.&quot;</span><span class="p">)</span>
461461
<span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span><span class="s2">&quot;DPO micro batch size is doubled for computing loss.&quot;</span><span class="p">)</span>
462-
<span class="bp">self</span><span class="o">.</span><span class="n">actor_rollout_ref</span><span class="o">.</span><span class="n">actor</span><span class="o">.</span><span class="n">ppo_mini_batch_size</span> <span class="o">*=</span> <span class="mi">2</span>
463462
<span class="bp">self</span><span class="o">.</span><span class="n">actor_rollout_ref</span><span class="o">.</span><span class="n">actor</span><span class="o">.</span><span class="n">ppo_micro_batch_size_per_gpu</span> <span class="o">*=</span> <span class="mi">2</span> <span class="c1"># type: ignore</span>
464463
<span class="bp">self</span><span class="o">.</span><span class="n">actor_rollout_ref</span><span class="o">.</span><span class="n">ref</span><span class="o">.</span><span class="n">log_prob_micro_batch_size_per_gpu</span> <span class="o">*=</span> <span class="mi">2</span> <span class="c1"># type: ignore</span>
464+
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">actor_rollout_ref</span><span class="o">.</span><span class="n">rollout</span><span class="o">.</span><span class="n">n</span> <span class="o">!=</span> <span class="mi">2</span><span class="p">:</span>
465+
<span class="bp">self</span><span class="o">.</span><span class="n">actor_rollout_ref</span><span class="o">.</span><span class="n">rollout</span><span class="o">.</span><span class="n">n</span> <span class="o">=</span> <span class="mi">2</span>
465466
<span class="c1"># TODO: check other fields</span>
466467
<span class="bp">self</span><span class="o">.</span><span class="n">enable_preview</span> <span class="o">=</span> <span class="n">config</span><span class="o">.</span><span class="n">trainer</span><span class="o">.</span><span class="n">enable_preview</span></div>
467468
</div>

_modules/trinity/trainer/trainer.html

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -159,32 +159,34 @@ <h1>Source code for trinity.trainer.trainer</h1><div class="highlight"><pre>
159159
<span class="sd"> bool: Whether to continue training.</span>
160160
<span class="sd"> &quot;&quot;&quot;</span>
161161
<span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">set_mode</span><span class="p">(</span><span class="n">algo_type</span><span class="p">)</span>
162+
<span class="k">if</span> <span class="n">algo_type</span><span class="o">.</span><span class="n">is_rft</span><span class="p">()</span> <span class="ow">and</span> <span class="bp">self</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">trainer</span><span class="o">.</span><span class="n">get_exp_strategy</span><span class="p">:</span>
163+
<span class="n">strategy</span> <span class="o">=</span> <span class="n">ReadStrategy</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">trainer</span><span class="o">.</span><span class="n">get_exp_strategy</span><span class="p">)</span>
164+
<span class="k">else</span><span class="p">:</span>
165+
<span class="n">strategy</span> <span class="o">=</span> <span class="kc">None</span>
166+
<span class="k">try</span><span class="p">:</span>
167+
<span class="k">if</span> <span class="n">algo_type</span><span class="o">.</span><span class="n">is_sft</span><span class="p">():</span>
168+
<span class="n">exps</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">sft_warmup_buffer</span><span class="o">.</span><span class="n">read</span><span class="p">()</span>
169+
<span class="k">else</span><span class="p">:</span>
170+
<span class="n">exps</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">train_buffer</span><span class="o">.</span><span class="n">read</span><span class="p">(</span><span class="n">strategy</span><span class="o">=</span><span class="n">strategy</span><span class="p">)</span>
171+
<span class="k">except</span> <span class="ne">StopIteration</span><span class="p">:</span>
172+
<span class="bp">self</span><span class="o">.</span><span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span><span class="s2">&quot;No more data to train. Stop training.&quot;</span><span class="p">)</span>
173+
<span class="k">return</span> <span class="kc">False</span><span class="p">,</span> <span class="mi">0</span> <span class="c1"># TODO: get the actual step number</span>
174+
162175
<span class="k">if</span> <span class="n">algo_type</span><span class="o">.</span><span class="n">is_sft</span><span class="p">():</span>
163-
<span class="n">exps</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">sft_warmup_buffer</span><span class="o">.</span><span class="n">read</span><span class="p">()</span>
164176
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">train_sft_step</span><span class="p">(</span>
165177
<span class="n">Experiences</span><span class="o">.</span><span class="n">gather_experiences</span><span class="p">(</span>
166178
<span class="n">exps</span><span class="p">,</span>
167179
<span class="n">pad_token_id</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">buffer</span><span class="o">.</span><span class="n">pad_token_id</span><span class="p">,</span> <span class="c1"># type: ignore</span>
168180
<span class="p">)</span>
169181
<span class="p">)</span>
170182
<span class="k">elif</span> <span class="n">algo_type</span><span class="o">.</span><span class="n">is_rft</span><span class="p">():</span>
171-
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">trainer</span><span class="o">.</span><span class="n">get_exp_strategy</span><span class="p">:</span>
172-
<span class="n">strategy</span> <span class="o">=</span> <span class="n">ReadStrategy</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">trainer</span><span class="o">.</span><span class="n">get_exp_strategy</span><span class="p">)</span>
173-
<span class="k">else</span><span class="p">:</span>
174-
<span class="n">strategy</span> <span class="o">=</span> <span class="kc">None</span>
175-
<span class="k">try</span><span class="p">:</span>
176-
<span class="n">exps</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">train_buffer</span><span class="o">.</span><span class="n">read</span><span class="p">(</span><span class="n">strategy</span><span class="o">=</span><span class="n">strategy</span><span class="p">)</span>
177-
<span class="k">except</span> <span class="ne">StopIteration</span><span class="p">:</span>
178-
<span class="bp">self</span><span class="o">.</span><span class="n">logger</span><span class="o">.</span><span class="n">warning</span><span class="p">(</span><span class="s2">&quot;No more data to train. Stop training.&quot;</span><span class="p">)</span>
179-
<span class="k">return</span> <span class="kc">False</span><span class="p">,</span> <span class="mi">0</span> <span class="c1"># TODO: get the actual step number</span>
180183
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">train_rft_step</span><span class="p">(</span>
181184
<span class="n">Experiences</span><span class="o">.</span><span class="n">gather_experiences</span><span class="p">(</span>
182185
<span class="n">exps</span><span class="p">,</span>
183186
<span class="n">pad_token_id</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">buffer</span><span class="o">.</span><span class="n">pad_token_id</span><span class="p">,</span> <span class="c1"># type: ignore</span>
184187
<span class="p">)</span>
185188
<span class="p">)</span>
186189
<span class="k">elif</span> <span class="n">algo_type</span><span class="o">.</span><span class="n">is_dpo</span><span class="p">():</span>
187-
<span class="n">exps</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">train_buffer</span><span class="o">.</span><span class="n">read</span><span class="p">()</span>
188190
<span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">engine</span><span class="o">.</span><span class="n">train_dpo_step</span><span class="p">(</span>
189191
<span class="n">Experiences</span><span class="o">.</span><span class="n">gather_dpo_experiences</span><span class="p">(</span>
190192
<span class="n">exps</span><span class="p">,</span>

0 commit comments

Comments
 (0)