Skip to content

Commit 6c5841a

Browse files
committed
wrap up multi hop rag
1 parent 0439354 commit 6c5841a

File tree

4 files changed

+56
-28
lines changed

4 files changed

+56
-28
lines changed

adalflow/adalflow/optim/trainer/trainer.py

Lines changed: 53 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ class Trainer(Component):
9494
max_error_samples: Optional[int] = 2
9595
max_correct_samples: Optional[int] = 2
9696
debug: bool = False
97+
sequential_order: List[str] = ["text", "demo"]
9798

9899
def __init__(
99100
self,
@@ -119,6 +120,7 @@ def __init__(
119120
exclude_input_fields_from_bootstrap_demos: bool = False,
120121
debug: bool = False,
121122
save_traces: bool = False,  # save traces in the few-shot demos
123+
sequential_order: List[str] = ["text", "demo"],
122124
*args,
123125
**kwargs,
124126
) -> None:
@@ -161,6 +163,7 @@ def __init__(
161163
self.exclude_input_fields_from_bootstrap_demos = (
162164
exclude_input_fields_from_bootstrap_demos
163165
)
166+
self.sequential_order = sequential_order
164167

165168
# TODO: need to support checkpoint resume too!
166169
def diagnose(self, dataset: Any, split: str = "train"):
@@ -503,7 +506,6 @@ def fit(
503506
and len(self.text_optimizers) > 0
504507
):
505508
if self.strategy == "random":
506-
507509
self._fit_text_grad_demo_mix_random(
508510
train_loader,
509511
train_dataset,
@@ -525,37 +527,62 @@ def fit(
525527
raise ValueError(f"Strategy {self.strategy} not supported")
526528

527529
else: # sequential, text first and demo second
528-
if len(self.text_optimizers) > 0:
529-
if self.strategy == "random":
530-
trainer_results = self._fit_text_grad_random(
531-
train_loader,
532-
val_dataset,
533-
test_dataset,
534-
trainer_results,
535-
starting_step=starting_step,
536-
)
537-
starting_step += self.max_steps
538-
elif self.strategy == "constrained":
539-
trainer_results = self._fit_text_grad_constraint(
530+
531+
def run_text_optimizers(starting_step: int, trainer_results: TrainerResult):
532+
if len(self.text_optimizers) > 0:
533+
if self.strategy == "random":
534+
trainer_results = self._fit_text_grad_random(
535+
train_loader,
536+
val_dataset,
537+
test_dataset,
538+
trainer_results,
539+
starting_step=starting_step,
540+
)
541+
starting_step += self.max_steps
542+
elif self.strategy == "constrained":
543+
trainer_results = self._fit_text_grad_constraint(
544+
train_loader,
545+
val_dataset,
546+
test_dataset,
547+
trainer_results=trainer_results,
548+
starting_step=starting_step,
549+
)
550+
starting_step += self.max_steps
551+
else:
552+
raise ValueError(f"Strategy {self.strategy} not supported")
553+
554+
def run_demo_optimizers(starting_step: int, trainer_results: TrainerResult):
555+
if len(self.demo_optimizers) > 0:
556+
self.adaltask.configure_teacher_generator()
557+
self._fit_demos_random(
540558
train_loader,
559+
train_dataset,
541560
val_dataset,
542561
test_dataset,
543562
trainer_results=trainer_results,
544563
starting_step=starting_step,
545564
)
546-
starting_step += self.max_steps
547-
else:
548-
raise ValueError(f"Strategy {self.strategy} not supported")
549-
if len(self.demo_optimizers) > 0:
550-
self.adaltask.configure_teacher_generator()  # attempt to use the newest teacher
551-
self._fit_demos_random(
552-
train_loader,
553-
train_dataset,
554-
val_dataset,
555-
test_dataset,
556-
trainer_results=trainer_results,
557-
starting_step=starting_step,
558-
)
565+
566+
if self.sequential_order == ["text", "demo"]:
567+
run_text_optimizers(starting_step, trainer_results)
568+
run_demo_optimizers(starting_step, trainer_results)
569+
else:
570+
run_demo_optimizers(starting_step, trainer_results)
571+
run_text_optimizers(starting_step, trainer_results)
572+
# if len(self.text_optimizers) > 0:
573+
# run_text_optimizers(starting_step, trainer_results)
574+
575+
# if len(self.demo_optimizers) > 0:
576+
# run_demo_optimizers(starting_step, trainer_results)
577+
# self.adaltask.configure_teacher_generator()  # attempt to use the newest teacher
578+
# self._fit_demos_random(
579+
# train_loader,
580+
# train_dataset,
581+
# val_dataset,
582+
# test_dataset,
583+
# trainer_results=trainer_results,
584+
# starting_step=starting_step,
585+
# )
559586

560587
end_time = time.time()
561588
print(f"Training time: {end_time - start_time}s")

benchmarks/hotpot_qa/adal_exp/build_multi_hop_rag.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ def __init__(self, model_client, model_kwargs, passages_per_hop=3, max_hops=2):
255255
name=f"few_shot_demos_{i}",
256256
data=None,
257257
role_desc="To provide few shot demos to the language model",
258-
requires_opt=False,
258+
requires_opt=True,
259259
param_type=ParameterType.DEMOS,
260260
),
261261
"task_desc_str": Parameter(

benchmarks/hotpot_qa/adal_exp/build_vanilla_rag.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ def __init__(self, passages_per_hop=3, model_client=None, model_kwargs=None):
161161
),
162162
"few_shot_demos": adal.Parameter(
163163
data=None,
164-
requires_opt=False,
164+
requires_opt=True,
165165
role_desc="To provide few shot demos to the language model",
166166
param_type=adal.ParameterType.DEMOS,
167167
),

benchmarks/hotpot_qa/adal_exp/train_multi_hop_rag.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ def train(
130130
weighted_sampling=True,
131131
optimization_order=optimization_order,
132132
exclude_input_fields_from_bootstrap_demos=exclude_input_fields_from_bootstrap_demos,
133+
sequential_order=["text", "demo"],
133134
)
134135
print(trainer)
135136

0 commit comments

Comments
 (0)