Skip to content

Commit 34310ce

Browse files
committed
prompt -> messages
1 parent 540d031 commit 34310ce

File tree

1 file changed: +19 additions, −5 deletions

guardrails/integrations/llama_index/guardrails_engine.py

19 additions, 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,12 +56,14 @@ def __init__(
     def guard(self) -> Guard:
         return self._guard

-    def engine_api(self, prompt: str, **kwargs) -> str:
+    def engine_api(self, messages: List[Dict[str, str]], **kwargs) -> str:
+        user_messages = [m for m in messages if m["role"] == "user"]
+        query = user_messages[-1]["content"]
         if isinstance(self._engine, BaseQueryEngine):
-            response = self._engine.query(prompt)
+            response = self._engine.query(query)
         elif isinstance(self._engine, BaseChatEngine):
             chat_history = kwargs.get("chat_history", [])
-            response = self._engine.chat(prompt, chat_history)
+            response = self._engine.chat(query, chat_history)
         else:
             raise ValueError("Unsupported engine type")

@@ -77,9 +79,15 @@ def _query(self, query_bundle: "QueryBundle") -> "RESPONSE_TYPE":
         if isinstance(query_bundle, str):
            query_bundle = QueryBundle(query_bundle)
         try:
+            messages = [
+                {
+                    "role": "user",
+                    "content": query_bundle.query_str,
+                }
+            ]
             validated_output = self.guard(
                 llm_api=self.engine_api,
-                prompt=query_bundle.query_str,
+                messages=messages,
                 **self._guard_kwargs,
             )

@@ -128,9 +136,15 @@ def chat(
         if chat_history is None:
             chat_history = []
         try:
+            messages = [
+                {
+                    "role": "user",
+                    "content": message,
+                }
+            ]
             validated_output = self.guard(
                 llm_api=self.engine_api,
-                prompt=message,
+                messages=messages,
                 chat_history=chat_history,
                 **self._guard_kwargs,
             )

0 commit comments

Comments
 (0)