Skip to content

Commit 0fdea83

Browse files
fix: prompt consumes same iterator twice (#768)
1 parent 38d0668 commit 0fdea83

File tree

3 files changed

+37
-2
lines changed

3 files changed

+37
-2
lines changed

packages/ragbits-core/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22

33
## Unreleased
44

5+
- Add tool_choice parameter to LLM interface (#738)
6+
- Fix Prompt consumes same iterator twice leading to no data added to chat (#768)
7+
58
## 1.2.1 (2025-08-04)
69

710
## 1.2.0 (2025-08-01)

packages/ragbits-core/src/ragbits/core/prompt/prompt.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,6 @@ def __init__(self, input_data: PromptInputT | None = None, history: ChatFormat |
188188
self.rendered_system_prompt = (
189189
self._render_template(self.system_prompt_template, input_data) if self.system_prompt_template else None
190190
)
191-
self.rendered_user_prompt = self._render_template(self.user_prompt_template, input_data)
192191
self.attachments = self._get_attachments_from_input_data(input_data)
193192

194193
# Additional few shot examples that can be added dynamically using methods
@@ -197,7 +196,9 @@ def __init__(self, input_data: PromptInputT | None = None, history: ChatFormat |
197196

198197
# Additional conversation history that can be added dynamically using methods
199198
self._conversation_history: list[dict[str, Any]] = history or []
200-
self.add_user_message(input_data if input_data else self.rendered_user_prompt)
199+
200+
self.add_user_message(input_data or self._render_template(self.user_prompt_template, input_data))
201+
self.rendered_user_prompt = self.chat[-1]["content"]
201202
super().__init__()
202203

203204
@property

packages/ragbits-core/tests/unit/prompts/test_prompt.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from collections.abc import Iterator
12
from pathlib import Path
23
from typing import Any
34

@@ -19,6 +20,15 @@ class _PromptInput(pydantic.BaseModel):
1920
age: int
2021

2122

23+
class _PromptInputWithIterator(pydantic.BaseModel):
24+
"""
25+
Input format for the TestPromptWithIterator, which can be consumed only once.
26+
"""
27+
28+
model_config = pydantic.ConfigDict(arbitrary_types_allowed=True)
29+
context: Iterator[str]
30+
31+
2232
class _SingleAttachmentPromptInput(pydantic.BaseModel):
2333
"""
2434
Single input format for the TestAttachmentPrompt.
@@ -204,6 +214,9 @@ class TestAttachmentPrompt(Prompt):
204214
assert chat[0]["role"] == "user"
205215
assert chat[0]["content"][0]["text"] == "What is in this image?"
206216
assert chat[0]["content"][1]["type"] == "image_url"
217+
assert len(prompt.rendered_user_prompt) == 2
218+
assert prompt.rendered_user_prompt[0] == {"type": "text", "text": "What is in this image?"}
219+
assert prompt.rendered_user_prompt[1]["type"] == "image_url"
207220

208221

209222
def test_image_prompt_encoding():
@@ -270,6 +283,9 @@ class TestAttachmentPrompt(Prompt):
270283
assert chat[0]["role"] == "user"
271284
assert chat[0]["content"][0]["text"] == "What is in this PDF?"
272285
assert chat[0]["content"][1]["type"] == "file"
286+
assert len(prompt.rendered_user_prompt) == 2
287+
assert prompt.rendered_user_prompt[0] == {"type": "text", "text": "What is in this PDF?"}
288+
assert prompt.rendered_user_prompt[1]["type"] == "file"
273289

274290

275291
def test_pdf_prompt_encoding():
@@ -337,6 +353,21 @@ class TestPrompt(Prompt[_PromptInput, str]): # pylint: disable=unused-variable
337353
]
338354

339355

356+
def test_prompt_with_iterator_input():
357+
"""Test that a prompt can be created with an iterator input."""
358+
359+
class TestPrompt(Prompt[_PromptInputWithIterator, str]): # pylint: disable=unused-variable
360+
"""A test prompt"""
361+
362+
user_prompt = "Context: {% for chunk in context %}{{ chunk }}{%- endfor %}"
363+
364+
prompt = TestPrompt(_PromptInputWithIterator(context=iter(["Hello", " World"])))
365+
assert prompt.rendered_user_prompt == "Context: Hello World"
366+
assert prompt.chat == [
367+
{"role": "user", "content": "Context: Hello World"},
368+
]
369+
370+
340371
def test_input_type_must_be_pydantic_model():
341372
"""Test that an error is raised when the input type is not a Pydantic model."""
342373
with pytest.raises(AssertionError):

0 commit comments

Comments
 (0)