Skip to content

Commit ac03b08

Browse files
committed
apply suggestions
1 parent 88f9996 commit ac03b08

File tree

1 file changed

+31
-13
lines changed

1 file changed

+31
-13
lines changed

apisix/plugins/ai-prompt-guard.lua

Lines changed: 31 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -82,20 +82,16 @@ end
8282
local function get_content_to_check(conf, messages)
8383
local contents = {}
8484
if conf.match_all_conversation_history then
85-
for _, msg in ipairs(messages) do
86-
if msg.content then
87-
core.table.insert(contents, msg.content)
88-
end
89-
end
85+
return messages
9086
else
9187
if #messages > 0 then
9288
local last_msg = messages[#messages]
93-
if last_msg.content then
94-
core.table.insert(contents, last_msg.content)
89+
if last_msg then
90+
core.table.insert(contents, last_msg)
9591
end
9692
end
9793
end
98-
return table.concat(contents, " ")
94+
return contents
9995
end
10096

10197
function _M.access(conf, ctx)
@@ -111,12 +107,34 @@ function _M.access(conf, ctx)
111107
end
112108

113109
local messages = json_body.messages or {}
114-
if not conf.match_all_roles and #messages > 0 and messages[#messages].role ~= "user" then
115-
return
110+
messages = get_content_to_check(conf, messages)
111+
if not conf.match_all_roles then
112+
-- filter to only user messages
113+
local new_messages = {}
114+
for _, msg in ipairs(messages) do
115+
if not msg then
116+
return 400, {message = "request doesn't contain messages"}
117+
end
118+
if msg.role == "user" then
119+
core.table.insert(new_messages, msg)
120+
end
121+
end
122+
messages = new_messages
116123
end
117-
118-
local content_to_check = get_content_to_check(conf, messages)
119-
124+
if #messages == 0 then --nothing to check
125+
return 200
126+
end
127+
-- extract only messages
128+
local content = {}
129+
for _, msg in ipairs(messages) do
130+
if not msg then
131+
return 400, {message = "request doesn't contain messages"}
132+
end
133+
if msg.content then
134+
core.table.insert(content, msg.content)
135+
end
136+
end
137+
local content_to_check = table.concat(content, " ")
120138
-- Allow patterns check
121139
if #conf.allow_patterns > 0 then
122140
local any_allowed = false

0 commit comments

Comments
 (0)