
Commit 7ba0c28

fix(ragas): remove mutable defaults (#684)
Minor PR to clean the whole codebase of the well-known gotcha of [mutable default arguments](https://docs.python-guide.org/writing/gotchas/#mutable-default-arguments), which I encourage everyone to be more careful about.
1 parent e1e05f8 commit 7ba0c28
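
For context, here is the gotcha in a self-contained sketch (illustrative code, not taken from this repository): a mutable default is evaluated once, when the function is defined, and the resulting object is then shared by every call that falls back to the default.

def append_bad(item, bucket=[]):
    # The [] above is created once at definition time and reused across calls.
    bucket.append(item)
    return bucket

def append_good(item, bucket=None):
    # None sentinel: build a fresh list inside the body on every call.
    bucket = [] if bucket is None else bucket
    bucket.append(item)
    return bucket

print(append_bad("a"), append_bad("b"))    # ['a', 'b'] ['a', 'b'] -- state leaks across calls
print(append_good("a"), append_good("b"))  # ['a'] ['b']

The commit applies the second form (a None default, normalized inside the function body) throughout the files below.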

7 files changed (+30 / -28 lines)


src/ragas/embeddings/base.py

Lines changed: 1 addition & 3 deletions
@@ -44,9 +44,7 @@ def set_run_config(self, run_config: RunConfig):

class LangchainEmbeddingsWrapper(BaseRagasEmbeddings):
    def __init__(
-        self,
-        embeddings: Embeddings,
-        run_config: t.Optional[RunConfig] = None
+        self, embeddings: Embeddings, run_config: t.Optional[RunConfig] = None
    ):
        self.embeddings = embeddings
        if run_config is None:

src/ragas/evaluation.py

Lines changed: 6 additions & 3 deletions
@@ -15,9 +15,9 @@
    LangchainEmbeddingsWrapper,
    embedding_factory,
)
-from ragas.llms import llm_factory
from ragas.exceptions import ExceptionInRunner
from ragas.executor import Executor
+from ragas.llms import llm_factory
from ragas.llms.base import BaseRagasLLM, LangchainLLMWrapper
from ragas.metrics._answer_correctness import AnswerCorrectness
from ragas.metrics.base import Metric, MetricWithEmbeddings, MetricWithLLM

@@ -42,11 +42,11 @@ def evaluate(
    metrics: list[Metric] | None = None,
    llm: t.Optional[BaseRagasLLM | LangchainLLM] = None,
    embeddings: t.Optional[BaseRagasEmbeddings | LangchainEmbeddings] = None,
-    callbacks: Callbacks = [],
+    callbacks: Callbacks = None,
    is_async: bool = False,
    run_config: t.Optional[RunConfig] = None,
    raise_exceptions: bool = True,
-    column_map: t.Dict[str, str] = {},
+    column_map: t.Optional[t.Dict[str, str]] = None,
) -> Result:
    """
    Run the evaluation on the dataset with different metrics

@@ -120,6 +120,9 @@ def evaluate(
    'answer_relevancy': 0.874}
    ```
    """
+    column_map = column_map or {}
+    callbacks = callbacks or []
+
    if dataset is None:
        raise ValueError("Provide dataset!")
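
The two lines added to the function body above are the usual normalization step for a None sentinel. A minimal sketch of the same pattern (hypothetical names, not the real evaluate() signature):

import typing as t

def evaluate_sketch(
    callbacks: t.Optional[list] = None,
    column_map: t.Optional[t.Dict[str, str]] = None,
) -> None:
    # Rebind each sentinel to a fresh object on every call.
    callbacks = callbacks or []
    column_map = column_map or {}
    # Caveat of `or`: any falsy argument (such as an explicitly passed empty
    # dict) is also replaced with a fresh object, which is harmless here.

Note that the `callbacks: Callbacks = None` annotation used throughout this commit type-checks without an extra t.Optional[...] wrapper because LangChain's Callbacks type is already an Optional alias.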

src/ragas/executor.py

Lines changed: 8 additions & 5 deletions
@@ -1,8 +1,8 @@
from __future__ import annotations
-import sys

import asyncio
import logging
+import sys
import threading
import typing as t
from dataclasses import dataclass, field

@@ -24,28 +24,31 @@ def runner_exception_hook(args: threading.ExceptHookArgs):
# set a custom exception hook
# threading.excepthook = runner_exception_hook

+
def as_completed(loop, coros, max_workers):
    loop_arg_dict = {"loop": loop} if sys.version_info[:2] < (3, 10) else {}
    if max_workers == -1:
        return asyncio.as_completed(coros, **loop_arg_dict)
-
+
    # loop argument is removed since Python 3.10
    semaphore = asyncio.Semaphore(max_workers, **loop_arg_dict)
+
    async def sema_coro(coro):
        async with semaphore:
            return await coro
-
+
    sema_coros = [sema_coro(c) for c in coros]
    return asyncio.as_completed(sema_coros, **loop_arg_dict)

+
class Runner(threading.Thread):
    def __init__(
        self,
        jobs: t.List[t.Tuple[t.Coroutine, str]],
        desc: str,
        keep_progress_bar: bool = True,
        raise_exceptions: bool = True,
-        run_config: t.Optional[RunConfig] = None
+        run_config: t.Optional[RunConfig] = None,
    ):
        super().__init__()
        self.jobs = jobs

@@ -59,7 +62,7 @@ def __init__(
        self.futures = as_completed(
            loop=self.loop,
            coros=[coro for coro, _ in self.jobs],
-            max_workers=self.run_config.max_workers
+            max_workers=self.run_config.max_workers,
        )

    async def _aresults(self) -> t.List[t.Any]:

src/ragas/llms/base.py

Lines changed: 7 additions & 10 deletions
@@ -60,7 +60,7 @@ def generate_text(
        n: int = 1,
        temperature: float = 1e-8,
        stop: t.Optional[t.List[str]] = None,
-        callbacks: Callbacks = [],
+        callbacks: Callbacks = None,
    ) -> LLMResult:
        ...

@@ -71,7 +71,7 @@ async def agenerate_text(
        n: int = 1,
        temperature: float = 1e-8,
        stop: t.Optional[t.List[str]] = None,
-        callbacks: Callbacks = [],
+        callbacks: Callbacks = None,
    ) -> LLMResult:
        ...

@@ -81,7 +81,7 @@ async def generate(
        n: int = 1,
        temperature: float = 1e-8,
        stop: t.Optional[t.List[str]] = None,
-        callbacks: Callbacks = [],
+        callbacks: Callbacks = None,
        is_async: bool = True,
    ) -> LLMResult:
        """Generate text using the given event loop."""

@@ -119,9 +119,7 @@ class LangchainLLMWrapper(BaseRagasLLM):
    """

    def __init__(
-        self,
-        langchain_llm: BaseLanguageModel,
-        run_config: t.Optional[RunConfig] = None
+        self, langchain_llm: BaseLanguageModel, run_config: t.Optional[RunConfig] = None
    ):
        self.langchain_llm = langchain_llm
        if run_config is None:

@@ -134,7 +132,7 @@ def generate_text(
        n: int = 1,
        temperature: float = 1e-8,
        stop: t.Optional[t.List[str]] = None,
-        callbacks: t.Optional[Callbacks] = None,
+        callbacks: Callbacks = None,
    ) -> LLMResult:
        temperature = self.get_temperature(n=n)
        if is_multiple_completion_supported(self.langchain_llm):

@@ -164,7 +162,7 @@ async def agenerate_text(
        n: int = 1,
        temperature: float = 1e-8,
        stop: t.Optional[t.List[str]] = None,
-        callbacks: t.Optional[Callbacks] = None,
+        callbacks: Callbacks = None,
    ) -> LLMResult:
        temperature = self.get_temperature(n=n)
        if is_multiple_completion_supported(self.langchain_llm):

@@ -206,8 +204,7 @@ def set_run_config(self, run_config: RunConfig):


def llm_factory(
-    model: str = "gpt-3.5-turbo-16k",
-    run_config: t.Optional[RunConfig] = None
+    model: str = "gpt-3.5-turbo-16k", run_config: t.Optional[RunConfig] = None
) -> BaseRagasLLM:
    timeout = None
    if run_config is not None:

src/ragas/metrics/base.py

Lines changed: 4 additions & 2 deletions
@@ -60,7 +60,8 @@ def save(self, cache_dir: t.Optional[str] = None) -> None:
            "adapt() is not implemented for {} metric".format(self.name)
        )

-    def score(self: t.Self, row: t.Dict, callbacks: Callbacks = []) -> float:
+    def score(self: t.Self, row: t.Dict, callbacks: Callbacks = None) -> float:
+        callbacks = callbacks or []
        rm, group_cm = new_group(
            self.name, inputs=row, callbacks=callbacks, is_async=False
        )

@@ -78,8 +79,9 @@ def score(self: t.Self, row: t.Dict, callbacks: Callbacks = []) -> float:
        return score

    async def ascore(
-        self: t.Self, row: t.Dict, callbacks: Callbacks = [], is_async: bool = True
+        self: t.Self, row: t.Dict, callbacks: Callbacks = None, is_async: bool = True
    ) -> float:
+        callbacks = callbacks or []
        rm, group_cm = new_group(
            self.name, inputs=row, callbacks=callbacks, is_async=True
        )

src/ragas/testset/docstore.py

Lines changed: 3 additions & 4 deletions
@@ -78,7 +78,7 @@ class Direction(str, Enum):
    PREV = "prev"
    UP = "up"
    DOWN = "down"
-
+

class Node(Document):
    keyphrases: t.List[str] = Field(default_factory=list, repr=False)

@@ -240,7 +240,7 @@ def add_nodes(self, nodes: t.Sequence[Node], show_progress=True):
                )
                result_idx += 1

-            if n.keyphrases == []:
+            if not n.keyphrases:
                nodes_to_extract.update({i: result_idx})
                executor.submit(
                    self.extractor.extract,

@@ -250,7 +250,7 @@ def add_nodes(self, nodes: t.Sequence[Node], show_progress=True):
                result_idx += 1

        results = executor.results()
-        if results == []:
+        if not results:
            raise ExceptionInRunner()

        for i, n in enumerate(nodes):

@@ -336,7 +336,6 @@ def adjustment_factor(wins, alpha):
    def get_similar(
        self, node: Node, threshold: float = 0.7, top_k: int = 3
    ) -> t.Union[t.List[Document], t.List[Node]]:
-        items = []
        doc = node
        if doc.embedding is None:
            raise ValueError("Document has no embedding.")

src/ragas/testset/evolutions.py

Lines changed: 1 addition & 1 deletion
@@ -464,7 +464,7 @@ async def _aevolve(
        # find a similar node and generate a question based on both
        merged_node = self.merge_nodes(current_nodes)
        similar_node = self.docstore.get_similar(merged_node, top_k=1)
-        if similar_node == []:
+        if not similar_node:
            # retry
            new_random_nodes = self.docstore.get_random_nodes(k=1)
            current_nodes = CurrentNodes(
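
A side note on the `== []` comparisons replaced in the last two files: `not seq` is the PEP 8 idiom for an emptiness check, and unlike `seq == []` it is also true for None and other falsy values, which is the intended behaviour at these call sites. A tiny illustrative sketch:

empty_list, missing = [], None
assert not empty_list and not missing              # truthiness treats both as "empty"
assert (empty_list == []) and not (missing == [])  # equality only matches a real empty list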
