Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 30 additions & 10 deletions apisix/plugins/ai-proxy-multi.lua
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ local pcall = pcall
local ipairs = ipairs
local type = type

local internal_server_error = ngx.HTTP_INTERNAL_SERVER_ERROR
local priority_balancer = require("apisix.balancer.priority")

local pickers = {}
Expand Down Expand Up @@ -157,30 +156,48 @@ local function pick_target(ctx, conf, ups_tab)
create_server_picker, conf, ups_tab)
end
if not server_picker then
return internal_server_error, "failed to fetch server picker"
return nil, nil, "failed to fetch server picker"
end
ctx.server_picker = server_picker

local instance_name = server_picker.get(ctx)
local instance_conf = get_instance_conf(conf.instances, instance_name)

local instance_name, err = server_picker.get(ctx)
if err then
return nil, nil, err
end
ctx.balancer_server = instance_name
ctx.server_picker = server_picker
if conf.fallback_strategy == "instance_health_and_rate_limiting" then
local ai_rate_limiting = require("apisix.plugins.ai-rate-limiting")
for _ = 1, #conf.instances do
if ai_rate_limiting.check_instance_status(nil, ctx, instance_name) then
break
end
core.log.info("ai instance: ", instance_name,
" is not available, try to pick another one")
server_picker.after_balance(ctx, true)
instance_name, err = server_picker.get(ctx)
if err then
return nil, nil, err
end
ctx.balancer_server = instance_name
end
end

local instance_conf = get_instance_conf(conf.instances, instance_name)
return instance_name, instance_conf
end


local function pick_ai_instance(ctx, conf, ups_tab)
local instance_name, instance_conf
local instance_name, instance_conf, err
if #conf.instances == 1 then
instance_name = conf.instances[1].name
instance_conf = conf.instances[1]
else
instance_name, instance_conf = pick_target(ctx, conf, ups_tab)
instance_name, instance_conf, err = pick_target(ctx, conf, ups_tab)
end

core.log.info("picked instance: ", instance_name)
return instance_name, instance_conf
return instance_name, instance_conf, err
end


Expand All @@ -194,7 +211,10 @@ function _M.access(conf, ctx)
ups_tab["hash_on"] = hash_on
end

local name, ai_instance = pick_ai_instance(ctx, conf, ups_tab)
local name, ai_instance, err = pick_ai_instance(ctx, conf, ups_tab)
if err then
return 503, err
end
ctx.picked_ai_instance_name = name
ctx.picked_ai_instance = ai_instance
end
Expand Down
9 changes: 7 additions & 2 deletions apisix/plugins/ai-proxy/schema.lua
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ _M.ai_proxy_schema = {
type = "integer",
minimum = 1,
maximum = 60000,
default = 3000,
default = 30000,
description = "timeout in milliseconds",
},
keepalive = {type = "boolean", default = true},
Expand Down Expand Up @@ -188,11 +188,16 @@ _M.ai_proxy_multi_schema = {
default = { algorithm = "roundrobin" }
},
instances = ai_instance_schema,
fallback_strategy = {
type = "string",
enum = { "instance_health_and_rate_limiting" },
default = "instance_health_and_rate_limiting",
},
timeout = {
type = "integer",
minimum = 1,
maximum = 60000,
default = 3000,
default = 30000,
description = "timeout in milliseconds",
},
keepalive = {type = "boolean", default = true},
Expand Down
6 changes: 3 additions & 3 deletions apisix/plugins/ai-rate-limiting.lua
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,9 @@ end
function _M.check_instance_status(conf, ctx, instance_name)
if conf == nil then
local plugins = ctx.plugins
for _, plugin in ipairs(plugins) do
if plugin.name == plugin_name then
conf = plugin
for i = 1, #plugins, 2 do
if plugins[i]["name"] == plugin_name then
conf = plugins[i + 1]
end
end
end
Expand Down
151 changes: 148 additions & 3 deletions t/plugin/ai-rate-limiting.t
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ add_block_preprocessor(sub {
end

ngx.status = 200
ngx.say([[
ngx.say(string.format([[
{
"choices": [
{
Expand All @@ -127,12 +127,12 @@ add_block_preprocessor(sub {
],
"created": 1723780938,
"id": "chatcmpl-9wiSIg5LYrrpxwsr2PubSQnbtod1P",
"model": "gpt-4o-2024-05-13",
"model": "%s",
"object": "chat.completion",
"system_fingerprint": "fp_abc28019ad",
"usage": { "completion_tokens": 5, "prompt_tokens": 8, "total_tokens": 10 }
}
]])
]], body.model))
return
end

Expand Down Expand Up @@ -537,3 +537,148 @@ Authorization: Bearer token
Authorization: Bearer token
--- error_code eval
[200, 200, 200, 200, 200, 200, 200, 403, 503]



=== TEST 13: ai-rate-limiting & ai-proxy-multi, with instance_health_and_rate_limiting strategy
--- config
location /t {
content_by_lua_block {
local t = require("lib.test_admin").test
local code, body = t('/apisix/admin/routes/1',
ngx.HTTP_PUT,
[[{
"uri": "/ai",
"plugins": {
"ai-proxy-multi": {
"fallback_strategy": "instance_health_and_rate_limiting",
"instances": [
{
"name": "openai-gpt4",
"provider": "openai",
"weight": 1,
"priority": 1,
"auth": {
"header": {
"Authorization": "Bearer token"
}
},
"options": {
"model": "gpt-4"
},
"override": {
"endpoint": "http://localhost:16724"
}
},
{
"name": "openai-gpt3",
"provider": "openai",
"weight": 1,
"priority": 0,
"auth": {
"header": {
"Authorization": "Bearer token"
}
},
"options": {
"model": "gpt-3"
},
"override": {
"endpoint": "http://localhost:16724"
}
}
],
"ssl_verify": false
},
"ai-rate-limiting": {
"limit": 10,
"time_window": 60
}
},
"upstream": {
"type": "roundrobin",
"nodes": {
"canbeanything.com": 1
}
}
}]]
)

if code >= 300 then
ngx.status = code
end
ngx.say(body)
}
}
--- response_body
passed



=== TEST 14: fallback strategy should works
--- config
location /t {
content_by_lua_block {
local t = require("lib.test_admin").test
local core = require("apisix.core")
local code, _, body = t("/ai",
ngx.HTTP_POST,
[[{
"messages": [
{ "role": "system", "content": "You are a mathematician" },
{ "role": "user", "content": "What is 1+1?" }
]
}]],
nil,
{
["test-type"] = "options",
["Content-Type"] = "application/json",
}
)

assert(code == 200, "first request should be successful")
assert(core.string.find(body, "gpt-4"),
"first request should be handled by higher priority instance")

local code, _, body = t("/ai",
ngx.HTTP_POST,
[[{
"messages": [
{ "role": "system", "content": "You are a mathematician" },
{ "role": "user", "content": "What is 1+1?" }
]
}]],
nil,
{
["test-type"] = "options",
["Content-Type"] = "application/json",
}
)

assert(code == 200, "second request should be successful")
assert(core.string.find(body, "gpt-3"),
"second request should be handled by lower priority instance")

local code, body = t("/ai",
ngx.HTTP_POST,
[[{
"messages": [
{ "role": "system", "content": "You are a mathematician" },
{ "role": "user", "content": "What is 1+1?" }
]
}]],
nil,
{
["test-type"] = "options",
["Content-Type"] = "application/json",
}
)

assert(code == 503, "third request should be failed")
assert(core.string.find(body, "all servers tried"), "all servers tried")

ngx.say("passed")
}
}
--- response_body
passed
Loading