Skip to content

Commit 0efdb8e

Browse files
feat(ai-proxy): support embeddings API (#12062)
1 parent 7d8eb88 commit 0efdb8e

File tree

6 files changed

+124
-86
lines changed

6 files changed

+124
-86
lines changed

apisix/plugins/ai-drivers/openai-base.lua

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ local CONTENT_TYPE_JSON = "application/json"
2525
local core = require("apisix.core")
2626
local http = require("resty.http")
2727
local url = require("socket.url")
28-
local schema = require("apisix.plugins.ai-drivers.schema")
2928
local ngx_re = require("ngx.re")
3029

3130
local ngx_print = ngx.print
@@ -59,11 +58,6 @@ function _M.validate_request(ctx)
5958
return nil, err
6059
end
6160

62-
local ok, err = core.schema.check(schema.chat_request_schema, request_table)
63-
if not ok then
64-
return nil, "request format doesn't match schema: " .. err
65-
end
66-
6761
return request_table, nil
6862
end
6963

apisix/plugins/ai-proxy/schema.lua

Lines changed: 1 addition & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -42,43 +42,8 @@ local model_options_schema = {
4242
type = "string",
4343
description = "Model to execute.",
4444
},
45-
max_tokens = {
46-
type = "integer",
47-
description = "Defines the max_tokens, if using chat or completion models.",
48-
default = 256
49-
50-
},
51-
input_cost = {
52-
type = "number",
53-
description = "Defines the cost per 1M tokens in your prompt.",
54-
minimum = 0
55-
56-
},
57-
output_cost = {
58-
type = "number",
59-
description = "Defines the cost per 1M tokens in the output of the AI.",
60-
minimum = 0
61-
62-
},
63-
temperature = {
64-
type = "number",
65-
description = "Defines the matching temperature, if using chat or completion models.",
66-
minimum = 0.0,
67-
maximum = 5.0,
68-
69-
},
70-
top_p = {
71-
type = "number",
72-
description = "Defines the top-p probability mass, if supported.",
73-
minimum = 0,
74-
maximum = 1,
75-
76-
},
77-
stream = {
78-
description = "Stream response by SSE",
79-
type = "boolean",
80-
}
8145
},
46+
additionalProperties = true,
8247
}
8348

8449
local ai_instance_schema = {

docs/en/latest/plugins/ai-proxy-multi.md

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -63,12 +63,6 @@ Proxying requests to OpenAI is supported now. Other LLM services will be support
6363
| provider.auth | Yes | object | Authentication details, including headers and query parameters. | |
6464
| provider.auth.header | No | object | Authentication details sent via headers. Header name must match `^[a-zA-Z0-9._-]+$`. | |
6565
| provider.auth.query | No | object | Authentication details sent via query parameters. Keys must match `^[a-zA-Z0-9._-]+$`. | |
66-
| provider.options.max_tokens | No | integer | Defines the maximum tokens for chat or completion models. | 256 |
67-
| provider.options.input_cost | No | number | Cost per 1M tokens in the input prompt. Minimum is 0. | |
68-
| provider.options.output_cost | No | number | Cost per 1M tokens in the AI-generated output. Minimum is 0. | |
69-
| provider.options.temperature | No | number | Defines the model's temperature (0.0 - 5.0) for randomness in responses. | |
70-
| provider.options.top_p | No | number | Defines the top-p probability mass (0 - 1) for nucleus sampling. | |
71-
| provider.options.stream | No | boolean | Enables streaming responses via SSE. | |
7266
| provider.override.endpoint | No | string | Custom host override for the AI provider. | |
7367
| timeout | No | integer | Request timeout in milliseconds (1-60000). | 30000 |
7468
| keepalive | No | boolean | Enables keepalive connections. | true |

docs/en/latest/plugins/ai-proxy.md

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -56,12 +56,6 @@ Proxying requests to OpenAI is supported now. Other LLM services will be support
5656
| model.provider | Yes | String | Name of the AI service provider (`openai`). |
5757
| model.name | Yes | String | Model name to execute. |
5858
| model.options | No | Object | Key/value settings for the model |
59-
| model.options.max_tokens | No | Integer | Defines the max tokens if using chat or completion models. Default: 256 |
60-
| model.options.input_cost | No | Number | Cost per 1M tokens in your prompt. Minimum: 0 |
61-
| model.options.output_cost | No | Number | Cost per 1M tokens in the output of the AI. Minimum: 0 |
62-
| model.options.temperature | No | Number | Matching temperature for models. Range: 0.0 - 5.0 |
63-
| model.options.top_p | No | Number | Top-p probability mass. Range: 0 - 1 |
64-
| model.options.stream | No | Boolean | Stream response by SSE. |
6559
| override.endpoint | No | String | Override the endpoint of the AI provider |
6660
| timeout | No | Integer | Timeout in milliseconds for requests to LLM. Range: 1 - 60000. Default: 30000 |
6761
| keepalive | No | Boolean | Enable keepalive for requests to LLM. Default: true |

t/plugin/ai-proxy-multi.t

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -360,19 +360,7 @@ unsupported content-type: application/x-www-form-urlencoded, only application/js
360360
361361
362362
363-
=== TEST 11: request schema validity check
364-
--- request
365-
POST /anything
366-
{ "messages-missing": [ { "role": "system", "content": "xyz" } ] }
367-
--- more_headers
368-
Authorization: Bearer token
369-
--- error_code: 400
370-
--- response_body chomp
371-
request format doesn't match schema: property "messages" is required
372-
373-
374-
375-
=== TEST 12: model options being merged to request body
363+
=== TEST 11: model options being merged to request body
376364
--- config
377365
location /t {
378366
content_by_lua_block {
@@ -441,7 +429,7 @@ options_works
441429
442430
443431
444-
=== TEST 13: override path
432+
=== TEST 12: override path
445433
--- config
446434
location /t {
447435
content_by_lua_block {
@@ -509,7 +497,7 @@ path override works
509497
510498
511499
512-
=== TEST 14: set route with stream = true (SSE)
500+
=== TEST 13: set route with stream = true (SSE)
513501
--- config
514502
location /t {
515503
content_by_lua_block {
@@ -558,7 +546,7 @@ passed
558546
559547
560548
561-
=== TEST 15: test if SSE works as expected
549+
=== TEST 14: test if SSE works as expected
562550
--- config
563551
location /t {
564552
content_by_lua_block {

t/plugin/ai-proxy.t

Lines changed: 119 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,66 @@ add_block_preprocessor(sub {
106106
}
107107
}
108108
109+
location /v1/embeddings {
110+
content_by_lua_block {
111+
if ngx.req.get_method() ~= "POST" then
112+
ngx.status = 400
113+
ngx.say("unsupported request method: ", ngx.req.get_method())
114+
end
115+
116+
local header_auth = ngx.req.get_headers()["authorization"]
117+
if header_auth ~= "Bearer token" then
118+
ngx.status = 401
119+
ngx.say("unauthorized")
120+
return
121+
end
122+
123+
ngx.req.read_body()
124+
local body, err = ngx.req.get_body_data()
125+
local json = require("cjson.safe")
126+
body, err = json.decode(body)
127+
if err then
128+
ngx.status = 400
129+
ngx.say("failed to get request body: ", err)
130+
end
131+
132+
if body.model ~= "text-embedding-ada-002" then
133+
ngx.status = 400
134+
ngx.say("unsupported model: ", body.model)
135+
return
136+
end
137+
138+
if body.encoding_format ~= "float" then
139+
ngx.status = 400
140+
ngx.say("unsupported encoding format: ", body.encoding_format)
141+
return
142+
end
143+
144+
ngx.status = 200
145+
ngx.say([[
146+
{
147+
"object": "list",
148+
"data": [
149+
{
150+
"object": "embedding",
151+
"embedding": [
152+
0.0023064255,
153+
-0.009327292,
154+
-0.0028842222
155+
],
156+
"index": 0
157+
}
158+
],
159+
"model": "text-embedding-ada-002",
160+
"usage": {
161+
"prompt_tokens": 8,
162+
"total_tokens": 8
163+
}
164+
}
165+
]])
166+
}
167+
}
168+
109169
location /random {
110170
content_by_lua_block {
111171
ngx.say("path override works")
@@ -330,19 +390,7 @@ unsupported content-type: application/x-www-form-urlencoded, only application/js
330390
331391
332392
333-
=== TEST 11: request schema validity check
334-
--- request
335-
POST /anything
336-
{ "messages-missing": [ { "role": "system", "content": "xyz" } ] }
337-
--- more_headers
338-
Authorization: Bearer token
339-
--- error_code: 400
340-
--- response_body chomp
341-
request format doesn't match schema: property "messages" is required
342-
343-
344-
345-
=== TEST 12: model options being merged to request body
393+
=== TEST 11: model options being merged to request body
346394
--- config
347395
location /t {
348396
content_by_lua_block {
@@ -405,7 +453,7 @@ options_works
405453
406454
407455
408-
=== TEST 13: override path
456+
=== TEST 12: override path
409457
--- config
410458
location /t {
411459
content_by_lua_block {
@@ -467,7 +515,7 @@ path override works
467515
468516
469517
470-
=== TEST 14: set route with stream = true (SSE)
518+
=== TEST 13: set route with stream = true (SSE)
471519
--- config
472520
location /t {
473521
content_by_lua_block {
@@ -510,7 +558,7 @@ passed
510558
511559
512560
513-
=== TEST 15: test if SSE works as expected
561+
=== TEST 14: test if SSE works as expected
514562
--- config
515563
location /t {
516564
content_by_lua_block {
@@ -568,3 +616,58 @@ passed
568616
}
569617
--- response_body_like eval
570618
qr/6data: \[DONE\]\n\n/
619+
620+
621+
622+
=== TEST 15: proxy embedding endpoint
623+
--- config
624+
location /t {
625+
content_by_lua_block {
626+
local t = require("lib.test_admin").test
627+
local code, body = t('/apisix/admin/routes/1',
628+
ngx.HTTP_PUT,
629+
[[{
630+
"uri": "/embeddings",
631+
"plugins": {
632+
"ai-proxy": {
633+
"provider": "openai",
634+
"auth": {
635+
"header": {
636+
"Authorization": "Bearer token"
637+
}
638+
},
639+
"options": {
640+
"model": "text-embedding-ada-002",
641+
"encoding_format": "float"
642+
},
643+
"override": {
644+
"endpoint": "http://localhost:6724/v1/embeddings"
645+
}
646+
}
647+
}
648+
}]]
649+
)
650+
651+
if code >= 300 then
652+
ngx.status = code
653+
ngx.say(body)
654+
return
655+
end
656+
657+
ngx.say("passed")
658+
}
659+
}
660+
--- response_body
661+
passed
662+
663+
664+
665+
=== TEST 16: send request to embedding api
666+
--- request
667+
POST /embeddings
668+
{
669+
"input": "The food was delicious and the waiter..."
670+
}
671+
--- error_code: 200
672+
--- response_body_like eval
673+
qr/.*text-embedding-ada-002.*/

0 commit comments

Comments
 (0)