diff --git a/src/smolagents/models.py b/src/smolagents/models.py index 5d482afbc..2811f36b5 100644 --- a/src/smolagents/models.py +++ b/src/smolagents/models.py @@ -74,8 +74,15 @@ def convert(obj): return convert(obj) -def remove_content_after_stop_sequences(content: str, stop_sequences: list[str]) -> str: - """Remove content after any stop sequence is encountered.""" +def remove_content_after_stop_sequences(content: str | None, stop_sequences: list[str] | None) -> str | None: + """Remove content after any stop sequence is encountered. + + Some providers may return ``None`` content (for example when responding purely with tool calls), + so we skip processing in that case. + """ + if content is None or not stop_sequences: + return content + for stop_seq in stop_sequences: split = content.split(stop_seq) content = split[0] diff --git a/tests/test_models.py b/tests/test_models.py index 456299e3a..0151c25fb 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -724,6 +724,17 @@ def test_remove_content_after_stop_sequences(): assert removed_content == "Hello" +def test_remove_content_after_stop_sequences_handles_none(): + # Test with None stop sequence + content = "Hello world!" + removed_content = remove_content_after_stop_sequences(content, None) + assert removed_content == content + + # Test with None content + removed_content = remove_content_after_stop_sequences(None, [""]) + assert removed_content is None + + @pytest.mark.parametrize( "convert_images_to_image_urls, expected_clean_message", [