Skip to content

Commit 92428a6

Browse files
committed
change embedding api
1 parent fc18988 commit 92428a6

File tree

10 files changed

+1227
-72
lines changed

10 files changed

+1227
-72
lines changed

nbs/embedding/base.ipynb

Lines changed: 1072 additions & 17 deletions
Large diffs are not rendered by default.

nbs/metric/base.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
"from tqdm import tqdm\n",
4747
"\n",
4848
"from ragas_annotator.prompt.base import Prompt\n",
49-
"from ragas_annotator.embedding.base import RagasEmbedding\n",
49+
"from ragas_annotator.embedding.base import BaseEmbedding\n",
5050
"from ragas_annotator.metric import MetricResult\n",
5151
"from ragas_annotator.llm import RagasLLM\n",
5252
"from ragas_annotator.project.core import Project\n",
@@ -119,7 +119,7 @@
119119
" # Run all tasks concurrently and return results\n",
120120
" return await asyncio.gather(*async_tasks)\n",
121121
" \n",
122-
" def train(self,project:Project, experiment_names: t.List[str], model:NotionModel, embedding_model: RagasEmbedding,method: t.Dict[str, t.Any]):\n",
122+
" def train(self,project:Project, experiment_names: t.List[str], model:NotionModel, embedding_model: BaseEmbedding,method: t.Dict[str, t.Any]):\n",
123123
" \n",
124124
" assert isinstance(self.prompt, Prompt)\n",
125125
" self.prompt = DynamicFewShotPrompt.from_prompt(self.prompt,embedding_model)\n",

nbs/prompt/base.ipynb

Lines changed: 53 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -61,24 +61,28 @@
6161
" \n",
6262
" def format(self, **kwargs) -> str:\n",
6363
" \"\"\"Format the prompt with the provided variables.\"\"\"\n",
64-
" # Start with examples if we have them\n",
64+
"\n",
6565
" prompt_parts = []\n",
66-
" \n",
67-
" \n",
68-
" # Add instruction with variables filled in\n",
6966
" prompt_parts.append(self.instruction.format(**kwargs))\n",
67+
" prompt_parts.append(self._format_examples())\n",
68+
"\n",
69+
" # Combine all parts\n",
70+
" return \"\\n\\n\".join(prompt_parts)\n",
71+
" \n",
72+
" def _format_examples(self) -> str:\n",
7073
" \n",
7174
" # Add examples in a simple format\n",
75+
" examples = []\n",
7276
" if self.examples:\n",
73-
" prompt_parts.append(\"Examples:\")\n",
77+
" examples.append(\"Examples:\")\n",
7478
" for i, (inputs, output) in enumerate(self.examples, 1):\n",
7579
" example_input = \"\\n\".join([f\"{k}: {v}\" for k, v in inputs.items()])\n",
7680
" example_output = \"\\n\".join([f\"{k}: {v}\" for k, v in output.items()])\n",
7781
" \n",
78-
" prompt_parts.append(f\"Example {i}:\\nInput:\\n{example_input}\\nOutput:\\n{example_output}\")\n",
79-
"\n",
80-
" # Combine all parts\n",
81-
" return \"\\n\\n\".join(prompt_parts)\n",
82+
" examples.append(f\"Example {i}:\\nInput:\\n{example_input}\\nOutput:\\n{example_output}\")\n",
83+
" \n",
84+
" return \"\\n\\n\".join(examples) if examples else \"\"\n",
85+
" \n",
8286
" \n",
8387
" def add_example(self, inputs: t.Dict, output: t.Dict) -> None:\n",
8488
" \"\"\"\n",
@@ -105,7 +109,7 @@
105109
" \n",
106110
" def __str__(self) -> str:\n",
107111
" \"\"\"String representation showing the instruction.\"\"\"\n",
108-
" return f\"Prompt(instruction='{self.instruction}')\""
112+
" return f\"Prompt(instruction='{self.instruction}',\\n examples={self.examples})\""
109113
]
110114
},
111115
{
@@ -169,6 +173,45 @@
169173
"\n",
170174
"print(prompt.format(response=\"You can get a full refund if you miss your flight.\", expected_answer=\"Refunds depend on ticket type; only refundable tickets qualify for full refunds.\"))"
171175
]
176+
},
177+
{
178+
"cell_type": "code",
179+
"execution_count": null,
180+
"metadata": {},
181+
"outputs": [
182+
{
183+
"name": "stdout",
184+
"output_type": "stream",
185+
"text": [
186+
"Prompt(instruction='Evaluate if given answer {response} is same as expected answer {expected_answer}',\n",
187+
" examples=Examples:\n",
188+
"\n",
189+
"Example 1:\n",
190+
"Input:\n",
191+
"response: You can get a full refund if you miss your flight.\n",
192+
"expected_answer: Refunds depend on ticket type; only refundable tickets qualify for full refunds.\n",
193+
"Output:\n",
194+
"score: fail\n",
195+
"\n",
196+
"Example 2:\n",
197+
"Input:\n",
198+
"response: Each passenger gets 1 free checked bag up to 23kg.\n",
199+
"expected_answer: Each passenger gets 1 free checked bag up to 23kg.\n",
200+
"Output:\n",
201+
"score: pass)\n"
202+
]
203+
}
204+
],
205+
"source": [
206+
"print(str(prompt))"
207+
]
208+
},
209+
{
210+
"cell_type": "code",
211+
"execution_count": null,
212+
"metadata": {},
213+
"outputs": [],
214+
"source": []
172215
}
173216
],
174217
"metadata": {

nbs/prompt/dynamic_few_shot.ipynb

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
"from abc import ABC, abstractmethod\n",
3030
"\n",
3131
"from ragas_annotator.prompt.base import Prompt\n",
32-
"from ragas_annotator.embedding.base import RagasEmbedding\n",
32+
"from ragas_annotator.embedding import BaseEmbedding\n",
3333
"\n",
3434
"class ExampleStore(ABC):\n",
3535
" @abstractmethod\n",
@@ -205,7 +205,7 @@
205205
" def from_prompt(\n",
206206
" cls,\n",
207207
" prompt: Prompt,\n",
208-
" embedding_model: RagasEmbedding,\n",
208+
" embedding_model: BaseEmbedding,\n",
209209
" num_examples: int = 3\n",
210210
" ) -> \"DynamicFewShotPrompt\":\n",
211211
" \"\"\"Create a DynamicFewShotPrompt from a Prompt object.\"\"\"\n",
@@ -251,10 +251,11 @@
251251
],
252252
"source": [
253253
"#| eval: false\n",
254+
"from ragas_annotator.embedding import ragas_embedding\n",
254255
"from ragas_annotator.prompt import Prompt\n",
255256
"from openai import OpenAI\n",
256257
"\n",
257-
"embedding = RagasEmbedding(client=OpenAI(),model=\"text-embedding-3-small\")\n",
258+
"embedding = ragas_embedding(provider=\"openai\", client=OpenAI(),model=\"text-embedding-3-small\")\n",
258259
"\n",
259260
"# Create a basic prompt\n",
260261
"prompt = Prompt(\n",

ragas_annotator/_modidx.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -119,16 +119,30 @@
119119
'ragas_annotator.dataset.Dataset.pop': ('dataset.html#dataset.pop', 'ragas_annotator/dataset.py'),
120120
'ragas_annotator.dataset.Dataset.save': ( 'dataset.html#dataset.save',
121121
'ragas_annotator/dataset.py')},
122-
'ragas_annotator.embedding.base': { 'ragas_annotator.embedding.base.RagasEmbedding': ( 'embedding/base.html#ragasembedding',
123-
'ragas_annotator/embedding/base.py'),
124-
'ragas_annotator.embedding.base.RagasEmbedding.aembed_document': ( 'embedding/base.html#ragasembedding.aembed_document',
125-
'ragas_annotator/embedding/base.py'),
126-
'ragas_annotator.embedding.base.RagasEmbedding.aembed_text': ( 'embedding/base.html#ragasembedding.aembed_text',
127-
'ragas_annotator/embedding/base.py'),
128-
'ragas_annotator.embedding.base.RagasEmbedding.embed_document': ( 'embedding/base.html#ragasembedding.embed_document',
122+
'ragas_annotator.embedding.base': { 'ragas_annotator.embedding.base.BaseEmbedding': ( 'embedding/base.html#baseembedding',
123+
'ragas_annotator/embedding/base.py'),
124+
'ragas_annotator.embedding.base.BaseEmbedding.aembed_document': ( 'embedding/base.html#baseembedding.aembed_document',
129125
'ragas_annotator/embedding/base.py'),
130-
'ragas_annotator.embedding.base.RagasEmbedding.embed_text': ( 'embedding/base.html#ragasembedding.embed_text',
131-
'ragas_annotator/embedding/base.py')},
126+
'ragas_annotator.embedding.base.BaseEmbedding.aembed_text': ( 'embedding/base.html#baseembedding.aembed_text',
127+
'ragas_annotator/embedding/base.py'),
128+
'ragas_annotator.embedding.base.BaseEmbedding.embed_document': ( 'embedding/base.html#baseembedding.embed_document',
129+
'ragas_annotator/embedding/base.py'),
130+
'ragas_annotator.embedding.base.BaseEmbedding.embed_text': ( 'embedding/base.html#baseembedding.embed_text',
131+
'ragas_annotator/embedding/base.py'),
132+
'ragas_annotator.embedding.base.OpenAIEmbeddings': ( 'embedding/base.html#openaiembeddings',
133+
'ragas_annotator/embedding/base.py'),
134+
'ragas_annotator.embedding.base.OpenAIEmbeddings.__init__': ( 'embedding/base.html#openaiembeddings.__init__',
135+
'ragas_annotator/embedding/base.py'),
136+
'ragas_annotator.embedding.base.OpenAIEmbeddings.aembed_document': ( 'embedding/base.html#openaiembeddings.aembed_document',
137+
'ragas_annotator/embedding/base.py'),
138+
'ragas_annotator.embedding.base.OpenAIEmbeddings.aembed_text': ( 'embedding/base.html#openaiembeddings.aembed_text',
139+
'ragas_annotator/embedding/base.py'),
140+
'ragas_annotator.embedding.base.OpenAIEmbeddings.embed_document': ( 'embedding/base.html#openaiembeddings.embed_document',
141+
'ragas_annotator/embedding/base.py'),
142+
'ragas_annotator.embedding.base.OpenAIEmbeddings.embed_text': ( 'embedding/base.html#openaiembeddings.embed_text',
143+
'ragas_annotator/embedding/base.py'),
144+
'ragas_annotator.embedding.base.ragas_embedding': ( 'embedding/base.html#ragas_embedding',
145+
'ragas_annotator/embedding/base.py')},
132146
'ragas_annotator.exceptions': { 'ragas_annotator.exceptions.DuplicateError': ( 'utils/exceptions.html#duplicateerror',
133147
'ragas_annotator/exceptions.py'),
134148
'ragas_annotator.exceptions.NotFoundError': ( 'utils/exceptions.html#notfounderror',
@@ -417,6 +431,8 @@
417431
'ragas_annotator/prompt/base.py'),
418432
'ragas_annotator.prompt.base.Prompt.__str__': ( 'prompt/base.html#prompt.__str__',
419433
'ragas_annotator/prompt/base.py'),
434+
'ragas_annotator.prompt.base.Prompt._format_examples': ( 'prompt/base.html#prompt._format_examples',
435+
'ragas_annotator/prompt/base.py'),
420436
'ragas_annotator.prompt.base.Prompt._validate_instruction': ( 'prompt/base.html#prompt._validate_instruction',
421437
'ragas_annotator/prompt/base.py'),
422438
'ragas_annotator.prompt.base.Prompt.add_example': ( 'prompt/base.html#prompt.add_example',
Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1-
from ragas_annotator.embedding.base import RagasEmbedding
1+
from ragas_annotator.embedding.base import BaseEmbedding
2+
from ragas_annotator.embedding.base import ragas_embedding
23

3-
__all__ = ['RagasEmbedding']
4+
__all__ = ['ragas_embedding','BaseEmbedding']

ragas_annotator/embedding/base.py

Lines changed: 50 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,67 @@
11
# AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/embedding/base.ipynb.
22

33
# %% auto 0
4-
__all__ = ['RagasEmbedding']
4+
__all__ = ['BaseEmbedding', 'OpenAIEmbeddings', 'ragas_embedding']
55

66
# %% ../../nbs/embedding/base.ipynb 2
77
import typing as t
8-
from dataclasses import dataclass
8+
from abc import ABC, abstractmethod
99

10+
#TODO: Add support for other providers like HuggingFace, Cohere, etc.
11+
#TODO: handle async calls properly and ensure that the client supports async if needed.
1012

11-
@dataclass
12-
class RagasEmbedding:
13-
client: t.Any
14-
model: str
13+
class BaseEmbedding(ABC):
14+
@abstractmethod
15+
def embed_text(self, text: str, **kwargs: t.Any) -> t.List[float]:
16+
pass
1517

16-
def embed_text(self,text:str,**kwargs: t.Any) -> t.List[float]:
17-
18+
@abstractmethod
19+
async def aembed_text(self, text: str, **kwargs: t.Any) -> t.List[float]:
20+
pass
21+
22+
@abstractmethod
23+
def embed_document(self, documents: t.List[str], **kwargs: t.Any) -> t.List[t.List[float]]:
24+
pass
25+
26+
@abstractmethod
27+
async def aembed_document(self, documents: t.List[str], **kwargs: t.Any) -> t.List[t.List[float]]:
28+
pass
29+
30+
31+
class OpenAIEmbeddings(BaseEmbedding):
32+
def __init__(self, client: t.Any, model: str):
33+
self.client = client
34+
self.model = model
35+
36+
def embed_text(self, text: str, **kwargs: t.Any) -> t.List[float]:
1837
return self.client.embeddings.create(input=text, model=self.model, **kwargs).data[0].embedding
19-
20-
async def aembed_text(self,text:str,**kwargs: t.Any):
21-
22-
await self.client.embeddings.create(input=text, model=self.model, **kwargs).data[0].embedding
2338

39+
async def aembed_text(self, text: str, **kwargs: t.Any) -> t.List[float]:
40+
response = await self.client.embeddings.create(input=text, model=self.model, **kwargs)
41+
return response.data[0].embedding
2442

25-
def embed_document(self,documents:t.List[str],**kwargs: t.Any) -> t.List[t.List[float]]:
43+
def embed_document(self, documents: t.List[str], **kwargs: t.Any) -> t.List[t.List[float]]:
2644
embeddings = self.client.embeddings.create(input=documents, model=self.model, **kwargs)
2745
return [embedding.embedding for embedding in embeddings.data]
2846

29-
async def aembed_document(self,documents:t.List[str],**kwargs: t.Any) -> t.List[t.List[float]]:
47+
async def aembed_document(self, documents: t.List[str], **kwargs: t.Any) -> t.List[t.List[float]]:
3048
embeddings = await self.client.embeddings.create(input=documents, model=self.model, **kwargs)
3149
return [embedding.embedding for embedding in embeddings.data]
32-
50+
51+
52+
def ragas_embedding(provider: str, model: str, client: t.Any) -> BaseEmbedding:
53+
"""
54+
Factory function to create an embedding instance based on the provider.
55+
56+
Args:
57+
provider (str): The name of the embedding provider (e.g., "openai").
58+
model (str): The model name to use for embeddings.
59+
**kwargs: Additional arguments for the provider's client.
60+
61+
Returns:
62+
BaseEmbedding: An instance of the specified embedding provider.
63+
"""
64+
if provider.lower() == "openai":
65+
return OpenAIEmbeddings(client=client, model=model)
66+
67+
raise ValueError(f"Unsupported provider: {provider}")

ragas_annotator/metric/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from tqdm import tqdm
1616

1717
from ..prompt.base import Prompt
18-
from ..embedding.base import RagasEmbedding
18+
from ..embedding.base import BaseEmbedding
1919
from . import MetricResult
2020
from ..llm import RagasLLM
2121
from ..project.core import Project
@@ -88,7 +88,7 @@ async def abatch_score(self, inputs: t.List[t.Dict[str, t.Any]], reasoning: bool
8888
# Run all tasks concurrently and return results
8989
return await asyncio.gather(*async_tasks)
9090

91-
def train(self,project:Project, experiment_names: t.List[str], model:NotionModel, embedding_model: RagasEmbedding,method: t.Dict[str, t.Any]):
91+
def train(self,project:Project, experiment_names: t.List[str], model:NotionModel, embedding_model: BaseEmbedding,method: t.Dict[str, t.Any]):
9292

9393
assert isinstance(self.prompt, Prompt)
9494
self.prompt = DynamicFewShotPrompt.from_prompt(self.prompt,embedding_model)

ragas_annotator/prompt/base.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -41,24 +41,28 @@ def _validate_instruction(self):
4141

4242
def format(self, **kwargs) -> str:
4343
"""Format the prompt with the provided variables."""
44-
# Start with examples if we have them
44+
4545
prompt_parts = []
46-
47-
48-
# Add instruction with variables filled in
4946
prompt_parts.append(self.instruction.format(**kwargs))
47+
prompt_parts.append(self._format_examples())
48+
49+
# Combine all parts
50+
return "\n\n".join(prompt_parts)
51+
52+
def _format_examples(self) -> str:
5053

5154
# Add examples in a simple format
55+
examples = []
5256
if self.examples:
53-
prompt_parts.append("Examples:")
57+
examples.append("Examples:")
5458
for i, (inputs, output) in enumerate(self.examples, 1):
5559
example_input = "\n".join([f"{k}: {v}" for k, v in inputs.items()])
5660
example_output = "\n".join([f"{k}: {v}" for k, v in output.items()])
5761

58-
prompt_parts.append(f"Example {i}:\nInput:\n{example_input}\nOutput:\n{example_output}")
59-
60-
# Combine all parts
61-
return "\n\n".join(prompt_parts)
62+
examples.append(f"Example {i}:\nInput:\n{example_input}\nOutput:\n{example_output}")
63+
64+
return "\n\n".join(examples) if examples else ""
65+
6266

6367
def add_example(self, inputs: t.Dict, output: t.Dict) -> None:
6468
"""
@@ -85,4 +89,4 @@ def add_example(self, inputs: t.Dict, output: t.Dict) -> None:
8589

8690
def __str__(self) -> str:
8791
"""String representation showing the instruction."""
88-
return f"Prompt(instruction='{self.instruction}')"
92+
return f"Prompt(instruction='{self.instruction}',\n examples={self.examples})"

ragas_annotator/prompt/dynamic_few_shot.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from abc import ABC, abstractmethod
1010

1111
from .base import Prompt
12-
from ..embedding.base import RagasEmbedding
12+
from ..embedding import BaseEmbedding
1313

1414
class ExampleStore(ABC):
1515
@abstractmethod
@@ -185,7 +185,7 @@ def add_example(self, inputs: t.Dict, output: t.Dict) -> None:
185185
def from_prompt(
186186
cls,
187187
prompt: Prompt,
188-
embedding_model: RagasEmbedding,
188+
embedding_model: BaseEmbedding,
189189
num_examples: int = 3
190190
) -> "DynamicFewShotPrompt":
191191
"""Create a DynamicFewShotPrompt from a Prompt object."""

0 commit comments

Comments
 (0)