|
177 | 177 | "outputs": [], |
178 | 178 | "source": [ |
179 | 179 | "def get_corr(targets, predictions):\n", |
180 | | - " scores = [kendalltau(x, y).correlation for x, y in zip(targets, predictions)]\n", |
181 | | - " return [score if not np.isnan(score) else 0 for score in scores ]" |
| 180 | + " scores = [kendalltau(x, y).correlation for x, y in zip(targets, predictions)]\n", |
| 181 | + " return [score if not np.isnan(score) else 0 for score in scores]" |
182 | 182 | ] |
183 | 183 | }, |
184 | 184 | { |
|
355 | 355 | "metadata": {}, |
356 | 356 | "outputs": [], |
357 | 357 | "source": [ |
358 | | - "def gpt_faithfulness(question:list, context:list, answer:list):\n", |
359 | | - " prompt = [faithfulness.format(c,q, a) for c,q,a in zip(question,context,answer)]\n", |
360 | | - " output = [output for output in llm(prompt)['choices']]\n", |
361 | | - " scores = [(out[\"text\"].strip()) for out in output ]\n", |
362 | | - " scores = [int(score) if score in ['1','2','3','4','5'] else 1 for score in scores]\n", |
| 358 | + "def gpt_faithfulness(question: list, context: list, answer: list):\n", |
| 359 | + " prompt = [\n", |
| 360 | + " faithfulness.format(c, q, a) for c, q, a in zip(question, context, answer)\n", |
| 361 | + " ]\n", |
| 362 | + " output = [output for output in llm(prompt)[\"choices\"]]\n", |
| 363 | + " scores = [(out[\"text\"].strip()) for out in output]\n", |
| 364 | + " scores = [\n", |
| 365 | + " int(score) if score in [\"1\", \"2\", \"3\", \"4\", \"5\"] else 1 for score in scores\n", |
| 366 | + " ]\n", |
363 | 367 | " return scores\n", |
364 | 368 | "\n", |
365 | | - "def gpt_relevance(question:list, answer:list):\n", |
366 | | - " prompt = [relevence.format(q,a) for q,a in zip(question,answer)]\n", |
367 | | - " output = [output for output in llm(prompt)['choices']]\n", |
368 | | - " scores = [(out[\"text\"].strip()) for out in output ]\n", |
369 | | - " scores = [int(score) if score in ['1','2','3','4','5'] else 1 for score in scores]\n", |
| 369 | + "\n", |
| 370 | + "def gpt_relevance(question: list, answer: list):\n", |
| 371 | + " prompt = [relevence.format(q, a) for q, a in zip(question, answer)]\n", |
| 372 | + " output = [output for output in llm(prompt)[\"choices\"]]\n", |
| 373 | + " scores = [(out[\"text\"].strip()) for out in output]\n", |
| 374 | + " scores = [\n", |
| 375 | + " int(score) if score in [\"1\", \"2\", \"3\", \"4\", \"5\"] else 1 for score in scores\n", |
| 376 | + " ]\n", |
370 | 377 | " return scores" |
371 | 378 | ] |
372 | 379 | }, |
|
425 | 432 | "metadata": {}, |
426 | 433 | "outputs": [], |
427 | 434 | "source": [ |
428 | | - "q,a,c = wikiqa_ragas['train'][0]['question'],wikiqa_ragas['train'][0]['generated_without_rag'],wikiqa_ragas['train'][0]['context']" |
| 435 | + "q, a, c = (\n", |
| 436 | + " wikiqa_ragas[\"train\"][0][\"question\"],\n", |
| 437 | + " wikiqa_ragas[\"train\"][0][\"generated_without_rag\"],\n", |
| 438 | + " wikiqa_ragas[\"train\"][0][\"context\"],\n", |
| 439 | + ")" |
429 | 440 | ] |
430 | 441 | }, |
431 | 442 | { |
|
446 | 457 | } |
447 | 458 | ], |
448 | 459 | "source": [ |
449 | | - "gpt_faithfulness([q],[c], [a])" |
| 460 | + "gpt_faithfulness([q], [c], [a])" |
450 | 461 | ] |
451 | 462 | }, |
452 | 463 | { |
|
517 | 528 | "def predict_(examples):\n", |
518 | 529 | " scores = {}\n", |
519 | 530 | " questions = examples[\"question\"]\n", |
520 | | - " context = examples['context']\n", |
| 531 | + " context = examples[\"context\"]\n", |
521 | 532 | " for col in COLUMNS:\n", |
522 | 533 | " passage = examples[col]\n", |
523 | 534 | " inputs = list(zip(questions, passage))\n", |
524 | | - " #scores[f\"{col}_relevance\"] = t5_qgen.predict(inputs, show_progress=False)\n", |
525 | | - " scores[f\"{col}_relevance\"] = gpt_faithfulness(questions,context,passage)\n", |
| 535 | + " # scores[f\"{col}_relevance\"] = t5_qgen.predict(inputs, show_progress=False)\n", |
| 536 | + " scores[f\"{col}_relevance\"] = gpt_faithfulness(questions, context, passage)\n", |
526 | 537 | " return scores" |
527 | 538 | ] |
528 | 539 | }, |
|
553 | 564 | }, |
554 | 565 | "outputs": [], |
555 | 566 | "source": [ |
556 | | - "output = (\n", |
557 | | - " wikiqa_ragas[\"train\"]\n", |
558 | | - " .map(predict_relevance, batched=True, batch_size=10)\n", |
559 | | - ")" |
| 567 | + "output = wikiqa_ragas[\"train\"].map(predict_relevance, batched=True, batch_size=10)" |
560 | 568 | ] |
561 | 569 | }, |
562 | 570 | { |
|
622 | 630 | } |
623 | 631 | ], |
624 | 632 | "source": [ |
625 | | - "output = (\n", |
626 | | - " wikiqa_ragas[\"train\"]\n", |
627 | | - " .map(predict_relevance, batched=True, batch_size=10)\n", |
628 | | - ")" |
| 633 | + "output = wikiqa_ragas[\"train\"].map(predict_relevance, batched=True, batch_size=10)" |
629 | 634 | ] |
630 | 635 | }, |
631 | 636 | { |
|
877 | 882 | "metadata": {}, |
878 | 883 | "outputs": [], |
879 | 884 | "source": [ |
880 | | - "def predict_faithfulness(examples,scoring_fun=NLI.score):\n", |
| 885 | + "def predict_faithfulness(examples, scoring_fun=NLI.score):\n", |
881 | 886 | " scores = {}\n", |
882 | 887 | " questions = examples[\"question\"]\n", |
883 | 888 | " contexts = examples[\"answer_context\"]\n", |
|
0 commit comments