
Commit 1223bd9 (parent 76d6247)

feat(inference): vllm support

Signed-off-by: axel7083 <[email protected]>

9 files changed, +198 -7 lines changed

packages/backend/src/assets/ai.json

Lines changed: 7 additions & 0 deletions
@@ -490,6 +490,13 @@
       },
       "memory": 4372811936,
       "backend": "llama-cpp"
+    },
+    {
+      "id": "Qwen/Qwen2-VL-2B-Instruct",
+      "name": "Qwen/Qwen2-VL-2B-Instruct",
+      "description": "Qwen/Qwen2-VL-2B-Instruct",
+      "url": "huggingface:/Qwen/Qwen2-VL-2B-Instruct",
+      "backend": "vllm"
     }
   ],
   "categories": [

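For orientation, the new entry follows the same shape as the existing catalog entries. A rough TypeScript sketch of that shape, with the field list inferred from this diff only (the project's real ModelInfo type has more fields):

// Hypothetical sketch of an ai.json model entry; fields taken from this diff.
interface CatalogModelEntry {
  id: string;          // e.g. "Qwen/Qwen2-VL-2B-Instruct"
  name: string;
  description: string;
  url?: string;        // e.g. "huggingface:/Qwen/Qwen2-VL-2B-Instruct"
  memory?: number;     // bytes, present on the llama-cpp entries
  backend?: string;    // "llama-cpp" | "whisper-cpp" | "vllm"
}
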
packages/backend/src/assets/inference-images.json

Lines changed: 3 additions & 0 deletions
@@ -5,5 +5,8 @@
   "llamacpp": {
     "default": "quay.io/ramalama/ramalama-llama-server@sha256:cbadb36fbbc2abf9867a33e6dfe3f2df4a76774259b5d4d24d50f4fc7e525406",
     "cuda": "quay.io/ramalama/cuda-llama-server@sha256:56efc824e5b3ae6a6a11e9537ed9e2ac05f9f9fc6f2e81a55eb67b662c94fe95"
+  },
+  "vllm": {
+    "default": "public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:v0.8.4"
   }
 }
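
The new "vllm" entry is consumed by the VLLM provider added later in this commit via config.image ?? images.vllm.default. A minimal sketch of that lookup (the helper name is illustrative, not part of the commit):

// Hypothetical helper: prefer a caller-supplied image, otherwise fall back to
// the default registered in inference-images.json (import path as used by the VLLM provider).
import * as images from '../../assets/inference-images.json';

function resolveVllmImage(override?: string): string {
  return override ?? images.vllm.default;
}

// resolveVllmImage() === 'public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:v0.8.4'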

packages/backend/src/managers/modelsManager.ts

Lines changed: 3 additions & 0 deletions
@@ -375,6 +375,8 @@ export class ModelsManager implements Disposable {
     model: ModelInfo,
     labels?: { [key: string]: string },
   ): Promise<string> {
+    console.log('[ModelsManager] upload model', model);
+
     // ensure the model upload is not disabled
     if (this.configurationRegistry.getExtensionConfiguration().modelUploadDisabled) {
       console.warn('The model upload is disabled, this may cause the inference server to take a few minutes to start.');
@@ -392,6 +394,7 @@
 
     // perform download
     const path = uploader.perform(model.id);
+    console.log('[ModelsManager] path got', path);
     await this.updateModelInfos();
 
     return path;

packages/backend/src/managers/playgroundV2Manager.ts

Lines changed: 12 additions & 4 deletions
@@ -38,6 +38,7 @@ import type { TaskRegistry } from '../registries/TaskRegistry';
 import type { CancellationTokenRegistry } from '../registries/CancellationTokenRegistry';
 import { getHash } from '../utils/sha';
 import type { RpcExtension } from '@shared/messages/MessageProxy';
+import { InferenceType } from '@shared/models/IInference';
 
 export class PlaygroundV2Manager implements Disposable {
   #conversationRegistry: ConversationRegistry;
@@ -123,8 +124,10 @@
 
     // create/start inference server if necessary
     const servers = this.inferenceManager.getServers();
+    console.log('servers', servers);
     const server = servers.find(s => s.models.map(mi => mi.id).includes(model.id));
     if (!server) {
+      console.warn(`no server running found with modelId ${model.id}, creating new one`);
       await this.inferenceManager.createInferenceServer(
         await withDefaultConfiguration({
           modelsInfo: [model],
@@ -239,13 +242,18 @@
       abortController.abort('cancel');
     });
 
+    const messages = this.getFormattedMessages(conversation.id);
+    console.log('[PlaygroundV2Manager] messages', messages);
+    console.log('[PlaygroundV2Manager] messages', options);
+
     client.chat.completions
       .create(
         {
-          messages: this.getFormattedMessages(conversation.id),
+          messages: messages,
           stream: true,
           model: modelInfo.file.file,
-          ...options,
+          // vllm is not compatible with options provided, only llamacpp is
+          ...(server.type === InferenceType.LLAMA_CPP ? options : {}),
         },
         {
           signal: abortController.signal,
@@ -333,8 +341,8 @@
       .map(
         message =>
           ({
-            name: undefined,
-            ...message,
+            role: message.role,
+            content: message.content,
           }) as ChatCompletionMessageParam,
       );
   }
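
The behavioral change here is that extra model options are only forwarded when the backing server is llama-cpp, since the vLLM OpenAI-compatible server does not accept them. A standalone sketch of that gating (helper name and option type are illustrative, not part of the commit):

import { InferenceType } from '@shared/models/IInference';

// Illustrative helper mirroring the request built above: sampling options are
// spread into the payload only for llama-cpp servers.
function buildCompletionRequest(
  serverType: InferenceType,
  modelFile: string,
  messages: Array<{ role: string; content: string }>,
  options: Record<string, unknown>,
): Record<string, unknown> {
  return {
    messages,
    stream: true,
    model: modelFile,
    ...(serverType === InferenceType.LLAMA_CPP ? options : {}),
  };
}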

packages/backend/src/models/HuggingFaceModelHandler.ts

Lines changed: 3 additions & 2 deletions
@@ -23,6 +23,7 @@ import type { CompletionEvent } from './baseEvent';
 import { getDurationSecondsSince } from '../utils/utils';
 import type { ModelsManager } from '../managers/modelsManager';
 import fs from 'node:fs/promises';
+import { dirname, basename } from 'node:path';
 
 function parseURL(url: string): { repo: string; revision?: string } | undefined {
   const u = URL.parse(url);
@@ -128,8 +129,8 @@
     const model = hfModels.find(m => m.repo?.repo === repo.id.name && m.repo?.revision === ref);
     if (model) {
       model.model.file = {
-        path: revision.path,
-        file: revision.path,
+        path: dirname(revision.path),
+        file: basename(revision.path),
         creation: revision.lastModifiedAt,
         size: revision.size,
       };
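
With this change, file.path holds the snapshots directory and file.file holds the snapshot commit hash, which is the layout the VLLM provider below expects. A small illustration (the cache path is made up):

import { dirname, basename } from 'node:path';

// Illustrative Hugging Face cache path; the real value comes from the cache scan.
const revisionPath =
  '/home/user/.cache/huggingface/hub/models--Qwen--Qwen2-VL-2B-Instruct/snapshots/0a1b2c3d4e5f';

const file = {
  path: dirname(revisionPath),  // .../models--Qwen--Qwen2-VL-2B-Instruct/snapshots
  file: basename(revisionPath), // '0a1b2c3d4e5f' (the snapshot commit hash)
};

console.log(file);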

packages/backend/src/studio.ts

Lines changed: 4 additions & 0 deletions
@@ -62,6 +62,7 @@ import { HuggingFaceModelHandler } from './models/HuggingFaceModelHandler';
 import { LlamaStackApiImpl } from './llama-stack-api-impl';
 import { LLAMA_STACK_API_CHANNEL, type LlamaStackAPI } from '@shared/LlamaStackAPI';
 import { LlamaStackManager } from './managers/llama-stack/llamaStackManager';
+import { VLLM } from './workers/provider/VLLM';
 
 export class Studio {
   readonly #extensionContext: ExtensionContext;
@@ -280,6 +281,9 @@
     this.#extensionContext.subscriptions.push(
       this.#inferenceProviderRegistry.register(new WhisperCpp(this.#taskRegistry, this.#podmanConnection)),
     );
+    this.#extensionContext.subscriptions.push(
+      this.#inferenceProviderRegistry.register(new VLLM(this.#taskRegistry, this.#podmanConnection)),
+    );
 
     /**
      * The inference manager create, stop, manage Inference servers

packages/backend/src/workers/provider/VLLM.ts

Lines changed: 162 additions & 0 deletions (new file)
@@ -0,0 +1,162 @@
/**********************************************************************
 * Copyright (C) 2024 Red Hat, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ***********************************************************************/

import { InferenceProvider } from './InferenceProvider';
import type { TaskRegistry } from '../../registries/TaskRegistry';
import type { PodmanConnection } from '../../managers/podmanConnection';
import { type InferenceServer, InferenceType } from '@shared/models/IInference';
import type { InferenceServerConfig } from '@shared/models/InferenceServerConfig';
import type { ContainerProviderConnection, MountConfig } from '@podman-desktop/api';
import * as images from '../../assets/inference-images.json';
import { LABEL_INFERENCE_SERVER } from '../../utils/inferenceUtils';
import { DISABLE_SELINUX_LABEL_SECURITY_OPTION } from '../../utils/utils';
import { basename, dirname } from 'node:path';
import { join as joinposix } from 'node:path/posix';
import { getLocalModelFile } from '../../utils/modelsUtils';
import { SECOND } from './LlamaCppPython';

export class VLLM extends InferenceProvider {
  constructor(
    taskRegistry: TaskRegistry,
    private podmanConnection: PodmanConnection,
  ) {
    super(taskRegistry, InferenceType.VLLM, 'vllm');
  }

  dispose(): void {}

  public enabled = (): boolean => true;

  /**
   * Here is an example
   *
   * podman run -it --rm
   * -v C:\Users\axels\.cache\huggingface\hub\models--mistralai--Mistral-7B-v0.1:/cache/models--mistralai--Mistral-7B-v0.1
   * -e HF_HUB_CACHE=/cache
   * localhost/vllm-cpu-env:latest
   * --model=/cache/models--mistralai--Mistral-7B-v0.1/snapshots/7231864981174d9bee8c7687c24c8344414eae6b
   *
   * @param config
   */
  override async perform(config: InferenceServerConfig): Promise<InferenceServer> {
    if (config.modelsInfo.length !== 1)
      throw new Error(`only one model is supported, received ${config.modelsInfo.length}`);

    const modelInfo = config.modelsInfo[0];
    if (modelInfo.backend !== InferenceType.VLLM) {
      throw new Error(`VLLM requires models with backend type ${InferenceType.VLLM} got ${modelInfo.backend}.`);
    }

    if (modelInfo.file === undefined) {
      throw new Error('The model info file provided is undefined');
    }

    console.log('[VLLM]', config);
    console.log('[VLLM] modelInfo.file', modelInfo.file.path);

    // something ~/.cache/huggingface/hub/models--facebook--opt-125m/snapshots
    // modelInfo.file.path

    const fullPath = getLocalModelFile(modelInfo);

    // modelInfo.file.path must be under the form $(HF_HUB_CACHE)/<repo-type>--<repo-id>/snapshots/<commit-hash>
    const parent = dirname(fullPath);
    const commitHash = basename(fullPath);
    const name = basename(parent);
    if (name !== 'snapshots') throw new Error('you must provide snapshot path for vllm');
    const modelCache = dirname(parent);

    let connection: ContainerProviderConnection | undefined;
    if (config.connection) {
      connection = this.podmanConnection.getContainerProviderConnection(config.connection);
    } else {
      connection = this.podmanConnection.findRunningContainerProviderConnection();
    }

    if (!connection) throw new Error('no running connection could be found');

    const labels: Record<string, string> = {
      ...config.labels,
      [LABEL_INFERENCE_SERVER]: JSON.stringify(config.modelsInfo.map(model => model.id)),
    };

    const imageInfo = await this.pullImage(connection, config.image ?? images.vllm.default, labels);
    // https://huggingface.co/docs/transformers/main/en/installation#offline-mode
    // HF_HUB_OFFLINE in main
    // TRANSFORMERS_OFFLINE for legacy
    const envs: string[] = [`HF_HUB_CACHE=/cache`, 'TRANSFORMERS_OFFLINE=1', 'HF_HUB_OFFLINE=1'];

    labels['api'] = `http://localhost:${config.port}/inference`;

    const mounts: MountConfig = [
      {
        Target: `/cache/${modelInfo.id}`,
        Source: modelCache,
        Type: 'bind',
      },
    ];

    const containerInfo = await this.createContainer(
      imageInfo.engineId,
      {
        Image: imageInfo.Id,
        Detach: true,
        Labels: labels,
        HostConfig: {
          AutoRemove: false,
          Mounts: mounts,
          PortBindings: {
            '8000/tcp': [
              {
                HostPort: `${config.port}`,
              },
            ],
          },
          SecurityOpt: [DISABLE_SELINUX_LABEL_SECURITY_OPTION],
        },
        HealthCheck: {
          // must be the port INSIDE the container not the exposed one
          Test: ['CMD-SHELL', `curl -sSf localhost:8000/version > /dev/null`],
          Interval: SECOND * 5,
          Retries: 4 * 5,
        },
        Env: envs,
        Cmd: [
          `--model=${joinposix('/cache', modelInfo.id, 'snapshots', commitHash)}`,
          `--served_model_name=${modelInfo.file.file}`,
          '--chat-template-content-format=openai',
        ],
      },
      labels,
    );

    return {
      models: [modelInfo],
      status: 'running',
      connection: {
        port: config.port,
      },
      container: {
        containerId: containerInfo.id,
        engineId: containerInfo.engineId,
      },
      type: InferenceType.VLLM,
      labels: labels,
    };
  }
}
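
Once perform resolves, the container publishes vLLM's HTTP API on config.port. A rough way to exercise it from the host, assuming vLLM's standard /version and OpenAI-compatible /v1/chat/completions endpoints (the port and model name below are placeholders; the served model name is modelInfo.file.file, i.e. the snapshot commit hash):

// Hypothetical smoke test against a server started by this provider (Node 18+ global fetch).
async function smokeTest(port: number, servedModelName: string): Promise<void> {
  // Same endpoint the container health check curls (localhost:8000/version inside the container).
  const version = await fetch(`http://localhost:${port}/version`).then(r => r.json());
  console.log('vllm version:', version);

  // OpenAI-compatible chat completion against the served model.
  const response = await fetch(`http://localhost:${port}/v1/chat/completions`, {
    method: 'POST',
    headers: { 'Content-Type': 'application/json' },
    body: JSON.stringify({
      model: servedModelName,
      messages: [{ role: 'user', content: 'Hello!' }],
    }),
  });
  const completion = await response.json();
  console.log(completion.choices?.[0]?.message?.content);
}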

packages/frontend/src/pages/PlaygroundCreate.svelte

Lines changed: 3 additions & 1 deletion
@@ -16,7 +16,9 @@ import ModelSelect from '/@/lib/select/ModelSelect.svelte';
 import { InferenceType } from '@shared/models/IInference';
 
 let localModels: ModelInfo[];
-$: localModels = $modelsInfo.filter(model => model.file && model.backend === InferenceType.LLAMA_CPP);
+$: localModels = $modelsInfo.filter(
+  model => model.file && (model.backend === InferenceType.LLAMA_CPP || model.backend === InferenceType.VLLM),
+);
 $: availModels = $modelsInfo.filter(model => !model.file);
 let model: ModelInfo | undefined = undefined;
 let submitted: boolean = false;

packages/shared/src/models/IInference.ts

Lines changed: 1 addition & 0 deletions
@@ -20,6 +20,7 @@ import type { ModelInfo } from './IModelInfo';
 export enum InferenceType {
   LLAMA_CPP = 'llama-cpp',
   WHISPER_CPP = 'whisper-cpp',
+  VLLM = 'vllm',
   NONE = 'none',
 }
