Commit 9df24e5

Merge pull request #5996 from ChatGPTNextWeb/feature/cogview

Feature/cogview

2 parents e467ce0 + bc322be

File tree: 7 files changed, +169 −28 lines

app/client/platforms/glm.ts

Lines changed: 112 additions & 19 deletions
@@ -25,12 +25,103 @@ import { getMessageTextContent } from "@/app/utils";
 import { RequestPayload } from "./openai";
 import { fetch } from "@/app/utils/stream";
 
+interface BasePayload {
+  model: string;
+}
+
+interface ChatPayload extends BasePayload {
+  messages: ChatOptions["messages"];
+  stream?: boolean;
+  temperature?: number;
+  presence_penalty?: number;
+  frequency_penalty?: number;
+  top_p?: number;
+}
+
+interface ImageGenerationPayload extends BasePayload {
+  prompt: string;
+  size?: string;
+  user_id?: string;
+}
+
+interface VideoGenerationPayload extends BasePayload {
+  prompt: string;
+  duration?: number;
+  resolution?: string;
+  user_id?: string;
+}
+
+type ModelType = "chat" | "image" | "video";
+
 export class ChatGLMApi implements LLMApi {
   private disableListModels = true;
 
+  private getModelType(model: string): ModelType {
+    if (model.startsWith("cogview-")) return "image";
+    if (model.startsWith("cogvideo-")) return "video";
+    return "chat";
+  }
+
+  private getModelPath(type: ModelType): string {
+    switch (type) {
+      case "image":
+        return ChatGLM.ImagePath;
+      case "video":
+        return ChatGLM.VideoPath;
+      default:
+        return ChatGLM.ChatPath;
+    }
+  }
+
+  private createPayload(
+    messages: ChatOptions["messages"],
+    modelConfig: any,
+    options: ChatOptions,
+  ): BasePayload {
+    const modelType = this.getModelType(modelConfig.model);
+    const lastMessage = messages[messages.length - 1];
+    const prompt =
+      typeof lastMessage.content === "string"
+        ? lastMessage.content
+        : lastMessage.content.map((c) => c.text).join("\n");
+
+    switch (modelType) {
+      case "image":
+        return {
+          model: modelConfig.model,
+          prompt,
+          size: options.config.size,
+        } as ImageGenerationPayload;
+      default:
+        return {
+          messages,
+          stream: options.config.stream,
+          model: modelConfig.model,
+          temperature: modelConfig.temperature,
+          presence_penalty: modelConfig.presence_penalty,
+          frequency_penalty: modelConfig.frequency_penalty,
+          top_p: modelConfig.top_p,
+        } as ChatPayload;
+    }
+  }
+
+  private parseResponse(modelType: ModelType, json: any): string {
+    switch (modelType) {
+      case "image": {
+        const imageUrl = json.data?.[0]?.url;
+        return imageUrl ? `![Generated Image](${imageUrl})` : "";
+      }
+      case "video": {
+        const videoUrl = json.data?.[0]?.url;
+        return videoUrl ? `<video controls src="${videoUrl}"></video>` : "";
+      }
+      default:
+        return this.extractMessage(json);
+    }
+  }
+
   path(path: string): string {
     const accessStore = useAccessStore.getState();
-
     let baseUrl = "";
 
     if (accessStore.useCustomConfig) {
@@ -51,7 +142,6 @@ export class ChatGLMApi implements LLMApi {
     }
 
     console.log("[Proxy Endpoint] ", baseUrl, path);
-
     return [baseUrl, path].join("/");
   }
 
@@ -79,53 +169,55 @@ export class ChatGLMApi implements LLMApi {
       },
     };
 
-    const requestPayload: RequestPayload = {
-      messages,
-      stream: options.config.stream,
-      model: modelConfig.model,
-      temperature: modelConfig.temperature,
-      presence_penalty: modelConfig.presence_penalty,
-      frequency_penalty: modelConfig.frequency_penalty,
-      top_p: modelConfig.top_p,
-    };
+    const modelType = this.getModelType(modelConfig.model);
+    const requestPayload = this.createPayload(messages, modelConfig, options);
+    const path = this.path(this.getModelPath(modelType));
 
-    console.log("[Request] glm payload: ", requestPayload);
+    console.log(`[Request] glm ${modelType} payload: `, requestPayload);
 
-    const shouldStream = !!options.config.stream;
     const controller = new AbortController();
     options.onController?.(controller);
 
     try {
-      const chatPath = this.path(ChatGLM.ChatPath);
       const chatPayload = {
         method: "POST",
         body: JSON.stringify(requestPayload),
         signal: controller.signal,
         headers: getHeaders(),
       };
 
-      // make a fetch request
       const requestTimeoutId = setTimeout(
        () => controller.abort(),
        REQUEST_TIMEOUT_MS,
      );
 
+      if (modelType === "image" || modelType === "video") {
+        const res = await fetch(path, chatPayload);
+        clearTimeout(requestTimeoutId);
+
+        const resJson = await res.json();
+        console.log(`[Response] glm ${modelType}:`, resJson);
+        const message = this.parseResponse(modelType, resJson);
+        options.onFinish(message, res);
+        return;
+      }
+
+      const shouldStream = !!options.config.stream;
       if (shouldStream) {
        const [tools, funcs] = usePluginStore
          .getState()
          .getAsTools(
            useChatStore.getState().currentSession().mask?.plugin || [],
          );
        return stream(
-          chatPath,
+          path,
          requestPayload,
          getHeaders(),
          tools as any,
          funcs,
          controller,
          // parseSSE
          (text: string, runTools: ChatMessageTool[]) => {
-            // console.log("parseSSE", text, runTools);
            const json = JSON.parse(text);
            const choices = json.choices as Array<{
              delta: {
@@ -154,7 +246,7 @@ export class ChatGLMApi implements LLMApi {
            }
            return choices[0]?.delta?.content;
          },
-          // processToolMessage, include tool_calls message and tool call results
+          // processToolMessage
          (
            requestPayload: RequestPayload,
            toolCallMessage: any,
@@ -172,7 +264,7 @@ export class ChatGLMApi implements LLMApi {
          options,
        );
      } else {
-        const res = await fetch(chatPath, chatPayload);
+        const res = await fetch(path, chatPayload);
        clearTimeout(requestTimeoutId);
 
        const resJson = await res.json();
@@ -184,6 +276,7 @@ export class ChatGLMApi implements LLMApi {
      options.onError?.(e as Error);
    }
  }
+
  async usage() {
    return {
      used: 0,
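
To make the new routing concrete, here is a small standalone sketch (illustrative only, not part of the diff). It restates the getModelType/getModelPath dispatch above, together with the ChatGLM path constants that app/constant.ts adds later in this commit:

// Standalone sketch of the model-type dispatch added in glm.ts.
type ModelType = "chat" | "image" | "video";

// Mirrors the path constants added to app/constant.ts below.
const ChatGLM = {
  ChatPath: "api/paas/v4/chat/completions",
  ImagePath: "api/paas/v4/images/generations",
  VideoPath: "api/paas/v4/videos/generations",
};

function getModelType(model: string): ModelType {
  if (model.startsWith("cogview-")) return "image";
  if (model.startsWith("cogvideo-")) return "video";
  return "chat";
}

function getModelPath(type: ModelType): string {
  switch (type) {
    case "image":
      return ChatGLM.ImagePath;
    case "video":
      return ChatGLM.VideoPath;
    default:
      return ChatGLM.ChatPath;
  }
}

// CogView models are sent to the image endpoint with a single fetch;
// everything else keeps the (optionally streaming) chat endpoint.
console.log(getModelPath(getModelType("cogview-3-plus"))); // api/paas/v4/images/generations
console.log(getModelPath(getModelType("glm-4-flash"))); // api/paas/v4/chat/completions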

app/client/platforms/openai.ts

Lines changed: 2 additions & 2 deletions
@@ -24,7 +24,7 @@ import {
   stream,
 } from "@/app/utils/chat";
 import { cloudflareAIGatewayUrl } from "@/app/utils/cloudflare";
-import { DalleSize, DalleQuality, DalleStyle } from "@/app/typing";
+import { ModelSize, DalleQuality, DalleStyle } from "@/app/typing";
 
 import {
   ChatOptions,
@@ -73,7 +73,7 @@ export interface DalleRequestPayload {
   prompt: string;
   response_format: "url" | "b64_json";
   n: number;
-  size: DalleSize;
+  size: ModelSize;
   quality: DalleQuality;
   style: DalleStyle;
 }

app/components/chat.tsx

Lines changed: 8 additions & 5 deletions
@@ -72,14 +72,16 @@ import {
   isDalle3,
   showPlugins,
   safeLocalStorage,
+  getModelSizes,
+  supportsCustomSize,
 } from "../utils";
 
 import { uploadImage as uploadImageRemote } from "@/app/utils/chat";
 
 import dynamic from "next/dynamic";
 
 import { ChatControllerPool } from "../client/controller";
-import { DalleSize, DalleQuality, DalleStyle } from "../typing";
+import { DalleQuality, DalleStyle, ModelSize } from "../typing";
 import { Prompt, usePromptStore } from "../store/prompt";
 import Locale from "../locales";
 
@@ -519,10 +521,11 @@ export function ChatActions(props: {
   const [showSizeSelector, setShowSizeSelector] = useState(false);
   const [showQualitySelector, setShowQualitySelector] = useState(false);
   const [showStyleSelector, setShowStyleSelector] = useState(false);
-  const dalle3Sizes: DalleSize[] = ["1024x1024", "1792x1024", "1024x1792"];
+  const modelSizes = getModelSizes(currentModel);
   const dalle3Qualitys: DalleQuality[] = ["standard", "hd"];
   const dalle3Styles: DalleStyle[] = ["vivid", "natural"];
-  const currentSize = session.mask.modelConfig?.size ?? "1024x1024";
+  const currentSize =
+    session.mask.modelConfig?.size ?? ("1024x1024" as ModelSize);
   const currentQuality = session.mask.modelConfig?.quality ?? "standard";
   const currentStyle = session.mask.modelConfig?.style ?? "vivid";
 
@@ -673,7 +676,7 @@ export function ChatActions(props: {
         />
       )}
 
-      {isDalle3(currentModel) && (
+      {supportsCustomSize(currentModel) && (
         <ChatAction
           onClick={() => setShowSizeSelector(true)}
           text={currentSize}
@@ -684,7 +687,7 @@ export function ChatActions(props: {
       {showSizeSelector && (
         <Selector
           defaultSelectedValue={currentSize}
-          items={dalle3Sizes.map((m) => ({
+          items={modelSizes.map((m) => ({
             title: m,
             value: m,
           }))}
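
A short illustrative trace of the rewired size control (hypothetical values; assumes the getModelSizes/supportsCustomSize helpers added in app/utils.ts below):

import { getModelSizes, supportsCustomSize } from "@/app/utils";

const currentModel = "cogview-3-plus";
if (supportsCustomSize(currentModel)) {
  // One Selector entry per size the current model supports.
  const items = getModelSizes(currentModel).map((m) => ({ title: m, value: m }));
  console.log(items.length); // 7 for CogView models, 3 for dall-e-3
}
// For a plain chat model such as "glm-4-flash", supportsCustomSize returns
// false and the size ChatAction is not rendered at all.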

app/constant.ts

Lines changed: 11 additions & 0 deletions
@@ -233,6 +233,8 @@ export const XAI = {
 export const ChatGLM = {
   ExampleEndpoint: CHATGLM_BASE_URL,
   ChatPath: "api/paas/v4/chat/completions",
+  ImagePath: "api/paas/v4/images/generations",
+  VideoPath: "api/paas/v4/videos/generations",
 };
 
 export const DEFAULT_INPUT_TEMPLATE = `{{input}}`; // input / time / model / lang
@@ -431,6 +433,15 @@
   "glm-4-long",
   "glm-4-flashx",
   "glm-4-flash",
+  "glm-4v-plus",
+  "glm-4v",
+  "glm-4v-flash", // free
+  "cogview-3-plus",
+  "cogview-3",
+  "cogview-3-flash", // free
+  // polling-based async tasks are not yet supported
+  // "cogvideox",
+  // "cogvideox-flash", // free
 ];
 
 let seq = 1000; // built-in model sequence numbers start from 1000

app/store/config.ts

Lines changed: 2 additions & 2 deletions
@@ -1,5 +1,5 @@
 import { LLMModel } from "../client/api";
-import { DalleSize, DalleQuality, DalleStyle } from "../typing";
+import { DalleQuality, DalleStyle, ModelSize } from "../typing";
 import { getClientConfig } from "../config/client";
 import {
   DEFAULT_INPUT_TEMPLATE,
@@ -78,7 +78,7 @@ export const DEFAULT_CONFIG = {
     compressProviderName: "",
     enableInjectSystemPrompts: true,
     template: config?.template ?? DEFAULT_INPUT_TEMPLATE,
-    size: "1024x1024" as DalleSize,
+    size: "1024x1024" as ModelSize,
     quality: "standard" as DalleQuality,
     style: "vivid" as DalleStyle,
   },

app/typing.ts

Lines changed: 11 additions & 0 deletions
@@ -11,3 +11,14 @@ export interface RequestMessage {
 export type DalleSize = "1024x1024" | "1792x1024" | "1024x1792";
 export type DalleQuality = "standard" | "hd";
 export type DalleStyle = "vivid" | "natural";
+
+export type ModelSize =
+  | "1024x1024"
+  | "1792x1024"
+  | "1024x1792"
+  | "768x1344"
+  | "864x1152"
+  | "1344x768"
+  | "1152x864"
+  | "1440x720"
+  | "720x1440";

app/utils.ts

Lines changed: 23 additions & 0 deletions
@@ -7,6 +7,7 @@ import { ServiceProvider } from "./constant";
 import { fetch as tauriStreamFetch } from "./utils/stream";
 import { VISION_MODEL_REGEXES, EXCLUDE_VISION_MODEL_REGEXES } from "./constant";
 import { getClientConfig } from "./config/client";
+import { ModelSize } from "./typing";
 
 export function trimTopic(topic: string) {
   // Fix an issue where double quotes still show in the Indonesian language
@@ -271,6 +272,28 @@ export function isDalle3(model: string) {
   return "dall-e-3" === model;
 }
 
+export function getModelSizes(model: string): ModelSize[] {
+  if (isDalle3(model)) {
+    return ["1024x1024", "1792x1024", "1024x1792"];
+  }
+  if (model.toLowerCase().includes("cogview")) {
+    return [
+      "1024x1024",
+      "768x1344",
+      "864x1152",
+      "1344x768",
+      "1152x864",
+      "1440x720",
+      "720x1440",
+    ];
+  }
+  return [];
+}
+
+export function supportsCustomSize(model: string): boolean {
+  return getModelSizes(model).length > 0;
+}
+
 export function showPlugins(provider: ServiceProvider, model: string) {
   if (
     provider == ServiceProvider.OpenAI ||
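
For reference, the expected behavior of the two new helpers, with return values read directly off the implementation above (sketch):

import { getModelSizes, supportsCustomSize } from "@/app/utils";

console.log(getModelSizes("dall-e-3")); // ["1024x1024", "1792x1024", "1024x1792"]
console.log(getModelSizes("CogView-3-Flash")); // matches via toLowerCase(): the seven CogView sizes
console.log(getModelSizes("glm-4-plus")); // []
console.log(supportsCustomSize("cogview-3")); // true
console.log(supportsCustomSize("glm-4-plus")); // false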
