|
137 | 137 | "\n", |
138 | 138 | "# Prompt user to enter their API keys securely\n", |
139 | 139 | "openai_api_key = getpass(\"Please enter your OpenAI API key: \")\n", |
140 | | - "groq_api_key = getpass(\"Please enter your GROQ API key, simplly press Enter if you don't have one: \")\n", |
| 140 | + "groq_api_key = getpass(\n", |
| 141 | + "    \"Please enter your GROQ API key, simply press Enter if you don't have one: \"\n", |
| 142 | + ")\n", |
141 | 143 | "\n", |
142 | 144 | "\n", |
143 | 145 | "# Set environment variables\n", |
144 | | - "os.environ['OPENAI_API_KEY'] = openai_api_key\n", |
145 | | - "os.environ['GROQ_API_KEY'] = groq_api_key\n", |
| 146 | + "os.environ[\"OPENAI_API_KEY\"] = openai_api_key\n", |
| 147 | + "os.environ[\"GROQ_API_KEY\"] = groq_api_key\n", |
146 | 148 | "\n", |
147 | 149 | "print(\"API keys have been set.\")" |
148 | 150 | ] |
|
209 | 211 | "<END_OF_USER>\n", |
210 | 212 | "\"\"\"\n", |
211 | 213 | "\n", |
| 214 | + "\n", |
212 | 215 | "class ObjectCountTaskPipeline(adal.Component):\n", |
213 | 216 | " def __init__(self, model_client: adal.ModelClient, model_kwargs: Dict):\n", |
214 | 217 | " super().__init__()\n", |
|
242 | 245 | " self, question: str, id: str = None\n", |
243 | 246 | " ) -> Union[adal.GeneratorOutput, adal.Parameter]:\n", |
244 | 247 | " output = self.llm_counter(prompt_kwargs={\"input_str\": question}, id=id)\n", |
245 | | - " return output\n", |
246 | | - "\n", |
247 | | - "\n" |
| 248 | + " return output" |
248 | 249 | ] |
249 | 250 | }, |
250 | 251 | { |
|
329 | 330 | "from adalflow.components.model_client.groq_client import GroqAPIClient\n", |
330 | 331 | "\n", |
331 | 332 | "\n", |
332 | | - "if len(os.environ['OPENAI_API_KEY']) > 1:\n", |
333 | | - " gpt_3_model = {\n", |
334 | | - " \"model_client\": OpenAIClient(),\n", |
335 | | - " \"model_kwargs\": {\n", |
336 | | - " \"model\": \"gpt-3.5-turbo\",\n", |
337 | | - " \"max_tokens\": 2000,\n", |
338 | | - " \"temperature\": 0.0,\n", |
339 | | - " \"top_p\": 0.99,\n", |
340 | | - " \"frequency_penalty\": 0,\n", |
341 | | - " \"presence_penalty\": 0,\n", |
342 | | - " \"stop\": None,\n", |
343 | | - " },\n", |
344 | | - " }\n", |
345 | | - " gpt_4o_model = {\n", |
346 | | - " \"model_client\": OpenAIClient(),\n", |
347 | | - " \"model_kwargs\": {\n", |
348 | | - " \"model\": \"gpt-4o\",\n", |
349 | | - " \"max_tokens\": 4000,\n", |
350 | | - " \"temperature\": 0.0,\n", |
351 | | - " \"top_p\": 0.99,\n", |
352 | | - " \"frequency_penalty\": 0,\n", |
353 | | - " \"presence_penalty\": 0,\n", |
354 | | - " \"stop\": None,\n", |
355 | | - " },\n", |
356 | | - " }\n", |
| 333 | + "if len(os.environ[\"OPENAI_API_KEY\"]) > 1:\n", |
| 334 | + " gpt_3_model = {\n", |
| 335 | + " \"model_client\": OpenAIClient(),\n", |
| 336 | + " \"model_kwargs\": {\n", |
| 337 | + " \"model\": \"gpt-3.5-turbo\",\n", |
| 338 | + " \"max_tokens\": 2000,\n", |
| 339 | + " \"temperature\": 0.0,\n", |
| 340 | + " \"top_p\": 0.99,\n", |
| 341 | + " \"frequency_penalty\": 0,\n", |
| 342 | + " \"presence_penalty\": 0,\n", |
| 343 | + " \"stop\": None,\n", |
| 344 | + " },\n", |
| 345 | + " }\n", |
| 346 | + " gpt_4o_model = {\n", |
| 347 | + " \"model_client\": OpenAIClient(),\n", |
| 348 | + " \"model_kwargs\": {\n", |
| 349 | + " \"model\": \"gpt-4o\",\n", |
| 350 | + " \"max_tokens\": 4000,\n", |
| 351 | + " \"temperature\": 0.0,\n", |
| 352 | + " \"top_p\": 0.99,\n", |
| 353 | + " \"frequency_penalty\": 0,\n", |
| 354 | + " \"presence_penalty\": 0,\n", |
| 355 | + " \"stop\": None,\n", |
| 356 | + " },\n", |
| 357 | + " }\n", |
357 | 358 | "\n", |
358 | | - "if len(os.environ['GROQ_API_KEY']) > 1:\n", |
359 | | - " llama_3_1_model ={\n", |
360 | | - " \"model_client\": GroqAPIClient(),\n", |
361 | | - " \"model_kwargs\": {\n", |
362 | | - " \"model\": \"llama-3.1-8b-instant\"\n", |
363 | | - " }\n", |
364 | | - " }\n", |
| 359 | + "if len(os.environ[\"GROQ_API_KEY\"]) > 1:\n", |
| 360 | + " llama_3_1_model = {\n", |
| 361 | + " \"model_client\": GroqAPIClient(),\n", |
| 362 | + " \"model_kwargs\": {\"model\": \"llama-3.1-8b-instant\"},\n", |
| 363 | + " }\n", |
365 | 364 | "\n", |
366 | 365 | "\n", |
367 | 366 | "question = \"I have a flute, a piano, a trombone, four stoves, a violin, an accordion, a clarinet, a drum, two lamps, and a trumpet. How many musical instruments do I have?\"\n", |
368 | 367 | "task_pipeline = ObjectCountTaskPipeline(**gpt_3_model)\n", |
369 | | - "print(task_pipeline)\n" |
| 368 | + "print(task_pipeline)" |
370 | 369 | ] |
371 | 370 | }, |
372 | 371 | { |
|
467 | 466 | "from adalflow.datasets.big_bench_hard import BigBenchHard\n", |
468 | 467 | "from adalflow.utils.data import subset_dataset\n", |
469 | 468 | "\n", |
| 469 | + "\n", |
470 | 470 | "def load_datasets(max_samples: int = None):\n", |
471 | 471 | " \"\"\"Load the dataset\"\"\"\n", |
472 | 472 | " train_data = BigBenchHard(split=\"train\")\n", |
|
479 | 479 | " val_data = subset_dataset(val_data, max_samples)\n", |
480 | 480 | " test_data = subset_dataset(test_data, max_samples)\n", |
481 | 481 | "\n", |
482 | | - " return train_data, val_data, test_data\n" |
| 482 | + " return train_data, val_data, test_data" |
483 | 483 | ] |
484 | 484 | }, |
485 | 485 | { |
|
583 | 583 | " def prepare_task(self, sample: Example):\n", |
584 | 584 | " return self.task.call, {\"question\": sample.question, \"id\": sample.id}\n", |
585 | 585 | "\n", |
586 | | - " def prepare_eval(\n", |
587 | | - " self, sample: Example, y_pred: adal.GeneratorOutput\n", |
588 | | - " ) -> float:\n", |
| 586 | + " def prepare_eval(self, sample: Example, y_pred: adal.GeneratorOutput) -> float:\n", |
589 | 587 | " y_label = -1\n", |
590 | | - " if (y_pred is not None and y_pred.data is not None): # if y_pred and y_pred.data: might introduce bug when the data is 0\n", |
| 588 | + " if (\n", |
| 589 | + " y_pred is not None and y_pred.data is not None\n", |
| 590 | + " ): # if y_pred and y_pred.data: might introduce bug when the data is 0\n", |
591 | 591 | " y_label = y_pred.data\n", |
592 | 592 | " return self.eval_fn, {\"y\": y_label, \"y_gt\": sample.answer}" |
593 | 593 | ] |
|
820 | 820 | "from adalflow.datasets.types import Example\n", |
821 | 821 | "\n", |
822 | 822 | "\n", |
823 | | - "class ObjectCountAdalComponent(adal.AdalComponent):# noqa: F811\n", |
| 823 | + "class ObjectCountAdalComponent(adal.AdalComponent): # noqa: F811\n", |
824 | 824 | " def __init__(\n", |
825 | 825 | " self,\n", |
826 | 826 | " model_client: adal.ModelClient,\n", |
|
844 | 844 | " def prepare_task(self, sample: Example):\n", |
845 | 845 | " return self.task.call, {\"question\": sample.question, \"id\": sample.id}\n", |
846 | 846 | "\n", |
847 | | - "\n", |
848 | | - " def prepare_eval(\n", |
849 | | - " self, sample: Example, y_pred: adal.GeneratorOutput\n", |
850 | | - " ) -> float:\n", |
| 847 | + " def prepare_eval(self, sample: Example, y_pred: adal.GeneratorOutput) -> float:\n", |
851 | 848 | " y_label = -1\n", |
852 | | - " if (y_pred is not None and y_pred.data is not None): # if y_pred and y_pred.data: might introduce bug when the data is 0\n", |
| 849 | + " if (\n", |
| 850 | + " y_pred is not None and y_pred.data is not None\n", |
| 851 | + " ): # if y_pred and y_pred.data: might introduce bug when the data is 0\n", |
853 | 852 | " y_label = y_pred.data\n", |
854 | 853 | " return self.eval_fn, {\"y\": y_label, \"y_gt\": sample.answer}\n", |
855 | 854 | "\n", |
|
891 | 890 | " **gpt_3_model,\n", |
892 | 891 | " teacher_model_config=gpt_4o_model,\n", |
893 | 892 | " text_optimizer_model_config=gpt_4o_model,\n", |
894 | | - " backward_engine_model_config=gpt_4o_model\n", |
| 893 | + " backward_engine_model_config=gpt_4o_model,\n", |
895 | 894 | " )\n", |
896 | 895 | " print(adal_component)\n", |
897 | 896 | " trainer = adal.Trainer(\n", |
|
916 | 915 | " test_dataset=test_dataset,\n", |
917 | 916 | " debug=debug,\n", |
918 | 917 | " resume_from_ckpt=resume_from_ckpt,\n", |
919 | | - " )\n" |
| 918 | + " )" |
920 | 919 | ] |
921 | 920 | }, |
922 | 921 | { |
|
3255 | 3254 | } |
3256 | 3255 | ], |
3257 | 3256 | "source": [ |
3258 | | - "train(debug=False, max_steps=12, strategy=\"constrained\",\n", |
3259 | | - " raw_shots=0, bootstrap_shots=1,\n", |
3260 | | - " exclude_input_fields_from_bootstrap_demos=True\n", |
3261 | | - " )" |
| 3257 | + "train(\n", |
| 3258 | + " debug=False,\n", |
| 3259 | + " max_steps=12,\n", |
| 3260 | + " strategy=\"constrained\",\n", |
| 3261 | + " raw_shots=0,\n", |
| 3262 | + " bootstrap_shots=1,\n", |
| 3263 | + " exclude_input_fields_from_bootstrap_demos=True,\n", |
| 3264 | + ")" |
3262 | 3265 | ] |
3263 | 3266 | }, |
3264 | 3267 | { |
|
6015 | 6018 | } |
6016 | 6019 | ], |
6017 | 6020 | "source": [ |
6018 | | - "\n", |
6019 | 6021 | "ckpt_path = \"/content/adalflow/ckpt/ObjectCountAdalComponent/constrained_max_steps_12_4e8a1_run_1.json\"\n", |
6020 | 6022 | "\n", |
6021 | | - "train(debug=False, max_steps=12, strategy=\"constrained\",\n", |
6022 | | - " raw_shots=0, bootstrap_shots=1,\n", |
6023 | | - " resume_from_ckpt=ckpt_path,\n", |
6024 | | - " exclude_input_fields_from_bootstrap_demos=True)" |
| 6023 | + "train(\n", |
| 6024 | + " debug=False,\n", |
| 6025 | + " max_steps=12,\n", |
| 6026 | + " strategy=\"constrained\",\n", |
| 6027 | + " raw_shots=0,\n", |
| 6028 | + " bootstrap_shots=1,\n", |
| 6029 | + " resume_from_ckpt=ckpt_path,\n", |
| 6030 | + " exclude_input_fields_from_bootstrap_demos=True,\n", |
| 6031 | + ")" |
6025 | 6032 | ] |
6026 | 6033 | }, |
6027 | 6034 | { |
|
8038 | 8045 | } |
8039 | 8046 | ], |
8040 | 8047 | "source": [ |
8041 | | - "\n", |
8042 | | - "train(debug=False, max_steps=12, strategy=\"random\",\n", |
8043 | | - " raw_shots=0, bootstrap_shots=1,\n", |
8044 | | - " resume_from_ckpt=ckpt_path,\n", |
8045 | | - " exclude_input_fields_from_bootstrap_demos=False)" |
| 8048 | + "train(\n", |
| 8049 | + " debug=False,\n", |
| 8050 | + " max_steps=12,\n", |
| 8051 | + " strategy=\"random\",\n", |
| 8052 | + " raw_shots=0,\n", |
| 8053 | + " bootstrap_shots=1,\n", |
| 8054 | + " resume_from_ckpt=ckpt_path,\n", |
| 8055 | + " exclude_input_fields_from_bootstrap_demos=False,\n", |
| 8056 | + ")" |
8046 | 8057 | ] |
8047 | 8058 | }, |
8048 | 8059 | { |
|
0 commit comments