diff --git a/server/__tests__/utils/AiProviders/lmStudio/index.test.js b/server/__tests__/utils/AiProviders/lmStudio/index.test.js
new file mode 100644
index 00000000000..ecfe8d480e7
--- /dev/null
+++ b/server/__tests__/utils/AiProviders/lmStudio/index.test.js
@@ -0,0 +1,152 @@
+const { LMStudioLLM } = require("../../../../utils/AiProviders/lmStudio");
+
+jest.mock("../../../../utils/EmbeddingEngines/native");
+jest.mock("openai");
+
+global.fetch = jest.fn();
+
+describe("LMStudioLLM", () => {
+  beforeEach(() => {
+    jest.clearAllMocks();
+    LMStudioLLM.modelContextWindows = {};
+    LMStudioLLM._cachePromise = null;
+    process.env.LMSTUDIO_BASE_PATH = "http://localhost:1234";
+  });
+
+  afterEach(() => {
+    delete process.env.LMSTUDIO_BASE_PATH;
+  });
+
+  describe("Constructor initialization", () => {
+    it("initializes limits immediately to prevent race conditions", () => {
+      const llm = new LMStudioLLM();
+
+      expect(llm.limits).toBeDefined();
+      expect(llm.limits.user).toBe(4096 * 0.7);
+      expect(llm.limits.system).toBe(4096 * 0.15);
+      expect(llm.limits.history).toBe(4096 * 0.15);
+    });
+
+    it("sets model from preference or defaults to fallback", () => {
+      const llm1 = new LMStudioLLM(null, "custom-model");
+      expect(llm1.model).toBe("custom-model");
+
+      const llm2 = new LMStudioLLM();
+      expect(llm2.model).toBe("Loaded from Chat UI");
+    });
+  });
+
+  describe("cacheContextWindows", () => {
+    it("caches context windows for all chat models", async () => {
+      global.fetch.mockResolvedValueOnce({
+        ok: true,
+        json: async () => ({
+          data: [
+            { id: "model-1", type: "chat", max_context_length: 8192 },
+            { id: "embedding-model", type: "embeddings", max_context_length: 2048 },
+          ],
+        }),
+      });
+
+      await LMStudioLLM.cacheContextWindows(true);
+
+      expect(LMStudioLLM.modelContextWindows["model-1"]).toBe(8192);
+      expect(LMStudioLLM.modelContextWindows["embedding-model"]).toBeUndefined();
+    });
+
+    it("handles concurrent cache requests without duplicate fetches", async () => {
+      global.fetch.mockImplementation(() =>
+        new Promise((resolve) =>
+          setTimeout(() => resolve({
+            ok: true,
+            json: async () => ({ data: [{ id: "model-1", type: "chat", max_context_length: 8192 }] }),
+          }), 100)
+        )
+      );
+
+      const promise1 = LMStudioLLM.cacheContextWindows(true);
+      const promise2 = LMStudioLLM.cacheContextWindows(false);
+
+      await Promise.all([promise1, promise2]);
+
+      expect(fetch).toHaveBeenCalledTimes(1);
+    });
+  });
+
+  describe("ensureModelCached", () => {
+    it("returns immediately if model is already cached", async () => {
+      LMStudioLLM.modelContextWindows["model-1"] = 8192;
+
+      await LMStudioLLM.ensureModelCached("model-1");
+
+      expect(fetch).not.toHaveBeenCalled();
+    });
+
+    it("refreshes cache when model is not found", async () => {
+      global.fetch.mockResolvedValueOnce({
+        ok: true,
+        json: async () => ({
+          data: [{ id: "new-model", type: "chat", max_context_length: 4096 }],
+        }),
+      });
+
+      await LMStudioLLM.ensureModelCached("new-model");
+
+      expect(fetch).toHaveBeenCalledTimes(1);
+      expect(LMStudioLLM.modelContextWindows["new-model"]).toBe(4096);
+    });
+  });
+
+  describe("getChatCompletion", () => {
+    it("ensures model is cached before making request", async () => {
+      global.fetch.mockResolvedValueOnce({
+        ok: true,
+        json: async () => ({
+          data: [{ id: "test-model", type: "chat", max_context_length: 8192 }],
+        }),
+      });
+
+      const llm = new LMStudioLLM(null, "test-model");
+      const mockCreate = jest.fn().mockResolvedValue({
+        choices: [{ message: { content: "Response" } }],
+        usage: { prompt_tokens: 10, completion_tokens: 20, total_tokens: 30 },
+      });
+
+      llm.lmstudio = {
+        chat: {
+          completions: { create: mockCreate },
+        },
+      };
+
+      await llm.getChatCompletion([{ role: "user", content: "Hello" }], { temperature: 0.7 });
+
+      expect(LMStudioLLM.modelContextWindows["test-model"]).toBe(8192);
+      expect(mockCreate).toHaveBeenCalled();
+    });
+  });
+
+  describe("streamGetChatCompletion", () => {
+    it("ensures model is cached before streaming", async () => {
+      global.fetch.mockResolvedValueOnce({
+        ok: true,
+        json: async () => ({
+          data: [{ id: "test-model", type: "chat", max_context_length: 8192 }],
+        }),
+      });
+
+      const llm = new LMStudioLLM(null, "test-model");
+      const mockCreate = jest.fn().mockResolvedValue({ [Symbol.asyncIterator]: jest.fn() });
+
+      llm.lmstudio = {
+        chat: {
+          completions: { create: mockCreate },
+        },
+      };
+
+      await llm.streamGetChatCompletion([{ role: "user", content: "Hello" }], { temperature: 0.7 });
+
+      expect(LMStudioLLM.modelContextWindows["test-model"]).toBe(8192);
+      expect(mockCreate).toHaveBeenCalled();
+    });
+  });
+});
diff --git a/server/__tests__/utils/agents/aibitat/providers/lmstudio.test.js b/server/__tests__/utils/agents/aibitat/providers/lmstudio.test.js
new file mode 100644
index 00000000000..0a24a61b126
--- /dev/null
+++ b/server/__tests__/utils/agents/aibitat/providers/lmstudio.test.js
@@ -0,0 +1,85 @@
+const LMStudioProvider = require("../../../../../utils/agents/aibitat/providers/lmstudio");
+const { LMStudioLLM } = require("../../../../../utils/AiProviders/lmStudio");
+
+jest.mock("openai");
+jest.mock("../../../../../utils/AiProviders/lmStudio", () => ({
+  LMStudioLLM: {
+    ensureModelCached: jest.fn(),
+    cacheContextWindows: jest.fn(),
+  },
+  parseLMStudioBasePath: jest.fn((path) => path),
+}));
+
+describe("LMStudioProvider", () => {
+  beforeEach(() => {
+    jest.clearAllMocks();
+    process.env.LMSTUDIO_BASE_PATH = "http://localhost:1234";
+  });
+
+  afterEach(() => {
+    delete process.env.LMSTUDIO_BASE_PATH;
+    delete process.env.LMSTUDIO_MODEL_PREF;
+  });
+
+  describe("Initialization", () => {
+    it("initializes with provided model or defaults", () => {
+      const provider1 = new LMStudioProvider({ model: "custom-model" });
+      expect(provider1.model).toBe("custom-model");
+
+      const provider2 = new LMStudioProvider();
+      expect(provider2.model).toBe("Loaded from Chat UI");
+    });
+
+    it("supports agent streaming", () => {
+      const provider = new LMStudioProvider();
+      expect(provider.supportsAgentStreaming).toBe(true);
+    });
+  });
+
+  describe("Chat completion", () => {
+    it("ensures model is cached before completing", async () => {
+      const provider = new LMStudioProvider({ model: "test-model" });
+      const mockCreate = jest.fn().mockResolvedValue({
+        choices: [{ message: { content: "Response" } }],
+      });
+
+      provider._client = {
+        chat: {
+          completions: { create: mockCreate },
+        },
+      };
+
+      await provider.complete([{ role: "user", content: "Hello" }]);
+
+      expect(LMStudioLLM.ensureModelCached).toHaveBeenCalledWith("test-model");
+      expect(mockCreate).toHaveBeenCalled();
+    });
+
+    it("ensures model is cached before streaming", async () => {
+      const provider = new LMStudioProvider({ model: "test-model" });
+      const mockCreate = jest.fn().mockResolvedValue({
+        [Symbol.asyncIterator]: async function* () {
+          yield { choices: [{ delta: { content: "Hello" } }] };
+        },
+      });
+
+      provider._client = {
+        chat: {
+          completions: { create: mockCreate },
+        },
+      };
+
+      await provider.stream([{ role: "user", content: "Hello" }], [], null);
+
+      expect(LMStudioLLM.ensureModelCached).toHaveBeenCalledWith("test-model");
+      expect(mockCreate).toHaveBeenCalled();
+    });
+  });
+
+  describe("Cost calculation", () => {
+    it("returns zero cost for LMStudio", () => {
+      const provider = new LMStudioProvider();
+      expect(provider.getCost({})).toBe(0);
+    });
+  });
+});
diff --git a/server/utils/AiProviders/lmStudio/index.js b/server/utils/AiProviders/lmStudio/index.js
index d95d6c30f84..1827ac155fa 100644
--- a/server/utils/AiProviders/lmStudio/index.js
+++ b/server/utils/AiProviders/lmStudio/index.js
@@ -12,6 +12,8 @@ const { OpenAI: OpenAIApi } = require("openai");
 class LMStudioLLM {
   /** @see LMStudioLLM.cacheContextWindows */
   static modelContextWindows = {};
+  /** Tracks the current caching operation to prevent race conditions */
+  static _cachePromise = null;
 
   constructor(embedder = null, modelPreference = null) {
     if (!process.env.LMSTUDIO_BASE_PATH)
@@ -36,6 +38,14 @@
     this.embedder = embedder ?? new NativeEmbedder();
     this.defaultTemp = 0.7;
 
+    // Initialize limits with default values that will be updated by cacheContextWindows
+    this.limits = {
+      history: this.promptWindowLimit() * 0.15,
+      system: this.promptWindowLimit() * 0.15,
+      user: this.promptWindowLimit() * 0.7,
+    };
+
+    // Start caching in background and update limits when fetched
     LMStudioLLM.cacheContextWindows(true).then(() => {
       this.limits = {
         history: this.promptWindowLimit() * 0.15,
@@ -67,34 +77,73 @@
    */
   static async cacheContextWindows(force = false) {
     try {
-      // Skip if we already have cached context windows and we're not forcing a refresh
-      if (Object.keys(LMStudioLLM.modelContextWindows).length > 0 && !force)
-        return;
+      if (LMStudioLLM._cachePromise && !force) {
+        return await LMStudioLLM._cachePromise;
+      }
 
-      const endpoint = new URL(process.env.LMSTUDIO_BASE_PATH);
-      endpoint.pathname = "/api/v0/models";
-      await fetch(endpoint.toString())
-        .then((res) => {
-          if (!res.ok)
-            throw new Error(`LMStudio:cacheContextWindows - ${res.statusText}`);
-          return res.json();
-        })
-        .then(({ data: models }) => {
-          models.forEach((model) => {
-            if (model.type === "embeddings") return;
-            LMStudioLLM.modelContextWindows[model.id] =
-              model.max_context_length;
+      // Already have cached context windows and not forcing a refresh
+      if (Object.keys(LMStudioLLM.modelContextWindows).length > 0 && !force) {
+        return;
+      }
+
+      // Store cache promise to prevent multiple requests
+      LMStudioLLM._cachePromise = (async () => {
+        const endpoint = new URL(process.env.LMSTUDIO_BASE_PATH);
+        endpoint.pathname = "/api/v0/models";
+        await fetch(endpoint.toString())
+          .then((res) => {
+            if (!res.ok)
+              throw new Error(
+                `LMStudio:cacheContextWindows - ${res.statusText}`
+              );
+            return res.json();
+          })
+          .then(({ data: models }) => {
+            models.forEach((model) => {
+              if (model.type === "embeddings") return;
+              LMStudioLLM.modelContextWindows[model.id] =
+                model.max_context_length;
+            });
+          })
+          .catch((e) => {
+            LMStudioLLM.#slog(`Error caching context windows`, e);
+            return;
           });
-        })
-        .catch((e) => {
-          LMStudioLLM.#slog(`Error caching context windows`, e);
-          return;
-        });
-      LMStudioLLM.#slog(`Context windows cached for all models!`);
+        LMStudioLLM.#slog(`Context windows cached for all models!`);
+      })();
+
+      await LMStudioLLM._cachePromise;
     } catch (e) {
       LMStudioLLM.#slog(`Error caching context windows`, e);
       return;
+    } finally {
+      LMStudioLLM._cachePromise = null;
+    }
+  }
+
+  /**
+   * Ensure a specific model is cached. If the model is not in the cache,
+   * refresh the cache to get the latest models from LMStudio.
+   * Handles the case where users download new models after the initial cache.
+   * @param {string} modelName - Model name to check
+   * @returns {Promise}
+   */
+  static async ensureModelCached(modelName) {
+    if (LMStudioLLM.modelContextWindows[modelName]) {
+      return;
+    }
+
+    // Model may have been downloaded after the initial cache so try to refresh
+    LMStudioLLM.#slog(
+      `Model "${modelName}" not in cache, refreshing model list...`
+    );
+    await LMStudioLLM.cacheContextWindows(true);
+
+    if (!LMStudioLLM.modelContextWindows[modelName]) {
+      LMStudioLLM.#slog(
+        `Model "${modelName}" still not found after refresh, will use fallback context window`
+      );
     }
   }
@@ -198,6 +247,10 @@
         `LMStudio chat: ${this.model} is not valid or defined model for chat completion!`
       );
 
+    // Ensure context window is cached before proceeding
+    // Prevents race conditions and handles newly downloaded models
+    await LMStudioLLM.ensureModelCached(this.model);
+
     const result = await LLMPerformanceMonitor.measureAsyncFunction(
       this.lmstudio.chat.completions.create({
         model: this.model,
@@ -230,6 +283,10 @@
         `LMStudio chat: ${this.model} is not valid or defined model for chat completion!`
      );
 
+    // Ensure context window is cached before proceeding
+    // Prevents race conditions and handles newly downloaded models
+    await LMStudioLLM.ensureModelCached(this.model);
+
     const measuredStreamRequest = await LLMPerformanceMonitor.measureStream(
       this.lmstudio.chat.completions.create({
         model: this.model,
diff --git a/server/utils/agents/aibitat/providers/lmstudio.js b/server/utils/agents/aibitat/providers/lmstudio.js
index bf6f238fadd..81c39e411e6 100644
--- a/server/utils/agents/aibitat/providers/lmstudio.js
+++ b/server/utils/agents/aibitat/providers/lmstudio.js
@@ -41,7 +41,10 @@ class LMStudioProvider extends InheritMultiple([Provider, UnTooled]) {
   }
 
   async #handleFunctionCallChat({ messages = [] }) {
-    await LMStudioLLM.cacheContextWindows();
+    // Ensure context window is cached before proceeding
+    // Prevents race conditions and handles newly downloaded models
+    await LMStudioLLM.ensureModelCached(this.model);
+
     return await this.client.chat.completions
       .create({
         model: this.model,
@@ -60,7 +63,10 @@ class LMStudioProvider extends InheritMultiple([Provider, UnTooled]) {
   }
 
   async #handleFunctionCallStream({ messages = [] }) {
-    await LMStudioLLM.cacheContextWindows();
+    // Ensure context window is cached before proceeding
+    // Prevents race conditions and handles newly downloaded models
+    await LMStudioLLM.ensureModelCached(this.model);
+
     return await this.client.chat.completions.create({
       model: this.model,
       stream: true,