1313# limitations under the License.
1414from __future__ import annotations
1515
16- from typing import Any , Iterable , Optional , TYPE_CHECKING
16+ from typing import Any , Iterable , Optional , TYPE_CHECKING , cast
1717
1818from pydantic import ValidationError
1919
2020from neo4j_graphrag .exceptions import LLMGenerationError
2121from neo4j_graphrag .llm .base import LLMInterface
22- from neo4j_graphrag .llm .types import LLMResponse , MessageList , UserMessage , BaseMessage
22+ from neo4j_graphrag .llm .types import LLMResponse , MessageList , UserMessage
2323
2424if TYPE_CHECKING :
2525 from anthropic .types .message_param import MessageParam
@@ -71,22 +71,22 @@ def __init__(
7171 self .async_client = anthropic .AsyncAnthropic (** kwargs )
7272
def get_messages(
    self, input: str, message_history: Optional[list[dict[str, str]]] = None
) -> Iterable[MessageParam]:
    """Assemble the Anthropic-format message list.

    Validates the optional conversation history against the ``MessageList``
    pydantic model, then appends the new user ``input`` as the final message.

    Args:
        input: The text of the new user message.
        message_history: Prior turns as role/content dicts, oldest first.

    Returns:
        An iterable of Anthropic ``MessageParam``-shaped dicts.

    Raises:
        LLMGenerationError: If the history fails pydantic validation.
    """
    conversation: list[dict[str, str]] = []
    if message_history:
        try:
            # Validation only — the model instance itself is discarded.
            MessageList(messages=message_history)  # type: ignore
        except ValidationError as validation_error:
            raise LLMGenerationError(validation_error.errors()) from validation_error
        conversation.extend(message_history)
    conversation.append(UserMessage(content=input).model_dump())
    # The dicts match Anthropic's MessageParam shape; cast for the type checker.
    return cast(Iterable[MessageParam], conversation)
8585
8686 def invoke (
8787 self ,
8888 input : str ,
89- message_history : Optional [list [BaseMessage ]] = None ,
89+ message_history : Optional [list [dict [ str , str ] ]] = None ,
9090 system_instruction : Optional [str ] = None ,
9191 ) -> LLMResponse :
9292 """Sends text to the LLM and returns a response.
@@ -108,18 +108,18 @@ def invoke(
108108 )
109109 response = self .client .messages .create (
110110 model = self .model_name ,
111- system = system_message ,
111+ system = system_message , # type: ignore
112112 messages = messages ,
113113 ** self .model_params ,
114114 )
115- return LLMResponse (content = response .content )
115+ return LLMResponse (content = response .content ) # type: ignore
116116 except self .anthropic .APIError as e :
117117 raise LLMGenerationError (e )
118118
119119 async def ainvoke (
120120 self ,
121121 input : str ,
122- message_history : Optional [list [BaseMessage ]] = None ,
122+ message_history : Optional [list [dict [ str , str ] ]] = None ,
123123 system_instruction : Optional [str ] = None ,
124124 ) -> LLMResponse :
125125 """Asynchronously sends text to the LLM and returns a response.
@@ -141,10 +141,10 @@ async def ainvoke(
141141 )
142142 response = await self .async_client .messages .create (
143143 model = self .model_name ,
144- system = system_message ,
144+ system = system_message , # type: ignore
145145 messages = messages ,
146146 ** self .model_params ,
147147 )
148- return LLMResponse (content = response .content )
148+ return LLMResponse (content = response .content ) # type: ignore
149149 except self .anthropic .APIError as e :
150150 raise LLMGenerationError (e )
0 commit comments