Skip to content

Commit 2802019

Browse files
authored
feat: add native support for OpenAI and AzureOpenAI (#261)
- Add support for OpenAI and AzureOpenai both embeddings and LLMs - rework `RagasLLM` with an async version of generate - checks for API_KEYs and tests that ensure it is working
1 parent c5b586e commit 2802019

29 files changed

+785
-288
lines changed

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ lint: ## Running lint checker: ruff
1818
@ruff check src docs tests
1919
type: ## Running type checker: pyright
2020
@echo "(pyright) Typechecking codebase..."
21-
@pyright src
21+
PYRIGHT_PYTHON_FORCE_VERSION=latest pyright src
2222
clean: ## Clean all generated files
2323
@echo "Cleaning all generated files..."
2424
@cd $(GIT_ROOT)/docs && make clean

docs/howtos/customisations/azure-openai.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@
164164
" openai_api_base=\"https://your-endpoint.openai.azure.com/\",\n",
165165
" openai_api_type=\"azure\",\n",
166166
")\n",
167-
"# wrapper around azure_model \n",
167+
"# wrapper around azure_model\n",
168168
"ragas_azure_model = LangchainLLM(azure_model)\n",
169169
"# patch the new RagasLLM instance\n",
170170
"answer_relevancy.llm = ragas_azure_model\n",

docs/howtos/integrations/langfuse.ipynb

Lines changed: 70 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,14 @@
2525
"outputs": [],
2626
"source": [
2727
"import os\n",
28+
"\n",
2829
"# TODO REMOVE ENVIRONMENT VARIABLES!!!\n",
2930
"# get keys for your project from https://cloud.langfuse.com\n",
3031
"os.environ[\"LANGFUSE_PUBLIC_KEY\"] = \"\"\n",
3132
"os.environ[\"LANGFUSE_SECRET_KEY\"] = \"\"\n",
32-
" \n",
33+
"\n",
3334
"# your openai key\n",
34-
"#os.environ[\"OPENAI_API_KEY\"] = \"\""
35+
"# os.environ[\"OPENAI_API_KEY\"] = \"\""
3536
]
3637
},
3738
{
@@ -86,7 +87,7 @@
8687
"source": [
8788
"from datasets import load_dataset\n",
8889
"\n",
89-
"fiqa_eval = load_dataset(\"explodinggradients/fiqa\", \"ragas_eval\")['baseline']\n",
90+
"fiqa_eval = load_dataset(\"explodinggradients/fiqa\", \"ragas_eval\")[\"baseline\"]\n",
9091
"fiqa_eval"
9192
]
9293
},
@@ -180,7 +181,7 @@
180181
],
181182
"source": [
182183
"row = fiqa_eval[0]\n",
183-
"row['question'], row['answer']"
184+
"row[\"question\"], row[\"answer\"]"
184185
]
185186
},
186187
{
@@ -199,7 +200,7 @@
199200
"outputs": [],
200201
"source": [
201202
"from langfuse import Langfuse\n",
202-
" \n",
203+
"\n",
203204
"langfuse = Langfuse()"
204205
]
205206
},
@@ -223,7 +224,7 @@
223224
" for m in metrics:\n",
224225
" print(f\"calculating {m.name}\")\n",
225226
" scores[m.name] = m.score_single(\n",
226-
" {'question': query, 'contexts': chunks, 'answer': answer}\n",
227+
" {\"question\": query, \"contexts\": chunks, \"answer\": answer}\n",
227228
" )\n",
228229
" return scores"
229230
]
@@ -272,26 +273,38 @@
272273
}
273274
],
274275
"source": [
275-
"from langfuse.model import CreateTrace, CreateSpan, CreateGeneration, CreateEvent, CreateScore\n",
276+
"from langfuse.model import (\n",
277+
" CreateTrace,\n",
278+
" CreateSpan,\n",
279+
" CreateGeneration,\n",
280+
" CreateEvent,\n",
281+
" CreateScore,\n",
282+
")\n",
276283
"\n",
277284
"# start a new trace when you get a question\n",
278-
"question = row['question']\n",
279-
"trace = langfuse.trace(CreateTrace(name = \"rag\"))\n",
285+
"question = row[\"question\"]\n",
286+
"trace = langfuse.trace(CreateTrace(name=\"rag\"))\n",
280287
"\n",
281288
"# retrieve the relevant chunks\n",
282289
"# chunks = get_similar_chunks(question)\n",
283-
"contexts = row['contexts']\n",
290+
"contexts = row[\"contexts\"]\n",
284291
"# pass it as span\n",
285-
"trace.span(CreateSpan(\n",
286-
" name = \"retrieval\", input={'question': question}, output={'contexts': contexts}\n",
287-
"))\n",
292+
"trace.span(\n",
293+
" CreateSpan(\n",
294+
" name=\"retrieval\", input={\"question\": question}, output={\"contexts\": contexts}\n",
295+
" )\n",
296+
")\n",
288297
"\n",
289298
"# use llm to generate a answer with the chunks\n",
290299
"# answer = get_response_from_llm(question, chunks)\n",
291-
"answer = row['answer']\n",
292-
"trace.span(CreateSpan(\n",
293-
" name = \"generation\", input={'question': question, 'contexts': contexts}, output={'answer': answer}\n",
294-
"))\n",
300+
"answer = row[\"answer\"]\n",
301+
"trace.span(\n",
302+
" CreateSpan(\n",
303+
" name=\"generation\",\n",
304+
" input={\"question\": question, \"contexts\": contexts},\n",
305+
" output={\"answer\": answer},\n",
306+
" )\n",
307+
")\n",
295308
"\n",
296309
"# compute scores for the question, context, answer tuple\n",
297310
"ragas_scores = score_with_ragas(question, contexts, answer)\n",
@@ -357,20 +370,31 @@
357370
"metadata": {},
358371
"outputs": [],
359372
"source": [
360-
"from langfuse.model import CreateTrace, CreateSpan, CreateGeneration, CreateEvent, CreateScore\n",
373+
"from langfuse.model import (\n",
374+
" CreateTrace,\n",
375+
" CreateSpan,\n",
376+
" CreateGeneration,\n",
377+
" CreateEvent,\n",
378+
" CreateScore,\n",
379+
")\n",
380+
"\n",
361381
"# fiqa traces\n",
362382
"for interaction in fiqa_eval.select(range(10, 20)):\n",
363-
" trace = langfuse.trace(CreateTrace(name = \"rag\"))\n",
364-
" trace.span(CreateSpan(\n",
365-
" name = \"retrieval\", \n",
366-
" input={'question': question}, \n",
367-
" output={'contexts': contexts}\n",
368-
" ))\n",
369-
" trace.span(CreateSpan(\n",
370-
" name = \"generation\", \n",
371-
" input={'question': question, 'contexts': contexts}, \n",
372-
" output={'answer': answer}\n",
373-
" ))\n",
383+
" trace = langfuse.trace(CreateTrace(name=\"rag\"))\n",
384+
" trace.span(\n",
385+
" CreateSpan(\n",
386+
" name=\"retrieval\",\n",
387+
" input={\"question\": question},\n",
388+
" output={\"contexts\": contexts},\n",
389+
" )\n",
390+
" )\n",
391+
" trace.span(\n",
392+
" CreateSpan(\n",
393+
" name=\"generation\",\n",
394+
" input={\"question\": question, \"contexts\": contexts},\n",
395+
" output={\"answer\": answer},\n",
396+
" )\n",
397+
" )\n",
374398
"\n",
375399
"# await that Langfuse SDK has processed all events before trying to retrieve it in the next step\n",
376400
"langfuse.flush()"
@@ -393,12 +417,10 @@
393417
"source": [
394418
"def get_traces(name=None, limit=None, user_id=None):\n",
395419
" all_data = []\n",
396-
" page = 1 \n",
420+
" page = 1\n",
397421
"\n",
398422
" while True:\n",
399-
" response = langfuse.client.trace.list(\n",
400-
" name=name, page=page, user_id=user_id\n",
401-
" )\n",
423+
" response = langfuse.client.trace.list(name=name, page=page, user_id=user_id)\n",
402424
" if not response.data:\n",
403425
" break\n",
404426
" page += 1\n",
@@ -430,7 +452,7 @@
430452
"from random import sample\n",
431453
"\n",
432454
"NUM_TRACES_TO_SAMPLE = 3\n",
433-
"traces = get_traces(name='rag', limit=5)\n",
455+
"traces = get_traces(name=\"rag\", limit=5)\n",
434456
"traces_sample = sample(traces, NUM_TRACES_TO_SAMPLE)\n",
435457
"\n",
436458
"len(traces_sample)"
@@ -464,15 +486,15 @@
464486
"for t in traces_sample:\n",
465487
" observations = [langfuse.client.observations.get(o) for o in t.observations]\n",
466488
" for o in observations:\n",
467-
" if o.name == 'retrieval':\n",
468-
" question = o.input['question']\n",
469-
" contexts = o.output['contexts']\n",
470-
" if o.name=='generation':\n",
471-
" answer = o.output['answer']\n",
472-
" evaluation_batch['question'].append(question)\n",
473-
" evaluation_batch['contexts'].append(contexts)\n",
474-
" evaluation_batch['answer'].append(answer)\n",
475-
" evaluation_batch['trace_id'].append(t.id)"
489+
" if o.name == \"retrieval\":\n",
490+
" question = o.input[\"question\"]\n",
491+
" contexts = o.output[\"contexts\"]\n",
492+
" if o.name == \"generation\":\n",
493+
" answer = o.output[\"answer\"]\n",
494+
" evaluation_batch[\"question\"].append(question)\n",
495+
" evaluation_batch[\"contexts\"].append(contexts)\n",
496+
" evaluation_batch[\"answer\"].append(answer)\n",
497+
" evaluation_batch[\"trace_id\"].append(t.id)"
476498
]
477499
},
478500
{
@@ -671,10 +693,11 @@
671693
"\n",
672694
"for _, row in df.iterrows():\n",
673695
" for metric_name in [\"faithfulness\", \"answer_relevancy\"]:\n",
674-
" langfuse.score(InitialScore(\n",
675-
" name=metric_name,\n",
676-
" value=row[metric_name],\n",
677-
" trace_id=row[\"trace_id\"]))"
696+
" langfuse.score(\n",
697+
" InitialScore(\n",
698+
" name=metric_name, value=row[metric_name], trace_id=row[\"trace_id\"]\n",
699+
" )\n",
700+
" )"
678701
]
679702
},
680703
{

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,10 @@ dependencies = [
66
"sentence-transformers",
77
"datasets",
88
"tiktoken",
9-
"langchain>=0.0.288",
9+
"langchain",
1010
"openai",
1111
"pysbd>=0.3.4",
12+
"nest-asyncio",
1213
]
1314
dynamic = ["version", "readme"]
1415

src/ragas/async_utils.py

Lines changed: 28 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""Async utils."""
22
import asyncio
3-
from itertools import zip_longest
4-
from typing import Any, Coroutine, Iterable, List
3+
from typing import Any, Coroutine, List
54

65

76
def run_async_tasks(
@@ -10,50 +9,40 @@ def run_async_tasks(
109
progress_bar_desc: str = "Running async tasks",
1110
) -> List[Any]:
1211
"""Run a list of async tasks."""
13-
1412
tasks_to_execute: List[Any] = tasks
15-
if show_progress:
13+
14+
# if running in notebook, use nest_asyncio to hijack the event loop
15+
try:
16+
loop = asyncio.get_running_loop()
1617
try:
1718
import nest_asyncio
18-
from tqdm.asyncio import tqdm
19-
20-
# jupyter notebooks already have an event loop running
21-
# we need to reuse it instead of creating a new one
19+
except ImportError:
20+
raise RuntimeError(
21+
"nest_asyncio is required to run async tasks in jupyter. Please install it via `pip install nest_asyncio`." # noqa
22+
)
23+
else:
2224
nest_asyncio.apply()
23-
loop = asyncio.get_event_loop()
25+
except RuntimeError:
26+
loop = asyncio.new_event_loop()
27+
28+
# gather tasks to run
29+
if show_progress:
30+
from tqdm.asyncio import tqdm
31+
32+
async def _gather() -> List[Any]:
33+
"gather tasks and show progress bar"
34+
return await tqdm.gather(*tasks_to_execute, desc=progress_bar_desc)
35+
36+
else: # don't show_progress
2437

25-
async def _tqdm_gather() -> List[Any]:
26-
return await tqdm.gather(*tasks_to_execute, desc=progress_bar_desc)
38+
async def _gather() -> List[Any]:
39+
return await asyncio.gather(*tasks_to_execute)
2740

28-
tqdm_outputs: List[Any] = loop.run_until_complete(_tqdm_gather())
29-
return tqdm_outputs
41+
try:
42+
outputs: List[Any] = loop.run_until_complete(_gather())
43+
except Exception as e:
3044
# run the operation w/o tqdm on hitting a fatal
3145
# may occur in some environments where tqdm.asyncio
3246
# is not supported
33-
except ImportError as e:
34-
print(e)
35-
except Exception:
36-
pass
37-
38-
async def _gather() -> List[Any]:
39-
return await asyncio.gather(*tasks_to_execute)
40-
41-
outputs: List[Any] = asyncio.run(_gather())
47+
raise RuntimeError("Fatal error occurred while running async tasks.", e)
4248
return outputs
43-
44-
45-
def chunks(iterable: Iterable, size: int) -> Iterable:
46-
args = [iter(iterable)] * size
47-
return zip_longest(*args, fillvalue=None)
48-
49-
50-
async def batch_gather(
51-
tasks: List[Coroutine], batch_size: int = 10, verbose: bool = False
52-
) -> List[Any]:
53-
output: List[Any] = []
54-
for task_chunk in chunks(tasks, batch_size):
55-
output_chunk = await asyncio.gather(*task_chunk)
56-
output.extend(output_chunk)
57-
if verbose:
58-
print(f"Completed {len(output)} out of {len(tasks)} tasks")
59-
return output

src/ragas/embeddings/__init__.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,13 @@
1-
from ragas.embeddings.base import HuggingfaceEmbeddings, OpenAIEmbeddings
1+
from ragas.embeddings.base import (
2+
AzureOpenAIEmbeddings,
3+
HuggingfaceEmbeddings,
4+
OpenAIEmbeddings,
5+
RagasEmbeddings,
6+
)
27

3-
__all__ = ["HuggingfaceEmbeddings", "OpenAIEmbeddings"]
8+
__all__ = [
9+
"HuggingfaceEmbeddings",
10+
"OpenAIEmbeddings",
11+
"AzureOpenAIEmbeddings",
12+
"RagasEmbeddings",
13+
]

0 commit comments

Comments
 (0)