@@ -317,8 +317,8 @@ def diagnose_report(
 
     def debug_report(
         self,
-        text_grad_debug_path: Optional[str] = None,
-        few_shot_demo_debug_path: Optional[str] = None,
+        text_grad_debug_path: Optional[Dict[str, object]] = None,
+        few_shot_demo_debug_path: Optional[Dict[str, object]] = None,
     ):
         import colorama
         from colorama import Fore
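
The two debug arguments are no longer directory strings; they are the dictionaries returned by the one-step debug helpers further down in this diff (ultimately the output of `draw_graph`). A minimal sketch of a manual call with the new types; the dictionary keys and file paths below are illustrative assumptions, not values taken from this commit:

    # Hypothetical call, not part of the commit: debug_report now expects
    # dictionaries of debug artifacts rather than path strings.
    trainer.debug_report(
        text_grad_debug_path={"trace_graph": "/tmp/debug/trace_graph.png"},          # assumed key/path
        few_shot_demo_debug_path={"student_graph": "/tmp/debug/student_graph.png"},  # assumed key/path
    )
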
@@ -361,9 +361,12 @@ def fit(
         resume_from_ckpt: Optional[
             str
         ] = None,  # TODO: have a more comprehensive ckpt loading in the future
-    ):
+    ) -> Tuple[str, TrainerResult]:
         r"""
         train_loader: An iterable or collection of iterables specifying training samples.
+
+        Returns:
+            Tuple[str, TrainerResult]: Checkpoint file and the TrainerResult object
         """
         start_time = time.time()
 
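
Since `fit` now returns the checkpoint path together with the `TrainerResult`, callers can unpack both directly. A minimal usage sketch, assuming `trainer` is an already configured `Trainer` and the dataset arguments stand in for the caller's own data:

    # Sketch: consume the new return value of Trainer.fit.
    ckpt_file, trainer_results = trainer.fit(
        train_dataset=train_dataset,  # placeholder dataset objects
        val_dataset=val_dataset,
        test_dataset=test_dataset,
    )
    print(f"checkpoint written to: {ckpt_file}")
    print(f"final prompts: {trainer_results.prompts[-1]}")  # prompts per step, as used later in this diff
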
@@ -491,7 +494,7 @@ def fit(
                     train_loader, train_dataset, val_dataset, test_dataset
                 )
             self.debug_report(text_grad_debug_path, few_shot_demo_debug_path)
-            return
+            return self.ckpt_file, trainer_results
 
         ########Run text_optimizers and demo optimizers in sequential order ########
         if (
@@ -557,6 +560,7 @@ def fit(
         end_time = time.time()
         print(f"Training time: {end_time - start_time}s")
         print(f"ckpt_file: {self.ckpt_file}")
+        return self.ckpt_file, trainer_results
 
     @staticmethod
     def _estimate_num_epochs(train_loader: Any, max_steps: int):
@@ -684,7 +688,7 @@ def _pre_fit(self, val_dataset: Any, test_dataset: Any) -> TrainerResult:
 
     def _fit_demos_one_step_for_debug(
         self, train_loader, train_dataset: Any, val_dataset: Any, test_dataset: Any
-    ) -> str:
+    ) -> Dict[str, object]:
         """Trace both the teacher and the student demos with scores and for sampling.
         For demos: we need to run both the teacher mode and the student mode."""
 
@@ -760,6 +764,8 @@ def _fit_demos_one_step_for_debug(
 
         # 2. run student mode
 
+        demo_debug_result_path = None
+
         for batch_idx, batch in enumerate(train_loader):
             print(f"Training step: {batch_idx}")
             if batch_idx > 0:
@@ -820,7 +826,9 @@ def _fit_demos_one_step_for_debug(
             self._demo_optimizers_propose()
             graph_path = os.path.join(debug_path, "student_graph")
 
-            paths = losses_student[0].draw_graph(filepath=graph_path)  # noqa F841
+            demo_debug_result_path = losses_student[0].draw_graph(
+                filepath=graph_path
+            )  # noqa F841
 
             # test step
             self._demo_optimizers_step()
@@ -851,9 +859,9 @@ def _fit_demos_one_step_for_debug(
             if len(param._demos) == 0:
                 raise ValueError(f"No demos found, param: {param}")
 
-        return debug_path
+        return demo_debug_result_path
 
-    def _fit_text_grads_one_step_for_debug(self, train_loader: Any) -> str:
+    def _fit_text_grads_one_step_for_debug(self, train_loader: Any) -> Dict[str, str]:
         printc(
             "Debugging fitting one step with batch size 2 for text optimizer", "blue"
         )
@@ -901,8 +909,8 @@ def _fit_text_grads_one_step_for_debug(self, train_loader: Any) -> str:
         # test optimizer
         self._propose_text_optimizers()
 
-        total_loss.draw_graph(filepath=debug_path, full_trace=True)
-        return debug_path
+        debug_files = total_loss.draw_graph(filepath=debug_path, full_trace=True)
+        return debug_files
 
     def _set_demo_optimizers_dataset(self, train_dataset: Any):
         # init the dataset
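
Both one-step debug helpers now return the artifact mapping produced by `draw_graph` instead of the bare `debug_path` directory. A hedged sketch of consuming that mapping; the diff only guarantees a `Dict[str, str]` of names to file paths, so the printed keys are whatever `draw_graph` happens to emit:

    # Illustrative only: iterate the debug artifacts from the text-grad debug pass.
    debug_files = trainer._fit_text_grads_one_step_for_debug(train_loader)
    for name, path in (debug_files or {}).items():
        print(f"text-grad debug artifact '{name}': {path}")
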
@@ -1701,6 +1709,9 @@ def _text_grad_constraint_propose_step(
         all_y_preds,
         include_demo_optimizers: bool = False,
     ):
+        """Handles both the mixed training and the separate training.
+        When include_demo_optimizers is True, the demo optimizers are included in the training.
+        """
         # compute moving batch acc
         from adalflow.optim.parameter import Parameter
 
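
The new docstring says one propose step serves both regimes: text-gradient-only training and mixed training that also proposes few-shot demos. A toy sketch of that dispatch pattern; this is not AdalFlow's internal code, and `propose()` plus the optimizer lists are stand-ins for illustration:

    # Toy illustration of the include_demo_optimizers flag, not the real method body.
    def propose_step(text_optimizers, demo_optimizers, include_demo_optimizers=False):
        for opt in text_optimizers:
            opt.propose()          # text-gradient proposals always run
        if include_demo_optimizers:
            for opt in demo_optimizers:
                opt.propose()      # demo proposals only in mixed training
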
@@ -1894,6 +1905,7 @@ def _fit_text_grad_constraint(
                     trainer_results.prompts[-1],
                     total_steps,
                 )
+                self._add_failed_proposals_text_optimizers()
                 continue
 
             # prune the correct sample size if its too big, same with error samples