1+ from collections .abc import Iterator
12from pathlib import Path
23from typing import Any
34
@@ -19,6 +20,15 @@ class _PromptInput(pydantic.BaseModel):
1920 age : int
2021
2122
23+ class _PromptInputWithIterator (pydantic .BaseModel ):
24+ """
25+ Input format for the TestPromptWithIterator, which can be consumed only once.
26+ """
27+
28+ model_config = pydantic .ConfigDict (arbitrary_types_allowed = True )
29+ context : Iterator [str ]
30+
31+
2232class _SingleAttachmentPromptInput (pydantic .BaseModel ):
2333 """
2434 Single input format for the TestAttachmentPrompt.
@@ -204,6 +214,9 @@ class TestAttachmentPrompt(Prompt):
204214 assert chat [0 ]["role" ] == "user"
205215 assert chat [0 ]["content" ][0 ]["text" ] == "What is in this image?"
206216 assert chat [0 ]["content" ][1 ]["type" ] == "image_url"
217+ assert len (prompt .rendered_user_prompt ) == 2
218+ assert prompt .rendered_user_prompt [0 ] == {"type" : "text" , "text" : "What is in this image?" }
219+ assert prompt .rendered_user_prompt [1 ]["type" ] == "image_url"
207220
208221
209222def test_image_prompt_encoding ():
@@ -270,6 +283,9 @@ class TestAttachmentPrompt(Prompt):
270283 assert chat [0 ]["role" ] == "user"
271284 assert chat [0 ]["content" ][0 ]["text" ] == "What is in this PDF?"
272285 assert chat [0 ]["content" ][1 ]["type" ] == "file"
286+ assert len (prompt .rendered_user_prompt ) == 2
287+ assert prompt .rendered_user_prompt [0 ] == {"type" : "text" , "text" : "What is in this PDF?" }
288+ assert prompt .rendered_user_prompt [1 ]["type" ] == "file"
273289
274290
275291def test_pdf_prompt_encoding ():
@@ -337,6 +353,21 @@ class TestPrompt(Prompt[_PromptInput, str]): # pylint: disable=unused-variable
337353 ]
338354
339355
356+ def test_prompt_with_iterator_input ():
357+ """Test that a prompt can be created with an iterator input."""
358+
359+ class TestPrompt (Prompt [_PromptInputWithIterator , str ]): # pylint: disable=unused-variable
360+ """A test prompt"""
361+
362+ user_prompt = "Context: {% for chunk in context %}{{ chunk }}{%- endfor %}"
363+
364+ prompt = TestPrompt (_PromptInputWithIterator (context = iter (["Hello" , " World" ])))
365+ assert prompt .rendered_user_prompt == "Context: Hello World"
366+ assert prompt .chat == [
367+ {"role" : "user" , "content" : "Context: Hello World" },
368+ ]
369+
370+
340371def test_input_type_must_be_pydantic_model ():
341372 """Test that an error is raised when the input type is not a Pydantic model."""
342373 with pytest .raises (AssertionError ):
0 commit comments