Skip to content

Commit 3396c06

Browse files
committed
Merge branch 'master' of github.com:apache/apisix into revolyssup/ai-prompt-guard
2 parents 39827ac + 7dba835 commit 3396c06

33 files changed

+1916
-863
lines changed

Makefile

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -382,9 +382,6 @@ install: runtime
382382
$(ENV_INSTALL) -d $(ENV_INST_LUADIR)/apisix/plugins/ai-rag/vector-search
383383
$(ENV_INSTALL) apisix/plugins/ai-rag/vector-search/*.lua $(ENV_INST_LUADIR)/apisix/plugins/ai-rag/vector-search
384384

385-
# ai-content-moderation plugin
386-
$(ENV_INSTALL) -d $(ENV_INST_LUADIR)/apisix/plugins/ai
387-
$(ENV_INSTALL) apisix/plugins/ai/*.lua $(ENV_INST_LUADIR)/apisix/plugins/ai
388385

389386
$(ENV_INSTALL) bin/apisix $(ENV_INST_BINDIR)/apisix
390387

apisix/cli/config.lua

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ local _M = {
218218
"ai-prompt-decorator",
219219
"ai-prompt-guard",
220220
"ai-rag",
221-
"ai-content-moderation",
221+
"ai-aws-content-moderation",
222222
"proxy-mirror",
223223
"proxy-rewrite",
224224
"workflow",

apisix/core/config_etcd.lua

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,14 @@ local function do_run_watch(premature)
262262
cancel_watch(http_cli)
263263
break
264264
end
265+
266+
if rev < watch_ctx.rev then
267+
log.error("received smaller revision, rev=", rev, ", watch_ctx.rev=",
268+
watch_ctx.rev,". etcd may be restarted. resyncing....")
269+
produce_res(nil, "restarted")
270+
cancel_watch(http_cli)
271+
break
272+
end
265273
if rev > watch_ctx.rev then
266274
watch_ctx.rev = rev + 1
267275
end
@@ -569,6 +577,7 @@ local function load_full_data(self, dir_res, headers)
569577
end
570578

571579
if headers then
580+
self.prev_index = tonumber(headers["X-Etcd-Index"]) or 0
572581
self:upgrade_version(headers["X-Etcd-Index"])
573582
end
574583

@@ -633,7 +642,7 @@ local function sync_data(self)
633642
log.info("res: ", json.delay_encode(dir_res, true), ", err: ", err)
634643

635644
if not dir_res then
636-
if err == "compacted" then
645+
if err == "compacted" or err == "restarted" then
637646
self.need_reload = true
638647
log.error("waitdir [", self.key, "] err: ", err,
639648
", will read the configuration again via readdir")

apisix/plugins/ai-content-moderation.lua renamed to apisix/plugins/ai-aws-content-moderation.lua

Lines changed: 28 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -19,48 +19,34 @@ local aws_instance = require("resty.aws")()
1919
local http = require("resty.http")
2020
local fetch_secrets = require("apisix.secret").fetch_secrets
2121

22-
local next = next
2322
local pairs = pairs
2423
local unpack = unpack
2524
local type = type
2625
local ipairs = ipairs
27-
local require = require
2826
local HTTP_INTERNAL_SERVER_ERROR = ngx.HTTP_INTERNAL_SERVER_ERROR
2927
local HTTP_BAD_REQUEST = ngx.HTTP_BAD_REQUEST
3028

31-
32-
local aws_comprehend_schema = {
33-
type = "object",
34-
properties = {
35-
access_key_id = { type = "string" },
36-
secret_access_key = { type = "string" },
37-
region = { type = "string" },
38-
endpoint = {
39-
type = "string",
40-
pattern = [[^https?://]]
41-
},
42-
ssl_verify = {
43-
type = "boolean",
44-
default = true
45-
}
46-
},
47-
required = { "access_key_id", "secret_access_key", "region", }
48-
}
49-
5029
local moderation_categories_pattern = "^(PROFANITY|HATE_SPEECH|INSULT|"..
5130
"HARASSMENT_OR_ABUSE|SEXUAL|VIOLENCE_OR_THREAT)$"
5231
local schema = {
5332
type = "object",
5433
properties = {
55-
provider = {
34+
comprehend = {
5635
type = "object",
5736
properties = {
58-
aws_comprehend = aws_comprehend_schema
37+
access_key_id = { type = "string" },
38+
secret_access_key = { type = "string" },
39+
region = { type = "string" },
40+
endpoint = {
41+
type = "string",
42+
pattern = [[^https?://]]
43+
},
44+
ssl_verify = {
45+
type = "boolean",
46+
default = true
47+
}
5948
},
60-
maxProperties = 1,
61-
-- ensure only one provider can be configured while implementing support for
62-
-- other providers
63-
required = { "aws_comprehend" }
49+
required = { "access_key_id", "secret_access_key", "region", }
6450
},
6551
moderation_categories = {
6652
type = "object",
@@ -78,20 +64,16 @@ local schema = {
7864
minimum = 0,
7965
maximum = 1,
8066
default = 0.5
81-
},
82-
llm_provider = {
83-
type = "string",
84-
enum = { "openai" },
8567
}
8668
},
87-
required = { "provider", "llm_provider" },
69+
required = { "comprehend" },
8870
}
8971

9072

9173
local _M = {
9274
version = 0.1,
9375
priority = 1040, -- TODO: might change
94-
name = "ai-content-moderation",
76+
name = "ai-aws-content-moderation",
9577
schema = schema,
9678
}
9779

@@ -107,51 +89,44 @@ function _M.rewrite(conf, ctx)
10789
return HTTP_INTERNAL_SERVER_ERROR, "failed to retrieve secrets from conf"
10890
end
10991

110-
local body, err = core.request.get_json_request_body_table()
92+
local body, err = core.request.get_body()
11193
if not body then
11294
return HTTP_BAD_REQUEST, err
11395
end
11496

115-
local msgs = body.messages
116-
if type(msgs) ~= "table" or #msgs < 1 then
117-
return HTTP_BAD_REQUEST, "messages not found in request body"
118-
end
119-
120-
local provider = conf.provider[next(conf.provider)]
97+
local comprehend = conf.comprehend
12198

12299
local credentials = aws_instance:Credentials({
123-
accessKeyId = provider.access_key_id,
124-
secretAccessKey = provider.secret_access_key,
125-
sessionToken = provider.session_token,
100+
accessKeyId = comprehend.access_key_id,
101+
secretAccessKey = comprehend.secret_access_key,
102+
sessionToken = comprehend.session_token,
126103
})
127104

128-
local default_endpoint = "https://comprehend." .. provider.region .. ".amazonaws.com"
129-
local scheme, host, port = unpack(http:parse_uri(provider.endpoint or default_endpoint))
105+
local default_endpoint = "https://comprehend." .. comprehend.region .. ".amazonaws.com"
106+
local scheme, host, port = unpack(http:parse_uri(comprehend.endpoint or default_endpoint))
130107
local endpoint = scheme .. "://" .. host
131108
aws_instance.config.endpoint = endpoint
132-
aws_instance.config.ssl_verify = provider.ssl_verify
109+
aws_instance.config.ssl_verify = comprehend.ssl_verify
133110

134111
local comprehend = aws_instance:Comprehend({
135112
credentials = credentials,
136113
endpoint = endpoint,
137-
region = provider.region,
114+
region = comprehend.region,
138115
port = port,
139116
})
140117

141-
local ai_module = require("apisix.plugins.ai." .. conf.llm_provider)
142-
local create_request_text_segments = ai_module.create_request_text_segments
143-
144-
local text_segments = create_request_text_segments(msgs)
145118
local res, err = comprehend:detectToxicContent({
146119
LanguageCode = "en",
147-
TextSegments = text_segments,
120+
TextSegments = {{
121+
Text = body
122+
}},
148123
})
149124

150125
if not res then
151126
core.log.error("failed to send request to ", endpoint, ": ", err)
152127
return HTTP_INTERNAL_SERVER_ERROR, err
153128
end
154-
129+
core.log.warn("dibag: ", core.json.encode(res))
155130
local results = res.body and res.body.ResultList
156131
if type(results) ~= "table" or core.table.isempty(results) then
157132
return HTTP_INTERNAL_SERVER_ERROR, "failed to get moderation results from response"

apisix/plugins/ai-drivers/deepseek.lua

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
-- limitations under the License.
1616
--
1717

18-
return require("apisix.plugins.ai-drivers.openai-compatible").new(
18+
return require("apisix.plugins.ai-drivers.openai-base").new(
1919
{
2020
host = "api.deepseek.com",
2121
path = "/chat/completions",
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
--
2+
-- Licensed to the Apache Software Foundation (ASF) under one or more
3+
-- contributor license agreements. See the NOTICE file distributed with
4+
-- this work for additional information regarding copyright ownership.
5+
-- The ASF licenses this file to You under the Apache License, Version 2.0
6+
-- (the "License"); you may not use this file except in compliance with
7+
-- the License. You may obtain a copy of the License at
8+
--
9+
-- http://www.apache.org/licenses/LICENSE-2.0
10+
--
11+
-- Unless required by applicable law or agreed to in writing, software
12+
-- distributed under the License is distributed on an "AS IS" BASIS,
13+
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
-- See the License for the specific language governing permissions and
15+
-- limitations under the License.
16+
--
17+
local _M = {}
18+
19+
local mt = {
20+
__index = _M
21+
}
22+
23+
local core = require("apisix.core")
24+
local http = require("resty.http")
25+
local url = require("socket.url")
26+
27+
local pairs = pairs
28+
local type = type
29+
local setmetatable = setmetatable
30+
31+
32+
function _M.new(opts)
33+
34+
local self = {
35+
host = opts.host,
36+
port = opts.port,
37+
path = opts.path,
38+
}
39+
return setmetatable(self, mt)
40+
end
41+
42+
43+
function _M.request(self, conf, request_table, extra_opts)
44+
local httpc, err = http.new()
45+
if not httpc then
46+
return nil, "failed to create http client to send request to LLM server: " .. err
47+
end
48+
httpc:set_timeout(conf.timeout)
49+
50+
local endpoint = extra_opts.endpoint
51+
local parsed_url
52+
if endpoint then
53+
parsed_url = url.parse(endpoint)
54+
end
55+
56+
local ok, err = httpc:connect({
57+
scheme = endpoint and parsed_url.scheme or "https",
58+
host = endpoint and parsed_url.host or self.host,
59+
port = endpoint and parsed_url.port or self.port,
60+
ssl_verify = conf.ssl_verify,
61+
ssl_server_name = endpoint and parsed_url.host or self.host,
62+
pool_size = conf.keepalive and conf.keepalive_pool,
63+
})
64+
65+
if not ok then
66+
return nil, "failed to connect to LLM server: " .. err
67+
end
68+
69+
local query_params = extra_opts.query_params
70+
71+
if type(parsed_url) == "table" and parsed_url.query and #parsed_url.query > 0 then
72+
local args_tab = core.string.decode_args(parsed_url.query)
73+
if type(args_tab) == "table" then
74+
core.table.merge(query_params, args_tab)
75+
end
76+
end
77+
78+
local path = (endpoint and parsed_url.path or self.path)
79+
80+
local headers = extra_opts.headers
81+
headers["Content-Type"] = "application/json"
82+
local params = {
83+
method = "POST",
84+
headers = headers,
85+
keepalive = conf.keepalive,
86+
ssl_verify = conf.ssl_verify,
87+
path = path,
88+
query = query_params
89+
}
90+
91+
if extra_opts.model_options then
92+
for opt, val in pairs(extra_opts.model_options) do
93+
request_table[opt] = val
94+
end
95+
end
96+
97+
local req_json, err = core.json.encode(request_table)
98+
if not req_json then
99+
return nil, err
100+
end
101+
102+
params.body = req_json
103+
104+
local res, err = httpc:request(params)
105+
if not res then
106+
return nil, err
107+
end
108+
109+
return res, nil, httpc
110+
end
111+
112+
return _M

0 commit comments

Comments
 (0)