1
+ from collections .abc import Iterator
1
2
from pathlib import Path
2
3
from typing import Any
3
4
@@ -19,6 +20,15 @@ class _PromptInput(pydantic.BaseModel):
19
20
age : int
20
21
21
22
23
+ class _PromptInputWithIterator (pydantic .BaseModel ):
24
+ """
25
+ Input format for the TestPromptWithIterator, which can be consumed only once.
26
+ """
27
+
28
+ model_config = pydantic .ConfigDict (arbitrary_types_allowed = True )
29
+ context : Iterator [str ]
30
+
31
+
22
32
class _SingleAttachmentPromptInput (pydantic .BaseModel ):
23
33
"""
24
34
Single input format for the TestAttachmentPrompt.
@@ -204,6 +214,9 @@ class TestAttachmentPrompt(Prompt):
204
214
assert chat [0 ]["role" ] == "user"
205
215
assert chat [0 ]["content" ][0 ]["text" ] == "What is in this image?"
206
216
assert chat [0 ]["content" ][1 ]["type" ] == "image_url"
217
+ assert len (prompt .rendered_user_prompt ) == 2
218
+ assert prompt .rendered_user_prompt [0 ] == {"type" : "text" , "text" : "What is in this image?" }
219
+ assert prompt .rendered_user_prompt [1 ]["type" ] == "image_url"
207
220
208
221
209
222
def test_image_prompt_encoding ():
@@ -270,6 +283,9 @@ class TestAttachmentPrompt(Prompt):
270
283
assert chat [0 ]["role" ] == "user"
271
284
assert chat [0 ]["content" ][0 ]["text" ] == "What is in this PDF?"
272
285
assert chat [0 ]["content" ][1 ]["type" ] == "file"
286
+ assert len (prompt .rendered_user_prompt ) == 2
287
+ assert prompt .rendered_user_prompt [0 ] == {"type" : "text" , "text" : "What is in this PDF?" }
288
+ assert prompt .rendered_user_prompt [1 ]["type" ] == "file"
273
289
274
290
275
291
def test_pdf_prompt_encoding ():
@@ -337,6 +353,21 @@ class TestPrompt(Prompt[_PromptInput, str]): # pylint: disable=unused-variable
337
353
]
338
354
339
355
356
+ def test_prompt_with_iterator_input ():
357
+ """Test that a prompt can be created with an iterator input."""
358
+
359
+ class TestPrompt (Prompt [_PromptInputWithIterator , str ]): # pylint: disable=unused-variable
360
+ """A test prompt"""
361
+
362
+ user_prompt = "Context: {% for chunk in context %}{{ chunk }}{%- endfor %}"
363
+
364
+ prompt = TestPrompt (_PromptInputWithIterator (context = iter (["Hello" , " World" ])))
365
+ assert prompt .rendered_user_prompt == "Context: Hello World"
366
+ assert prompt .chat == [
367
+ {"role" : "user" , "content" : "Context: Hello World" },
368
+ ]
369
+
370
+
340
371
def test_input_type_must_be_pydantic_model ():
341
372
"""Test that an error is raised when the input type is not a Pydantic model."""
342
373
with pytest .raises (AssertionError ):
0 commit comments