Skip to content

Commit 23d5f64

Browse files
hwchase17 and ssprigge
authored
Harrison/ngram example (#846)
Co-authored-by: Sean Spriggens <[email protected]>
1 parent 0de5504 commit 23d5f64

File tree

3 files changed

+420
-3
lines changed

3 files changed

+420
-3
lines changed

docs/modules/prompts/examples/example_selectors.ipynb

Lines changed: 235 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
},
2424
{
2525
"cell_type": "code",
26-
"execution_count": null,
26+
"execution_count": 1,
2727
"id": "8244ff60",
2828
"metadata": {},
2929
"outputs": [],
@@ -81,7 +81,7 @@
8181
" template=\"Input: {input}\\nOutput: {output}\",\n",
8282
")\n",
8383
"example_selector = LengthBasedExampleSelector(\n",
84-
" # These are the examples is has available to choose from.\n",
84+
" # These are the examples it has available to choose from.\n",
8585
" examples=examples, \n",
8686
" # This is the PromptTemplate being used to format the examples.\n",
8787
" example_prompt=example_prompt, \n",
@@ -439,10 +439,242 @@
439439
"print(similar_prompt.format(adjective=\"worried\"))"
440440
]
441441
},
442+
{
443+
"cell_type": "markdown",
444+
"id": "4aaeed2f",
445+
"metadata": {},
446+
"source": [
447+
"## NGram Overlap ExampleSelector\n",
448+
"\n",
449+
"The NGramOverlapExampleSelector selects and orders examples based on which examples are most similar to the input, according to an ngram overlap score. The ngram overlap score is a float between 0.0 and 1.0, inclusive. \n",
450+
"\n",
451+
"The selector allows for a threshold score to be set. Examples with an ngram overlap score less than or equal to the threshold are excluded. The threshold is set to -1.0, by default, so will not exclude any examples, only reorder them. Setting the threshold to 0.0 will exclude examples that have no ngram overlaps with the input.\n"
452+
]
453+
},
454+
{
455+
"cell_type": "code",
456+
"execution_count": 2,
457+
"id": "9cbc0acc",
458+
"metadata": {},
459+
"outputs": [],
460+
"source": [
461+
"from langchain.prompts import PromptTemplate\n",
462+
"from langchain.prompts.example_selector.ngram_overlap import NGramOverlapExampleSelector"
463+
]
464+
},
465+
{
466+
"cell_type": "code",
467+
"execution_count": 3,
468+
"id": "4f318f4b",
469+
"metadata": {},
470+
"outputs": [],
471+
"source": [
472+
"# These are examples of a fictional translation task.\n",
473+
"examples = [\n",
474+
" {\"input\": \"See Spot run.\", \"output\": \"Ver correr a Spot.\"},\n",
475+
" {\"input\": \"My dog barks.\", \"output\": \"Mi perro ladra.\"},\n",
476+
" {\"input\": \"Spot can run.\", \"output\": \"Spot puede correr.\"},\n",
477+
"]"
478+
]
479+
},
480+
{
481+
"cell_type": "code",
482+
"execution_count": 4,
483+
"id": "bf75e0fe",
484+
"metadata": {},
485+
"outputs": [],
486+
"source": [
487+
"example_prompt = PromptTemplate(\n",
488+
" input_variables=[\"input\", \"output\"],\n",
489+
" template=\"Input: {input}\\nOutput: {output}\",\n",
490+
")\n",
491+
"example_selector = NGramOverlapExampleSelector(\n",
492+
" # These are the examples it has available to choose from.\n",
493+
" examples=examples, \n",
494+
" # This is the PromptTemplate being used to format the examples.\n",
495+
" example_prompt=example_prompt, \n",
496+
" # This is the threshold, at which selector stops.\n",
497+
" # It is set to -1.0 by default.\n",
498+
" threshold=-1.0,\n",
499+
" # For negative threshold:\n",
500+
" # Selector sorts examples by ngram overlap score, and excludes none.\n",
501+
" # For threshold greater than 1.0:\n",
502+
" # Selector excludes all examples, and returns an empty list.\n",
503+
" # For threshold equal to 0.0:\n",
504+
" # Selector sorts examples by ngram overlap score,\n",
505+
" # and excludes those with no ngram overlap with input.\n",
506+
")\n",
507+
"dynamic_prompt = FewShotPromptTemplate(\n",
508+
" # We provide an ExampleSelector instead of examples.\n",
509+
" example_selector=example_selector,\n",
510+
" example_prompt=example_prompt,\n",
511+
" prefix=\"Give the Spanish translation of every input\",\n",
512+
" suffix=\"Input: {sentence}\\nOutput:\", \n",
513+
" input_variables=[\"sentence\"],\n",
514+
")"
515+
]
516+
},
517+
{
518+
"cell_type": "code",
519+
"execution_count": 5,
520+
"id": "83fb218a",
521+
"metadata": {},
522+
"outputs": [
523+
{
524+
"name": "stdout",
525+
"output_type": "stream",
526+
"text": [
527+
"Give the Spanish translation of every input\n",
528+
"\n",
529+
"Input: Spot can run.\n",
530+
"Output: Spot puede correr.\n",
531+
"\n",
532+
"Input: See Spot run.\n",
533+
"Output: Ver correr a Spot.\n",
534+
"\n",
535+
"Input: My dog barks.\n",
536+
"Output: Mi perro ladra.\n",
537+
"\n",
538+
"Input: Spot can run fast.\n",
539+
"Output:\n"
540+
]
541+
}
542+
],
543+
"source": [
544+
"# An example input with large ngram overlap with \"Spot can run.\"\n",
545+
"# and no overlap with \"My dog barks.\"\n",
546+
"print(dynamic_prompt.format(sentence=\"Spot can run fast.\"))"
547+
]
548+
},
549+
{
550+
"cell_type": "code",
551+
"execution_count": 6,
552+
"id": "485f5307",
553+
"metadata": {},
554+
"outputs": [
555+
{
556+
"name": "stdout",
557+
"output_type": "stream",
558+
"text": [
559+
"Give the Spanish translation of every input\n",
560+
"\n",
561+
"Input: Spot can run.\n",
562+
"Output: Spot puede correr.\n",
563+
"\n",
564+
"Input: See Spot run.\n",
565+
"Output: Ver correr a Spot.\n",
566+
"\n",
567+
"Input: Spot plays fetch.\n",
568+
"Output: Spot juega a buscar.\n",
569+
"\n",
570+
"Input: My dog barks.\n",
571+
"Output: Mi perro ladra.\n",
572+
"\n",
573+
"Input: Spot can run fast.\n",
574+
"Output:\n"
575+
]
576+
}
577+
],
578+
"source": [
579+
"# You can add examples to NGramOverlapExampleSelector as well.\n",
580+
"new_example = {\"input\": \"Spot plays fetch.\", \"output\": \"Spot juega a buscar.\"}\n",
581+
"\n",
582+
"example_selector.add_example(new_example)\n",
583+
"print(dynamic_prompt.format(sentence=\"Spot can run fast.\"))"
584+
]
585+
},
586+
{
587+
"cell_type": "code",
588+
"execution_count": 7,
589+
"id": "606ce697",
590+
"metadata": {},
591+
"outputs": [
592+
{
593+
"name": "stdout",
594+
"output_type": "stream",
595+
"text": [
596+
"Give the Spanish translation of every input\n",
597+
"\n",
598+
"Input: Spot can run.\n",
599+
"Output: Spot puede correr.\n",
600+
"\n",
601+
"Input: See Spot run.\n",
602+
"Output: Ver correr a Spot.\n",
603+
"\n",
604+
"Input: Spot plays fetch.\n",
605+
"Output: Spot juega a buscar.\n",
606+
"\n",
607+
"Input: Spot can run fast.\n",
608+
"Output:\n"
609+
]
610+
}
611+
],
612+
"source": [
613+
"# You can set a threshold at which examples are excluded.\n",
614+
"# For example, setting threshold equal to 0.0\n",
615+
"# excludes examples with no ngram overlaps with input.\n",
616+
"# Since \"My dog barks.\" has no ngram overlaps with \"Spot can run fast.\"\n",
617+
"# it is excluded.\n",
618+
"example_selector.threshold=0.0\n",
619+
"print(dynamic_prompt.format(sentence=\"Spot can run fast.\"))"
620+
]
621+
},
622+
{
623+
"cell_type": "code",
624+
"execution_count": 87,
625+
"id": "7f8d72f7",
626+
"metadata": {},
627+
"outputs": [
628+
{
629+
"name": "stdout",
630+
"output_type": "stream",
631+
"text": [
632+
"Give the Spanish translation of every input\n",
633+
"\n",
634+
"Input: Spot can run.\n",
635+
"Output: Spot puede correr.\n",
636+
"\n",
637+
"Input: Spot plays fetch.\n",
638+
"Output: Spot juega a buscar.\n",
639+
"\n",
640+
"Input: Spot can play fetch.\n",
641+
"Output:\n"
642+
]
643+
}
644+
],
645+
"source": [
646+
"# Setting small nonzero threshold\n",
647+
"example_selector.threshold=0.09\n",
648+
"print(dynamic_prompt.format(sentence=\"Spot can play fetch.\"))"
649+
]
650+
},
651+
{
652+
"cell_type": "code",
653+
"execution_count": 88,
654+
"id": "09633aa8",
655+
"metadata": {},
656+
"outputs": [
657+
{
658+
"name": "stdout",
659+
"output_type": "stream",
660+
"text": [
661+
"Give the Spanish translation of every input\n",
662+
"\n",
663+
"Input: Spot can play fetch.\n",
664+
"Output:\n"
665+
]
666+
}
667+
],
668+
"source": [
669+
"# Setting threshold greater than 1.0\n",
670+
"example_selector.threshold=1.0+1e-9\n",
671+
"print(dynamic_prompt.format(sentence=\"Spot can play fetch.\"))"
672+
]
673+
},
442674
{
443675
"cell_type": "code",
444676
"execution_count": null,
445-
"id": "c746d6f4",
677+
"id": "39f30097",
446678
"metadata": {},
447679
"outputs": [],
448680
"source": []
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
"""Select and order examples based on ngram overlap score (sentence_bleu score).
2+
3+
https://www.nltk.org/_modules/nltk/translate/bleu_score.html
4+
https://aclanthology.org/P02-1040.pdf
5+
"""
6+
from typing import Dict, List
7+
8+
import numpy as np
9+
from pydantic import BaseModel, root_validator
10+
11+
from langchain.prompts.example_selector.base import BaseExampleSelector
12+
from langchain.prompts.prompt import PromptTemplate
13+
14+
15+
def ngram_overlap_score(source: List[str], example: List[str]) -> float:
    """Score the ngram overlap of *source* and *example* via sentence_bleu.

    Uses NLTK's sentence_bleu with the method1 smoothing function and
    automatic reweighting, so the result is a float in [0.0, 1.0] inclusive.

    https://www.nltk.org/_modules/nltk/translate/bleu_score.html
    https://aclanthology.org/P02-1040.pdf
    """
    # Imported lazily so the module can be imported without nltk installed.
    from nltk.translate.bleu_score import (  # type: ignore
        SmoothingFunction,
        sentence_bleu,
    )

    # Tokenize on whitespace: the first source string is the hypothesis,
    # and each example string is a reference.
    hypothesis_tokens = source[0].split()
    reference_token_lists = [sentence.split() for sentence in example]

    bleu = sentence_bleu(
        reference_token_lists,
        hypothesis_tokens,
        smoothing_function=SmoothingFunction().method1,
        auto_reweigh=True,
    )
    return float(bleu)
39+
40+
41+
class NGramOverlapExampleSelector(BaseExampleSelector, BaseModel):
    """Select and order examples based on ngram overlap score (sentence_bleu score).

    https://www.nltk.org/_modules/nltk/translate/bleu_score.html
    https://aclanthology.org/P02-1040.pdf
    """

    examples: List[dict]
    """A list of the examples that the prompt template expects."""

    example_prompt: PromptTemplate
    """Prompt template used to format the examples."""

    threshold: float = -1.0
    """Threshold at which algorithm stops. Set to -1.0 by default.

    For negative threshold:
    select_examples sorts examples by ngram_overlap_score, but excludes none.
    For threshold greater than 1.0:
    select_examples excludes all examples, and returns an empty list.
    For threshold equal to 0.0:
    select_examples sorts examples by ngram_overlap_score,
    and excludes examples with no ngram overlap with input.
    """

    @root_validator(pre=True)
    def check_dependencies(cls, values: Dict) -> Dict:
        """Check that valid dependencies (nltk) exist."""
        try:
            from nltk.translate.bleu_score import (  # noqa: disable=F401
                SmoothingFunction,
                sentence_bleu,
            )
        except ImportError as e:
            raise ValueError(
                "Not all the correct dependencies for this ExampleSelector exist"
            ) from e

        return values

    def add_example(self, example: Dict[str, str]) -> None:
        """Add new example to list."""
        self.examples.append(example)

    def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:
        """Return list of examples sorted by ngram_overlap_score with input.

        Descending order.
        Excludes any examples with ngram_overlap_score less than or equal to
        threshold.
        """
        # Guard: np.argmax raises ValueError on an empty sequence, so with no
        # candidate examples there is nothing to select.
        if not self.examples:
            return []

        inputs = list(input_variables.values())
        examples = []
        k = len(self.examples)
        score = [0.0] * k
        # Each example is scored on the value of the example prompt's first
        # input variable (e.g. its "input" field).
        first_prompt_template_key = self.example_prompt.input_variables[0]

        for i in range(k):
            score[i] = ngram_overlap_score(
                inputs, [self.examples[i][first_prompt_template_key]]
            )

        # Repeatedly take the highest-scoring remaining example; stop once the
        # best remaining score is below or (within float tolerance) equal to
        # the threshold.
        while True:
            arg_max = np.argmax(score)
            if (score[arg_max] < self.threshold) or abs(
                score[arg_max] - self.threshold
            ) < 1e-9:
                break

            examples.append(self.examples[arg_max])
            # Mark the selected example as consumed by dropping its score
            # strictly below the threshold so it cannot be picked again.
            score[arg_max] = self.threshold - 1.0

        return examples

0 commit comments

Comments
 (0)