17
17
***********************************************************************/
18
18
import {
19
19
containerEngine ,
20
- provider ,
21
20
type Webview ,
22
21
type TelemetryLogger ,
23
- type ImageInfo ,
24
22
type ContainerInfo ,
25
23
type ContainerInspectInfo ,
26
- type ProviderContainerConnection ,
27
24
} from '@podman-desktop/api' ;
28
25
import type { ContainerRegistry } from '../../registries/ContainerRegistry' ;
29
26
import type { PodmanConnection } from '../podmanConnection' ;
30
27
import { beforeEach , expect , describe , test , vi } from 'vitest' ;
31
28
import { InferenceManager } from './inferenceManager' ;
32
29
import type { ModelsManager } from '../modelsManager' ;
33
- import { LABEL_INFERENCE_SERVER , INFERENCE_SERVER_IMAGE } from '../../utils/inferenceUtils' ;
30
+ import { LABEL_INFERENCE_SERVER } from '../../utils/inferenceUtils' ;
34
31
import type { InferenceServerConfig } from '@shared/src/models/InferenceServerConfig' ;
35
32
import type { TaskRegistry } from '../../registries/TaskRegistry' ;
36
33
import { Messages } from '@shared/Messages' ;
34
+ import type { InferenceProviderRegistry } from '../../registries/InferenceProviderRegistry' ;
35
+ import type { InferenceProvider } from '../../workers/provider/InferenceProvider' ;
37
36
38
37
vi . mock ( '@podman-desktop/api' , async ( ) => {
39
38
return {
40
39
containerEngine : {
41
40
startContainer : vi . fn ( ) ,
42
41
stopContainer : vi . fn ( ) ,
43
- listContainers : vi . fn ( ) ,
44
42
inspectContainer : vi . fn ( ) ,
45
- pullImage : vi . fn ( ) ,
46
- listImages : vi . fn ( ) ,
47
- createContainer : vi . fn ( ) ,
48
43
deleteContainer : vi . fn ( ) ,
44
+ listContainers : vi . fn ( ) ,
49
45
} ,
50
46
Disposable : {
51
47
from : vi . fn ( ) ,
52
48
create : vi . fn ( ) ,
53
49
} ,
54
- provider : {
55
- getContainerConnections : vi . fn ( ) ,
56
- } ,
57
50
} ;
58
51
} ) ;
59
52
@@ -87,6 +80,11 @@ const taskRegistryMock = {
87
80
getTasksByLabels : vi . fn ( ) ,
88
81
} as unknown as TaskRegistry ;
89
82
83
+ const inferenceProviderRegistryMock = {
84
+ getAll : vi . fn ( ) ,
85
+ get : vi . fn ( ) ,
86
+ } as unknown as InferenceProviderRegistry ;
87
+
90
88
const getInitializedInferenceManager = async ( ) : Promise < InferenceManager > => {
91
89
const manager = new InferenceManager (
92
90
webviewMock ,
@@ -95,6 +93,7 @@ const getInitializedInferenceManager = async (): Promise<InferenceManager> => {
95
93
modelsManager ,
96
94
telemetryMock ,
97
95
taskRegistryMock ,
96
+ inferenceProviderRegistryMock ,
98
97
) ;
99
98
manager . init ( ) ;
100
99
await vi . waitUntil ( manager . isInitialize . bind ( manager ) , {
@@ -119,26 +118,6 @@ beforeEach(() => {
119
118
Health : undefined ,
120
119
} ,
121
120
} as unknown as ContainerInspectInfo ) ;
122
- vi . mocked ( provider . getContainerConnections ) . mockReturnValue ( [
123
- {
124
- providerId : 'test@providerId' ,
125
- connection : {
126
- type : 'podman' ,
127
- name : 'test@connection' ,
128
- status : ( ) => 'started' ,
129
- } ,
130
- } as unknown as ProviderContainerConnection ,
131
- ] ) ;
132
- vi . mocked ( containerEngine . listImages ) . mockResolvedValue ( [
133
- {
134
- Id : 'dummyImageId' ,
135
- engineId : 'dummyEngineId' ,
136
- RepoTags : [ INFERENCE_SERVER_IMAGE ] ,
137
- } ,
138
- ] as unknown as ImageInfo [ ] ) ;
139
- vi . mocked ( containerEngine . createContainer ) . mockResolvedValue ( {
140
- id : 'dummyCreatedContainerId' ,
141
- } ) ;
142
121
vi . mocked ( taskRegistryMock . getTasksByLabels ) . mockReturnValue ( [ ] ) ;
143
122
vi . mocked ( modelsManager . getLocalModelPath ) . mockReturnValue ( '/local/model.guff' ) ;
144
123
vi . mocked ( modelsManager . uploadModelToPodmanMachine ) . mockResolvedValue ( '/mnt/path/model.guff' ) ;
@@ -233,119 +212,59 @@ describe('init Inference Manager', () => {
233
212
* Testing the creation logic
234
213
*/
235
214
describe ( 'Create Inference Server' , ( ) => {
236
- test ( 'unknown providerId' , async ( ) => {
237
- const inferenceManager = await getInitializedInferenceManager ( ) ;
238
- await expect (
239
- inferenceManager . createInferenceServer (
240
- {
241
- providerId : 'unknown' ,
242
- } as unknown as InferenceServerConfig ,
243
- 'dummyTrackingId' ,
244
- ) ,
245
- ) . rejects . toThrowError ( 'cannot find any started container provider.' ) ;
215
+ test ( 'no provider available should throw an error' , async ( ) => {
216
+ vi . mocked ( inferenceProviderRegistryMock . getAll ) . mockReturnValue ( [ ] ) ;
246
217
247
- expect ( provider . getContainerConnections ) . toHaveBeenCalled ( ) ;
248
- } ) ;
249
-
250
- test ( 'unknown imageId' , async ( ) => {
251
218
const inferenceManager = await getInitializedInferenceManager ( ) ;
252
219
await expect (
253
- inferenceManager . createInferenceServer (
254
- {
255
- providerId : 'test@providerId' ,
256
- image : 'unknown' ,
257
- } as unknown as InferenceServerConfig ,
258
- 'dummyTrackingId' ,
259
- ) ,
260
- ) . rejects . toThrowError ( 'image unknown not found.' ) ;
261
-
262
- expect ( containerEngine . listImages ) . toHaveBeenCalled ( ) ;
220
+ inferenceManager . createInferenceServer ( {
221
+ inferenceProvider : undefined ,
222
+ labels : { } ,
223
+ modelsInfo : [ ] ,
224
+ port : 8888 ,
225
+ } ) ,
226
+ ) . rejects . toThrowError ( 'no enabled provider could be found.' ) ;
263
227
} ) ;
264
228
265
- test ( 'empty modelsInfo' , async ( ) => {
229
+ test ( 'inference provider provided should use get from InferenceProviderRegistry' , async ( ) => {
230
+ vi . mocked ( inferenceProviderRegistryMock . get ) . mockReturnValue ( {
231
+ enabled : ( ) => false ,
232
+ } as unknown as InferenceProvider ) ;
233
+
266
234
const inferenceManager = await getInitializedInferenceManager ( ) ;
267
235
await expect (
268
- inferenceManager . createInferenceServer (
269
- {
270
- providerId : 'test@providerId' ,
271
- image : INFERENCE_SERVER_IMAGE ,
272
- modelsInfo : [ ] ,
273
- } as unknown as InferenceServerConfig ,
274
- 'dummyTrackingId' ,
275
- ) ,
276
- ) . rejects . toThrowError ( 'Need at least one model info to start an inference server.' ) ;
236
+ inferenceManager . createInferenceServer ( {
237
+ inferenceProvider : 'dummy-inference-provider' ,
238
+ labels : { } ,
239
+ modelsInfo : [ ] ,
240
+ port : 8888 ,
241
+ } ) ,
242
+ ) . rejects . toThrowError ( 'provider requested is not enabled.' ) ;
243
+ expect ( inferenceProviderRegistryMock . get ) . toHaveBeenCalledWith ( 'dummy-inference-provider' ) ;
277
244
} ) ;
278
245
279
- test ( 'valid InferenceServerConfig' , async ( ) => {
246
+ test ( 'selected inference provider should receive config' , async ( ) => {
247
+ const provider : InferenceProvider = {
248
+ enabled : ( ) => true ,
249
+ name : 'dummy-inference-provider' ,
250
+ dispose : ( ) => { } ,
251
+ perform : vi . fn ( ) . mockResolvedValue ( { id : 'dummy-container-id' , engineId : 'dummy-engine-id' } ) ,
252
+ } as unknown as InferenceProvider ;
253
+ vi . mocked ( inferenceProviderRegistryMock . get ) . mockReturnValue ( provider ) ;
254
+
280
255
const inferenceManager = await getInitializedInferenceManager ( ) ;
281
- await inferenceManager . createInferenceServer (
282
- {
283
- port : 8888 ,
284
- providerId : 'test@providerId' ,
285
- image : INFERENCE_SERVER_IMAGE ,
286
- modelsInfo : [
287
- {
288
- id : 'dummyModelId' ,
289
- file : {
290
- file : 'model.guff' ,
291
- path : '/mnt/path' ,
292
- } ,
293
- } ,
294
- ] ,
295
- } as unknown as InferenceServerConfig ,
296
- 'dummyTrackingId' ,
297
- ) ;
298
256
299
- expect ( modelsManager . uploadModelToPodmanMachine ) . toHaveBeenCalledWith (
300
- {
301
- id : 'dummyModelId' ,
302
- file : {
303
- file : 'model.guff' ,
304
- path : '/mnt/path' ,
305
- } ,
306
- } ,
307
- {
308
- trackingId : 'dummyTrackingId' ,
309
- } ,
310
- ) ;
311
- expect ( taskRegistryMock . createTask ) . toHaveBeenNthCalledWith (
312
- 1 ,
313
- expect . stringContaining (
314
- 'Pulling ghcr.io/containers/podman-desktop-extension-ai-lab-playground-images/ai-lab-playground-chat:' ,
315
- ) ,
316
- 'loading' ,
317
- {
318
- trackingId : 'dummyTrackingId' ,
319
- } ,
320
- ) ;
321
- expect ( taskRegistryMock . createTask ) . toHaveBeenNthCalledWith ( 2 , 'Creating container.' , 'loading' , {
322
- trackingId : 'dummyTrackingId' ,
323
- } ) ;
324
- expect ( taskRegistryMock . updateTask ) . toHaveBeenLastCalledWith ( {
325
- state : 'success' ,
326
- } ) ;
327
- expect ( containerEngine . createContainer ) . toHaveBeenCalled ( ) ;
328
- expect ( inferenceManager . getServers ( ) ) . toStrictEqual ( [
329
- {
330
- connection : {
331
- port : 8888 ,
332
- } ,
333
- container : {
334
- containerId : 'dummyCreatedContainerId' ,
335
- engineId : 'dummyEngineId' ,
336
- } ,
337
- models : [
338
- {
339
- file : {
340
- file : 'model.guff' ,
341
- path : '/mnt/path' ,
342
- } ,
343
- id : 'dummyModelId' ,
344
- } ,
345
- ] ,
346
- status : 'running' ,
347
- } ,
348
- ] ) ;
257
+ const config : InferenceServerConfig = {
258
+ inferenceProvider : 'dummy-inference-provider' ,
259
+ labels : { } ,
260
+ modelsInfo : [ ] ,
261
+ port : 8888 ,
262
+ } ;
263
+ const result = await inferenceManager . createInferenceServer ( config ) ;
264
+
265
+ expect ( provider . perform ) . toHaveBeenCalledWith ( config ) ;
266
+
267
+ expect ( result ) . toBe ( 'dummy-container-id' ) ;
349
268
} ) ;
350
269
} ) ;
351
270
@@ -511,33 +430,6 @@ describe('Request Create Inference Server', () => {
511
430
trackingId : identifier ,
512
431
} ) ;
513
432
} ) ;
514
-
515
- test ( 'Pull image error should be reflected in task registry' , async ( ) => {
516
- vi . mocked ( containerEngine . pullImage ) . mockRejectedValue ( new Error ( 'dummy pull image error' ) ) ;
517
-
518
- const inferenceManager = await getInitializedInferenceManager ( ) ;
519
- inferenceManager . requestCreateInferenceServer ( {
520
- port : 8888 ,
521
- providerId : 'test@providerId' ,
522
- image : 'quay.io/bootsy/playground:v0' ,
523
- modelsInfo : [
524
- {
525
- id : 'dummyModelId' ,
526
- file : {
527
- file : 'dummyFile' ,
528
- path : 'dummyPath' ,
529
- } ,
530
- } ,
531
- ] ,
532
- } as unknown as InferenceServerConfig ) ;
533
-
534
- await vi . waitFor ( ( ) => {
535
- expect ( taskRegistryMock . updateTask ) . toHaveBeenLastCalledWith ( {
536
- state : 'error' ,
537
- error : 'Something went wrong while trying to create an inference server Error: dummy pull image error.' ,
538
- } ) ;
539
- } ) ;
540
- } ) ;
541
433
} ) ;
542
434
543
435
describe ( 'containerRegistry events' , ( ) => {
0 commit comments