|
589 | 589 | " \"type\": \"dpo\",\n",
|
590 | 590 | " \"dpo\": {\n",
|
591 | 591 | " \"hyperparameters\": {\n",
|
592 |
| - " \"n_epochs\": 2, \n", |
593 |
| - " \"beta\": 0.1, \n", |
594 |
| - " \"batch_size\": 8, \n", |
| 592 | + " \"n_epochs\": 2,\n", |
| 593 | + " \"beta\": 0.1,\n", |
| 594 | + " \"batch_size\": 8,\n", |
595 | 595 | " }\n",
|
596 | 596 | " },\n",
|
597 | 597 | " },\n",
|
|
610 | 610 | },
|
611 | 611 | {
|
612 | 612 | "cell_type": "code",
|
613 |
| - "execution_count": 25, |
| 613 | + "execution_count": null, |
614 | 614 | "metadata": {},
|
615 | 615 | "outputs": [],
|
616 | 616 | "source": [
|
|
620 | 620 | " responses = await generate_responses(testset, model=job.fine_tuned_model)\n",
|
621 | 621 | "\n",
|
622 | 622 | " post_run = sync_client.evals.runs.create(\n",
|
623 |
| - " name=ft.id,\n", |
624 |
| - " eval_id=logs_eval.id,\n", |
625 |
| - " data_source={\n", |
626 |
| - " \"type\": \"responses\",\n", |
627 |
| - " \"source\": {\"type\": \"responses\", \"limit\": len(test_pairs)},\n", |
628 |
| - " },\n", |
629 |
| - ")" |
| 623 | + " name=ft.id,\n", |
| 624 | + " eval_id=logs_eval.id,\n", |
| 625 | + " data_source={\n", |
| 626 | + " \"type\": \"responses\",\n", |
| 627 | + " \"source\": {\"type\": \"responses\", \"limit\": len(test_pairs)},\n", |
| 628 | + " },\n", |
| 629 | + " )" |
630 | 630 | ]
|
631 | 631 | },
|
632 | 632 | {
|
|
661 | 661 | ").data\n",
|
662 | 662 | "post_scores = [s.results[0][\"score\"] for s in post_data]\n",
|
663 | 663 | "\n",
|
| 664 | + "# print scores & a sample comparison from the test set for illustration\n", |
664 | 665 | "print(\n",
|
665 | 666 | " \"Δ mean:\",\n",
|
666 |
| - " sum(t - b for b, t in zip(base_scores, post_scores))\n", |
667 |
| - " / len(base_scores),\n", |
| 667 | + " sum(t - b for b, t in zip(base_scores, post_scores)) / len(base_scores),\n", |
668 | 668 | ")\n",
|
669 |
| - "# print a sample comparison from the test set for illustration\n", |
670 | 669 | "print(\"\\n=== SAMPLE COMPARISON ===\")\n",
|
671 | 670 | "idx = 0\n",
|
672 | 671 | "print(f\"Prompt:\\n {testset[idx]['item']['input']}\\n\")\n",
|
673 |
| - "print(\n", |
674 |
| - " f\"Base model reply: \\n {base_data[idx].sample.output[0].content} \\n\"\n", |
675 |
| - ")\n", |
676 |
| - "print(\n", |
677 |
| - " f\"DPO-tuned model reply \\n {post_data[idx].sample.output[0].content}\"\n", |
678 |
| - ")" |
| 672 | + "print(f\"Base model reply: \\n {base_data[idx].sample.output[0].content} \\n\")\n", |
| 673 | + "print(f\"DPO-tuned model reply \\n {post_data[idx].sample.output[0].content}\")" |
679 | 674 | ]
|
680 | 675 | }
|
681 | 676 | ],
|
|
0 commit comments