
Commit 9deccc8

SimFG and anistark authored
feat: add bypass_n option to LangchainLLMWrapper for n-completion control (#2354)
When using some OpenAI-compatible thinking models, certain errors would occur:

2025-10-10 11:22:45,625 - ragas.executor - ERROR - Exception raised in Job[3]: BadRequestError(Error code: 400 - {'error': {'code': 'InvalidParameter', 'message': 'Reasoning model does not support n > 1, logit_bias, logprobs, top_logprobs Request id: xxx', 'param': '', 'type': 'BadRequest'}})

---------

Co-authored-by: Ani <[email protected]>
1 parent 7360d2e commit 9deccc8
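
For context, a minimal usage sketch of the new option (the model name and base_url are illustrative placeholders, not part of this commit):

from langchain_openai import ChatOpenAI

from ragas.llms.base import LangchainLLMWrapper

# Hypothetical OpenAI-compatible reasoning endpoint that rejects n > 1.
llm = ChatOpenAI(model="my-reasoning-model", base_url="http://localhost:8000/v1")

# bypass_n=True keeps the wrapper from setting or forwarding the n parameter,
# avoiding the 400 InvalidParameter error quoted above.
wrapper = LangchainLLMWrapper(langchain_llm=llm, bypass_n=True)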

File tree

2 files changed (+206, -6 lines)


src/ragas/llms/base.py

Lines changed: 5 additions & 2 deletions
@@ -149,6 +149,7 @@ def __init__(
         is_finished_parser: t.Optional[t.Callable[[LLMResult], bool]] = None,
         cache: t.Optional[CacheInterface] = None,
         bypass_temperature: bool = False,
+        bypass_n: bool = False,
     ):
         super().__init__(cache=cache)
         self.langchain_llm = langchain_llm
@@ -158,6 +159,8 @@ def __init__(
         self.is_finished_parser = is_finished_parser
         # Certain LLMs (e.g., OpenAI o1 series) do not support temperature
         self.bypass_temperature = bypass_temperature
+        # Certain reasoning LLMs (e.g., OpenAI o1 series) do not support the n parameter for multiple completions
+        self.bypass_n = bypass_n

     def is_finished(self, response: LLMResult) -> bool:
         """
@@ -225,7 +228,7 @@ def generate_text(
             old_temperature = self.langchain_llm.temperature  # type: ignore
             self.langchain_llm.temperature = temperature  # type: ignore

-        if is_multiple_completion_supported(self.langchain_llm):
+        if is_multiple_completion_supported(self.langchain_llm) and not self.bypass_n:
             result = self.langchain_llm.generate_prompt(
                 prompts=[prompt],
                 n=n,
@@ -278,7 +281,7 @@ async def agenerate_text(
             self.langchain_llm.temperature = temperature  # type: ignore

         # handle n
-        if hasattr(self.langchain_llm, "n"):
+        if hasattr(self.langchain_llm, "n") and not self.bypass_n:
             self.langchain_llm.n = n  # type: ignore
             result = await self.langchain_llm.agenerate_prompt(
                 prompts=[prompt],
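
The hunks above show only the changed guards; the corresponding else branches fall outside the diff context. Judging from the tests below, when bypass_n is set the wrapper still honors n by duplicating the prompt and flattening the per-prompt outputs. A rough sketch of that pattern (an assumed shape, not the exact ragas source):

# Assumed fallback: reach n completions without the provider-side n parameter
# by sending the same prompt n times and merging the results.
result = self.langchain_llm.generate_prompt(
    prompts=[prompt] * n,  # one request slot per desired completion
    stop=stop,
    callbacks=callbacks,
)
# each duplicated prompt yields one generation; flatten into a single n-wide list
result.generations = [[gens[0] for gens in result.generations]]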

tests/unit/llms/test_llm.py

Lines changed: 201 additions & 4 deletions
@@ -1,13 +1,13 @@
 from __future__ import annotations

 import typing as t
+from unittest.mock import MagicMock, patch

+import pytest
 from langchain_core.outputs import Generation, LLMResult
+from langchain_core.prompt_values import PromptValue

-from ragas.llms.base import BaseRagasLLM
-
-if t.TYPE_CHECKING:
-    from langchain_core.prompt_values import PromptValue
+from ragas.llms.base import BaseRagasLLM, LangchainLLMWrapper


 class FakeTestLLM(BaseRagasLLM):
@@ -38,3 +38,200 @@ async def agenerate_text(

     def is_finished(self, response: LLMResult) -> bool:
         return True
+
+
+class MockLangchainLLM:
+    """Mock Langchain LLM for testing bypass_n functionality."""
+
+    def __init__(self):
+        self.n = None  # This makes hasattr(self.langchain_llm, "n") return True
+        self.temperature = None
+        self.model_name = "mock-model"
+
+    def generate_prompt(self, prompts, n=None, stop=None, callbacks=None):
+        # Track if n was passed to the method
+        self._n_passed = n
+        # Simulate the behavior where if n is passed, we return n generations per prompt
+        # If n is not passed, we return one generation per prompt
+        num_prompts = len(prompts)
+        if n is not None:
+            # If n is specified, return n generations for each prompt
+            generations = [
+                [Generation(text="test response")] * n for _ in range(num_prompts)
+            ]
+        else:
+            # If n is not specified, return one generation per prompt
+            generations = [
+                [Generation(text="test response")] for _ in range(num_prompts)
+            ]
+        return LLMResult(generations=generations)
+
+    async def agenerate_prompt(self, prompts, n=None, stop=None, callbacks=None):
+        # Track if n was passed to the method
+        self._n_passed = n
+        # If n is not passed as parameter but self.n is set, use self.n
+        if n is None and hasattr(self, "n") and self.n is not None:
+            n = self.n
+        # Simulate the behavior where if n is passed, we return n generations per prompt
+        # If n is not passed, we return one generation per prompt
+        num_prompts = len(prompts)
+        if n is not None:
+            # If n is specified, return n generations for each prompt
+            generations = [
+                [Generation(text="test response")] * n for _ in range(num_prompts)
+            ]
+        else:
+            # If n is not specified, return one generation per prompt
+            generations = [
+                [Generation(text="test response")] for _ in range(num_prompts)
+            ]
+        return LLMResult(generations=generations)
+
+
+def create_mock_prompt():
+    """Create a mock prompt for testing."""
+    prompt = MagicMock(spec=PromptValue)
+    prompt.to_string.return_value = "test prompt"
+    return prompt
+
+
+class TestLangchainLLMWrapperBypassN:
+    """Test bypass_n functionality in LangchainLLMWrapper."""
+
+    def test_bypass_n_true_sync_does_not_pass_n(self):
+        """Test that when bypass_n=True, n is not passed to underlying LLM in sync method."""
+        mock_llm = MockLangchainLLM()
+        # Mock is_multiple_completion_supported to return True for this test
+        with patch(
+            "ragas.llms.base.is_multiple_completion_supported", return_value=True
+        ):
+            wrapper = LangchainLLMWrapper(langchain_llm=mock_llm, bypass_n=True)
+            prompt = create_mock_prompt()
+
+            # Call generate_text with n=3
+            result = wrapper.generate_text(prompt, n=3)
+
+            # Verify that n was not passed to the underlying LLM
+            assert mock_llm._n_passed is None
+            # When bypass_n=True, the wrapper should duplicate prompts instead of passing n
+            # The result should still have 3 generations (created by duplicating prompts)
+            assert len(result.generations[0]) == 3
+
+    def test_bypass_n_false_sync_passes_n(self):
+        """Test that when bypass_n=False (default), n is passed to underlying LLM in sync method."""
+        mock_llm = MockLangchainLLM()
+        # Mock is_multiple_completion_supported to return True for this test
+        with patch(
+            "ragas.llms.base.is_multiple_completion_supported", return_value=True
+        ):
+            wrapper = LangchainLLMWrapper(langchain_llm=mock_llm, bypass_n=False)
+            prompt = create_mock_prompt()
+
+            # Call generate_text with n=3
+            result = wrapper.generate_text(prompt, n=3)
+
+            # Verify that n was passed to the underlying LLM
+            assert mock_llm._n_passed == 3
+            # Result should have 3 generations
+            assert len(result.generations[0]) == 3
+
+    @pytest.mark.asyncio
+    async def test_bypass_n_true_async_does_not_pass_n(self):
+        """Test that when bypass_n=True, n is not passed to underlying LLM in async method."""
+        mock_llm = MockLangchainLLM()
+        wrapper = LangchainLLMWrapper(langchain_llm=mock_llm, bypass_n=True)
+        prompt = create_mock_prompt()
+
+        # Call agenerate_text with n=3
+        result = await wrapper.agenerate_text(prompt, n=3)
+
+        # Verify that n was not passed to the underlying LLM
+        assert mock_llm._n_passed is None
+        # When bypass_n=True, the wrapper should duplicate prompts instead of passing n
+        # The result should still have 3 generations (created by duplicating prompts)
+        assert len(result.generations[0]) == 3
+
+    @pytest.mark.asyncio
+    async def test_bypass_n_false_async_passes_n(self):
+        """Test that when bypass_n=False (default), n is passed to underlying LLM in async method."""
+        mock_llm = MockLangchainLLM()
+        wrapper = LangchainLLMWrapper(langchain_llm=mock_llm, bypass_n=False)
+        prompt = create_mock_prompt()
+
+        # Call agenerate_text with n=3
+        result = await wrapper.agenerate_text(prompt, n=3)
+
+        # Verify that n was passed to the underlying LLM (via n attribute)
+        assert mock_llm.n == 3
+        # Result should have 3 generations
+        assert len(result.generations[0]) == 3
+
+    def test_default_bypass_n_behavior(self):
+        """Test that default behavior (bypass_n=False) remains unchanged."""
+        mock_llm = MockLangchainLLM()
+        # Mock is_multiple_completion_supported to return True for this test
+        with patch(
+            "ragas.llms.base.is_multiple_completion_supported", return_value=True
+        ):
+            # Create wrapper without explicitly setting bypass_n (should default to False)
+            wrapper = LangchainLLMWrapper(langchain_llm=mock_llm)
+            prompt = create_mock_prompt()
+
+            # Call generate_text with n=2
+            result = wrapper.generate_text(prompt, n=2)
+
+            # Verify that n was passed to the underlying LLM (default behavior)
+            assert mock_llm._n_passed == 2
+            assert len(result.generations[0]) == 2
+
+    @pytest.mark.asyncio
+    async def test_default_bypass_n_behavior_async(self):
+        """Test that default behavior (bypass_n=False) remains unchanged in async method."""
+        mock_llm = MockLangchainLLM()
+        # Create wrapper without explicitly setting bypass_n (should default to False)
+        wrapper = LangchainLLMWrapper(langchain_llm=mock_llm)
+        prompt = create_mock_prompt()
+
+        # Call agenerate_text with n=2
+        result = await wrapper.agenerate_text(prompt, n=2)
+
+        # Verify that n was passed to the underlying LLM (default behavior)
+        assert mock_llm.n == 2
+        assert len(result.generations[0]) == 2
+
+    def test_bypass_n_true_with_multiple_completion_supported(self):
+        """Test bypass_n=True with LLM that supports multiple completions."""
+        # Create a mock LLM that would normally support multiple completions
+        mock_llm = MockLangchainLLM()
+        # Mock the is_multiple_completion_supported to return True for this test
+        with patch(
+            "ragas.llms.base.is_multiple_completion_supported", return_value=True
+        ):
+            wrapper = LangchainLLMWrapper(langchain_llm=mock_llm, bypass_n=True)
+            prompt = create_mock_prompt()
+
+            # Call generate_text with n=3
+            result = wrapper.generate_text(prompt, n=3)
+
+            # Verify that n was not passed to the underlying LLM due to bypass_n=True
+            assert mock_llm._n_passed is None
+            # Result should still have 3 generations (created by duplicating prompts)
+            assert len(result.generations[0]) == 3
+
+    @pytest.mark.asyncio
+    async def test_bypass_n_true_with_multiple_completion_supported_async(self):
+        """Test bypass_n=True with LLM that supports multiple completions in async method."""
+        mock_llm = MockLangchainLLM()
+        with patch(
+            "ragas.llms.base.is_multiple_completion_supported", return_value=True
+        ):
+            wrapper = LangchainLLMWrapper(langchain_llm=mock_llm, bypass_n=True)
+            prompt = create_mock_prompt()
+
+            # Call agenerate_text with n=3
+            result = await wrapper.agenerate_text(prompt, n=3)
+
+            # Verify that n was not passed to the underlying LLM due to bypass_n=True
+            assert mock_llm._n_passed is None
+            # Result should still have 3 generations
+            assert len(result.generations[0]) == 3
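
Note that the async cases rely on pytest.mark.asyncio, so running them locally presumably requires the pytest-asyncio plugin; a typical invocation would be pytest tests/unit/llms/test_llm.py -k bypass_n.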
