Skip to content

Commit 2459574

Browse files
committed
facilitate few shot retrieval
1 parent 5f6bc59 commit 2459574

File tree

9 files changed

+895
-1
lines changed

9 files changed

+895
-1
lines changed

nbs/embedding/base.ipynb

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": null,
6+
"metadata": {},
7+
"outputs": [
8+
{
9+
"data": {
10+
"text/plain": [
11+
"True"
12+
]
13+
},
14+
"execution_count": null,
15+
"metadata": {},
16+
"output_type": "execute_result"
17+
}
18+
],
19+
"source": [
20+
"#| default_exp embedding.base"
21+
]
22+
},
23+
{
24+
"cell_type": "markdown",
25+
"metadata": {},
26+
"source": [
27+
"## RagasEmbedding"
28+
]
29+
},
30+
{
31+
"cell_type": "code",
32+
"execution_count": null,
33+
"metadata": {},
34+
"outputs": [],
35+
"source": [
36+
"#| export\n",
37+
"\n",
38+
"import typing as t\n",
39+
"from dataclasses import dataclass\n",
40+
"\n",
41+
"\n",
42+
"@dataclass\n",
43+
"class RagasEmbedding:\n",
44+
" client: t.Any\n",
45+
" model: str\n",
46+
" \n",
47+
" def embed_text(self,text:str,**kwargs: t.Any) -> t.List[float]:\n",
48+
" \n",
49+
" return self.client.embeddings.create(input=text, model=self.model, **kwargs).data[0].embedding\n",
50+
" \n",
51+
" async def aembed_text(self,text:str,**kwargs: t.Any):\n",
52+
" \n",
53+
" return (await self.client.embeddings.create(input=text, model=self.model, **kwargs)).data[0].embedding\n",
54+
" \n",
55+
" \n",
56+
" def embed_document(self,documents:t.List[str],**kwargs: t.Any) -> t.List[t.List[float]]:\n",
57+
" embeddings = self.client.embeddings.create(input=documents, model=self.model, **kwargs)\n",
58+
" return [embedding.embedding for embedding in embeddings.data]\n",
59+
" \n",
60+
" async def aembed_document(self,documents:t.List[str],**kwargs: t.Any) -> t.List[t.List[float]]:\n",
61+
" embeddings = await self.client.embeddings.create(input=documents, model=self.model, **kwargs)\n",
62+
" return [embedding.embedding for embedding in embeddings.data]\n"
63+
]
64+
},
65+
{
66+
"cell_type": "markdown",
67+
"metadata": {},
68+
"source": [
69+
"### Example Usage"
70+
]
71+
},
72+
{
73+
"cell_type": "code",
74+
"execution_count": null,
75+
"metadata": {},
76+
"outputs": [],
77+
"source": [
78+
"#| eval: false\n",
79+
"\n",
80+
"from openai import OpenAI\n",
81+
"embedding = RagasEmbedding(client=OpenAI(),model=\"text-embedding-3-small\")\n",
82+
"embedding.embed_text(\"Hello, world!\")"
83+
]
84+
}
85+
],
86+
"metadata": {
87+
"kernelspec": {
88+
"display_name": "python3",
89+
"language": "python",
90+
"name": "python3"
91+
}
92+
},
93+
"nbformat": 4,
94+
"nbformat_minor": 2
95+
}

nbs/metric/base.ipynb

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,10 @@
9494
" async_tasks.append(self.ascore(reasoning=reasoning, n=n, **input_dict))\n",
9595
" \n",
9696
" # Run all tasks concurrently and return results\n",
97-
" return await asyncio.gather(*async_tasks)"
97+
" return await asyncio.gather(*async_tasks)\n",
98+
" \n",
99+
" \n",
100+
" "
98101
]
99102
},
100103
{

nbs/prompt/base.ipynb

Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": null,
6+
"metadata": {},
7+
"outputs": [],
8+
"source": [
9+
"#| default_exp prompt.base"
10+
]
11+
},
12+
{
13+
"cell_type": "markdown",
14+
"metadata": {},
15+
"source": [
16+
"## Prompt"
17+
]
18+
},
19+
{
20+
"cell_type": "code",
21+
"execution_count": null,
22+
"metadata": {},
23+
"outputs": [],
24+
"source": [
25+
"#| export\n",
26+
"\n",
27+
"import typing as t\n",
28+
"import re\n",
29+
"\n",
30+
"class Prompt:\n",
31+
" def __init__(\n",
32+
" self,\n",
33+
" instruction: str,\n",
34+
" examples: t.Optional[t.List[t.Tuple[t.Dict, t.Dict]]] = None\n",
35+
" ):\n",
36+
" \"\"\"\n",
37+
" Create a simple prompt object.\n",
38+
" \n",
39+
" Parameters:\n",
40+
" -----------\n",
41+
" instruction : str\n",
42+
" The prompt instruction template with placeholders like {response}, {expected_answer}\n",
43+
" examples : Optional[List[Tuple[Dict, Dict]]]\n",
44+
" List of (input_dict, output_dict) pairs for few-shot learning\n",
45+
" \"\"\"\n",
46+
" self.instruction = instruction\n",
47+
" self.examples = []\n",
48+
" \n",
49+
" # Validate the instruction\n",
50+
" self._validate_instruction()\n",
51+
" \n",
52+
" # Add examples if provided\n",
53+
" if examples:\n",
54+
" for inputs, output in examples:\n",
55+
" self.add_example(inputs, output)\n",
56+
" \n",
57+
" def _validate_instruction(self):\n",
58+
" \"\"\"Ensure the instruction contains at least one placeholder.\"\"\"\n",
59+
" if not re.findall(r\"\\{(\\w+)\\}\", self.instruction):\n",
60+
" raise ValueError(\"Instruction must contain at least one placeholder like {response}\")\n",
61+
" \n",
62+
" def format(self, **kwargs) -> str:\n",
63+
" \"\"\"Format the prompt with the provided variables.\"\"\"\n",
64+
" # Collect the prompt sections: the filled-in instruction first, then any examples\n",
65+
" prompt_parts = []\n",
66+
" \n",
67+
" \n",
68+
" # Add instruction with variables filled in\n",
69+
" prompt_parts.append(self.instruction.format(**kwargs))\n",
70+
" \n",
71+
" # Add examples in a simple format\n",
72+
" if self.examples:\n",
73+
" prompt_parts.append(\"Examples:\")\n",
74+
" for i, (inputs, output) in enumerate(self.examples, 1):\n",
75+
" example_input = \"\\n\".join([f\"{k}: {v}\" for k, v in inputs.items()])\n",
76+
" example_output = \"\\n\".join([f\"{k}: {v}\" for k, v in output.items()])\n",
77+
" \n",
78+
" prompt_parts.append(f\"Example {i}:\\nInput:\\n{example_input}\\nOutput:\\n{example_output}\")\n",
79+
"\n",
80+
" # Combine all parts\n",
81+
" return \"\\n\\n\".join(prompt_parts)\n",
82+
" \n",
83+
" def add_example(self, inputs: t.Dict, output: t.Dict) -> None:\n",
84+
" \"\"\"\n",
85+
" Add an example to the prompt.\n",
86+
" \n",
87+
" Parameters:\n",
88+
" -----------\n",
89+
" inputs : Dict\n",
90+
" Dictionary of input values\n",
91+
" output : Dict\n",
92+
" Dictionary of output values\n",
93+
" \n",
94+
" Raises:\n",
95+
" -------\n",
96+
" TypeError\n",
97+
" If inputs or output is not a dictionary\n",
98+
" \"\"\"\n",
99+
" if not isinstance(inputs, dict):\n",
100+
" raise TypeError(f\"Expected inputs to be dict, got {type(inputs).__name__}\")\n",
101+
" if not isinstance(output, dict):\n",
102+
" raise TypeError(f\"Expected output to be dict, got {type(output).__name__}\")\n",
103+
" \n",
104+
" self.examples.append((inputs, output))\n",
105+
" \n",
106+
" def __str__(self) -> str:\n",
107+
" \"\"\"String representation showing the instruction.\"\"\"\n",
108+
" return f\"Prompt(instruction='{self.instruction}')\""
109+
]
110+
},
111+
{
112+
"cell_type": "markdown",
113+
"metadata": {},
114+
"source": [
115+
"### Example Usage"
116+
]
117+
},
118+
{
119+
"cell_type": "code",
120+
"execution_count": null,
121+
"metadata": {},
122+
"outputs": [
123+
{
124+
"name": "stdout",
125+
"output_type": "stream",
126+
"text": [
127+
"Evaluate if given answer You can get a full refund if you miss your flight. is same as expected answer Refunds depend on ticket type; only refundable tickets qualify for full refunds.\n",
128+
"\n",
129+
"Examples:\n",
130+
"\n",
131+
"Example 1:\n",
132+
"Input:\n",
133+
"response: You can get a full refund if you miss your flight.\n",
134+
"expected_answer: Refunds depend on ticket type; only refundable tickets qualify for full refunds.\n",
135+
"Output:\n",
136+
"score: fail\n",
137+
"\n",
138+
"Example 2:\n",
139+
"Input:\n",
140+
"response: Each passenger gets 1 free checked bag up to 23kg.\n",
141+
"expected_answer: Each passenger gets 1 free checked bag up to 23kg.\n",
142+
"Output:\n",
143+
"score: pass\n"
144+
]
145+
}
146+
],
147+
"source": [
148+
"# Create a basic prompt\n",
149+
"prompt = Prompt(\n",
150+
" instruction=\"Evaluate if given answer {response} is same as expected answer {expected_answer}\"\n",
151+
")\n",
152+
"\n",
153+
"# Add examples with dict inputs and dict outputs\n",
154+
"prompt.add_example(\n",
155+
" {\n",
156+
" \"response\": \"You can get a full refund if you miss your flight.\",\n",
157+
" \"expected_answer\": \"Refunds depend on ticket type; only refundable tickets qualify for full refunds.\"\n",
158+
" },\n",
159+
" {\"score\": \"fail\"}\n",
160+
")\n",
161+
"\n",
162+
"prompt.add_example(\n",
163+
" {\n",
164+
" \"response\": \"Each passenger gets 1 free checked bag up to 23kg.\",\n",
165+
" \"expected_answer\": \"Each passenger gets 1 free checked bag up to 23kg.\"\n",
166+
" },\n",
167+
" {\"score\": \"pass\"}\n",
168+
")\n",
169+
"\n",
170+
"print(prompt.format(response=\"You can get a full refund if you miss your flight.\", expected_answer=\"Refunds depend on ticket type; only refundable tickets qualify for full refunds.\"))"
171+
]
172+
}
173+
],
174+
"metadata": {
175+
"kernelspec": {
176+
"display_name": "python3",
177+
"language": "python",
178+
"name": "python3"
179+
}
180+
},
181+
"nbformat": 4,
182+
"nbformat_minor": 2
183+
}

0 commit comments

Comments
 (0)