Skip to content

Commit d2ea36b

Browse files
authored
feat(inference): introducing InferenceProviders (#1161)
* feat(inference): introducing InferenceProviders Signed-off-by: axel7083 <[email protected]> * feat: improve inference provider integration Signed-off-by: axel7083 <[email protected]> * fix: compilation Signed-off-by: axel7083 <[email protected]> * fix: labels propagation Signed-off-by: axel7083 <[email protected]> * fix: error message Signed-off-by: axel7083 <[email protected]> * fix: revert to podman desktop api 1.10.3 Signed-off-by: axel7083 <[email protected]> * fix: typecheck Signed-off-by: axel7083 <[email protected]> --------- Signed-off-by: axel7083 <[email protected]>
1 parent 987938c commit d2ea36b

18 files changed

+844
-453
lines changed

packages/backend/package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@
6565
"xml-js": "^1.6.11"
6666
},
6767
"devDependencies": {
68-
"@podman-desktop/api": "0.0.202404101645-5d46ba5",
68+
"@podman-desktop/api": "1.10.3",
6969
"@types/js-yaml": "^4.0.9",
7070
"@types/node": "^20",
7171
"@types/postman-collection": "^3.5.10",

packages/backend/src/managers/applicationManager.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,11 @@ import { ApplicationRegistry } from '../registries/ApplicationRegistry';
4747
import type { TaskRegistry } from '../registries/TaskRegistry';
4848
import { Publisher } from '../utils/Publisher';
4949
import { isQEMUMachine } from '../utils/podman';
50-
import { SECOND } from '../utils/inferenceUtils';
5150
import { getModelPropertiesForEnvironment } from '../utils/modelsUtils';
5251
import { getRandomName } from '../utils/randomUtils';
5352
import type { BuilderManager } from './recipes/BuilderManager';
5453
import type { PodManager } from './recipes/PodManager';
54+
import { SECOND } from '../workers/provider/LlamaCppPython';
5555

5656
export const LABEL_MODEL_ID = 'ai-lab-model-id';
5757
export const LABEL_MODEL_PORTS = 'ai-lab-model-ports';

packages/backend/src/managers/inference/inferenceManager.spec.ts

Lines changed: 52 additions & 160 deletions
Original file line numberDiff line numberDiff line change
@@ -17,43 +17,36 @@
1717
***********************************************************************/
1818
import {
1919
containerEngine,
20-
provider,
2120
type Webview,
2221
type TelemetryLogger,
23-
type ImageInfo,
2422
type ContainerInfo,
2523
type ContainerInspectInfo,
26-
type ProviderContainerConnection,
2724
} from '@podman-desktop/api';
2825
import type { ContainerRegistry } from '../../registries/ContainerRegistry';
2926
import type { PodmanConnection } from '../podmanConnection';
3027
import { beforeEach, expect, describe, test, vi } from 'vitest';
3128
import { InferenceManager } from './inferenceManager';
3229
import type { ModelsManager } from '../modelsManager';
33-
import { LABEL_INFERENCE_SERVER, INFERENCE_SERVER_IMAGE } from '../../utils/inferenceUtils';
30+
import { LABEL_INFERENCE_SERVER } from '../../utils/inferenceUtils';
3431
import type { InferenceServerConfig } from '@shared/src/models/InferenceServerConfig';
3532
import type { TaskRegistry } from '../../registries/TaskRegistry';
3633
import { Messages } from '@shared/Messages';
34+
import type { InferenceProviderRegistry } from '../../registries/InferenceProviderRegistry';
35+
import type { InferenceProvider } from '../../workers/provider/InferenceProvider';
3736

3837
vi.mock('@podman-desktop/api', async () => {
3938
return {
4039
containerEngine: {
4140
startContainer: vi.fn(),
4241
stopContainer: vi.fn(),
43-
listContainers: vi.fn(),
4442
inspectContainer: vi.fn(),
45-
pullImage: vi.fn(),
46-
listImages: vi.fn(),
47-
createContainer: vi.fn(),
4843
deleteContainer: vi.fn(),
44+
listContainers: vi.fn(),
4945
},
5046
Disposable: {
5147
from: vi.fn(),
5248
create: vi.fn(),
5349
},
54-
provider: {
55-
getContainerConnections: vi.fn(),
56-
},
5750
};
5851
});
5952

@@ -87,6 +80,11 @@ const taskRegistryMock = {
8780
getTasksByLabels: vi.fn(),
8881
} as unknown as TaskRegistry;
8982

83+
const inferenceProviderRegistryMock = {
84+
getAll: vi.fn(),
85+
get: vi.fn(),
86+
} as unknown as InferenceProviderRegistry;
87+
9088
const getInitializedInferenceManager = async (): Promise<InferenceManager> => {
9189
const manager = new InferenceManager(
9290
webviewMock,
@@ -95,6 +93,7 @@ const getInitializedInferenceManager = async (): Promise<InferenceManager> => {
9593
modelsManager,
9694
telemetryMock,
9795
taskRegistryMock,
96+
inferenceProviderRegistryMock,
9897
);
9998
manager.init();
10099
await vi.waitUntil(manager.isInitialize.bind(manager), {
@@ -119,26 +118,6 @@ beforeEach(() => {
119118
Health: undefined,
120119
},
121120
} as unknown as ContainerInspectInfo);
122-
vi.mocked(provider.getContainerConnections).mockReturnValue([
123-
{
124-
providerId: 'test@providerId',
125-
connection: {
126-
type: 'podman',
127-
name: 'test@connection',
128-
status: () => 'started',
129-
},
130-
} as unknown as ProviderContainerConnection,
131-
]);
132-
vi.mocked(containerEngine.listImages).mockResolvedValue([
133-
{
134-
Id: 'dummyImageId',
135-
engineId: 'dummyEngineId',
136-
RepoTags: [INFERENCE_SERVER_IMAGE],
137-
},
138-
] as unknown as ImageInfo[]);
139-
vi.mocked(containerEngine.createContainer).mockResolvedValue({
140-
id: 'dummyCreatedContainerId',
141-
});
142121
vi.mocked(taskRegistryMock.getTasksByLabels).mockReturnValue([]);
143122
vi.mocked(modelsManager.getLocalModelPath).mockReturnValue('/local/model.guff');
144123
vi.mocked(modelsManager.uploadModelToPodmanMachine).mockResolvedValue('/mnt/path/model.guff');
@@ -233,119 +212,59 @@ describe('init Inference Manager', () => {
233212
* Testing the creation logic
234213
*/
235214
describe('Create Inference Server', () => {
236-
test('unknown providerId', async () => {
237-
const inferenceManager = await getInitializedInferenceManager();
238-
await expect(
239-
inferenceManager.createInferenceServer(
240-
{
241-
providerId: 'unknown',
242-
} as unknown as InferenceServerConfig,
243-
'dummyTrackingId',
244-
),
245-
).rejects.toThrowError('cannot find any started container provider.');
215+
test('no provider available should throw an error', async () => {
216+
vi.mocked(inferenceProviderRegistryMock.getAll).mockReturnValue([]);
246217

247-
expect(provider.getContainerConnections).toHaveBeenCalled();
248-
});
249-
250-
test('unknown imageId', async () => {
251218
const inferenceManager = await getInitializedInferenceManager();
252219
await expect(
253-
inferenceManager.createInferenceServer(
254-
{
255-
providerId: 'test@providerId',
256-
image: 'unknown',
257-
} as unknown as InferenceServerConfig,
258-
'dummyTrackingId',
259-
),
260-
).rejects.toThrowError('image unknown not found.');
261-
262-
expect(containerEngine.listImages).toHaveBeenCalled();
220+
inferenceManager.createInferenceServer({
221+
inferenceProvider: undefined,
222+
labels: {},
223+
modelsInfo: [],
224+
port: 8888,
225+
}),
226+
).rejects.toThrowError('no enabled provider could be found.');
263227
});
264228

265-
test('empty modelsInfo', async () => {
229+
test('inference provider provided should use get from InferenceProviderRegistry', async () => {
230+
vi.mocked(inferenceProviderRegistryMock.get).mockReturnValue({
231+
enabled: () => false,
232+
} as unknown as InferenceProvider);
233+
266234
const inferenceManager = await getInitializedInferenceManager();
267235
await expect(
268-
inferenceManager.createInferenceServer(
269-
{
270-
providerId: 'test@providerId',
271-
image: INFERENCE_SERVER_IMAGE,
272-
modelsInfo: [],
273-
} as unknown as InferenceServerConfig,
274-
'dummyTrackingId',
275-
),
276-
).rejects.toThrowError('Need at least one model info to start an inference server.');
236+
inferenceManager.createInferenceServer({
237+
inferenceProvider: 'dummy-inference-provider',
238+
labels: {},
239+
modelsInfo: [],
240+
port: 8888,
241+
}),
242+
).rejects.toThrowError('provider requested is not enabled.');
243+
expect(inferenceProviderRegistryMock.get).toHaveBeenCalledWith('dummy-inference-provider');
277244
});
278245

279-
test('valid InferenceServerConfig', async () => {
246+
test('selected inference provider should receive config', async () => {
247+
const provider: InferenceProvider = {
248+
enabled: () => true,
249+
name: 'dummy-inference-provider',
250+
dispose: () => {},
251+
perform: vi.fn().mockResolvedValue({ id: 'dummy-container-id', engineId: 'dummy-engine-id' }),
252+
} as unknown as InferenceProvider;
253+
vi.mocked(inferenceProviderRegistryMock.get).mockReturnValue(provider);
254+
280255
const inferenceManager = await getInitializedInferenceManager();
281-
await inferenceManager.createInferenceServer(
282-
{
283-
port: 8888,
284-
providerId: 'test@providerId',
285-
image: INFERENCE_SERVER_IMAGE,
286-
modelsInfo: [
287-
{
288-
id: 'dummyModelId',
289-
file: {
290-
file: 'model.guff',
291-
path: '/mnt/path',
292-
},
293-
},
294-
],
295-
} as unknown as InferenceServerConfig,
296-
'dummyTrackingId',
297-
);
298256

299-
expect(modelsManager.uploadModelToPodmanMachine).toHaveBeenCalledWith(
300-
{
301-
id: 'dummyModelId',
302-
file: {
303-
file: 'model.guff',
304-
path: '/mnt/path',
305-
},
306-
},
307-
{
308-
trackingId: 'dummyTrackingId',
309-
},
310-
);
311-
expect(taskRegistryMock.createTask).toHaveBeenNthCalledWith(
312-
1,
313-
expect.stringContaining(
314-
'Pulling ghcr.io/containers/podman-desktop-extension-ai-lab-playground-images/ai-lab-playground-chat:',
315-
),
316-
'loading',
317-
{
318-
trackingId: 'dummyTrackingId',
319-
},
320-
);
321-
expect(taskRegistryMock.createTask).toHaveBeenNthCalledWith(2, 'Creating container.', 'loading', {
322-
trackingId: 'dummyTrackingId',
323-
});
324-
expect(taskRegistryMock.updateTask).toHaveBeenLastCalledWith({
325-
state: 'success',
326-
});
327-
expect(containerEngine.createContainer).toHaveBeenCalled();
328-
expect(inferenceManager.getServers()).toStrictEqual([
329-
{
330-
connection: {
331-
port: 8888,
332-
},
333-
container: {
334-
containerId: 'dummyCreatedContainerId',
335-
engineId: 'dummyEngineId',
336-
},
337-
models: [
338-
{
339-
file: {
340-
file: 'model.guff',
341-
path: '/mnt/path',
342-
},
343-
id: 'dummyModelId',
344-
},
345-
],
346-
status: 'running',
347-
},
348-
]);
257+
const config: InferenceServerConfig = {
258+
inferenceProvider: 'dummy-inference-provider',
259+
labels: {},
260+
modelsInfo: [],
261+
port: 8888,
262+
};
263+
const result = await inferenceManager.createInferenceServer(config);
264+
265+
expect(provider.perform).toHaveBeenCalledWith(config);
266+
267+
expect(result).toBe('dummy-container-id');
349268
});
350269
});
351270

@@ -511,33 +430,6 @@ describe('Request Create Inference Server', () => {
511430
trackingId: identifier,
512431
});
513432
});
514-
515-
test('Pull image error should be reflected in task registry', async () => {
516-
vi.mocked(containerEngine.pullImage).mockRejectedValue(new Error('dummy pull image error'));
517-
518-
const inferenceManager = await getInitializedInferenceManager();
519-
inferenceManager.requestCreateInferenceServer({
520-
port: 8888,
521-
providerId: 'test@providerId',
522-
image: 'quay.io/bootsy/playground:v0',
523-
modelsInfo: [
524-
{
525-
id: 'dummyModelId',
526-
file: {
527-
file: 'dummyFile',
528-
path: 'dummyPath',
529-
},
530-
},
531-
],
532-
} as unknown as InferenceServerConfig);
533-
534-
await vi.waitFor(() => {
535-
expect(taskRegistryMock.updateTask).toHaveBeenLastCalledWith({
536-
state: 'error',
537-
error: 'Something went wrong while trying to create an inference server Error: dummy pull image error.',
538-
});
539-
});
540-
});
541433
});
542434

543435
describe('containerRegistry events', () => {

0 commit comments

Comments (0)