@@ -94,6 +94,7 @@ class Trainer(Component):
9494 max_error_samples : Optional [int ] = 2
9595 max_correct_samples : Optional [int ] = 2
9696 debug : bool = False
97+ sequential_order : List [str ] = ["text" , "demo" ]
9798
9899 def __init__ (
99100 self ,
@@ -119,6 +120,7 @@ def __init__(
119120 exclude_input_fields_from_bootstrap_demos : bool = False ,
120121 debug : bool = False ,
121122 save_traces : bool = False , # save traces in the few-shot demos
123+ sequential_order : List [str ] = ["text" , "demo" ],
122124 * args ,
123125 ** kwargs ,
124126 ) -> None :
@@ -161,6 +163,7 @@ def __init__(
161163 self .exclude_input_fields_from_bootstrap_demos = (
162164 exclude_input_fields_from_bootstrap_demos
163165 )
166+ self .sequential_order = sequential_order
164167
165168 # TODO: need to support checkpoint resume too!
166169 def diagnose (self , dataset : Any , split : str = "train" ):
@@ -503,7 +506,6 @@ def fit(
503506 and len (self .text_optimizers ) > 0
504507 ):
505508 if self .strategy == "random" :
506-
507509 self ._fit_text_grad_demo_mix_random (
508510 train_loader ,
509511 train_dataset ,
@@ -525,37 +527,62 @@ def fit(
525527 raise ValueError (f"Strategy { self .strategy } not supported" )
526528
527529 else : # sequential, text first and demo second
528- if len (self .text_optimizers ) > 0 :
529- if self .strategy == "random" :
530- trainer_results = self ._fit_text_grad_random (
531- train_loader ,
532- val_dataset ,
533- test_dataset ,
534- trainer_results ,
535- starting_step = starting_step ,
536- )
537- starting_step += self .max_steps
538- elif self .strategy == "constrained" :
539- trainer_results = self ._fit_text_grad_constraint (
530+
531+ def run_text_optimizers (starting_step : int , trainer_results : TrainerResult ):
532+ if len (self .text_optimizers ) > 0 :
533+ if self .strategy == "random" :
534+ trainer_results = self ._fit_text_grad_random (
535+ train_loader ,
536+ val_dataset ,
537+ test_dataset ,
538+ trainer_results ,
539+ starting_step = starting_step ,
540+ )
541+ starting_step += self .max_steps
542+ elif self .strategy == "constrained" :
543+ trainer_results = self ._fit_text_grad_constraint (
544+ train_loader ,
545+ val_dataset ,
546+ test_dataset ,
547+ trainer_results = trainer_results ,
548+ starting_step = starting_step ,
549+ )
550+ starting_step += self .max_steps
551+ else :
552+ raise ValueError (f"Strategy { self .strategy } not supported" )
553+
554+ def run_demo_optimizers (starting_step : int , trainer_results : TrainerResult ):
555+ if len (self .demo_optimizers ) > 0 :
556+ self .adaltask .configure_teacher_generator ()
557+ self ._fit_demos_random (
540558 train_loader ,
559+ train_dataset ,
541560 val_dataset ,
542561 test_dataset ,
543562 trainer_results = trainer_results ,
544563 starting_step = starting_step ,
545564 )
546- starting_step += self .max_steps
547- else :
548- raise ValueError (f"Strategy { self .strategy } not supported" )
549- if len (self .demo_optimizers ) > 0 :
550- self .adaltask .configure_teacher_generator () # attemp to use the newest teacher as
551- self ._fit_demos_random (
552- train_loader ,
553- train_dataset ,
554- val_dataset ,
555- test_dataset ,
556- trainer_results = trainer_results ,
557- starting_step = starting_step ,
558- )
565+
566+ if self .sequential_order == ["text" , "demo" ]:
567+ run_text_optimizers (starting_step , trainer_results )
568+ run_demo_optimizers (starting_step , trainer_results )
569+ else :
570+ run_demo_optimizers (starting_step , trainer_results )
571+ run_text_optimizers (starting_step , trainer_results )
572+ # if len(self.text_optimizers) > 0:
573+ # run_text_optimizers(starting_step, trainer_results)
574+
575+ # if len(self.demo_optimizers) > 0:
576+ # run_demo_optimizers(starting_step, trainer_results)
577+ # self.adaltask.configure_teacher_generator() # attemp to use the newest teacher as
578+ # self._fit_demos_random(
579+ # train_loader,
580+ # train_dataset,
581+ # val_dataset,
582+ # test_dataset,
583+ # trainer_results=trainer_results,
584+ # starting_step=starting_step,
585+ # )
559586
560587 end_time = time .time ()
561588 print (f"Training time: { end_time - start_time } s" )