diff --git a/packages/backend/src/assets/webui.db b/packages/backend/src/assets/webui.db new file mode 100644 index 000000000..0f335a153 Binary files /dev/null and b/packages/backend/src/assets/webui.db differ diff --git a/packages/backend/src/managers/playgroundV2Manager.spec.ts b/packages/backend/src/managers/playgroundV2Manager.spec.ts deleted file mode 100644 index c354f2395..000000000 --- a/packages/backend/src/managers/playgroundV2Manager.spec.ts +++ /dev/null @@ -1,778 +0,0 @@ -/********************************************************************** - * Copyright (C) 2024 Red Hat, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ***********************************************************************/ - -import { expect, test, vi, beforeEach, afterEach, describe } from 'vitest'; -import OpenAI from 'openai'; -import { PlaygroundV2Manager } from './playgroundV2Manager'; -import type { TelemetryLogger, Webview } from '@podman-desktop/api'; -import type { InferenceServer } from '@shared/src/models/IInference'; -import type { InferenceManager } from './inference/inferenceManager'; -import { Messages } from '@shared/Messages'; -import type { ModelInfo } from '@shared/src/models/IModelInfo'; -import type { TaskRegistry } from '../registries/TaskRegistry'; -import type { Task, TaskState } from '@shared/src/models/ITask'; -import type { ChatMessage, ErrorMessage } from '@shared/src/models/IPlaygroundMessage'; -import type { CancellationTokenRegistry } from '../registries/CancellationTokenRegistry'; - -vi.mock('openai', () => ({ - default: vi.fn(), -})); - -const webviewMock = { - postMessage: vi.fn(), -} as unknown as Webview; - -const inferenceManagerMock = { - get: vi.fn(), - getServers: vi.fn(), - createInferenceServer: vi.fn(), - startInferenceServer: vi.fn(), -} as unknown as InferenceManager; - -const taskRegistryMock = { - createTask: vi.fn(), - getTasksByLabels: vi.fn(), - updateTask: vi.fn(), -} as unknown as TaskRegistry; - -const telemetryMock = { - logUsage: vi.fn(), - logError: vi.fn(), -} as unknown as TelemetryLogger; - -const cancellationTokenRegistryMock = { - createCancellationTokenSource: vi.fn(), - delete: vi.fn(), -} as unknown as CancellationTokenRegistry; - -beforeEach(() => { - vi.resetAllMocks(); - vi.mocked(webviewMock.postMessage).mockResolvedValue(true); - vi.useFakeTimers(); -}); - -afterEach(() => { - vi.useRealTimers(); -}); - -test('manager should be properly initialized', () => { - const manager = new PlaygroundV2Manager( - webviewMock, - inferenceManagerMock, - taskRegistryMock, - telemetryMock, - cancellationTokenRegistryMock, - ); - expect(manager.getConversations().length).toBe(0); -}); - -test('submit should throw an error if the server is stopped', async () => { - vi.mocked(inferenceManagerMock.getServers).mockReturnValue([ - { - status: 'running', - models: [ - { - id: 'model1', - }, - ], - } as unknown as InferenceServer, - ]); - const manager = new PlaygroundV2Manager( - webviewMock, - inferenceManagerMock, - taskRegistryMock, - telemetryMock, - cancellationTokenRegistryMock, - ); - await manager.createPlayground('playground 1', { id: 'model1' } as ModelInfo, 'tracking-1'); - - vi.mocked(inferenceManagerMock.getServers).mockReturnValue([ - { - status: 'stopped', - models: [ - { - id: 'model1', - }, - ], - } as unknown as InferenceServer, - ]); - - await expect(manager.submit(manager.getConversations()[0].id, 'dummyUserInput')).rejects.toThrowError( - 'Inference server is not running.', - ); -}); - -test('submit should throw an error if the server is unhealthy', async () => { - vi.mocked(inferenceManagerMock.getServers).mockReturnValue([ - { - status: 'running', - health: { - Status: 'unhealthy', - }, - models: [ - { - id: 'model1', - }, - ], - } as unknown as InferenceServer, - ]); - const manager = new PlaygroundV2Manager( - webviewMock, - inferenceManagerMock, - taskRegistryMock, - telemetryMock, - cancellationTokenRegistryMock, - ); - await manager.createPlayground('p1', { id: 'model1' } as ModelInfo, 'tracking-1'); - const playgroundId = manager.getConversations()[0].id; - await expect(manager.submit(playgroundId, 'dummyUserInput')).rejects.toThrowError( - 'Inference server is not healthy, currently status: unhealthy.', - ); -}); - -test('create playground should create conversation.', async () => { - vi.mocked(inferenceManagerMock.getServers).mockReturnValue([ - { - status: 'running', - health: { - Status: 'healthy', - }, - models: [ - { - id: 'dummyModelId', - file: { - file: 'dummyModelFile', - }, - }, - ], - } as unknown as InferenceServer, - ]); - const manager = new PlaygroundV2Manager( - webviewMock, - inferenceManagerMock, - taskRegistryMock, - telemetryMock, - cancellationTokenRegistryMock, - ); - expect(manager.getConversations().length).toBe(0); - await manager.createPlayground('playground 1', { id: 'model-1' } as ModelInfo, 'tracking-1'); - - const conversations = manager.getConversations(); - expect(conversations.length).toBe(1); -}); - -test('valid submit should create IPlaygroundMessage and notify the webview', async () => { - vi.mocked(inferenceManagerMock.getServers).mockReturnValue([ - { - status: 'running', - health: { - Status: 'healthy', - }, - models: [ - { - id: 'dummyModelId', - file: { - file: 'dummyModelFile', - }, - }, - ], - connection: { - port: 8888, - }, - } as unknown as InferenceServer, - ]); - const createMock = vi.fn().mockResolvedValue([]); - vi.mocked(OpenAI).mockReturnValue({ - chat: { - completions: { - create: createMock, - }, - }, - } as unknown as OpenAI); - - const manager = new PlaygroundV2Manager( - webviewMock, - inferenceManagerMock, - taskRegistryMock, - telemetryMock, - cancellationTokenRegistryMock, - ); - await manager.createPlayground('playground 1', { id: 'dummyModelId' } as ModelInfo, 'tracking-1'); - - const date = new Date(2000, 1, 1, 13); - vi.setSystemTime(date); - - const playgrounds = manager.getConversations(); - await manager.submit(playgrounds[0].id, 'dummyUserInput'); - - // Wait for assistant message to be completed - await vi.waitFor(() => { - expect((manager.getConversations()[0].messages[1] as ChatMessage).content).toBeDefined(); - }); - - const conversations = manager.getConversations(); - - expect(conversations.length).toBe(1); - expect(conversations[0].messages.length).toBe(2); - expect(conversations[0].messages[0]).toStrictEqual({ - content: 'dummyUserInput', - id: expect.anything(), - options: undefined, - role: 'user', - timestamp: expect.any(Number), - }); - expect(conversations[0].messages[1]).toStrictEqual({ - choices: undefined, - completed: expect.any(Number), - content: '', - id: expect.anything(), - role: 'assistant', - timestamp: expect.any(Number), - }); - - expect(webviewMock.postMessage).toHaveBeenLastCalledWith({ - id: Messages.MSG_CONVERSATIONS_UPDATE, - body: conversations, - }); -}); - -test('submit should send options', async () => { - vi.mocked(cancellationTokenRegistryMock.createCancellationTokenSource).mockReturnValue(55); - - vi.mocked(inferenceManagerMock.getServers).mockReturnValue([ - { - status: 'running', - health: { - Status: 'healthy', - }, - models: [ - { - id: 'dummyModelId', - file: { - file: 'dummyModelFile', - }, - }, - ], - connection: { - port: 8888, - }, - } as unknown as InferenceServer, - ]); - const createMock = vi.fn().mockResolvedValue([]); - vi.mocked(OpenAI).mockReturnValue({ - chat: { - completions: { - create: createMock, - }, - }, - } as unknown as OpenAI); - - const manager = new PlaygroundV2Manager( - webviewMock, - inferenceManagerMock, - taskRegistryMock, - telemetryMock, - cancellationTokenRegistryMock, - ); - await manager.createPlayground('playground 1', { id: 'dummyModelId' } as ModelInfo, 'tracking-1'); - - const playgrounds = manager.getConversations(); - const cancellationId = await manager.submit(playgrounds[0].id, 'dummyUserInput', { - temperature: 0.123, - max_tokens: 45, - top_p: 0.345, - }); - expect(cancellationId).toBe(55); - - const messages: unknown[] = [ - { - content: 'dummyUserInput', - id: expect.any(String), - role: 'user', - timestamp: expect.any(Number), - options: { - temperature: 0.123, - max_tokens: 45, - top_p: 0.345, - }, - }, - ]; - expect(createMock).toHaveBeenCalledWith( - { - messages, - model: 'dummyModelFile', - stream: true, - temperature: 0.123, - max_tokens: 45, - top_p: 0.345, - }, - { - signal: expect.anything(), - }, - ); - // at the end the token must be deleted once the request is complete - await vi.waitFor(() => { - expect(cancellationTokenRegistryMock.delete).toHaveBeenCalledWith(55); - }); -}); - -test('error', async () => { - vi.mocked(inferenceManagerMock.getServers).mockReturnValue([ - { - status: 'running', - health: { - Status: 'healthy', - }, - models: [ - { - id: 'dummyModelId', - file: { - file: 'dummyModelFile', - }, - }, - ], - connection: { - port: 8888, - }, - } as unknown as InferenceServer, - ]); - const createMock = vi.fn().mockRejectedValue('Please reduce the length of the messages or completion.'); - vi.mocked(OpenAI).mockReturnValue({ - chat: { - completions: { - create: createMock, - }, - }, - } as unknown as OpenAI); - - const manager = new PlaygroundV2Manager( - webviewMock, - inferenceManagerMock, - taskRegistryMock, - telemetryMock, - cancellationTokenRegistryMock, - ); - await manager.createPlayground('playground 1', { id: 'dummyModelId' } as ModelInfo, 'tracking-1'); - - const date = new Date(2000, 1, 1, 13); - vi.setSystemTime(date); - - const playgrounds = manager.getConversations(); - await manager.submit(playgrounds[0].id, 'dummyUserInput'); - - // Wait for error message - await vi.waitFor(() => { - expect((manager.getConversations()[0].messages[1] as ErrorMessage).error).toBeDefined(); - }); - - const conversations = manager.getConversations(); - - expect(conversations.length).toBe(1); - expect(conversations[0].messages.length).toBe(2); - expect(conversations[0].messages[0]).toStrictEqual({ - content: 'dummyUserInput', - id: expect.anything(), - options: undefined, - role: 'user', - timestamp: expect.any(Number), - }); - expect(conversations[0].messages[1]).toStrictEqual({ - error: 'Please reduce the length of the messages or completion. Note: You should start a new playground.', - id: expect.anything(), - timestamp: expect.any(Number), - }); - - expect(webviewMock.postMessage).toHaveBeenLastCalledWith({ - id: Messages.MSG_CONVERSATIONS_UPDATE, - body: conversations, - }); -}); - -test('creating a new playground should send new playground to frontend', async () => { - vi.mocked(inferenceManagerMock.getServers).mockReturnValue([]); - const manager = new PlaygroundV2Manager( - webviewMock, - inferenceManagerMock, - taskRegistryMock, - telemetryMock, - cancellationTokenRegistryMock, - ); - await manager.createPlayground( - 'a name', - { - id: 'model-1', - name: 'Model 1', - } as unknown as ModelInfo, - 'tracking-1', - ); - expect(webviewMock.postMessage).toHaveBeenCalledWith({ - id: Messages.MSG_CONVERSATIONS_UPDATE, - body: [ - { - id: expect.anything(), - modelId: 'model-1', - name: 'a name', - messages: [], - }, - ], - }); -}); - -test('creating a new playground with no name should send new playground to frontend with generated name', async () => { - vi.mocked(inferenceManagerMock.getServers).mockReturnValue([]); - const manager = new PlaygroundV2Manager( - webviewMock, - inferenceManagerMock, - taskRegistryMock, - telemetryMock, - cancellationTokenRegistryMock, - ); - await manager.createPlayground( - '', - { - id: 'model-1', - name: 'Model 1', - } as unknown as ModelInfo, - 'tracking-1', - ); - expect(webviewMock.postMessage).toHaveBeenCalledWith({ - id: Messages.MSG_CONVERSATIONS_UPDATE, - body: [ - { - id: expect.anything(), - modelId: 'model-1', - name: 'playground 1', - messages: [], - }, - ], - }); -}); - -test('creating a new playground with no model served should start an inference server', async () => { - vi.mocked(inferenceManagerMock.getServers).mockReturnValue([]); - const createInferenceServerMock = vi.mocked(inferenceManagerMock.createInferenceServer); - const manager = new PlaygroundV2Manager( - webviewMock, - inferenceManagerMock, - taskRegistryMock, - telemetryMock, - cancellationTokenRegistryMock, - ); - await manager.createPlayground( - 'a name', - { - id: 'model-1', - name: 'Model 1', - } as unknown as ModelInfo, - 'tracking-1', - ); - expect(createInferenceServerMock).toHaveBeenCalledWith({ - gpuLayers: expect.any(Number), - image: undefined, - providerId: undefined, - inferenceProvider: undefined, - labels: { - trackingId: 'tracking-1', - }, - modelsInfo: [ - { - id: 'model-1', - name: 'Model 1', - }, - ], - port: expect.anything(), - }); -}); - -test('creating a new playground with the model already served should not start an inference server', async () => { - vi.mocked(inferenceManagerMock.getServers).mockReturnValue([ - { - models: [ - { - id: 'model-1', - }, - ], - }, - ] as InferenceServer[]); - const createInferenceServerMock = vi.mocked(inferenceManagerMock.createInferenceServer); - const manager = new PlaygroundV2Manager( - webviewMock, - inferenceManagerMock, - taskRegistryMock, - telemetryMock, - cancellationTokenRegistryMock, - ); - await manager.createPlayground( - 'a name', - { - id: 'model-1', - name: 'Model 1', - } as unknown as ModelInfo, - 'tracking-1', - ); - expect(createInferenceServerMock).not.toHaveBeenCalled(); -}); - -test('creating a new playground with the model server stopped should start the inference server', async () => { - vi.mocked(inferenceManagerMock.getServers).mockReturnValue([ - { - models: [ - { - id: 'model-1', - }, - ], - status: 'stopped', - container: { - containerId: 'container-1', - }, - }, - ] as InferenceServer[]); - const createInferenceServerMock = vi.mocked(inferenceManagerMock.createInferenceServer); - const startInferenceServerMock = vi.mocked(inferenceManagerMock.startInferenceServer); - const manager = new PlaygroundV2Manager( - webviewMock, - inferenceManagerMock, - taskRegistryMock, - telemetryMock, - cancellationTokenRegistryMock, - ); - await manager.createPlayground( - 'a name', - { - id: 'model-1', - name: 'Model 1', - } as unknown as ModelInfo, - 'tracking-1', - ); - expect(createInferenceServerMock).not.toHaveBeenCalled(); - expect(startInferenceServerMock).toHaveBeenCalledWith('container-1'); -}); - -test('delete conversation should delete the conversation', async () => { - vi.mocked(inferenceManagerMock.getServers).mockReturnValue([]); - - const manager = new PlaygroundV2Manager( - webviewMock, - inferenceManagerMock, - taskRegistryMock, - telemetryMock, - cancellationTokenRegistryMock, - ); - expect(manager.getConversations().length).toBe(0); - await manager.createPlayground( - 'a name', - { - id: 'model-1', - name: 'Model 1', - } as unknown as ModelInfo, - 'tracking-1', - ); - - const conversations = manager.getConversations(); - expect(conversations.length).toBe(1); - manager.deleteConversation(conversations[0].id); - expect(manager.getConversations().length).toBe(0); - expect(webviewMock.postMessage).toHaveBeenCalled(); -}); - -test('creating a new playground with an existing name shoud fail', async () => { - vi.mocked(inferenceManagerMock.getServers).mockReturnValue([]); - const manager = new PlaygroundV2Manager( - webviewMock, - inferenceManagerMock, - taskRegistryMock, - telemetryMock, - cancellationTokenRegistryMock, - ); - await manager.createPlayground( - 'a name', - { - id: 'model-1', - name: 'Model 1', - } as unknown as ModelInfo, - 'tracking-1', - ); - await expect( - manager.createPlayground( - 'a name', - { - id: 'model-2', - name: 'Model 2', - } as unknown as ModelInfo, - 'tracking-2', - ), - ).rejects.toThrowError('a playground with the name a name already exists'); -}); - -test('requestCreatePlayground should call createPlayground and createTask, then updateTask', async () => { - vi.useRealTimers(); - const manager = new PlaygroundV2Manager( - webviewMock, - inferenceManagerMock, - taskRegistryMock, - telemetryMock, - cancellationTokenRegistryMock, - ); - const createTaskMock = vi.mocked(taskRegistryMock).createTask; - const updateTaskMock = vi.mocked(taskRegistryMock).updateTask; - createTaskMock.mockImplementation((_name: string, _state: TaskState, labels?: { [id: string]: string }) => { - return { - labels, - } as Task; - }); - const createPlaygroundSpy = vi.spyOn(manager, 'createPlayground').mockResolvedValue('playground-1'); - - const id = await manager.requestCreatePlayground('a name', { id: 'model-1' } as ModelInfo); - - expect(createPlaygroundSpy).toHaveBeenCalledWith('a name', { id: 'model-1' } as ModelInfo, expect.any(String)); - expect(createTaskMock).toHaveBeenCalledWith('Creating Playground environment', 'loading', { - trackingId: id, - }); - await new Promise(resolve => setTimeout(resolve, 0)); - expect(updateTaskMock).toHaveBeenCalledWith({ - labels: { - trackingId: id, - playgroundId: 'playground-1', - }, - state: 'success', - }); -}); - -test('requestCreatePlayground should call createPlayground and createTask, then updateTask when createPlayground fails', async () => { - vi.useRealTimers(); - const manager = new PlaygroundV2Manager( - webviewMock, - inferenceManagerMock, - taskRegistryMock, - telemetryMock, - cancellationTokenRegistryMock, - ); - const createTaskMock = vi.mocked(taskRegistryMock).createTask; - const updateTaskMock = vi.mocked(taskRegistryMock).updateTask; - const getTasksByLabelsMock = vi.mocked(taskRegistryMock).getTasksByLabels; - createTaskMock.mockImplementation((_name: string, _state: TaskState, labels?: { [id: string]: string }) => { - return { - labels, - } as Task; - }); - const createPlaygroundSpy = vi.spyOn(manager, 'createPlayground').mockRejectedValue(new Error('an error')); - - const id = await manager.requestCreatePlayground('a name', { id: 'model-1' } as ModelInfo); - - expect(createPlaygroundSpy).toHaveBeenCalledWith('a name', { id: 'model-1' } as ModelInfo, expect.any(String)); - expect(createTaskMock).toHaveBeenCalledWith('Creating Playground environment', 'loading', { - trackingId: id, - }); - - getTasksByLabelsMock.mockReturnValue([ - { - labels: { - trackingId: id, - }, - } as unknown as Task, - ]); - - await new Promise(resolve => setTimeout(resolve, 0)); - expect(updateTaskMock).toHaveBeenCalledWith({ - error: 'Something went wrong while trying to create a playground environment Error: an error.', - labels: { - trackingId: id, - }, - state: 'error', - }); -}); - -describe('system prompt', () => { - test('set system prompt on non existing conversation should throw an error', async () => { - vi.mocked(inferenceManagerMock.getServers).mockReturnValue([ - { - status: 'running', - models: [ - { - id: 'model1', - }, - ], - } as unknown as InferenceServer, - ]); - const manager = new PlaygroundV2Manager( - webviewMock, - inferenceManagerMock, - taskRegistryMock, - telemetryMock, - cancellationTokenRegistryMock, - ); - - expect(() => { - manager.setSystemPrompt('invalid', 'content'); - }).toThrowError('conversation with id invalid does not exist.'); - }); - - test('set system prompt should throw an error if user already submit message', async () => { - vi.mocked(inferenceManagerMock.getServers).mockReturnValue([ - { - status: 'running', - health: { - Status: 'healthy', - }, - models: [ - { - id: 'dummyModelId', - file: { - file: 'dummyModelFile', - }, - }, - ], - connection: { - port: 8888, - }, - } as unknown as InferenceServer, - ]); - const createMock = vi.fn().mockResolvedValue([]); - vi.mocked(OpenAI).mockReturnValue({ - chat: { - completions: { - create: createMock, - }, - }, - } as unknown as OpenAI); - - const manager = new PlaygroundV2Manager( - webviewMock, - inferenceManagerMock, - taskRegistryMock, - telemetryMock, - cancellationTokenRegistryMock, - ); - await manager.createPlayground('playground 1', { id: 'dummyModelId' } as ModelInfo, 'tracking-1'); - - const date = new Date(2000, 1, 1, 13); - vi.setSystemTime(date); - - const conversations = manager.getConversations(); - await manager.submit(conversations[0].id, 'dummyUserInput'); - - // Wait for assistant message to be completed - await vi.waitFor(() => { - expect((manager.getConversations()[0].messages[1] as ChatMessage).content).toBeDefined(); - }); - - expect(() => { - manager.setSystemPrompt(manager.getConversations()[0].id, 'newSystemPrompt'); - }).toThrowError('Cannot change system prompt on started conversation.'); - }); -}); diff --git a/packages/backend/src/managers/playgroundV2Manager.ts b/packages/backend/src/managers/playgroundV2Manager.ts index 4a030e6fa..99dc13a4d 100644 --- a/packages/backend/src/managers/playgroundV2Manager.ts +++ b/packages/backend/src/managers/playgroundV2Manager.ts @@ -36,6 +36,8 @@ import { getRandomString } from '../utils/randomUtils'; import type { TaskRegistry } from '../registries/TaskRegistry'; import type { CancellationTokenRegistry } from '../registries/CancellationTokenRegistry'; import { getHash } from '../utils/sha'; +import type { ConfigurationRegistry } from '../registries/ConfigurationRegistry'; +import type { PodmanConnection } from './podmanConnection'; export class PlaygroundV2Manager implements Disposable { #conversationRegistry: ConversationRegistry; @@ -46,17 +48,24 @@ export class PlaygroundV2Manager implements Disposable { private taskRegistry: TaskRegistry, private telemetry: TelemetryLogger, private cancellationTokenRegistry: CancellationTokenRegistry, + configurationRegistry: ConfigurationRegistry, + podmanConnection: PodmanConnection, ) { - this.#conversationRegistry = new ConversationRegistry(webview); + this.#conversationRegistry = new ConversationRegistry( + webview, + configurationRegistry, + taskRegistry, + podmanConnection, + ); } - deleteConversation(conversationId: string): void { + async deleteConversation(conversationId: string): Promise { const conversation = this.#conversationRegistry.get(conversationId); this.telemetry.logUsage('playground.delete', { totalMessages: conversation.messages.length, modelId: getHash(conversation.modelId), }); - this.#conversationRegistry.deleteConversation(conversationId); + await this.#conversationRegistry.deleteConversation(conversationId); } async requestCreatePlayground(name: string, model: ModelInfo): Promise { @@ -117,11 +126,11 @@ export class PlaygroundV2Manager implements Disposable { } // Create conversation - const conversationId = this.#conversationRegistry.createConversation(name, model.id); + const conversationId = await this.#conversationRegistry.createConversation(name, model.id); // create/start inference server if necessary const servers = this.inferenceManager.getServers(); - const server = servers.find(s => s.models.map(mi => mi.id).includes(model.id)); + let server = servers.find(s => s.models.map(mi => mi.id).includes(model.id)); if (!server) { await this.inferenceManager.createInferenceServer( await withDefaultConfiguration({ @@ -131,10 +140,15 @@ export class PlaygroundV2Manager implements Disposable { }, }), ); + server = this.inferenceManager.findServerByModel(model); } else if (server.status === 'stopped') { await this.inferenceManager.startInferenceServer(server.container.containerId); } + if (server && server.status === 'running') { + await this.#conversationRegistry.startConversationContainer(server, trackingId, conversationId); + } + return conversationId; } diff --git a/packages/backend/src/registries/ConfigurationRegistry.ts b/packages/backend/src/registries/ConfigurationRegistry.ts index 19ed02f63..5fc0c0ad2 100644 --- a/packages/backend/src/registries/ConfigurationRegistry.ts +++ b/packages/backend/src/registries/ConfigurationRegistry.ts @@ -79,6 +79,10 @@ export class ConfigurationRegistry extends Publisher imp return path.join(this.appUserDirectory, 'models'); } + public getConversationsPath(): string { + return path.join(this.appUserDirectory, 'conversations'); + } + dispose(): void { this.#configurationDisposable?.dispose(); } diff --git a/packages/backend/src/registries/ConversationRegistry.ts b/packages/backend/src/registries/ConversationRegistry.ts index eab300242..1a99fb101 100644 --- a/packages/backend/src/registries/ConversationRegistry.ts +++ b/packages/backend/src/registries/ConversationRegistry.ts @@ -25,14 +25,38 @@ import type { Message, PendingChat, } from '@shared/src/models/IPlaygroundMessage'; -import type { Disposable, Webview } from '@podman-desktop/api'; +import { + type Disposable, + type Webview, + type ContainerCreateOptions, + containerEngine, + type ContainerProviderConnection, + type ImageInfo, + type PullEvent, +} from '@podman-desktop/api'; import { Messages } from '@shared/Messages'; +import type { ConfigurationRegistry } from './ConfigurationRegistry'; +import path from 'node:path'; +import fs from 'node:fs'; +import type { InferenceServer } from '@shared/src/models/IInference'; +import { getFreeRandomPort } from '../utils/ports'; +import { DISABLE_SELINUX_LABEL_SECURITY_OPTION } from '../utils/utils'; +import { getImageInfo } from '../utils/inferenceUtils'; +import type { TaskRegistry } from './TaskRegistry'; +import type { PodmanConnection } from '../managers/podmanConnection'; + +const OPEN_WEBUI_IMAGE = 'ghcr.io/open-webui/open-webui:dev'; export class ConversationRegistry extends Publisher implements Disposable { #conversations: Map; #counter: number; - constructor(webview: Webview) { + constructor( + webview: Webview, + private configurationRegistry: ConfigurationRegistry, + private taskRegistry: TaskRegistry, + private podmanConnection: PodmanConnection, + ) { super(webview, Messages.MSG_CONVERSATIONS_UPDATE, () => this.getAll()); this.#conversations = new Map(); this.#counter = 0; @@ -76,13 +100,32 @@ export class ConversationRegistry extends Publisher implements D this.notify(); } - deleteConversation(id: string): void { + async deleteConversation(id: string): Promise { + const conversation = this.get(id); + if (conversation.container) { + await containerEngine.stopContainer(conversation.container?.engineId, conversation.container?.containerId); + } + await fs.promises.rm(path.join(this.configurationRegistry.getConversationsPath(), id), { + recursive: true, + force: true, + }); this.#conversations.delete(id); this.notify(); } - createConversation(name: string, modelId: string): string { + async createConversation(name: string, modelId: string): Promise { const conversationId = this.getUniqueId(); + const conversationFolder = path.join(this.configurationRegistry.getConversationsPath(), conversationId); + await fs.promises.mkdir(conversationFolder, { + recursive: true, + }); + //WARNING: this will not work in production mode but didn't find how to embed binary assets + //this code get an initialized database so that default user is not admin thus did not get the initial + //welcome modal dialog + await fs.promises.copyFile( + path.join(__dirname, '..', 'src', 'assets', 'webui.db'), + path.join(conversationFolder, 'webui.db'), + ); this.#conversations.set(conversationId, { name: name, modelId: modelId, @@ -93,6 +136,77 @@ export class ConversationRegistry extends Publisher implements D return conversationId; } + async startConversationContainer(server: InferenceServer, trackingId: string, conversationId: string): Promise { + const conversation = this.get(conversationId); + const port = await getFreeRandomPort('127.0.0.1'); + const connection = await this.podmanConnection.getConnectionByEngineId(server.container.engineId); + await this.pullImage(connection, OPEN_WEBUI_IMAGE, { + trackingId: trackingId, + }); + const inferenceServerContainer = await containerEngine.inspectContainer( + server.container.engineId, + server.container.containerId, + ); + const options: ContainerCreateOptions = { + Env: [ + 'DEFAULT_LOCALE=en-US', + 'WEBUI_AUTH=false', + 'ENABLE_OLLAMA_API=false', + `OPENAI_API_BASE_URL=http://${inferenceServerContainer.NetworkSettings.IPAddress}:8000/v1`, + 'OPENAI_API_KEY=sk_dummy', + `WEBUI_URL=http://localhost:${port}`, + `DEFAULT_MODELS=/models/${server.models[0].file?.file}`, + ], + Image: OPEN_WEBUI_IMAGE, + HostConfig: { + AutoRemove: true, + Mounts: [ + { + Source: path.join(this.configurationRegistry.getConversationsPath(), conversationId), + Target: '/app/backend/data', + Type: 'bind', + }, + ], + PortBindings: { + '8080/tcp': [ + { + HostPort: `${port}`, + }, + ], + }, + SecurityOpt: [DISABLE_SELINUX_LABEL_SECURITY_OPTION], + }, + }; + const c = await containerEngine.createContainer(server.container.engineId, options); + conversation.container = { engineId: c.engineId, containerId: c.id, port }; + } + + protected pullImage( + connection: ContainerProviderConnection, + image: string, + labels: { [id: string]: string }, + ): Promise { + // Creating a task to follow pulling progress + const pullingTask = this.taskRegistry.createTask(`Pulling ${image}.`, 'loading', labels); + + // get the default image info for this provider + return getImageInfo(connection, image, (_event: PullEvent) => {}) + .catch((err: unknown) => { + pullingTask.state = 'error'; + pullingTask.progress = undefined; + pullingTask.error = `Something went wrong while pulling ${image}: ${String(err)}`; + throw err; + }) + .then(imageInfo => { + pullingTask.state = 'success'; + pullingTask.progress = undefined; + return imageInfo; + }) + .finally(() => { + this.taskRegistry.updateTask(pullingTask); + }); + } + /** * This method will be responsible for finalizing the message by concatenating all the choices * @param conversationId diff --git a/packages/backend/src/studio-api-impl.ts b/packages/backend/src/studio-api-impl.ts index ca10f8f48..e0d7c4827 100644 --- a/packages/backend/src/studio-api-impl.ts +++ b/packages/backend/src/studio-api-impl.ts @@ -87,9 +87,9 @@ export class StudioApiImpl implements StudioAPI { // Do not wait on the promise as the api would probably timeout before the user answer. podmanDesktopApi.window .showWarningMessage(`Are you sure you want to delete this playground ?`, 'Confirm', 'Cancel') - .then((result: string | undefined) => { + .then(async (result: string | undefined) => { if (result === 'Confirm') { - this.playgroundV2.deleteConversation(conversationId); + await this.playgroundV2.deleteConversation(conversationId); } }) .catch((err: unknown) => { diff --git a/packages/backend/src/studio.ts b/packages/backend/src/studio.ts index d1e50a4e6..6e8a792bd 100644 --- a/packages/backend/src/studio.ts +++ b/packages/backend/src/studio.ts @@ -316,6 +316,8 @@ export class Studio { this.#taskRegistry, this.#telemetry, this.#cancellationTokenRegistry, + this.#configurationRegistry, + this.#podmanConnection, ); this.#extensionContext.subscriptions.push(this.#playgroundManager); diff --git a/packages/frontend/src/pages/Playground.svelte b/packages/frontend/src/pages/Playground.svelte index 62d1a6200..dcc586f90 100644 --- a/packages/frontend/src/pages/Playground.svelte +++ b/packages/frontend/src/pages/Playground.svelte @@ -1,113 +1,26 @@ {#if conversation} @@ -188,110 +68,12 @@ function handleOnClick(): void { -
-
- - -
-
- {#if conversation} - - {#key conversation.messages.length} - - {/key} - -
    - {#each messages as message} -
  • - -
  • - {/each} -
- {/if} -
-
-
- -
Next prompt will use these settings
-
-
Model Parameters
-
-
-
- -
- - - -
- What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output - more random, while lower values like 0.2 will make it more focused and deterministic. -
-
-
-
-
-
- -
- - - -
- The maximum number of tokens that can be generated in the chat completion. -
-
-
-
-
-
- -
- - - -
- An alternative to sampling with temperature, where the model considers the results of the - tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% - probability mass are considered. -
-
-
-
-
-
-
-
-
- {#if errorMsg} -
{errorMsg}
- {/if} -
- - -
- {#if !sendEnabled && cancellationTokenId !== undefined} - - {/if} -
+
+
+
diff --git a/packages/shared/src/models/IPlaygroundMessage.ts b/packages/shared/src/models/IPlaygroundMessage.ts index cdebc2046..4333305ac 100644 --- a/packages/shared/src/models/IPlaygroundMessage.ts +++ b/packages/shared/src/models/IPlaygroundMessage.ts @@ -57,6 +57,11 @@ export interface Conversation { messages: Message[]; modelId: string; name: string; + container?: { + engineId: string; + containerId: string; + port: number; + }; } export interface Choice {