Skip to content

Commit fe93067

Browse files
committed
feat: add OpenAI compatible chat endpoint
1 parent 99ab750 commit fe93067

File tree

10 files changed

+412
-10
lines changed

10 files changed

+412
-10
lines changed

app/ui/src/routes/bot/ds.tsx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ export default function BotDSRoot() {
4646
<div className="mx-auto my-3 w-full max-w-7xl">
4747
{status === "loading" && <SkeletonLoading />}
4848
{status === "success" && (
49-
<div className="px-4 sm:px-6 lg:px-8">
49+
<div className="px-4 sm:px-6 lg:px-8">
5050
<DsTable data={botData.data} />
5151
{botData.total >= 10 && (
5252
<div className="my-3 flex items-center justify-end">

server/src/handlers/api/v1/bot/bot/delete.handler.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ export const deleteSourceByIdHandler = async (
1414
const bot = await prisma.bot.findFirst({
1515
where: {
1616
id: bot_id,
17-
user_id: request.user.user_id,
17+
user_id: request.user?.is_admin ? undefined : request.user?.user_id
1818
},
1919
});
2020

@@ -79,7 +79,7 @@ export const deleteBotByIdHandler = async (
7979
const bot = await prisma.bot.findFirst({
8080
where: {
8181
id,
82-
user_id: request.user.user_id,
82+
user_id: request.user?.is_admin ? undefined : request.user?.user_id
8383
},
8484
});
8585

server/src/handlers/api/v1/bot/bot/get.handler.ts

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ export const getBotByIdEmbeddingsHandler = async (
1414
const bot = await prisma.bot.findFirst({
1515
where: {
1616
id,
17-
user_id: request.user.user_id,
17+
user_id: request.user?.is_admin ? undefined : request.user?.user_id
1818
},
1919
});
2020

@@ -100,7 +100,7 @@ export const getBotByIdHandler = async (
100100
const bot = await prisma.bot.findFirst({
101101
where: {
102102
id,
103-
user_id: request.user.user_id,
103+
user_id: request.user?.is_admin ? undefined : request.user?.user_id
104104
},
105105
});
106106

@@ -122,7 +122,7 @@ export const getAllBotsHandler = async (
122122

123123
const bots = await prisma.bot.findMany({
124124
where: {
125-
user_id: request.user.user_id,
125+
user_id: request.user?.is_admin ? undefined : request.user?.user_id
126126
},
127127
orderBy: {
128128
createdAt: "desc",
@@ -216,7 +216,7 @@ export const getBotByIdSettingsHandler = async (
216216
const bot = await prisma.bot.findFirst({
217217
where: {
218218
id,
219-
user_id: request.user.user_id,
219+
user_id: request.user?.is_admin ? undefined : request.user?.user_id
220220
},
221221
});
222222
if (!bot) {
@@ -292,7 +292,7 @@ export const isBotReadyHandler = async (
292292
const bot = await prisma.bot.findFirst({
293293
where: {
294294
id,
295-
user_id: request.user.user_id,
295+
user_id: request.user?.is_admin ? undefined : request.user?.user_id
296296
},
297297
});
298298

Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
import type { FastifyRequest, FastifyReply } from "fastify";
2+
import type { OpenaiRequestType } from "./type"
3+
import { getModelInfo } from "../../../../utils/get-model-info";
4+
import { embeddings } from "../../../../utils/embeddings";
5+
import { Document } from "langchain/document";
6+
import { BaseRetriever } from "@langchain/core/retrievers";
7+
import { DialoqbaseHybridRetrival } from "../../../../utils/hybrid";
8+
import { DialoqbaseVectorStore } from "../../../../utils/store";
9+
import { createChatModel } from "../bot/playground/chat.service";
10+
import { createChain } from "../../../../chain";
11+
import { openaiNonStreamResponse, openaiStreamResponse } from "./openai-response";
12+
import { groupOpenAiMessages } from "./other";
13+
import { nextTick } from "../../../../utils/nextTick";
14+
15+
16+
export const createChatCompletionHandler = async (
17+
request: FastifyRequest<OpenaiRequestType>,
18+
reply: FastifyReply
19+
) => {
20+
try {
21+
const {
22+
model,
23+
messages
24+
} = request.body;
25+
26+
const prisma = request.server.prisma;
27+
28+
const bot = await prisma.bot.findFirst({
29+
where: {
30+
OR: [
31+
{
32+
id: model
33+
},
34+
{
35+
publicId: model
36+
}
37+
],
38+
user_id: request.user.is_admin ? undefined : request.user.user_id,
39+
},
40+
})
41+
42+
if (!bot) {
43+
return reply.status(404).send({
44+
error: {
45+
message: "Bot not found",
46+
type: "not_found",
47+
param: "model",
48+
code: "bot_not_found"
49+
}
50+
});
51+
}
52+
53+
54+
const embeddingInfo = await getModelInfo({
55+
prisma,
56+
model: bot.embedding,
57+
type: "embedding",
58+
});
59+
60+
if (!embeddingInfo) {
61+
return reply.status(400).send({
62+
error: {
63+
message: "Embedding not found",
64+
type: "not_found",
65+
param: "embedding",
66+
code: "embedding_not_found"
67+
}
68+
});
69+
}
70+
71+
72+
const embeddingModel = embeddings(
73+
embeddingInfo.model_provider!.toLowerCase(),
74+
embeddingInfo.model_id,
75+
embeddingInfo?.config
76+
);
77+
78+
const modelinfo = await getModelInfo({
79+
prisma,
80+
model: bot.model,
81+
type: "chat",
82+
});
83+
84+
if (!modelinfo) {
85+
return reply.status(400).send({
86+
error: {
87+
message: "Model not found",
88+
type: "not_found",
89+
param: "model",
90+
code: "model_not_found"
91+
}
92+
});
93+
}
94+
95+
const botConfig = (modelinfo.config as {}) || {};
96+
let retriever: BaseRetriever;
97+
let resolveWithDocuments: (value: Document[]) => void;
98+
99+
if (bot.use_hybrid_search) {
100+
retriever = new DialoqbaseHybridRetrival(embeddingModel, {
101+
botId: bot.id,
102+
sourceId: null,
103+
callbacks: [
104+
{
105+
handleRetrieverEnd(documents) {
106+
resolveWithDocuments(documents);
107+
},
108+
},
109+
],
110+
});
111+
} else {
112+
const vectorstore = await DialoqbaseVectorStore.fromExistingIndex(
113+
embeddingModel,
114+
{
115+
botId: bot.id,
116+
sourceId: null,
117+
}
118+
);
119+
120+
retriever = vectorstore.asRetriever({
121+
});
122+
}
123+
124+
const streamedModel = createChatModel(
125+
bot,
126+
bot.temperature,
127+
botConfig,
128+
true
129+
);
130+
const nonStreamingModel = createChatModel(bot, bot.temperature, botConfig);
131+
132+
const chain = createChain({
133+
llm: streamedModel,
134+
question_llm: nonStreamingModel,
135+
question_template: bot.questionGeneratorPrompt,
136+
response_template: bot.qaPrompt,
137+
retriever,
138+
});
139+
140+
if (!request.body.stream) {
141+
const res = await chain.invoke({
142+
question: messages[messages.length - 1].content,
143+
chat_history: groupOpenAiMessages(
144+
messages
145+
),
146+
})
147+
148+
149+
return reply.status(200).send(openaiNonStreamResponse(
150+
res,
151+
bot.name
152+
))
153+
}
154+
155+
const stream = await chain.stream({
156+
question: messages[messages.length - 1].content,
157+
chat_history: groupOpenAiMessages(
158+
messages
159+
),
160+
})
161+
reply.raw.setHeader("Content-Type", "text/event-stream");
162+
163+
for await (const token of stream) {
164+
reply.sse({
165+
data: openaiStreamResponse(
166+
token || "",
167+
bot.name
168+
)
169+
});
170+
}
171+
reply.sse({
172+
data: "[DONE]\n\n"
173+
})
174+
await nextTick();
175+
return reply.raw.end();
176+
} catch (error) {
177+
console.log(error)
178+
return reply.status(500).send({
179+
error: {
180+
message: error.message,
181+
type: "internal_server_error",
182+
param: null,
183+
code: "internal_server_error"
184+
}
185+
});
186+
}
187+
}
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import { randomUUID } from "node:crypto";
2+
3+
export const openaiNonStreamResponse = (message: string, model: string) => {
4+
return {
5+
id: randomUUID(),
6+
created: new Date().toISOString(),
7+
model,
8+
choices: [
9+
{
10+
index: 0,
11+
message: {
12+
role: "assistant",
13+
content: message,
14+
},
15+
},
16+
],
17+
object: "chat.completion",
18+
};
19+
};
20+
21+
export const openaiStreamResponse = (message: string, model: string) => {
22+
return JSON.stringify({
23+
id: randomUUID(),
24+
created: new Date().toISOString(),
25+
model,
26+
object: "chat.completion.chunk",
27+
choices: [
28+
{
29+
delta: {
30+
content: message,
31+
},
32+
},
33+
],
34+
});
35+
};
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
export function groupOpenAiMessages(
2+
messages: {
3+
role: "user" | "assistant";
4+
content: string;
5+
}[]
6+
) {
7+
if (messages.length % 2 !== 0) {
8+
messages.pop();
9+
}
10+
11+
const groupedMessages = [];
12+
for (let i = 0; i < messages.length; i += 2) {
13+
groupedMessages.push({
14+
[messages[i].role === "user" ? "human" : "ai"]: messages[i].content,
15+
[messages[i + 1].role === "user" ? "human" : "ai"]:
16+
messages[i + 1].content,
17+
});
18+
}
19+
20+
return groupedMessages;
21+
}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
export interface OpenaiRequestType {
2+
Body: {
3+
messages: {
4+
role: "user" | "assistant";
5+
content: string;
6+
}[]
7+
model: string;
8+
stream: boolean;
9+
temperature: number;
10+
}
11+
}

0 commit comments

Comments
 (0)