
Commit 4f260c8

fix: added support for azure LLM (#133)
Fixes #114 and #126. Shoutout to @gabriead for helping test this out :)
1 parent 5de2347 commit 4f260c8

2 files changed (+36, -30 lines)


docs/quickstart.ipynb

Lines changed: 13 additions & 23 deletions
@@ -60,6 +60,8 @@
    "\n",
    "Ragas performs a `ground_truth`-free evaluation of your RAG pipelines. This is because, for most people, building a gold-labeled dataset that represents the distribution they see in production is a very expensive process.\n",
    "\n",
+   "**Note:** *While ragas was originally aimed at `ground_truth`-free evaluations, some aspects of the RAG pipeline need `ground_truth` in order to be measured. We're in the process of building testset generation features that will make this easier. Check out [issue#136](https://github.com/explodinggradients/ragas/issues/136) for more details.*\n",
+   "\n",
    "Hence, to work with ragas, all you need is the following data\n",
    "- question: `list[str]` - These are the questions your RAG pipeline will be evaluated on.\n",
    "- answer: `list[str]` - The answer generated by the RAG pipeline and given to the user.\n",
@@ -73,7 +75,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 8,
+   "execution_count": 1,
    "id": "b658e02f",
    "metadata": {},
    "outputs": [
@@ -87,7 +89,7 @@
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
-      "model_id": "e481f1b6ae824149aaf5afe96330fda3",
+      "model_id": "a2dfebb012dd4b79b3a6ed951ce0d406",
       "version_major": 2,
       "version_minor": 0
      },
@@ -109,7 +111,7 @@
      "})"
     ]
    },
-   "execution_count": 8,
+   "execution_count": 1,
    "metadata": {},
    "output_type": "execute_result"
   }
@@ -141,7 +143,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 9,
+   "execution_count": 3,
    "id": "f17bcf9d",
    "metadata": {},
    "outputs": [],
@@ -185,7 +187,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 10,
+   "execution_count": null,
    "id": "22eb6f97",
    "metadata": {},
    "outputs": [
@@ -200,7 +202,7 @@
     "name": "stderr",
     "output_type": "stream",
     "text": [
-     "100%|█████████████████████████████████████████████████████████████| 1/1 [00:06<00:00, 6.57s/it]\n"
+     "100%|████████████████████████████████████████████████████████████| 2/2 [04:08<00:00, 124.31s/it]\n"
     ]
    },
    {
@@ -214,7 +216,7 @@
     "name": "stderr",
     "output_type": "stream",
     "text": [
-     "100%|█████████████████████████████████████████████████████████████| 1/1 [00:28<00:00, 28.82s/it]\n"
+     "100%|████████████████████████████████████████████████████████████| 2/2 [06:29<00:00, 194.60s/it]\n"
     ]
    },
    {
@@ -228,7 +230,7 @@
     "name": "stderr",
     "output_type": "stream",
     "text": [
-     "100%|█████████████████████████████████████████████████████████████| 1/1 [00:07<00:00, 7.53s/it]\n"
+     "100%|█████████████████████████████████████████████████████████████| 2/2 [01:16<00:00, 38.12s/it]\n"
     ]
    },
    {
@@ -242,7 +244,7 @@
     "name": "stderr",
     "output_type": "stream",
     "text": [
-     "100%|█████████████████████████████████████████████████████████████| 1/1 [00:24<00:00, 24.13s/it]\n"
+     "100%|████████████████████████████████████████████████████████████| 2/2 [07:53<00:00, 236.95s/it]\n"
     ]
    },
    {
@@ -256,25 +258,15 @@
     "name": "stderr",
     "output_type": "stream",
     "text": [
-     "100%|█████████████████████████████████████████████████████████████| 1/1 [00:07<00:00, 7.31s/it]\n"
+     " 50%|██████████████████████████████| 1/2 [00:46<00:46, 46.32s/it]"
     ]
-   },
-   {
-    "data": {
-     "text/plain": [
-      "{'ragas_score': 0.3482, 'context_relevancy': 0.1296, 'faithfulness': 0.8889, 'answer_relevancy': 0.9285, 'context_recall': 0.6370, 'harmfulness': 0.0000}"
-     ]
-    },
-    "execution_count": 10,
-    "metadata": {},
-    "output_type": "execute_result"
    }
   ],
   "source": [
    "from ragas import evaluate\n",
    "\n",
    "result = evaluate(\n",
-   "    fiqa_eval[\"baseline\"].select(range(3)),\n",
+   "    fiqa_eval[\"baseline\"],\n",
    "    metrics=[\n",
    "        context_relevancy,\n",
    "        faithfulness,\n",
@@ -454,8 +446,6 @@
   "source": [
    "And that's it!\n",
    "\n",
-   "You can check out the [ragas in action] notebook to get a feel of what is like to use it while trying to improve your pipelines.\n",
-   "\n",
    "If you have any suggestions/feedback/things you're not happy about, please do share them in the [issue section](https://github.com/explodinggradients/ragas/issues). We love hearing from you 😁"
   ]
  }
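
The notebook's evaluation step boils down to the pattern below. This is a minimal sketch rather than the notebook verbatim: the hand-built single-row dataset and the `from ragas.metrics import ...` path are assumptions matching the quickstart's data description, whereas the notebook itself evaluates the prebuilt `fiqa_eval["baseline"]` split.

    from datasets import Dataset

    from ragas import evaluate
    from ragas.metrics import (
        answer_relevancy,
        context_recall,
        context_relevancy,
        faithfulness,
    )

    # Hypothetical dataset in the column layout the quickstart describes:
    # question / answer / contexts, plus ground_truths for context_recall.
    data = Dataset.from_dict(
        {
            "question": ["How do I open a brokerage account?"],
            "answer": ["Most brokers let you open an account online in minutes."],
            "contexts": [["Brokerage accounts can usually be opened online."]],
            "ground_truths": [["You can open a brokerage account online."]],
        }
    )

    result = evaluate(
        data,
        metrics=[context_relevancy, faithfulness, answer_relevancy, context_recall],
    )
    print(result)  # dict-like scores, e.g. {'ragas_score': ..., 'faithfulness': ...}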

src/ragas/metrics/llms.py

Lines changed: 23 additions & 7 deletions
@@ -2,9 +2,9 @@

 import typing as t

-from langchain.chat_models import ChatOpenAI
+from langchain.chat_models import AzureChatOpenAI, ChatOpenAI
 from langchain.chat_models.base import BaseChatModel
-from langchain.llms import OpenAI
+from langchain.llms import AzureOpenAI, OpenAI
 from langchain.llms.base import BaseLLM
 from langchain.prompts import ChatPromptTemplate
 from langchain.schema import LLMResult
@@ -17,18 +17,33 @@ def isOpenAI(llm: BaseLLM | BaseChatModel) -> bool:
     return isinstance(llm, OpenAI) or isinstance(llm, ChatOpenAI)


+# have to specify it twice for runtime and static checks
+MULTIPLE_COMPLETION_SUPPORTED = [OpenAI, ChatOpenAI, AzureOpenAI, AzureChatOpenAI]
+MultipleCompletionSupportedLLM = t.Union[
+    OpenAI, ChatOpenAI, AzureOpenAI, AzureChatOpenAI
+]
+
+
+def multiple_completion_supported(llm: BaseLLM | BaseChatModel) -> bool:
+    for model in MULTIPLE_COMPLETION_SUPPORTED:
+        if isinstance(llm, model):
+            return True
+    return False
+
+
 def generate(
     prompts: list[ChatPromptTemplate],
     llm: BaseLLM | BaseChatModel,
-    n: t.Optional[int] = None,
+    n: int = 1,
     temperature: float = 0,
     callbacks: t.Optional[Callbacks] = None,
 ) -> LLMResult:
-    old_n = None
+    old_n: int = 1
     n_swapped = False
     llm.temperature = temperature
     if n is not None:
-        if isinstance(llm, OpenAI) or isinstance(llm, ChatOpenAI):
+        if multiple_completion_supported(llm):
+            llm = t.cast(MultipleCompletionSupportedLLM, llm)
             old_n = llm.n
             llm.n = n
             n_swapped = True
@@ -44,7 +59,8 @@ def generate(
     ps = [p.format_messages() for p in prompts]
     result = llm.generate(ps, callbacks=callbacks)

-    if (isinstance(llm, OpenAI) or isinstance(llm, ChatOpenAI)) and n_swapped:
-        llm.n = old_n  # type: ignore
+    if multiple_completion_supported(llm) and n_swapped:
+        llm = t.cast(MultipleCompletionSupportedLLM, llm)
+        llm.n = old_n

     return result
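
The net effect of the patch: `generate()` can now temporarily swap `n` on Azure-hosted models too, not just on `OpenAI`/`ChatOpenAI`. Below is a minimal sketch of driving the new path with an Azure deployment; the `AzureChatOpenAI` constructor arguments are placeholders in the langchain style of this period, so treat the exact parameter names and values as assumptions to verify against your installed version.

    from langchain.chat_models import AzureChatOpenAI
    from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate

    from ragas.metrics.llms import generate

    # Placeholder Azure settings -- substitute your own resource and deployment.
    llm = AzureChatOpenAI(
        deployment_name="my-gpt-35-deployment",
        openai_api_base="https://my-resource.openai.azure.com/",
        openai_api_version="2023-05-15",
        openai_api_key="...",
    )

    # generate() calls prompt.format_messages() with no arguments, so the
    # template must contain no unfilled variables.
    prompt = ChatPromptTemplate.from_messages(
        [HumanMessagePromptTemplate.from_template("What is a bond index fund?")]
    )

    # multiple_completion_supported(llm) is now True for AzureChatOpenAI, so
    # generate() swaps llm.n to 2 for this call and restores it afterwards.
    result = generate([prompt], llm, n=2)
    print(len(result.generations[0]))  # 2 completions for the single prompt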
