Skip to content

Commit 7b12353

Browse files
authored
fix: typecast in TestsetGeneration fails for python3.8 (#215)
1 parent f284000 commit 7b12353

File tree

3 files changed

+22
-7
lines changed

3 files changed

+22
-7
lines changed

src/ragas/metrics/answer_correctness.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,11 @@ def _score_batch(
6666
faith_scores = self.faithfulness._score_batch(ds_faithfulness) # type: ignore
6767
similarity_scores = self.answer_similarity._score_batch(dataset) # type: ignore
6868

69-
scores = np.vstack([faith_scores, similarity_scores])
69+
scores_stacked = np.vstack([faith_scores, similarity_scores])
7070
scores = np.average(
71-
[faith_scores, similarity_scores], axis=0, weights=self.weights
71+
scores_stacked,
72+
axis=0,
73+
weights=self.weights,
7274
)
7375

7476
return scores.tolist()

src/ragas/testset/testset_generator.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -242,14 +242,16 @@ def _generate_context(self, question: str, text_chunk: str) -> t.List[str]:
242242
for qstn in question.split("\n")
243243
]
244244

245-
def _remove_nodes(self, available_indices: list, node_idx: list) -> t.List:
245+
def _remove_nodes(
246+
self, available_indices: list[BaseNode], node_idx: list
247+
) -> t.List[BaseNode]:
246248
for idx in node_idx:
247249
available_indices.remove(idx)
248250
return available_indices
249251

250252
def _generate_doc_nodes_map(
251253
self, documenet_nodes: t.List[BaseNode]
252-
) -> t.Dict[str, BaseNode]:
254+
) -> t.Dict[str, t.List[BaseNode]]:
253255
doc_nodes_map: t.Dict[str, t.List[BaseNode]] = defaultdict(list[BaseNode])
254256
for node in documenet_nodes:
255257
if node.ref_doc_id:
@@ -288,15 +290,15 @@ def generate(
288290

289291
if isinstance(documents[0], LangchainDocument):
290292
# cast to LangchainDocument since its the only case here
291-
documents = t.cast(list[LangchainDocument], documents)
293+
documents = t.cast(t.List[LangchainDocument], documents)
292294
documents = [
293295
LlamaindexDocument.from_langchain_format(doc) for doc in documents
294296
]
295297
# Convert documents into nodes
296298
node_parser = SimpleNodeParser.from_defaults(
297299
chunk_size=self.chunk_size, chunk_overlap=0, include_metadata=True
298300
)
299-
documents = t.cast(list[LlamaindexDocument], documents)
301+
documents = t.cast(t.List[LlamaindexDocument], documents)
300302
document_nodes: t.List[BaseNode] = node_parser.get_nodes_from_documents(
301303
documents=documents
302304
)
@@ -319,7 +321,7 @@ def generate(
319321
pbar = tqdm(total=test_size)
320322
while count < test_size and available_nodes != []:
321323
evolve_type = self._get_evolve_type()
322-
curr_node = self.rng.choice(available_nodes, size=1)[0]
324+
curr_node = self.rng.choice(np.array(available_nodes), size=1)[0]
323325
available_nodes = self._remove_nodes(available_nodes, [curr_node])
324326

325327
neighbor_nodes = doc_nodes_map[curr_node.source_node.node_id]
@@ -353,6 +355,8 @@ def generate(
353355
similarity_cutoff=self.threshold / 10,
354356
)
355357
if indices:
358+
# type cast indices from list[Any] to list[int]
359+
indices = t.cast(t.List[int], indices)
356360
best_neighbor = neighbor_nodes[indices[0]]
357361
question = self._multicontext_question(
358362
question=seed_question,

tests/unit/test_simple.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,15 @@
1+
from __future__ import annotations
2+
3+
import typing as t
4+
5+
16
def test_import():
27
import ragas
38
from ragas.testset.testset_generator import TestsetGenerator
49

510
assert TestsetGenerator is not None
611
assert ragas is not None
12+
13+
14+
def test_type_casting():
15+
t.cast(t.List[int], [1, 2, 3])

0 commit comments

Comments
 (0)