diff --git a/core/src/browser/extensions/engines/AIEngine.ts b/core/src/browser/extensions/engines/AIEngine.ts
index 1be9770342..8c02f128b3 100644
--- a/core/src/browser/extensions/engines/AIEngine.ts
+++ b/core/src/browser/extensions/engines/AIEngine.ts
@@ -1,223 +1,6 @@
 import { BaseExtension } from '../../extension'
 import { EngineManager } from './EngineManager'
-/* AIEngine class types */
-
-export interface chatCompletionRequestMessage {
-  role: 'system' | 'user' | 'assistant' | 'tool'
-  content: string | null | Content[] // Content can be a string OR an array of content parts
-  reasoning?: string | null // Some models return reasoning in completed responses
-  reasoning_content?: string | null // Some models return reasoning in completed responses
-  name?: string
-  tool_calls?: any[] // Simplified
-  tool_call_id?: string
-}
-
-export interface Content {
-  type: 'text' | 'image_url' | 'input_audio'
-  text?: string
-  image_url?: string
-  input_audio?: InputAudio
-}
-
-export interface InputAudio {
-  data: string // Base64 encoded audio data
-  format: 'mp3' | 'wav' | 'ogg' | 'flac' // Add more formats as needed/llama-server seems to support mp3
-}
-
-export interface ToolFunction {
-  name: string // Required: a-z, A-Z, 0-9, _, -, max length 64
-  description?: string
-  parameters?: Record<string, unknown> // JSON Schema object
-  strict?: boolean | null // Defaults to false
-}
-
-export interface Tool {
-  type: 'function' // Currently, only 'function' is supported
-  function: ToolFunction
-}
-
-export interface ToolCallOptions {
-  tools?: Tool[]
-}
-
-// A specific tool choice to force the model to call
-export interface ToolCallSpec {
-  type: 'function'
-  function: {
-    name: string
-  }
-}
-
-// tool_choice may be one of several modes or a specific call
-export type ToolChoice = 'none' | 'auto' | 'required' | ToolCallSpec
-
-export interface chatCompletionRequest {
-  model: string // Model ID, though for local it might be implicit via sessionInfo
-  messages: chatCompletionRequestMessage[]
-  thread_id?: string // Thread/conversation ID for context tracking
-  return_progress?: boolean
-  tools?: Tool[]
-  tool_choice?: ToolChoice
-  // Core sampling parameters
-  temperature?: number | null
-  dynatemp_range?: number | null
-  dynatemp_exponent?: number | null
-  top_k?: number | null
-  top_p?: number | null
-  min_p?: number | null
-  typical_p?: number | null
-  repeat_penalty?: number | null
-  repeat_last_n?: number | null
-  presence_penalty?: number | null
-  frequency_penalty?: number | null
-  dry_multiplier?: number | null
-  dry_base?: number | null
-  dry_allowed_length?: number | null
-  dry_penalty_last_n?: number | null
-  dry_sequence_breakers?: string[] | null
-  xtc_probability?: number | null
-  xtc_threshold?: number | null
-  mirostat?: number | null // 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0
-  mirostat_tau?: number | null
-  mirostat_eta?: number | null
-
-  n_predict?: number | null
-  n_indent?: number | null
-  n_keep?: number | null
-  stream?: boolean | null
-  stop?: string | string[] | null
-  seed?: number | null // RNG seed
-
-  // Advanced sampling
-  logit_bias?: { [key: string]: number } | null
-  n_probs?: number | null
-  min_keep?: number | null
-  t_max_predict_ms?: number | null
-  image_data?: Array<{ data: string; id: number }> | null
-
-  // Internal/optimization parameters
-  id_slot?: number | null
-  cache_prompt?: boolean | null
-  return_tokens?: boolean | null
-  samplers?: string[] | null
-  timings_per_token?: boolean | null
-  post_sampling_probs?: boolean | null
-  chat_template_kwargs?: chat_template_kdict | null
-}
-
-export interface chat_template_kdict {
-  enable_thinking: false
-}
-
-export interface chatCompletionChunkChoiceDelta {
-  content?: string | null
-  role?: 'system' | 'user' | 'assistant' | 'tool'
-  tool_calls?: any[] // Simplified
-}
-
-export interface chatCompletionChunkChoice {
-  index: number
-  delta: chatCompletionChunkChoiceDelta
-  finish_reason?: 'stop' | 'length' | 'tool_calls' | 'content_filter' | 'function_call' | null
-}
-
-export interface chatCompletionPromptProgress {
-  cache: number
-  processed: number
-  time_ms: number
-  total: number
-}
-
-export interface chatCompletionChunk {
-  id: string
-  object: 'chat.completion.chunk'
-  created: number
-  model: string
-  choices: chatCompletionChunkChoice[]
-  system_fingerprint?: string
-  prompt_progress?: chatCompletionPromptProgress
-}
-
-export interface chatCompletionChoice {
-  index: number
-  message: chatCompletionRequestMessage // Response message
-  finish_reason: 'stop' | 'length' | 'tool_calls' | 'content_filter' | 'function_call'
-  logprobs?: any // Simplified
-}
-
-export interface chatCompletion {
-  id: string
-  object: 'chat.completion'
-  created: number
-  model: string // Model ID used
-  choices: chatCompletionChoice[]
-  usage?: {
-    prompt_tokens: number
-    completion_tokens: number
-    total_tokens: number
-  }
-  system_fingerprint?: string
-}
-// --- End OpenAI types ---
-
-// Shared model metadata
-export interface modelInfo {
-  id: string // e.g. "qwen3-4B" or "org/model/quant"
-  name: string // human‑readable, e.g., "Qwen3 4B Q4_0"
-  quant_type?: string // q4_0 (optional as it might be part of ID or name)
-  providerId: string // e.g. "llama.cpp"
-  port: number
-  sizeBytes: number
-  tags?: string[]
-  path?: string // Absolute path to the model file, if applicable
-  // Additional provider-specific metadata can be added here
-  [key: string]: any
-}
-
-// 1. /list
-export type listResult = modelInfo[]
-
-export interface SessionInfo {
-  pid: number // opaque handle for unload/chat
-  port: number // llama-server output port (corrected from portid)
-  model_id: string //name of the model
-  model_path: string // path of the loaded model
-  is_embedding: boolean
-  api_key: string
-  mmproj_path?: string
-}
-
-export interface UnloadResult {
-  success: boolean
-  error?: string
-}
-
-// 5. /chat
-export interface chatOptions {
-  providerId: string
-  sessionId: string
-  /** Full OpenAI ChatCompletionRequest payload */
-  payload: chatCompletionRequest
-}
-// Output for /chat will be Promise<chatCompletion> for non-streaming
-// or Promise<AsyncIterable<chatCompletionChunk>> for streaming
-
-// 7. /import
-export interface ImportOptions {
-  modelPath: string
-  mmprojPath?: string
-  modelSha256?: string
-  modelSize?: number
-  mmprojSha256?: string
-  mmprojSize?: number
-}
-
-export interface importResult {
-  success: boolean
-  modelInfo?: modelInfo
-  error?: string
-}
-
 /**
  * Base AIEngine
  * Applicable to all AI Engines
  */
@@ -240,63 +23,4 @@ export abstract class AIEngine extends BaseExtension {
   registerEngine() {
     EngineManager.instance().register(this)
   }
-
-  /**
-   * Gets model info
-   * @param modelId
-   */
-  abstract get(modelId: string): Promise<modelInfo | undefined>
-
-  /**
-   * Lists available models
-   */
-  abstract list(): Promise<modelInfo[]>
-
-  /**
-   * Loads a model into memory
-   */
-  abstract load(modelId: string, settings?: any): Promise<SessionInfo>
-
-  /**
-   * Unloads a model from memory
-   */
-  abstract unload(sessionId: string): Promise<UnloadResult>
-
-  /**
-   * Sends a chat request to the model
-   */
-  abstract chat(
-    opts: chatCompletionRequest,
-    abortController?: AbortController
-  ): Promise<chatCompletion | AsyncIterable<chatCompletionChunk>>
-
-  /**
-   * Deletes a model
-   */
-  abstract delete(modelId: string): Promise<void>
-
-  /**
-   * Updates a model
-   */
-  abstract update(modelId: string, model: Partial<modelInfo>): Promise<void>
-  /**
-   * Imports a model
-   */
-  abstract import(modelId: string, opts: ImportOptions): Promise<void>
-
-  /**
-   * Aborts an ongoing model import
-   */
-  abstract abortImport(modelId: string): Promise<void>
-
-  /**
-   * Get currently loaded models
-   */
-  abstract getLoadedModels(): Promise<string[]>
-
-  /**
-   * Check if a tool is supported by the model
-   * @param modelId
-   */
-  abstract isToolSupported(modelId: string): Promise<boolean>
 }
diff --git a/core/src/browser/extensions/engines/LocalAIEngine.test.ts b/core/src/browser/extensions/engines/LocalAIEngine.test.ts
new file mode 100644
index 0000000000..b7d50e814f
--- /dev/null
+++ b/core/src/browser/extensions/engines/LocalAIEngine.test.ts
@@ -0,0 +1,69 @@
+import { describe, it, expect, beforeEach, vi } from 'vitest'
+import { LocalAIEngine } from './LocalAIEngine'
+
+class TestLocalAIEngine extends LocalAIEngine {
+  provider = 'test-provider'
+
+  async onUnload(): Promise<void> {}
+
+  async get() {
+    return undefined
+  }
+  async list() {
+    return []
+  }
+  async load() {
+    return {} as any
+  }
+  async unload() {
+    return {} as any
+  }
+  async chat() {
+    return {} as any
+  }
+  async delete() {}
+  async update() {}
+  async import() {}
+  async abortImport() {}
+  async getLoadedModels() {
+    return []
+  }
+  async isToolSupported() {
+    return false
+  }
+}
+
+describe('LocalAIEngine', () => {
+  let engine: TestLocalAIEngine
+
+  beforeEach(() => {
+    engine = new TestLocalAIEngine('', '')
+    vi.clearAllMocks()
+  })
+
+  describe('onLoad', () => {
+    it('should call super.onLoad', async () => {
+      const superOnLoadSpy = vi.spyOn(
+        Object.getPrototypeOf(Object.getPrototypeOf(engine)),
+        'onLoad'
+      )
+
+      await engine.onLoad()
+
+      expect(superOnLoadSpy).toHaveBeenCalled()
+    })
+  })
+
+  describe('abstract requirements', () => {
+    it('should implement provider', () => {
+      expect(engine.provider).toBe('test-provider')
+    })
+
+    it('should implement abstract methods', async () => {
+      expect(await engine.get('id')).toBeUndefined()
+      expect(await engine.list()).toEqual([])
+      expect(await engine.getLoadedModels()).toEqual([])
+      expect(await engine.isToolSupported('id')).toBe(false)
+    })
+  })
+})
diff --git a/core/src/browser/extensions/engines/LocalAIEngine.ts b/core/src/browser/extensions/engines/LocalAIEngine.ts
new file mode 100644
index 0000000000..609e164c87
--- /dev/null
+++ b/core/src/browser/extensions/engines/LocalAIEngine.ts
@@ -0,0 +1,87 @@
+import { AIEngine } from './AIEngine'
+import {
+  modelInfo,
+  SessionInfo,
+  UnloadResult,
+  chatCompletionRequest,
+  chatCompletion,
+  chatCompletionChunk,
+  ImportOptions,
+} from './LocalAIEngineTypes'
+/**
+ * Base AI Local Inference Provider
+ */
+export abstract class LocalAIEngine extends AIEngine {
+  /**
+   * This class represents a base for local inference providers in the OpenAI architecture.
+   * It extends the AIEngine class and provides the implementation of loading and unloading models locally.
+   */
+
+  override async onLoad(): Promise<void> {
+    super.onLoad() // ensures registration happens
+  }
+
+  /*
+   * For any clean ups before extension shutdown
+   */
+  abstract onUnload(): Promise<void>
+
+  /**
+   * Gets model info
+   * @param modelId
+   */
+  abstract get(modelId: string): Promise<modelInfo | undefined>
+
+  /**
+   * Lists available models
+   */
+  abstract list(): Promise<modelInfo[]>
+
+  /**
+   * Loads a model into memory
+   */
+  abstract load(modelId: string, settings?: any): Promise<SessionInfo>
+
+  /**
+   * Unloads a model from memory
+   */
+  abstract unload(sessionId: string): Promise<UnloadResult>
+
+  /**
+   * Sends a chat request to the model
+   */
+  abstract chat(
+    opts: chatCompletionRequest,
+    abortController?: AbortController
+  ): Promise<chatCompletion | AsyncIterable<chatCompletionChunk>>
+
+  /**
+   * Deletes a model
+   */
+  abstract delete(modelId: string): Promise<void>
+
+  /**
+   * Updates a model
+   */
+  abstract update(modelId: string, model: Partial<modelInfo>): Promise<void>
+  /**
+   * Imports a model
+   */
+  abstract import(modelId: string, opts: ImportOptions): Promise<void>
+
+  /**
+   * Aborts an ongoing model import
+   */
+  abstract abortImport(modelId: string): Promise<void>
+
+  /**
+   * Get currently loaded models
+   */
+  abstract getLoadedModels(): Promise<string[]>
+
+  /**
+   * Check if a tool is supported by the model
+   * @param modelId
+   */
+  abstract isToolSupported(modelId: string): Promise<boolean>
+}
diff --git a/core/src/browser/extensions/engines/LocalAIEngineTypes.ts b/core/src/browser/extensions/engines/LocalAIEngineTypes.ts
new file mode 100644
index 0000000000..6d6d475fec
--- /dev/null
+++ b/core/src/browser/extensions/engines/LocalAIEngineTypes.ts
@@ -0,0 +1,216 @@
+/* AIEngine class types */
+
+export interface chatCompletionRequestMessage {
+  role: 'system' | 'user' | 'assistant' | 'tool'
+  content: string | null | Content[] // Content can be a string OR an array of content parts
+  reasoning?: string | null // Some models return reasoning in completed responses
+  reasoning_content?: string | null // Some models return reasoning in completed responses
+  name?: string
+  tool_calls?: any[] // Simplified
+  tool_call_id?: string
+}
+
+export interface Content {
+  type: 'text' | 'image_url' | 'input_audio'
+  text?: string
+  image_url?: string
+  input_audio?: InputAudio
+}
+
+export interface InputAudio {
+  data: string // Base64 encoded audio data
+  format: 'mp3' | 'wav' | 'ogg' | 'flac' // Add more formats as needed/llama-server seems to support mp3
+}
+
+export interface ToolFunction {
+  name: string // Required: a-z, A-Z, 0-9, _, -, max length 64
+  description?: string
+  parameters?: Record<string, unknown> // JSON Schema object
+  strict?: boolean | null // Defaults to false
+}
+
+export interface Tool {
+  type: 'function' // Currently, only 'function' is supported
+  function: ToolFunction
+}
+
+export interface ToolCallOptions {
+  tools?: Tool[]
+}
+
+// A specific tool choice to force the model to call
+export interface ToolCallSpec {
+  type: 'function'
+  function: {
+    name: string
+  }
+}
+
+// tool_choice may be one of several modes or a specific call
+export type ToolChoice = 'none' | 'auto' | 'required' | ToolCallSpec
+
+export interface chatCompletionRequest {
+  model: string // Model ID, though for local it might be implicit via sessionInfo
+  messages: chatCompletionRequestMessage[]
+  thread_id?: string // Thread/conversation ID for context tracking
+  return_progress?: boolean
+  tools?: Tool[]
+  tool_choice?: ToolChoice
+  // Core sampling parameters
+  temperature?: number | null
+  dynatemp_range?: number | null
+  dynatemp_exponent?: number | null
+  top_k?: number | null
+  top_p?: number | null
+  min_p?: number | null
+  typical_p?: number | null
+  repeat_penalty?: number | null
+  repeat_last_n?: number | null
+  presence_penalty?: number | null
+  frequency_penalty?: number | null
+  dry_multiplier?: number | null
+  dry_base?: number | null
+  dry_allowed_length?: number | null
+  dry_penalty_last_n?: number | null
+  dry_sequence_breakers?: string[] | null
+  xtc_probability?: number | null
+  xtc_threshold?: number | null
+  mirostat?: number | null // 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0
+  mirostat_tau?: number | null
+  mirostat_eta?: number | null
+
+  n_predict?: number | null
+  n_indent?: number | null
+  n_keep?: number | null
+  stream?: boolean | null
+  stop?: string | string[] | null
+  seed?: number | null // RNG seed
+
+  // Advanced sampling
+  logit_bias?: { [key: string]: number } | null
+  n_probs?: number | null
+  min_keep?: number | null
+  t_max_predict_ms?: number | null
+  image_data?: Array<{ data: string; id: number }> | null
+
+  // Internal/optimization parameters
+  id_slot?: number | null
+  cache_prompt?: boolean | null
+  return_tokens?: boolean | null
+  samplers?: string[] | null
+  timings_per_token?: boolean | null
+  post_sampling_probs?: boolean | null
+  chat_template_kwargs?: chat_template_kdict | null
+}
+
+export interface chat_template_kdict {
+  enable_thinking: false
+}
+
+export interface chatCompletionChunkChoiceDelta {
+  content?: string | null
+  role?: 'system' | 'user' | 'assistant' | 'tool'
+  tool_calls?: any[] // Simplified
+}
+
+export interface chatCompletionChunkChoice {
+  index: number
+  delta: chatCompletionChunkChoiceDelta
+  finish_reason?: 'stop' | 'length' | 'tool_calls' | 'content_filter' | 'function_call' | null
+}
+
+export interface chatCompletionPromptProgress {
+  cache: number
+  processed: number
+  time_ms: number
+  total: number
+}
+
+export interface chatCompletionChunk {
+  id: string
+  object: 'chat.completion.chunk'
+  created: number
+  model: string
+  choices: chatCompletionChunkChoice[]
+  system_fingerprint?: string
+  prompt_progress?: chatCompletionPromptProgress
+}
+
+export interface chatCompletionChoice {
+  index: number
+  message: chatCompletionRequestMessage // Response message
+  finish_reason: 'stop' | 'length' | 'tool_calls' | 'content_filter' | 'function_call'
+  logprobs?: any // Simplified
+}
+
+export interface chatCompletion {
+  id: string
+  object: 'chat.completion'
+  created: number
+  model: string // Model ID used
+  choices: chatCompletionChoice[]
+  usage?: {
+    prompt_tokens: number
+    completion_tokens: number
+    total_tokens: number
+  }
+  system_fingerprint?: string
+}
+// --- End OpenAI types ---
+
+// Shared model metadata
+export interface modelInfo {
+  id: string // e.g. "qwen3-4B" or "org/model/quant"
+  name: string // human‑readable, e.g., "Qwen3 4B Q4_0"
+  quant_type?: string // q4_0 (optional as it might be part of ID or name)
+  providerId: string // e.g. "llama.cpp"
+  port: number
+  sizeBytes: number
+  tags?: string[]
+  path?: string // Absolute path to the model file, if applicable
+  // Additional provider-specific metadata can be added here
+  [key: string]: any
+}
+
+// 1. /list
+export type listResult = modelInfo[]
+
+export interface SessionInfo {
+  pid: number // opaque handle for unload/chat
+  port: number // llama-server output port (corrected from portid)
+  model_id: string //name of the model
+  model_path: string // path of the loaded model
+  is_embedding: boolean
+  api_key: string
+  mmproj_path?: string
+}
+
+export interface UnloadResult {
+  success: boolean
+  error?: string
+}
+
+// 5. /chat
+export interface chatOptions {
+  providerId: string
+  sessionId: string
+  /** Full OpenAI ChatCompletionRequest payload */
+  payload: chatCompletionRequest
+}
+// Output for /chat will be Promise<chatCompletion> for non-streaming
+// or Promise<AsyncIterable<chatCompletionChunk>> for streaming
+
+// 7. /import
+export interface ImportOptions {
+  modelPath: string
+  mmprojPath?: string
+  modelSha256?: string
+  modelSize?: number
+  mmprojSha256?: string
+  mmprojSize?: number
+}
+
+export interface importResult {
+  success: boolean
+  modelInfo?: modelInfo
+  error?: string
+}
diff --git a/core/src/browser/extensions/engines/LocalOAIEngine.test.ts b/core/src/browser/extensions/engines/LocalOAIEngine.test.ts
deleted file mode 100644
index 3523c3ce6f..0000000000
--- a/core/src/browser/extensions/engines/LocalOAIEngine.test.ts
+++ /dev/null
@@ -1,134 +0,0 @@
-import { describe, it, expect, beforeEach, vi, type Mock } from 'vitest'
-import { LocalOAIEngine } from './LocalOAIEngine'
-import { events } from '../../events'
-import { Model, ModelEvent } from '../../../types'
-
-vi.mock('../../events')
-
-class TestLocalOAIEngine extends LocalOAIEngine {
-  inferenceUrl = 'http://test-local-inference-url'
-  provider = 'test-local-provider'
-  nodeModule = 'test-node-module'
-
-  async headers() {
-    return { Authorization: 'Bearer test-token' }
-  }
-
-  async loadModel(model: Model & { file_path?: string }): Promise<void> {
-    this.loadedModel = model
-  }
-
-  async unloadModel(model?: Model) {
-    this.loadedModel = undefined
-  }
-}
-
-describe('LocalOAIEngine', () => {
-  let engine: TestLocalOAIEngine
-  const mockModel: Model & { file_path?: string } = {
-    object: 'model',
-    version: '1.0.0',
-    format: 'gguf',
-    sources: [],
-    id: 'test-model',
-    name: 'Test Model',
-    description: 'A test model',
-    settings: {},
-    parameters: {},
-    metadata: {},
-    file_path: '/path/to/model.gguf'
-  }
-
-  beforeEach(() => {
-    engine = new TestLocalOAIEngine('', '')
-    vi.clearAllMocks()
-  })
-
-  describe('onLoad', () => {
-    it('should call super.onLoad and subscribe to model events', () => {
-      const superOnLoadSpy = vi.spyOn(Object.getPrototypeOf(Object.getPrototypeOf(engine)), 'onLoad')
-
-      engine.onLoad()
-
-      expect(superOnLoadSpy).toHaveBeenCalled()
-      expect(events.on).toHaveBeenCalledWith(
-        ModelEvent.OnModelInit,
-        expect.any(Function)
-      )
-      expect(events.on).toHaveBeenCalledWith(
-        ModelEvent.OnModelStop,
-        expect.any(Function)
-      )
-    })
-
-    it('should load model when OnModelInit event is triggered', () => {
-      const loadModelSpy = vi.spyOn(engine, 'loadModel')
-      engine.onLoad()
-
-      // Get the event handler for OnModelInit
-      const onModelInitCall = (events.on as Mock).mock.calls.find(
-        call => call[0] === ModelEvent.OnModelInit
-      )
-      const onModelInitHandler = onModelInitCall[1]
-
-      // Trigger the event handler
-      onModelInitHandler(mockModel)
-
-      expect(loadModelSpy).toHaveBeenCalledWith(mockModel)
-    })
-
-    it('should unload model when OnModelStop event is triggered', () => {
-      const unloadModelSpy = vi.spyOn(engine, 'unloadModel')
-      engine.onLoad()
-
-      // Get the event handler for OnModelStop
-      const onModelStopCall = (events.on as Mock).mock.calls.find(
-        call => call[0] === ModelEvent.OnModelStop
-      )
-      const onModelStopHandler = onModelStopCall[1]
-
-      // Trigger the event handler
-      onModelStopHandler(mockModel)
-
-      expect(unloadModelSpy).toHaveBeenCalledWith(mockModel)
-    })
-  })
-
-  describe('properties', () => {
-    it('should have correct default function names', () => {
-      expect(engine.loadModelFunctionName).toBe('loadModel')
-      expect(engine.unloadModelFunctionName).toBe('unloadModel')
-    })
-
-    it('should have abstract nodeModule property implemented', () => {
-      expect(engine.nodeModule).toBe('test-node-module')
-    })
-  })
-
-  describe('loadModel', () => {
-    it('should load the model and set loadedModel', async () => {
-      await engine.loadModel(mockModel)
-      expect(engine.loadedModel).toBe(mockModel)
-    })
-
-    it('should handle model with file_path', async () => {
-      const modelWithPath = { ...mockModel, file_path: '/custom/path/model.gguf' }
-      await engine.loadModel(modelWithPath)
-      expect(engine.loadedModel).toBe(modelWithPath)
-    })
-  })
-
-  describe('unloadModel', () => {
-    it('should unload the model and clear loadedModel', async () => {
-      engine.loadedModel = mockModel
-      await engine.unloadModel(mockModel)
-      expect(engine.loadedModel).toBeUndefined()
-    })
-
-    it('should handle unload without passing a model', async () => {
-      engine.loadedModel = mockModel
-      await engine.unloadModel()
-      expect(engine.loadedModel).toBeUndefined()
-    })
-  })
-})
diff --git a/core/src/browser/extensions/engines/LocalOAIEngine.ts b/core/src/browser/extensions/engines/LocalOAIEngine.ts
deleted file mode 100644
index d9f9220bf4..0000000000
--- a/core/src/browser/extensions/engines/LocalOAIEngine.ts
+++ /dev/null
@@ -1,41 +0,0 @@
-import { events } from '../../events'
-import { Model, ModelEvent } from '../../../types'
-import { OAIEngine } from './OAIEngine'
-
-/**
- * Base OAI Local Inference Provider
- * Added the implementation of loading and unloading model (applicable to local inference providers)
- */
-export abstract class LocalOAIEngine extends OAIEngine {
-  // The inference engine
-  abstract nodeModule: string
-  loadModelFunctionName: string = 'loadModel'
-  unloadModelFunctionName: string = 'unloadModel'
-
-  /**
-   * This class represents a base for local inference providers in the OpenAI architecture.
-   * It extends the OAIEngine class and provides the implementation of loading and unloading models locally.
-   * The loadModel function subscribes to the ModelEvent.OnModelInit event, loading models when initiated.
-   * The unloadModel function subscribes to the ModelEvent.OnModelStop event, unloading models when stopped.
-   */
-  override onLoad() {
-    super.onLoad()
-    // These events are applicable to local inference providers
-    events.on(ModelEvent.OnModelInit, (model: Model) => this.loadModel(model))
-    events.on(ModelEvent.OnModelStop, (model: Model) => this.unloadModel(model))
-  }
-
-  /**
-   * Load the model.
-   */
-  async loadModel(model: Model & { file_path?: string }): Promise<void> {
-    // Implementation of loading the model
-  }
-
-  /**
-   * Stops the model.
-   */
-  async unloadModel(model?: Model) {
-    // Implementation of unloading the model
-  }
-}
diff --git a/core/src/browser/extensions/engines/index.ts b/core/src/browser/extensions/engines/index.ts
index 34ef45afd1..7a1ac6ed2d 100644
--- a/core/src/browser/extensions/engines/index.ts
+++ b/core/src/browser/extensions/engines/index.ts
@@ -1,5 +1,6 @@
 export * from './AIEngine'
 export * from './OAIEngine'
-export * from './LocalOAIEngine'
 export * from './RemoteOAIEngine'
 export * from './EngineManager'
+export * from './LocalAIEngine'
+export * from './LocalAIEngineTypes'
diff --git a/extensions/llamacpp-extension/src/index.ts b/extensions/llamacpp-extension/src/index.ts
index 3f2054e2be..7dbec61be5 100644
--- a/extensions/llamacpp-extension/src/index.ts
+++ b/extensions/llamacpp-extension/src/index.ts
@@ -7,7 +7,6 @@
  */

 import {
-  AIEngine,
   getJanDataFolderPath,
   fs,
   joinPath,
@@ -22,6 +21,7 @@ import {
   AppEvent,
   DownloadEvent,
   chatCompletionRequestMessage,
+  LocalAIEngine,
 } from '@janhq/core'

 import { error, info, warn } from '@tauri-apps/plugin-log'
@@ -183,7 +183,7 @@ const logger = {
 // - lib/
 // - e.g. libcudart.so.12

-export default class llamacpp_extension extends AIEngine {
+export default class llamacpp_extension extends LocalAIEngine {
   provider: string = 'llamacpp'
   autoUnload: boolean = true
   timeout: number = 600
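
Note (not part of the patch): below is a minimal sketch of what a downstream provider looks like once it extends the new LocalAIEngine instead of the removed LocalOAIEngine. The provider id 'example-provider', the stub return values, and the AsyncIterable return type on chat() are illustrative assumptions rather than code from the repository; a real provider (see the llamacpp-extension hunk above) would spawn and manage a local inference server inside load()/unload().

// Hypothetical provider: every identifier and value below is a placeholder stub.
import {
  LocalAIEngine,
  modelInfo,
  SessionInfo,
  UnloadResult,
  chatCompletionRequest,
  chatCompletion,
  chatCompletionChunk,
  ImportOptions,
} from '@janhq/core'

export default class ExampleProvider extends LocalAIEngine {
  provider = 'example-provider'

  async onUnload(): Promise<void> {}

  async get(modelId: string): Promise<modelInfo | undefined> {
    return undefined
  }

  async list(): Promise<modelInfo[]> {
    return []
  }

  async load(modelId: string): Promise<SessionInfo> {
    // A real implementation would start a local server (e.g. llama-server) here
    // and return the session handle it produces.
    return { pid: 0, port: 0, model_id: modelId, model_path: '', is_embedding: false, api_key: '' }
  }

  async unload(sessionId: string): Promise<UnloadResult> {
    return { success: true }
  }

  async chat(
    opts: chatCompletionRequest
  ): Promise<chatCompletion | AsyncIterable<chatCompletionChunk>> {
    throw new Error('chat is not implemented in this sketch')
  }

  async delete(modelId: string): Promise<void> {}
  async update(modelId: string, model: Partial<modelInfo>): Promise<void> {}
  async import(modelId: string, opts: ImportOptions): Promise<void> {}
  async abortImport(modelId: string): Promise<void> {}

  async getLoadedModels(): Promise<string[]> {
    return []
  }

  async isToolSupported(modelId: string): Promise<boolean> {
    return false
  }
}

Because LocalAIEngine.onLoad() still calls super.onLoad(), such a subclass keeps registering itself with EngineManager exactly as before; only the model lifecycle methods move from event subscriptions to the explicit abstract API.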