Commit 65809e5

Merge pull request #783 from parea-ai/chore-update-evals-output-is-list
chore: update evals list parsing
2 parents 3c219ce + a015af7 commit 65809e5

6 files changed: +27 -35 lines

parea/evals/rag/context_query_relevancy.py

Lines changed: 2 additions & 8 deletions
@@ -1,6 +1,6 @@
 from typing import Callable, List, Optional

-from parea.evals.utils import call_openai, sent_tokenize
+from parea.evals.utils import call_openai, get_context, sent_tokenize
 from parea.schemas.log import Log


@@ -27,13 +27,7 @@ def context_query_relevancy_factory(
     def context_query_relevancy(log: Log) -> float:
         """Quantifies how much the retrieved context relates to the query."""
         question = log.inputs[question_field]
-        if context_fields:
-            context = "\n".join(log.inputs[context_field] for context_field in context_fields)
-        else:
-            if isinstance(log.output, list):
-                context = "\n".join(log.output)
-            else:
-                context = str(log.output)
+        context = get_context(log, context_fields)

         extracted_sentences = call_openai(
             model=model,
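
For reference, a hedged sketch of the call the eval now makes: with explicit context_fields, get_context joins the named inputs with newlines, matching the old branch. Constructing Log directly with only inputs is an illustrative assumption and may need more fields in practice:

    # Illustrative only, not part of the commit.
    from parea.evals.utils import get_context
    from parea.schemas.log import Log

    log = Log(inputs={"question": "What is RAG?", "context": "Retrieval-augmented generation combines search with LLMs."})
    context = get_context(log, ["context"])
    # -> "Retrieval-augmented generation combines search with LLMs."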

parea/evals/rag/context_ranking_listwise.py

Lines changed: 2 additions & 8 deletions
@@ -1,6 +1,6 @@
 from typing import Callable, List, Optional

-from parea.evals.utils import call_openai, ndcg
+from parea.evals.utils import call_openai, get_context, ndcg
 from parea.schemas.log import Log


@@ -99,13 +99,7 @@ def progressive_reranking(query: str, contexts: List[str]) -> List[int]:
     def context_ranking(log: Log) -> float:
         """Quantifies if the retrieved context is ranked by their relevancy by re-ranking the contexts."""
         question = log.inputs[question_field]
-        if context_fields:
-            contexts = [log.inputs[context_field] for context_field in context_fields]
-        else:
-            if isinstance(log.output, list):
-                contexts = log.output
-            else:
-                contexts = [str(log.output)]
+        contexts = get_context(log, context_fields, True)

         reranked_indices = progressive_reranking(question, contexts)

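Both ranking evals pass True as the third argument (as_list), so get_context returns the contexts as a list instead of one joined string, which is what progressive_reranking expects. A hedged sketch; the Log construction and field names are illustrative:

    # Illustrative only, not part of the commit.
    from parea.evals.utils import get_context
    from parea.schemas.log import Log

    log = Log(inputs={"question": "q", "doc_1": "first chunk", "doc_2": "second chunk"})
    contexts = get_context(log, ["doc_1", "doc_2"], True)
    # -> ["first chunk", "second chunk"]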

parea/evals/rag/context_ranking_pointwise.py

Lines changed: 2 additions & 8 deletions
@@ -1,6 +1,6 @@
 from typing import Callable, List, Optional

-from parea.evals.utils import call_openai, safe_json_loads
+from parea.evals.utils import call_openai, get_context, safe_json_loads
 from parea.schemas.log import Log


@@ -40,13 +40,7 @@ def context_ranking_pointwise_factory(
     def context_ranking_pointwise(log: Log) -> float:
         """Quantifies if the retrieved context is ranked by their relevancy"""
         question = log.inputs[question_field]
-        if context_fields:
-            contexts = [log.inputs[context_field] for context_field in context_fields]
-        else:
-            if isinstance(log.output, list):
-                contexts = log.output
-            else:
-                contexts = [str(log.output)]
+        contexts = get_context(log, context_fields, True)

         verifications = []
         for context in contexts:

parea/evals/rag/percent_target_supported_by_context.py

Lines changed: 3 additions & 8 deletions
@@ -2,7 +2,7 @@

 import re

-from parea.evals.utils import call_openai
+from parea.evals.utils import call_openai, get_context
 from parea.schemas.log import Log


@@ -14,13 +14,8 @@ def percent_target_supported_by_context_factory(
     def percent_target_supported_by_context(log: Log) -> Union[float, None]:
         """Quantifies how many sentences in the target/ground truth are supported by the retrieved context."""
         question = log.inputs[question_field]
-        if context_fields:
-            context = "\n".join(log.inputs[context_field] for context_field in context_fields)
-        else:
-            if isinstance(log.output, list):
-                context = "\n".join(log.output)
-            else:
-                context = str(log.output)
+        context = get_context(log, context_fields)
+
         if (target := log.target) is None:
             return None


parea/evals/utils.py

Lines changed: 17 additions & 2 deletions
@@ -1,4 +1,4 @@
-from typing import Callable, List, Union
+from typing import Callable, List, Optional, Union

 import json
 import warnings
@@ -10,7 +10,7 @@
 from openai import __version__ as openai_version

 from parea.parea_logger import parea_logger
-from parea.schemas import EvaluationResult
+from parea.schemas import EvaluationResult, Log
 from parea.schemas.log import Log
 from parea.schemas.models import UpdateLog

@@ -179,3 +179,18 @@ def get_tokens(model: str, text: str) -> List[int]:
     except Exception as e:
         print(f"Error encoding text: {e}")
         return []
+
+
+def get_context(log: Log, context_fields: Optional[List[str]] = None, as_list: bool = False) -> str:
+    if context_fields:
+        context_list = [log.inputs[context_field] for context_field in context_fields]
+        return context_list if as_list else "\n".join(context_list)
+    else:
+        context = log.output
+        try:
+            loaded_context = json.loads(log.output)
+            if isinstance(log.output, list):
+                return loaded_context if as_list else "\n".join(loaded_context)
+        except json.JSONDecodeError:
+            pass
+        return [context] if as_list else context
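
A hedged note on the fallback path as written above: with no context_fields, the helper tries json.loads on log.output; if that raises JSONDecodeError (plain-text output), the raw output is returned, wrapped in a list when as_list is True. The values below are illustrative:

    # Illustrative only, not part of the commit.
    from parea.evals.utils import get_context
    from parea.schemas.log import Log

    log = Log(inputs={"question": "q"}, output="a single retrieved passage")
    get_context(log)                # -> "a single retrieved passage"
    get_context(log, as_list=True)  # -> ["a single retrieved passage"]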

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api"
 [tool.poetry]
 name = "parea-ai"
 packages = [{ include = "parea" }]
-version = "0.2.135"
+version = "0.2.136"
 description = "Parea python sdk"
 readme = "README.md"
 authors = ["joel-parea-ai <[email protected]>"]
