Skip to content

Commit 35a59eb

Browse files
fix(ai-proxy): abstract a base for ai-proxy (#11991)
1 parent 2a5425f commit 35a59eb

File tree

4 files changed

+128
-105
lines changed

4 files changed

+128
-105
lines changed

apisix/plugins/ai-proxy-multi.lua

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,12 @@
1717

1818
local core = require("apisix.core")
1919
local schema = require("apisix.plugins.ai-proxy.schema")
20-
local ai_proxy = require("apisix.plugins.ai-proxy")
2120
local plugin = require("apisix.plugin")
21+
local base = require("apisix.plugins.ai-proxy.base")
2222

2323
local require = require
2424
local pcall = pcall
2525
local ipairs = ipairs
26-
local unpack = unpack
2726
local type = type
2827

2928
local internal_server_error = ngx.HTTP_INTERNAL_SERVER_ERROR
@@ -190,11 +189,11 @@ local function get_load_balanced_provider(ctx, conf, ups_tab, request_table)
190189
return provider_name, provider_conf
191190
end
192191

193-
ai_proxy.get_model_name = function (...)
192+
local function get_model_name(...)
194193
end
195194

196195

197-
ai_proxy.proxy_request_to_llm = function (conf, request_table, ctx)
196+
local function proxy_request_to_llm(conf, request_table, ctx)
198197
local ups_tab = {}
199198
local algo = core.table.try_read_attr(conf, "balancer", "algorithm")
200199
if algo == "chash" then
@@ -228,9 +227,7 @@ ai_proxy.proxy_request_to_llm = function (conf, request_table, ctx)
228227
end
229228

230229

231-
function _M.access(conf, ctx)
232-
local rets = {ai_proxy.access(conf, ctx)}
233-
return unpack(rets)
234-
end
230+
_M.access = base.new(proxy_request_to_llm, get_model_name)
231+
235232

236233
return _M

apisix/plugins/ai-proxy.lua

Lines changed: 5 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,10 @@
1616
--
1717
local core = require("apisix.core")
1818
local schema = require("apisix.plugins.ai-proxy.schema")
19+
local base = require("apisix.plugins.ai-proxy.base")
20+
1921
local require = require
2022
local pcall = pcall
21-
local internal_server_error = ngx.HTTP_INTERNAL_SERVER_ERROR
22-
local bad_request = ngx.HTTP_BAD_REQUEST
23-
local ngx_req = ngx.req
24-
local ngx_print = ngx.print
25-
local ngx_flush = ngx.flush
2623

2724
local plugin_name = "ai-proxy"
2825
local _M = {
@@ -42,24 +39,12 @@ function _M.check_schema(conf)
4239
end
4340

4441

45-
local CONTENT_TYPE_JSON = "application/json"
46-
47-
48-
local function keepalive_or_close(conf, httpc)
49-
if conf.set_keepalive then
50-
httpc:set_keepalive(10000, 100)
51-
return
52-
end
53-
httpc:close()
54-
end
55-
56-
57-
function _M.get_model_name(conf)
42+
local function get_model_name(conf)
5843
return conf.model.name
5944
end
6045

6146

62-
function _M.proxy_request_to_llm(conf, request_table, ctx)
47+
local function proxy_request_to_llm(conf, request_table, ctx)
6348
local ai_driver = require("apisix.plugins.ai-drivers." .. conf.model.provider)
6449
local extra_opts = {
6550
endpoint = core.table.try_read_attr(conf, "override", "endpoint"),
@@ -74,82 +59,6 @@ function _M.proxy_request_to_llm(conf, request_table, ctx)
7459
return res, nil, httpc
7560
end
7661

77-
function _M.access(conf, ctx)
78-
local ct = core.request.header(ctx, "Content-Type") or CONTENT_TYPE_JSON
79-
if not core.string.has_prefix(ct, CONTENT_TYPE_JSON) then
80-
return bad_request, "unsupported content-type: " .. ct
81-
end
82-
83-
local request_table, err = core.request.get_json_request_body_table()
84-
if not request_table then
85-
return bad_request, err
86-
end
87-
88-
local ok, err = core.schema.check(schema.chat_request_schema, request_table)
89-
if not ok then
90-
return bad_request, "request format doesn't match schema: " .. err
91-
end
92-
93-
request_table.model = _M.get_model_name(conf)
94-
95-
if core.table.try_read_attr(conf, "model", "options", "stream") then
96-
request_table.stream = true
97-
end
98-
99-
local res, err, httpc = _M.proxy_request_to_llm(conf, request_table, ctx)
100-
if not res then
101-
core.log.error("failed to send request to LLM service: ", err)
102-
return internal_server_error
103-
end
104-
105-
local body_reader = res.body_reader
106-
if not body_reader then
107-
core.log.error("LLM sent no response body")
108-
return internal_server_error
109-
end
110-
111-
if conf.passthrough then
112-
ngx_req.init_body()
113-
while true do
114-
local chunk, err = body_reader() -- will read chunk by chunk
115-
if err then
116-
core.log.error("failed to read response chunk: ", err)
117-
break
118-
end
119-
if not chunk then
120-
break
121-
end
122-
ngx_req.append_body(chunk)
123-
end
124-
ngx_req.finish_body()
125-
keepalive_or_close(conf, httpc)
126-
return
127-
end
128-
129-
if request_table.stream then
130-
while true do
131-
local chunk, err = body_reader() -- will read chunk by chunk
132-
if err then
133-
core.log.error("failed to read response chunk: ", err)
134-
break
135-
end
136-
if not chunk then
137-
break
138-
end
139-
ngx_print(chunk)
140-
ngx_flush(true)
141-
end
142-
keepalive_or_close(conf, httpc)
143-
return
144-
else
145-
local res_body, err = res:read_body()
146-
if not res_body then
147-
core.log.error("failed to read response body: ", err)
148-
return internal_server_error
149-
end
150-
keepalive_or_close(conf, httpc)
151-
return res.status, res_body
152-
end
153-
end
62+
_M.access = base.new(proxy_request_to_llm, get_model_name)
15463

15564
return _M

apisix/plugins/ai-proxy/base.lua

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
--
2+
-- Licensed to the Apache Software Foundation (ASF) under one or more
3+
-- contributor license agreements. See the NOTICE file distributed with
4+
-- this work for additional information regarding copyright ownership.
5+
-- The ASF licenses this file to You under the Apache License, Version 2.0
6+
-- (the "License"); you may not use this file except in compliance with
7+
-- the License. You may obtain a copy of the License at
8+
--
9+
-- http://www.apache.org/licenses/LICENSE-2.0
10+
--
11+
-- Unless required by applicable law or agreed to in writing, software
12+
-- distributed under the License is distributed on an "AS IS" BASIS,
13+
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
-- See the License for the specific language governing permissions and
15+
-- limitations under the License.
16+
--
17+
18+
local CONTENT_TYPE_JSON = "application/json"
19+
local core = require("apisix.core")
20+
local bad_request = ngx.HTTP_BAD_REQUEST
21+
local internal_server_error = ngx.HTTP_INTERNAL_SERVER_ERROR
22+
local schema = require("apisix.plugins.ai-proxy.schema")
23+
local ngx_req = ngx.req
24+
local ngx_print = ngx.print
25+
local ngx_flush = ngx.flush
26+
27+
local function keepalive_or_close(conf, httpc)
28+
if conf.set_keepalive then
29+
httpc:set_keepalive(10000, 100)
30+
return
31+
end
32+
httpc:close()
33+
end
34+
35+
local _M = {}
36+
37+
function _M.new(proxy_request_to_llm_func, get_model_name_func)
38+
return function(conf, ctx)
39+
local ct = core.request.header(ctx, "Content-Type") or CONTENT_TYPE_JSON
40+
if not core.string.has_prefix(ct, CONTENT_TYPE_JSON) then
41+
return bad_request, "unsupported content-type: " .. ct
42+
end
43+
44+
local request_table, err = core.request.get_json_request_body_table()
45+
if not request_table then
46+
return bad_request, err
47+
end
48+
49+
local ok, err = core.schema.check(schema.chat_request_schema, request_table)
50+
if not ok then
51+
return bad_request, "request format doesn't match schema: " .. err
52+
end
53+
54+
request_table.model = get_model_name_func(conf)
55+
56+
if core.table.try_read_attr(conf, "model", "options", "stream") then
57+
request_table.stream = true
58+
end
59+
60+
local res, err, httpc = proxy_request_to_llm_func(conf, request_table, ctx)
61+
if not res then
62+
core.log.error("failed to send request to LLM service: ", err)
63+
return internal_server_error
64+
end
65+
66+
local body_reader = res.body_reader
67+
if not body_reader then
68+
core.log.error("LLM sent no response body")
69+
return internal_server_error
70+
end
71+
72+
if conf.passthrough then
73+
ngx_req.init_body()
74+
while true do
75+
local chunk, err = body_reader() -- will read chunk by chunk
76+
if err then
77+
core.log.error("failed to read response chunk: ", err)
78+
break
79+
end
80+
if not chunk then
81+
break
82+
end
83+
ngx_req.append_body(chunk)
84+
end
85+
ngx_req.finish_body()
86+
keepalive_or_close(conf, httpc)
87+
return
88+
end
89+
90+
if request_table.stream then
91+
while true do
92+
local chunk, err = body_reader() -- will read chunk by chunk
93+
if err then
94+
core.log.error("failed to read response chunk: ", err)
95+
break
96+
end
97+
if not chunk then
98+
break
99+
end
100+
ngx_print(chunk)
101+
ngx_flush(true)
102+
end
103+
keepalive_or_close(conf, httpc)
104+
return
105+
else
106+
local res_body, err = res:read_body()
107+
if not res_body then
108+
core.log.error("failed to read response body: ", err)
109+
return internal_server_error
110+
end
111+
keepalive_or_close(conf, httpc)
112+
return res.status, res_body
113+
end
114+
end
115+
end
116+
117+
return _M

t/plugin/ai-proxy-multi2.t

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,7 @@ passed
289289
290290
291291
=== TEST 6: send request
292-
--- custom_trusted_cert: /etc/ssl/cert.pem
292+
--- custom_trusted_cert: /etc/ssl/certs/ca-certificates.crt
293293
--- request
294294
POST /anything
295295
{ "messages": [ { "role": "system", "content": "You are a mathematician" }, { "role": "user", "content": "What is 1+1?"} ] }

0 commit comments

Comments (0)