
Commit 75db229

feat: llamaIndex llm support (#205)
Adds support for LlamaIndex `ServiceContext` and `BaseLLM`, so LlamaIndex LLMs can be used directly with Ragas.
1 parent c2a64d5 commit 75db229
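
In practice the new wrapper lets a LlamaIndex LLM stand in for the default OpenAI model behind any Ragas metric. A minimal sketch, assuming llama_index is installed and OPENAI_API_KEY is set; the model name and the choice of the faithfulness metric are illustrative:

from llama_index.llms import OpenAI as LlamaIndexOpenAI

from ragas.llms import LlamaIndexLLM
from ragas.metrics import faithfulness

# wrap the LlamaIndex LLM in the adapter introduced by this commit and point an
# existing metric at it instead of the default ChatOpenAI from llm_factory()
faithfulness.llm = LlamaIndexLLM(llm=LlamaIndexOpenAI(model="gpt-3.5-turbo"))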

File tree

8 files changed: +222 additions, -17 deletions


docs/howtos/customisations/llms.ipynb

Lines changed: 4 additions & 4 deletions
@@ -179,8 +179,8 @@
 "from ragas import evaluate\n",
 "\n",
 "result = evaluate(\n",
-"    fiqa_eval[\"baseline\"].select(range(5)),  # showing only 5 for demonstration \n",
-"    metrics=[faithfulness]\n",
+"    fiqa_eval[\"baseline\"].select(range(5)),  # showing only 5 for demonstration\n",
+"    metrics=[faithfulness],\n",
 ")\n",
 "\n",
 "result"
@@ -301,8 +301,8 @@
 "from ragas import evaluate\n",
 "\n",
 "result = evaluate(\n",
-"    fiqa_eval[\"baseline\"].select(range(5)),  # showing only 5 for demonstration \n",
-"    metrics=[faithfulness]\n",
+"    fiqa_eval[\"baseline\"].select(range(5)),  # showing only 5 for demonstration\n",
+"    metrics=[faithfulness],\n",
 ")\n",
 "\n",
 "result"

requirements/docs.txt

Lines changed: 1 addition & 0 deletions
@@ -4,3 +4,4 @@ myst-parser[linkify]
 sphinx_design
 astroid<3
 myst-nb
+llama_index

src/ragas/llms/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+from ragas.llms.base import BaseRagasLLM, LangchainLLM, llm_factory
+from ragas.llms.llamaindex import LlamaIndexLLM
+
+__all__ = ["BaseRagasLLM", "LangchainLLM", "LlamaIndexLLM", "llm_factory"]
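
With the new package in place, every wrapper is importable from ragas.llms; a small sketch of the import surface and the default factory:

from ragas.llms import BaseRagasLLM, LangchainLLM, LlamaIndexLLM, llm_factory

# llm_factory() returns the default ChatOpenAI-backed wrapper the metrics fall back to
default_llm = llm_factory()
assert isinstance(default_llm, LangchainLLM)
assert isinstance(default_llm, BaseRagasLLM)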

src/ragas/llms/base.py

Lines changed: 160 additions & 0 deletions
@@ -0,0 +1,160 @@
+from __future__ import annotations
+
+import os
+import typing as t
+from abc import ABC, abstractmethod
+
+from langchain.chat_models import AzureChatOpenAI, ChatOpenAI
+from langchain.chat_models.base import BaseChatModel
+from langchain.llms import AzureOpenAI, OpenAI
+from langchain.llms.base import BaseLLM
+from langchain.schema import LLMResult
+
+from ragas.async_utils import run_async_tasks
+
+if t.TYPE_CHECKING:
+    from langchain.callbacks.base import Callbacks
+    from langchain.prompts import ChatPromptTemplate
+
+
+def isOpenAI(llm: BaseLLM | BaseChatModel) -> bool:
+    return isinstance(llm, OpenAI) or isinstance(llm, ChatOpenAI)
+
+
+# have to specify it twice for runtime and static checks
+MULTIPLE_COMPLETION_SUPPORTED = [OpenAI, ChatOpenAI, AzureOpenAI, AzureChatOpenAI]
+MultipleCompletionSupportedLLM = t.Union[
+    OpenAI, ChatOpenAI, AzureOpenAI, AzureChatOpenAI
+]
+
+
+class BaseRagasLLM(ABC):
+    """
+    BaseRagasLLM is the base class for all LLMs. It provides a consistent interface
+    for classes that interact with LLMs such as LangChain, LlamaIndex and LiteLLM,
+    and handles multiple completions even if the underlying LLM does not support them.
+
+    It currently takes in ChatPromptTemplates and returns LLMResults, which are
+    LangChain primitives.
+    """
+
+    # supports multiple completions for the given prompt
+    n_completions_supported: bool = False
+
+    @property
+    @abstractmethod
+    def llm(self):
+        ...
+
+    @abstractmethod
+    def generate(
+        self,
+        prompts: list[str],
+        n: int = 1,
+        temperature: float = 0,
+        callbacks: t.Optional[Callbacks] = None,
+    ) -> list[list[str]]:
+        ...
+
+
+class LangchainLLM(BaseRagasLLM):
+    n_completions_supported: bool = True
+
+    def __init__(self, llm: BaseLLM | BaseChatModel):
+        self.langchain_llm = llm
+
+    @property
+    def llm(self):
+        return self.langchain_llm
+
+    @staticmethod
+    def llm_supports_completions(llm):
+        for llm_type in MULTIPLE_COMPLETION_SUPPORTED:
+            if isinstance(llm, llm_type):
+                return True
+
+    def generate_multiple_completions(
+        self,
+        prompts: list[ChatPromptTemplate],
+        n: int = 1,
+        callbacks: t.Optional[Callbacks] = None,
+    ) -> LLMResult:
+        self.langchain_llm = t.cast(MultipleCompletionSupportedLLM, self.langchain_llm)
+        old_n = self.langchain_llm.n
+        self.langchain_llm.n = n
+
+        if isinstance(self.llm, BaseLLM):
+            ps = [p.format() for p in prompts]
+            result = self.llm.generate(ps, callbacks=callbacks)
+        else:  # if BaseChatModel
+            ps = [p.format_messages() for p in prompts]
+            result = self.llm.generate(ps, callbacks=callbacks)
+        self.llm.n = old_n
+
+        return result
+
+    async def generate_completions(
+        self,
+        prompts: list[ChatPromptTemplate],
+        callbacks: t.Optional[Callbacks] = None,
+    ) -> LLMResult:
+        if isinstance(self.llm, BaseLLM):
+            ps = [p.format() for p in prompts]
+            result = await self.llm.agenerate(ps, callbacks=callbacks)
+        else:  # if BaseChatModel
+            ps = [p.format_messages() for p in prompts]
+            result = await self.llm.agenerate(ps, callbacks=callbacks)
+
+        return result
+
+    def generate(
+        self,
+        prompts: list[ChatPromptTemplate],
+        n: int = 1,
+        temperature: float = 0,
+        callbacks: t.Optional[Callbacks] = None,
+    ) -> LLMResult:
+        # set temperature to 0.2 for multiple completions
+        temperature = 0.2 if n > 1 else 0
+        self.llm.temperature = temperature
+
+        if self.llm_supports_completions(self.llm):
+            return self.generate_multiple_completions(prompts, n, callbacks)
+        else:  # call generate_completions n times to mimic multiple completions
+            list_llmresults = run_async_tasks(
+                [self.generate_completions(prompts, callbacks) for _ in range(n)]
+            )
+
+            # fill results as if the LLM supported multiple completions
+            generations = []
+            for i in range(len(prompts)):
+                completions = []
+                for result in list_llmresults:
+                    completions.append(result.generations[i][0])
+                generations.append(completions)
+
+            # compute total token usage by adding individual token usage
+            llm_output = list_llmresults[0].llm_output
+            if "token_usage" in llm_output:
+                sum_prompt_tokens = 0
+                sum_completion_tokens = 0
+                sum_total_tokens = 0
+                for result in list_llmresults:
+                    token_usage = result.llm_output["token_usage"]
+                    sum_prompt_tokens += token_usage["prompt_tokens"]
+                    sum_completion_tokens += token_usage["completion_tokens"]
+                    sum_total_tokens += token_usage["total_tokens"]
+
+                llm_output["token_usage"] = {
+                    "prompt_tokens": sum_prompt_tokens,
+                    "completion_tokens": sum_completion_tokens,
+                    "total_tokens": sum_total_tokens,
+                }
+
+            return LLMResult(generations=generations, llm_output=llm_output)
+
+
+def llm_factory() -> LangchainLLM:
+    oai_key = os.getenv("OPENAI_API_KEY", "no-key")
+    openai_llm = ChatOpenAI(openai_api_key=oai_key)
+    return LangchainLLM(llm=openai_llm)
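
A short sketch of how the wrapper is driven by the metrics, assuming an OpenAI key in the environment; the prompt text and n value are only illustrative:

from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate

from ragas.llms import LangchainLLM

llm = LangchainLLM(llm=ChatOpenAI())  # ChatOpenAI is in MULTIPLE_COMPLETION_SUPPORTED
prompt = ChatPromptTemplate.from_messages(
    [HumanMessagePromptTemplate.from_template("List three criteria for judging a RAG answer.")]
)

# n > 1 takes the native multiple-completion path; a model outside the supported
# list would instead be called n times asynchronously and the results stitched together
result = llm.generate(prompts=[prompt], n=2)
print([gen.text for gen in result.generations[0]])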

src/ragas/llms/llamaindex.py

Lines changed: 47 additions & 0 deletions
@@ -0,0 +1,47 @@
+from __future__ import annotations
+
+import typing as t
+
+from langchain.schema.output import Generation, LLMResult
+from llama_index.llms.base import LLM as LiLLM
+
+from ragas.async_utils import run_async_tasks
+from ragas.llms.base import BaseRagasLLM
+
+if t.TYPE_CHECKING:
+    from langchain.callbacks.base import Callbacks
+    from langchain.prompts import ChatPromptTemplate
+
+
+class LlamaIndexLLM(BaseRagasLLM):
+    def __init__(self, llm: LiLLM) -> None:
+        self.llama_index_llm = llm
+
+    @property
+    def llm(self) -> LiLLM:
+        return self.llama_index_llm
+
+    def generate(
+        self,
+        prompts: list[ChatPromptTemplate],
+        n: int = 1,
+        temperature: float = 0,
+        callbacks: t.Optional[Callbacks] = None,
+    ) -> LLMResult:
+        # set temperature to 0.2 for multiple completions
+        temperature = 0.2 if n > 1 else 0
+        self.llm.temperature = temperature
+
+        # get task coroutines
+        tasks = []
+        for p in prompts:
+            tasks.extend([self.llm.acomplete(p.format()) for _ in range(n)])
+
+        # process results to LLMResult
+        # token usage is not included for now
+        results = run_async_tasks(tasks)
+        results2D = [results[i : i + n] for i in range(0, len(results), n)]
+        generations = [
+            [Generation(text=r.text) for r in result] for result in results2D
+        ]
+        return LLMResult(generations=generations)
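
A rough usage sketch for the adapter on its own, assuming a LlamaIndex OpenAI LLM; the prompt is illustrative and, as noted in the code, token usage is not reported:

from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
from llama_index.llms import OpenAI as LlamaIndexOpenAI

from ragas.llms import LlamaIndexLLM

ragas_llm = LlamaIndexLLM(llm=LlamaIndexOpenAI(model="gpt-3.5-turbo"))
prompt = ChatPromptTemplate.from_messages(
    [HumanMessagePromptTemplate.from_template("Reply with a one-word greeting.")]
)

# each prompt is formatted to plain text and sent through acomplete(); with n=2 the
# call is simply repeated, so generations[0] holds two Generation objects
result = ragas_llm.generate(prompts=[prompt], n=2)
print([gen.text for gen in result.generations[0]])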

src/ragas/metrics/base.py

Lines changed: 2 additions & 9 deletions
@@ -6,7 +6,6 @@
 """
 from __future__ import annotations

-import os
 import typing as t
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
@@ -20,7 +19,7 @@
 from tqdm import tqdm

 from ragas.exceptions import OpenAIKeyNotFound
-from ragas.metrics.llms import LangchainLLM
+from ragas.llms import LangchainLLM, llm_factory

 if t.TYPE_CHECKING:
     from langchain.callbacks.base import Callbacks
@@ -109,15 +108,9 @@ def get_batches(self, dataset_size: int) -> list[range]:
         return make_batches(dataset_size, self.batch_size)


-def _llm_factory() -> LangchainLLM:
-    oai_key = os.getenv("OPENAI_API_KEY", "no-key")
-    openai_llm = ChatOpenAI(openai_api_key=oai_key)
-    return LangchainLLM(llm=openai_llm)
-
-
 @dataclass
 class MetricWithLLM(Metric):
-    llm: LangchainLLM = field(default_factory=_llm_factory)
+    llm: LangchainLLM = field(default_factory=llm_factory)

     def init_model(self):
         if isinstance(self.llm, ChatOpenAI) or isinstance(self.llm, OpenAI):
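
Since the dataclass field now defaults to the shared llm_factory, overriding a metric's LLM is a plain field assignment; a hedged sketch where the Azure endpoint, deployment and API version are placeholders:

import os

from langchain.chat_models import AzureChatOpenAI

from ragas.llms import LangchainLLM
from ragas.metrics import faithfulness

# by default faithfulness.llm comes from llm_factory() (ChatOpenAI + OPENAI_API_KEY);
# any other LangChain chat model can be swapped in via the wrapper
azure_llm = AzureChatOpenAI(
    deployment_name="my-gpt-35-deployment",               # placeholder values --
    openai_api_base="https://example.openai.azure.com/",  # substitute your own
    openai_api_version="2023-05-15",                      # Azure OpenAI resource
    openai_api_key=os.getenv("AZURE_OPENAI_API_KEY", ""),
)
faithfulness.llm = LangchainLLM(llm=azure_llm)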

src/ragas/metrics/critique.py

Lines changed: 3 additions & 3 deletions
@@ -8,8 +8,8 @@
 from langchain.callbacks.manager import CallbackManager, trace_as_chain_group
 from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate

-from ragas.metrics.base import EvaluationMode, MetricWithLLM, _llm_factory
-from ragas.metrics.llms import LangchainLLM
+from ragas.llms import LangchainLLM
+from ragas.metrics.base import EvaluationMode, MetricWithLLM, llm_factory

 CRITIQUE_PROMPT = HumanMessagePromptTemplate.from_template(
     """Given a input and submission. Evaluate the submission only using the given criteria.
@@ -56,7 +56,7 @@ class AspectCritique(MetricWithLLM):
     strictness: int = field(default=1, repr=False)
     batch_size: int = field(default=15, repr=False)
     llm: LangchainLLM = field(
-        default_factory=_llm_factory,
+        default_factory=llm_factory,
         repr=False,
     )
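
AspectCritique picks up the same default through llm_factory, so defining a custom aspect needs no explicit LLM wiring; the criterion below is illustrative:

from ragas.metrics.critique import AspectCritique

# the llm field falls back to llm_factory() unless a wrapper is passed explicitly
conciseness = AspectCritique(
    name="conciseness",
    definition="Is the submission concise and free of unnecessary detail?",
)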

src/ragas/testset/testset_generator.py

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@
 from numpy.random import default_rng
 from tqdm import tqdm

-from ragas.metrics.llms import LangchainLLM
+from ragas.llms import LangchainLLM
 from ragas.testset.prompts import (
     ANSWER_FORMULATE,
     COMPRESS_QUESTION,
