Skip to content

Commit 2b9734d

Browse files
authored
fix: n_swapped check for generate (#73)
1 parent 135612d commit 2b9734d

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

src/ragas/metrics/llms.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,12 @@ def generate(
2020
n: t.Optional[int] = None,
2121
) -> LLMResult:
2222
old_n = None
23+
n_swapped = False
2324
if n is not None:
2425
if isinstance(llm, OpenAI) or isinstance(llm, ChatOpenAI):
2526
old_n = llm.n
2627
llm.n = n
28+
n_swapped = True
2729
else:
2830
raise Exception(
2931
f"n={n} was passed to generate but the LLM {llm} does not support it."
@@ -36,6 +38,6 @@ def generate(
3638
ps = [p.format_messages() for p in prompts]
3739
result = llm.generate(ps)
3840

39-
if isinstance(llm, OpenAI) or isinstance(llm, ChatOpenAI):
41+
if (isinstance(llm, OpenAI) or isinstance(llm, ChatOpenAI)) and n_swapped:
4042
llm.n = old_n # type: ignore
4143
return result

0 commit comments

Comments
 (0)