
Commit 2df7a14

Pouyanpi and RobGeada authored
fix: ensure that stop token is not ignored if llm_params is None (#1529)
* Fix ignored stop token if llm_params is None

Signed-off-by: Rob Geada <[email protected]>

* fix(llm): pass stop tokens to ainvoke or invoke instead of bind

Commit 67de947 ("chore(types): Type-clean /actions") introduced a bug where stop tokens were only passed to the LLM when llm_params was truthy. When llm_params was None or empty, stop tokens were completely ignored.

Fix: Restored the original pattern where stop is passed directly to ainvoke() as a kwarg, and .bind() is only used for llm_params:

- stop is now passed to _invoke_with_string_prompt() and _invoke_with_message_list()
- Those functions pass stop=stop to llm.ainvoke()
- .bind() is only called when llm_params is truthy (no stop in bind)

---------

Signed-off-by: Rob Geada <[email protected]>
Co-authored-by: Rob Geada <[email protected]>
1 parent f716ee0 commit 2df7a14
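The behavioral difference, reduced to a minimal sketch (simplified signature; the real llm_call also wires up callbacks, converts message-list prompts, and wraps errors in LLMCallException):

from typing import List, Optional

async def llm_call_sketch(llm, prompt: str, stop: Optional[List[str]] = None, llm_params: Optional[dict] = None):
    # Before the fix, stop was routed through .bind() together with llm_params:
    #     llm.bind(stop=stop, **llm_params) if llm_params ... else llm
    # so a falsy llm_params (None or {}) silently dropped the stop tokens.
    generation_llm = llm.bind(**llm_params) if llm_params else llm
    # After the fix, stop is always forwarded directly to ainvoke(), independent of llm_params.
    return await generation_llm.ainvoke(prompt, stop=stop)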

File tree

- nemoguardrails/actions/llm/utils.py
- tests/test_actions_llm_utils.py
- tests/test_llm_params_e2e.py
- tests/test_tool_calling_utils.py

4 files changed: +47 −12 lines changed

nemoguardrails/actions/llm/utils.py

Lines changed: 7 additions & 7 deletions

@@ -164,14 +164,12 @@ async def llm_call(
     _setup_llm_call_info(llm, model_name, model_provider)
     all_callbacks = _prepare_callbacks(custom_callback_handlers)

-    generation_llm: Union[BaseLanguageModel, Runnable] = (
-        llm.bind(stop=stop, **llm_params) if llm_params and llm is not None else llm
-    )
+    generation_llm: Union[BaseLanguageModel, Runnable] = llm.bind(**llm_params) if llm_params else llm

     if isinstance(prompt, str):
-        response = await _invoke_with_string_prompt(generation_llm, prompt, all_callbacks)
+        response = await _invoke_with_string_prompt(generation_llm, prompt, all_callbacks, stop)
     else:
-        response = await _invoke_with_message_list(generation_llm, prompt, all_callbacks)
+        response = await _invoke_with_message_list(generation_llm, prompt, all_callbacks, stop)

     _store_reasoning_traces(response)
     _store_tool_calls(response)

@@ -206,10 +204,11 @@ async def _invoke_with_string_prompt(
     llm: Union[BaseLanguageModel, Runnable],
     prompt: str,
     callbacks: BaseCallbackManager,
+    stop: Optional[List[str]],
 ):
     """Invoke LLM with string prompt."""
     try:
-        return await llm.ainvoke(prompt, config=RunnableConfig(callbacks=callbacks))
+        return await llm.ainvoke(prompt, config=RunnableConfig(callbacks=callbacks), stop=stop)
     except Exception as e:
         raise LLMCallException(e)

@@ -218,12 +217,13 @@ async def _invoke_with_message_list(
     llm: Union[BaseLanguageModel, Runnable],
     prompt: List[dict],
     callbacks: BaseCallbackManager,
+    stop: Optional[List[str]],
 ):
     """Invoke LLM with message list after converting to LangChain format."""
     messages = _convert_messages_to_langchain_format(prompt)

     try:
-        return await llm.ainvoke(messages, config=RunnableConfig(callbacks=callbacks))
+        return await llm.ainvoke(messages, config=RunnableConfig(callbacks=callbacks), stop=stop)
     except Exception as e:
         raise LLMCallException(e)
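A short usage sketch of the patched call path, assuming an already-configured LangChain llm (the prompt, stop values, and temperature below are illustrative, not taken from the repo):

from nemoguardrails.actions.llm.utils import llm_call

async def demo(llm):
    # stop reaches llm.ainvoke() even though llm_params is omitted (defaults to None)
    no_params = await llm_call(llm, "User: hi\nAssistant:", stop=["User:"])
    # .bind() receives only temperature; stop is still passed straight to ainvoke()
    with_params = await llm_call(llm, "User: hi\nAssistant:", stop=["User:"], llm_params={"temperature": 0.2})
    return no_params, with_params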

tests/test_actions_llm_utils.py

Lines changed: 16 additions & 0 deletions

@@ -532,3 +532,19 @@ def test_store_tool_calls_with_real_aimessage_multiple_tool_calls():
     assert len(tool_calls) == 2
     assert tool_calls[0]["name"] == "foo"
     assert tool_calls[1]["name"] == "bar"
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("llm_params", [None, {}])
+async def test_llm_call_stop_tokens_passed_without_llm_params(llm_params):
+    """Stop tokens must be passed to ainvoke even when llm_params is None or empty."""
+    from unittest.mock import AsyncMock, MagicMock
+
+    from nemoguardrails.actions.llm.utils import llm_call
+
+    mock_llm = AsyncMock()
+    mock_llm.ainvoke.return_value = MagicMock(content="response")
+
+    await llm_call(mock_llm, "prompt", stop=["User:"], llm_params=llm_params)
+
+    assert mock_llm.ainvoke.call_args[1]["stop"] == ["User:"]
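The unit test exercises exactly the two falsy values that triggered the regression: with llm_params=None or llm_params={}, the old code skipped .bind() entirely and the stop tokens were lost. AsyncMock records keyword arguments, so mock_llm.ainvoke.call_args[1]["stop"] is the stop kwarg of the single ainvoke() call, and the assertion fails whenever stop is not forwarded.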

tests/test_llm_params_e2e.py

Lines changed: 22 additions & 3 deletions

@@ -46,8 +46,7 @@ def nim_config_content():
     models:
       - type: main
         engine: nim
-        model: meta/llama-3.1-70b-instruct
-        api_base: https://integrate.api.nvidia.com/v1
+        model: meta/llama-3.3-70b-instruct
     """

@@ -197,6 +196,26 @@ async def test_openai_llm_params_streaming(self, openai_config_path):
         content = response.response[-1]["content"]
         assert "1" in content

+    @pytest.mark.asyncio
+    @pytest.mark.skipif(
+        not os.getenv("OPENAI_API_KEY"),
+        reason="OpenAI API key not available for e2e testing",
+    )
+    async def test_openai_stop_tokens_without_llm_params(self, openai_config_path):
+        """Test stop tokens work without llm_params (regression test for 67de94723)."""
+        config = RailsConfig.from_path(openai_config_path)
+        rails = LLMRails(config, verbose=False)
+
+        response = await llm_call(
+            rails.llm,
+            "Count from 1 to 10, one number per line.",
+            stop=["5"],
+            llm_params=None,
+        )
+
+        assert "4" in response
+        assert "5" not in response
+
     @pytest.mark.skipif(
         not LIVE_TEST_MODE,

@@ -392,7 +411,7 @@ async def test_openai_unsupported_params_error_handling(self, openai_config_path
         models:
           - type: main
             engine: openai
-            model: o1-mini
+            model: o3-mini
         """

         with tempfile.TemporaryDirectory() as temp_dir:
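The live assertions rely on how stop sequences behave: generation halts before the stop string is emitted, so with stop=["5"] a model counting from 1 to 10 should produce "4" but never "5". If the stop token were dropped (the pre-fix behavior with llm_params=None), the full count would come back and the "5" not in response check would fail. Like the other e2e tests in this file, it is skipped unless OPENAI_API_KEY is set.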

tests/test_tool_calling_utils.py

Lines changed: 2 additions & 2 deletions

@@ -249,7 +249,7 @@ async def test_llm_call_with_llm_params():
     result = await llm_call(mock_llm, "Test prompt", llm_params=llm_params)

     assert result == "LLM response with params"
-    mock_llm.bind.assert_called_once_with(stop=None, **llm_params)
+    mock_llm.bind.assert_called_once_with(**llm_params)
     mock_bound_llm.ainvoke.assert_called_once()

@@ -298,7 +298,7 @@ async def test_llm_call_with_llm_params_temperature_max_tokens():
     result = await llm_call(mock_llm, "Test prompt", llm_params=llm_params)

     assert result == "Response with temp and tokens"
-    mock_llm.bind.assert_called_once_with(stop=None, temperature=0.8, max_tokens=50)
+    mock_llm.bind.assert_called_once_with(temperature=0.8, max_tokens=50)
     mock_bound_llm.ainvoke.assert_called_once()
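These two assertion updates follow directly from the utils.py change: .bind() is now called with only the llm_params, so expecting an extra stop=None argument would make the existing tests fail even though the behavior they cover (params bound once, ainvoke called once) is unchanged.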
