22
33import inspect
44import json
5- from collections .abc import Callable
6- from typing import Any , Awaitable , Literal , Sequence
5+ from collections .abc import Awaitable , Callable , Sequence
6+ from typing import Any , Literal
77
88import pydantic_core
99from pydantic import BaseModel , Field , TypeAdapter , validate_call
@@ -19,7 +19,7 @@ class Message(BaseModel):
1919 role : Literal ["user" , "assistant" ]
2020 content : CONTENT_TYPES
2121
22- def __init__ (self , content : str | CONTENT_TYPES , ** kwargs ):
22+ def __init__ (self , content : str | CONTENT_TYPES , ** kwargs : Any ):
2323 if isinstance (content , str ):
2424 content = TextContent (type = "text" , text = content )
2525 super ().__init__ (content = content , ** kwargs )
@@ -30,7 +30,7 @@ class UserMessage(Message):
3030
3131 role : Literal ["user" , "assistant" ] = "user"
3232
33- def __init__ (self , content : str | CONTENT_TYPES , ** kwargs ):
33+ def __init__ (self , content : str | CONTENT_TYPES , ** kwargs : Any ):
3434 super ().__init__ (content = content , ** kwargs )
3535
3636
@@ -39,11 +39,13 @@ class AssistantMessage(Message):
3939
4040 role : Literal ["user" , "assistant" ] = "assistant"
4141
42- def __init__ (self , content : str | CONTENT_TYPES , ** kwargs ):
42+ def __init__ (self , content : str | CONTENT_TYPES , ** kwargs : Any ):
4343 super ().__init__ (content = content , ** kwargs )
4444
4545
46- message_validator = TypeAdapter (UserMessage | AssistantMessage )
46+ message_validator = TypeAdapter [UserMessage | AssistantMessage ](
47+ UserMessage | AssistantMessage
48+ )
4749
4850SyncPromptResult = (
4951 str | Message | dict [str , Any ] | Sequence [str | Message | dict [str , Any ]]
@@ -73,12 +75,12 @@ class Prompt(BaseModel):
7375 arguments : list [PromptArgument ] | None = Field (
7476 None , description = "Arguments that can be passed to the prompt"
7577 )
76- fn : Callable = Field (exclude = True )
78+ fn : Callable [..., PromptResult | Awaitable [ PromptResult ]] = Field (exclude = True )
7779
7880 @classmethod
7981 def from_function (
8082 cls ,
81- fn : Callable [..., PromptResult ],
83+ fn : Callable [..., PromptResult | Awaitable [ PromptResult ] ],
8284 name : str | None = None ,
8385 description : str | None = None ,
8486 ) -> "Prompt" :
@@ -99,7 +101,7 @@ def from_function(
99101 parameters = TypeAdapter (fn ).json_schema ()
100102
101103 # Convert parameters to PromptArguments
102- arguments = []
104+ arguments : list [ PromptArgument ] = []
103105 if "properties" in parameters :
104106 for param_name , param in parameters ["properties" ].items ():
105107 required = param_name in parameters .get ("required" , [])
@@ -138,25 +140,23 @@ async def render(self, arguments: dict[str, Any] | None = None) -> list[Message]
138140 result = await result
139141
140142 # Validate messages
141- if not isinstance (result , ( list , tuple ) ):
143+ if not isinstance (result , list | tuple ):
142144 result = [result ]
143145
144146 # Convert result to messages
145- messages = []
146- for msg in result :
147+ messages : list [ Message ] = []
148+ for msg in result : # type: ignore[reportUnknownVariableType]
147149 try :
148150 if isinstance (msg , Message ):
149151 messages .append (msg )
150152 elif isinstance (msg , dict ):
151- msg = message_validator .validate_python (msg )
152- messages .append (msg )
153+ messages .append (message_validator .validate_python (msg ))
153154 elif isinstance (msg , str ):
154- messages .append (
155- UserMessage (content = TextContent (type = "text" , text = msg ))
156- )
155+ content = TextContent (type = "text" , text = msg )
156+ messages .append (UserMessage (content = content ))
157157 else :
158- msg = json .dumps (pydantic_core .to_jsonable_python (msg ))
159- messages .append (Message (role = "user" , content = msg ))
158+ content = json .dumps (pydantic_core .to_jsonable_python (msg ))
159+ messages .append (Message (role = "user" , content = content ))
160160 except Exception :
161161 raise ValueError (
162162 f"Could not convert prompt result to message: { msg } "
0 commit comments