@@ -56,12 +56,14 @@ def __init__(
5656 def guard (self ) -> Guard :
5757 return self ._guard
5858
59- def engine_api (self , prompt : str , ** kwargs ) -> str :
59+ def engine_api (self , messages : List [Dict [str , str ]], ** kwargs ) -> str :
60+ user_messages = [m for m in messages if m ["role" ] == "user" ]
61+ query = user_messages [- 1 ]["content" ]
6062 if isinstance (self ._engine , BaseQueryEngine ):
61- response = self ._engine .query (prompt )
63+ response = self ._engine .query (query )
6264 elif isinstance (self ._engine , BaseChatEngine ):
6365 chat_history = kwargs .get ("chat_history" , [])
64- response = self ._engine .chat (prompt , chat_history )
66+ response = self ._engine .chat (query , chat_history )
6567 else :
6668 raise ValueError ("Unsupported engine type" )
6769
@@ -77,9 +79,15 @@ def _query(self, query_bundle: "QueryBundle") -> "RESPONSE_TYPE":
7779 if isinstance (query_bundle , str ):
7880 query_bundle = QueryBundle (query_bundle )
7981 try :
82+ messages = [
83+ {
84+ "role" : "user" ,
85+ "content" : query_bundle .query_str ,
86+ }
87+ ]
8088 validated_output = self .guard (
8189 llm_api = self .engine_api ,
82- prompt = query_bundle . query_str ,
90+ messages = messages ,
8391 ** self ._guard_kwargs ,
8492 )
8593
@@ -128,9 +136,15 @@ def chat(
128136 if chat_history is None :
129137 chat_history = []
130138 try :
139+ messages = [
140+ {
141+ "role" : "user" ,
142+ "content" : message ,
143+ }
144+ ]
131145 validated_output = self .guard (
132146 llm_api = self .engine_api ,
133- prompt = message ,
147+ messages = messages ,
134148 chat_history = chat_history ,
135149 ** self ._guard_kwargs ,
136150 )
0 commit comments