@@ -159,32 +159,34 @@ <h1>Source code for trinity.trainer.trainer</h1><div class="highlight"><pre>
159159< span class ="sd "> bool: Whether to continue training.</ span >
160160< span class ="sd "> """</ span >
161161 < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> engine</ span > < span class ="o "> .</ span > < span class ="n "> set_mode</ span > < span class ="p "> (</ span > < span class ="n "> algo_type</ span > < span class ="p "> )</ span >
162+ < span class ="k "> if</ span > < span class ="n "> algo_type</ span > < span class ="o "> .</ span > < span class ="n "> is_rft</ span > < span class ="p "> ()</ span > < span class ="ow "> and</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> config</ span > < span class ="o "> .</ span > < span class ="n "> trainer</ span > < span class ="o "> .</ span > < span class ="n "> get_exp_strategy</ span > < span class ="p "> :</ span >
163+ < span class ="n "> strategy</ span > < span class ="o "> =</ span > < span class ="n "> ReadStrategy</ span > < span class ="p "> (</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> config</ span > < span class ="o "> .</ span > < span class ="n "> trainer</ span > < span class ="o "> .</ span > < span class ="n "> get_exp_strategy</ span > < span class ="p "> )</ span >
164+ < span class ="k "> else</ span > < span class ="p "> :</ span >
165+ < span class ="n "> strategy</ span > < span class ="o "> =</ span > < span class ="kc "> None</ span >
166+ < span class ="k "> try</ span > < span class ="p "> :</ span >
167+ < span class ="k "> if</ span > < span class ="n "> algo_type</ span > < span class ="o "> .</ span > < span class ="n "> is_sft</ span > < span class ="p "> ():</ span >
168+ < span class ="n "> exps</ span > < span class ="o "> =</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> sft_warmup_buffer</ span > < span class ="o "> .</ span > < span class ="n "> read</ span > < span class ="p "> ()</ span >
169+ < span class ="k "> else</ span > < span class ="p "> :</ span >
170+ < span class ="n "> exps</ span > < span class ="o "> =</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> train_buffer</ span > < span class ="o "> .</ span > < span class ="n "> read</ span > < span class ="p "> (</ span > < span class ="n "> strategy</ span > < span class ="o "> =</ span > < span class ="n "> strategy</ span > < span class ="p "> )</ span >
171+ < span class ="k "> except</ span > < span class ="ne "> StopIteration</ span > < span class ="p "> :</ span >
172+ < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> logger</ span > < span class ="o "> .</ span > < span class ="n "> warning</ span > < span class ="p "> (</ span > < span class ="s2 "> "No more data to train. Stop training."</ span > < span class ="p "> )</ span >
173+ < span class ="k "> return</ span > < span class ="kc "> False</ span > < span class ="p "> ,</ span > < span class ="mi "> 0</ span > < span class ="c1 "> # TODO: get the actual step number</ span >
174+
162175 < span class ="k "> if</ span > < span class ="n "> algo_type</ span > < span class ="o "> .</ span > < span class ="n "> is_sft</ span > < span class ="p "> ():</ span >
163- < span class ="n "> exps</ span > < span class ="o "> =</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> sft_warmup_buffer</ span > < span class ="o "> .</ span > < span class ="n "> read</ span > < span class ="p "> ()</ span >
164176 < span class ="k "> return</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> engine</ span > < span class ="o "> .</ span > < span class ="n "> train_sft_step</ span > < span class ="p "> (</ span >
165177 < span class ="n "> Experiences</ span > < span class ="o "> .</ span > < span class ="n "> gather_experiences</ span > < span class ="p "> (</ span >
166178 < span class ="n "> exps</ span > < span class ="p "> ,</ span >
167179 < span class ="n "> pad_token_id</ span > < span class ="o "> =</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> config</ span > < span class ="o "> .</ span > < span class ="n "> buffer</ span > < span class ="o "> .</ span > < span class ="n "> pad_token_id</ span > < span class ="p "> ,</ span > < span class ="c1 "> # type: ignore</ span >
168180 < span class ="p "> )</ span >
169181 < span class ="p "> )</ span >
170182 < span class ="k "> elif</ span > < span class ="n "> algo_type</ span > < span class ="o "> .</ span > < span class ="n "> is_rft</ span > < span class ="p "> ():</ span >
171- < span class ="k "> if</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> config</ span > < span class ="o "> .</ span > < span class ="n "> trainer</ span > < span class ="o "> .</ span > < span class ="n "> get_exp_strategy</ span > < span class ="p "> :</ span >
172- < span class ="n "> strategy</ span > < span class ="o "> =</ span > < span class ="n "> ReadStrategy</ span > < span class ="p "> (</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> config</ span > < span class ="o "> .</ span > < span class ="n "> trainer</ span > < span class ="o "> .</ span > < span class ="n "> get_exp_strategy</ span > < span class ="p "> )</ span >
173- < span class ="k "> else</ span > < span class ="p "> :</ span >
174- < span class ="n "> strategy</ span > < span class ="o "> =</ span > < span class ="kc "> None</ span >
175- < span class ="k "> try</ span > < span class ="p "> :</ span >
176- < span class ="n "> exps</ span > < span class ="o "> =</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> train_buffer</ span > < span class ="o "> .</ span > < span class ="n "> read</ span > < span class ="p "> (</ span > < span class ="n "> strategy</ span > < span class ="o "> =</ span > < span class ="n "> strategy</ span > < span class ="p "> )</ span >
177- < span class ="k "> except</ span > < span class ="ne "> StopIteration</ span > < span class ="p "> :</ span >
178- < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> logger</ span > < span class ="o "> .</ span > < span class ="n "> warning</ span > < span class ="p "> (</ span > < span class ="s2 "> "No more data to train. Stop training."</ span > < span class ="p "> )</ span >
179- < span class ="k "> return</ span > < span class ="kc "> False</ span > < span class ="p "> ,</ span > < span class ="mi "> 0</ span > < span class ="c1 "> # TODO: get the actual step number</ span >
180183 < span class ="k "> return</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> engine</ span > < span class ="o "> .</ span > < span class ="n "> train_rft_step</ span > < span class ="p "> (</ span >
181184 < span class ="n "> Experiences</ span > < span class ="o "> .</ span > < span class ="n "> gather_experiences</ span > < span class ="p "> (</ span >
182185 < span class ="n "> exps</ span > < span class ="p "> ,</ span >
183186 < span class ="n "> pad_token_id</ span > < span class ="o "> =</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> config</ span > < span class ="o "> .</ span > < span class ="n "> buffer</ span > < span class ="o "> .</ span > < span class ="n "> pad_token_id</ span > < span class ="p "> ,</ span > < span class ="c1 "> # type: ignore</ span >
184187 < span class ="p "> )</ span >
185188 < span class ="p "> )</ span >
186189 < span class ="k "> elif</ span > < span class ="n "> algo_type</ span > < span class ="o "> .</ span > < span class ="n "> is_dpo</ span > < span class ="p "> ():</ span >
187- < span class ="n "> exps</ span > < span class ="o "> =</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> train_buffer</ span > < span class ="o "> .</ span > < span class ="n "> read</ span > < span class ="p "> ()</ span >
188190 < span class ="k "> return</ span > < span class ="bp "> self</ span > < span class ="o "> .</ span > < span class ="n "> engine</ span > < span class ="o "> .</ span > < span class ="n "> train_dpo_step</ span > < span class ="p "> (</ span >
189191 < span class ="n "> Experiences</ span > < span class ="o "> .</ span > < span class ="n "> gather_dpo_experiences</ span > < span class ="p "> (</ span >
190192 < span class ="n "> exps</ span > < span class ="p "> ,</ span >
0 commit comments