diff --git a/server/__tests__/utils/AiProviders/ollama/index.test.js b/server/__tests__/utils/AiProviders/ollama/index.test.js
new file mode 100644
index 00000000000..5ad7da24e45
--- /dev/null
+++ b/server/__tests__/utils/AiProviders/ollama/index.test.js
@@ -0,0 +1,190 @@
+const { OllamaAILLM } = require("../../../../utils/AiProviders/ollama");
+
+jest.mock("../../../../utils/EmbeddingEngines/native");
+jest.mock("ollama");
+
+describe("OllamaAILLM", () => {
+  beforeEach(() => {
+    jest.clearAllMocks();
+    OllamaAILLM.modelContextWindows = {};
+    OllamaAILLM._cachePromise = null;
+    process.env.OLLAMA_BASE_PATH = "http://localhost:11434";
+  });
+
+  afterEach(() => {
+    delete process.env.OLLAMA_BASE_PATH;
+  });
+
+  describe("Constructor initialization", () => {
+    it("initializes limits immediately to prevent race conditions", () => {
+      const llm = new OllamaAILLM();
+
+      expect(llm.limits).toBeDefined();
+      expect(llm.limits.user).toBe(4096 * 0.7);
+      expect(llm.limits.system).toBe(4096 * 0.15);
+      expect(llm.limits.history).toBe(4096 * 0.15);
+    });
+
+    it("sets model from preference or defaults to undefined", () => {
+      const llm1 = new OllamaAILLM(null, "custom-model");
+      expect(llm1.model).toBe("custom-model");
+
+      const llm2 = new OllamaAILLM();
+      expect(llm2.model).toBeUndefined();
+    });
+  });
+
+  describe("cacheContextWindows", () => {
+    it("caches context windows for all chat models", async () => {
+      const mockClient = {
+        list: jest.fn().mockResolvedValue({
+          models: [
+            { name: "model-1" },
+            { name: "embedding-model" },
+          ],
+        }),
+        show: jest.fn((params) => {
+          if (params.model === "model-1") {
+            return Promise.resolve({
+              capabilities: [],
+              model_info: { "general.context_length": 8192 },
+            });
+          }
+          return Promise.resolve({
+            capabilities: ["embedding"],
+            model_info: { "general.context_length": 2048 },
+          });
+        }),
+      };
+
+      const { Ollama } = require("ollama");
+      Ollama.mockImplementation(() => mockClient);
+
+      await OllamaAILLM.cacheContextWindows(true);
+
+      expect(OllamaAILLM.modelContextWindows["model-1"]).toBe(8192);
+      expect(OllamaAILLM.modelContextWindows["embedding-model"]).toBeUndefined();
+    });
+
+    it("handles concurrent cache requests without duplicate fetches", async () => {
+      const mockClient = {
+        list: jest.fn().mockImplementation(() =>
+          new Promise((resolve) =>
+            setTimeout(() => resolve({
+              models: [{ name: "model-1" }],
+            }), 100)
+          )
+        ),
+        show: jest.fn().mockResolvedValue({
+          capabilities: [],
+          model_info: { "general.context_length": 8192 },
+        }),
+      };
+
+      const { Ollama } = require("ollama");
+      Ollama.mockImplementation(() => mockClient);
+
+      const promise1 = OllamaAILLM.cacheContextWindows(true);
+      const promise2 = OllamaAILLM.cacheContextWindows(false);
+
+      await Promise.all([promise1, promise2]);
+
+      expect(mockClient.list).toHaveBeenCalledTimes(1);
+    });
+  });
+
+  describe("ensureModelCached", () => {
+    it("returns immediately if model is already cached", async () => {
+      OllamaAILLM.modelContextWindows["model-1"] = 8192;
+
+      const mockClient = {
+        list: jest.fn(),
+        show: jest.fn(),
+      };
+
+      const { Ollama } = require("ollama");
+      Ollama.mockImplementation(() => mockClient);
+
+      await OllamaAILLM.ensureModelCached("model-1");
+
+      expect(mockClient.list).not.toHaveBeenCalled();
+    });
+
+    it("refreshes cache when model is not found", async () => {
+      const mockClient = {
+        list: jest.fn().mockResolvedValue({
+          models: [{ name: "new-model" }],
+        }),
+        show: jest.fn().mockResolvedValue({
+          capabilities: [],
+          model_info: { "general.context_length": 4096 },
+        }),
+      };
+
+      const { Ollama } = require("ollama");
+      Ollama.mockImplementation(() => mockClient);
+
+      await OllamaAILLM.ensureModelCached("new-model");
+
+      expect(mockClient.list).toHaveBeenCalledTimes(1);
+      expect(OllamaAILLM.modelContextWindows["new-model"]).toBe(4096);
+    });
+  });
+
+  describe("getChatCompletion", () => {
+    it("ensures model is cached before making request", async () => {
+      const mockClient = {
+        list: jest.fn().mockResolvedValue({
+          models: [{ name: "test-model" }],
+        }),
+        show: jest.fn().mockResolvedValue({
+          capabilities: [],
+          model_info: { "general.context_length": 8192 },
+        }),
+        chat: jest.fn().mockResolvedValue({
+          message: { content: "Response" },
+          prompt_eval_count: 10,
+          eval_count: 20,
+        }),
+      };
+
+      const { Ollama } = require("ollama");
+      Ollama.mockImplementation(() => mockClient);
+
+      const llm = new OllamaAILLM(null, "test-model");
+      await llm.getChatCompletion([{ role: "user", content: "Hello" }], { temperature: 0.7 });
+
+      expect(OllamaAILLM.modelContextWindows["test-model"]).toBe(8192);
+      expect(mockClient.chat).toHaveBeenCalled();
+    });
+  });
+
+  describe("streamGetChatCompletion", () => {
+    it("ensures model is cached before streaming", async () => {
+      const mockClient = {
+        list: jest.fn().mockResolvedValue({
+          models: [{ name: "test-model" }],
+        }),
+        show: jest.fn().mockResolvedValue({
+          capabilities: [],
+          model_info: { "general.context_length": 8192 },
+        }),
+        chat: jest.fn().mockResolvedValue({
+          [Symbol.asyncIterator]: async function* () {
+            yield { message: { content: "Hello" } };
+          },
+        }),
+      };
+
+      const { Ollama } = require("ollama");
+      Ollama.mockImplementation(() => mockClient);
+
+      const llm = new OllamaAILLM(null, "test-model");
+      await llm.streamGetChatCompletion([{ role: "user", content: "Hello" }], { temperature: 0.7 });
+
+      expect(OllamaAILLM.modelContextWindows["test-model"]).toBe(8192);
+      expect(mockClient.chat).toHaveBeenCalled();
+    });
+  });
+});
+
diff --git a/server/__tests__/utils/agents/aibitat/providers/ollama.test.js b/server/__tests__/utils/agents/aibitat/providers/ollama.test.js
new file mode 100644
index 00000000000..ae8073a95e3
--- /dev/null
+++ b/server/__tests__/utils/agents/aibitat/providers/ollama.test.js
@@ -0,0 +1,82 @@
+const OllamaProvider = require("../../../../../utils/agents/aibitat/providers/ollama");
+const { OllamaAILLM } = require("../../../../../utils/AiProviders/ollama");
+
+jest.mock("ollama");
+jest.mock("../../../../../utils/AiProviders/ollama", () => ({
+  OllamaAILLM: {
+    ensureModelCached: jest.fn(),
+    cacheContextWindows: jest.fn(),
+    promptWindowLimit: jest.fn(() => 4096),
+  },
+}));
+
+describe("OllamaProvider", () => {
+  beforeEach(() => {
+    jest.clearAllMocks();
+    process.env.OLLAMA_BASE_PATH = "http://localhost:11434";
+  });
+
+  afterEach(() => {
+    delete process.env.OLLAMA_BASE_PATH;
+  });
+
+  describe("Initialization", () => {
+    it("initializes with provided model or null", () => {
+      const provider1 = new OllamaProvider({ model: "custom-model" });
+      expect(provider1.model).toBe("custom-model");
+
+      const provider2 = new OllamaProvider();
+      expect(provider2.model).toBeNull();
+    });
+
+    it("supports agent streaming", () => {
+      const provider = new OllamaProvider();
+      expect(provider.supportsAgentStreaming).toBe(true);
+    });
+  });
+
+  describe("Chat completion", () => {
+    it("ensures model is cached before completing", async () => {
+      const provider = new OllamaProvider({ model: "test-model" });
+      const mockChat = jest.fn().mockResolvedValue({
+        message: { content: "Response" },
+      });
+
+      provider._client = {
+        chat: mockChat,
+      };
+
+      await provider.complete([{ role: "user", content: "Hello" }]);
+
+      expect(OllamaAILLM.ensureModelCached).toHaveBeenCalledWith("test-model");
+      expect(mockChat).toHaveBeenCalled();
+    });
+
+    it("ensures model is cached before streaming", async () => {
+      const provider = new OllamaProvider({ model: "test-model" });
+      const mockChat = jest.fn().mockResolvedValue({
+        [Symbol.asyncIterator]: async function* () {
+          yield { message: { content: "Hello" }, done: false };
+          yield { done: true };
+        },
+      });
+
+      provider._client = {
+        chat: mockChat,
+      };
+
+      await provider.stream([{ role: "user", content: "Hello" }], [], null);
+
+      expect(OllamaAILLM.ensureModelCached).toHaveBeenCalledWith("test-model");
+      expect(mockChat).toHaveBeenCalled();
+    });
+  });
+
+  describe("Cost calculation", () => {
+    it("returns zero cost for Ollama", () => {
+      const provider = new OllamaProvider();
+      expect(provider.getCost({})).toBe(0);
+    });
+  });
+});
+
diff --git a/server/utils/AiProviders/ollama/index.js b/server/utils/AiProviders/ollama/index.js
index b88a8121870..dd750ca8374 100644
--- a/server/utils/AiProviders/ollama/index.js
+++ b/server/utils/AiProviders/ollama/index.js
@@ -13,6 +13,8 @@ const { Ollama } = require("ollama");
 class OllamaAILLM {
   /** @see OllamaAILLM.cacheContextWindows */
   static modelContextWindows = {};
+  /** Tracks the current caching operation to prevent race conditions */
+  static _cachePromise = null;
 
   constructor(embedder = null, modelPreference = null) {
     if (!process.env.OLLAMA_BASE_PATH)
@@ -37,6 +39,14 @@ class OllamaAILLM {
 
     this.embedder = embedder ?? new NativeEmbedder();
     this.defaultTemp = 0.7;
+    // Initialize limits with default values that will be updated by cacheContextWindows
+    this.limits = {
+      history: this.promptWindowLimit() * 0.15,
+      system: this.promptWindowLimit() * 0.15,
+      user: this.promptWindowLimit() * 0.7,
+    };
+
+    // Start caching in background and update limits when fetched
     OllamaAILLM.cacheContextWindows(true).then(() => {
       this.limits = {
         history: this.promptWindowLimit() * 0.15,
@@ -68,40 +78,77 @@
    */
   static async cacheContextWindows(force = false) {
     try {
-      // Skip if we already have cached context windows and we're not forcing a refresh
-      if (Object.keys(OllamaAILLM.modelContextWindows).length > 0 && !force)
+      if (OllamaAILLM._cachePromise && !force) {
+        return await OllamaAILLM._cachePromise;
+      }
+
+      // Already have cached context windows and not forcing a refresh
+      if (Object.keys(OllamaAILLM.modelContextWindows).length > 0 && !force) {
         return;
+      }
 
-      const authToken = process.env.OLLAMA_AUTH_TOKEN;
-      const basePath = process.env.OLLAMA_BASE_PATH;
-      const client = new Ollama({
-        host: basePath,
-        headers: authToken ? { Authorization: `Bearer ${authToken}` } : {},
-      });
+      // Store cache promise to prevent multiple requests
+      OllamaAILLM._cachePromise = (async () => {
+        const authToken = process.env.OLLAMA_AUTH_TOKEN;
+        const basePath = process.env.OLLAMA_BASE_PATH;
+        const client = new Ollama({
+          host: basePath,
+          headers: authToken ? { Authorization: `Bearer ${authToken}` } : {},
+        });
 
-      const { models } = await client.list().catch(() => ({ models: [] }));
-      if (!models.length) return;
+        const { models } = await client.list().catch(() => ({ models: [] }));
+        if (!models.length) return;
 
-      const infoPromises = models.map((model) =>
-        client
-          .show({ model: model.name })
-          .then((info) => ({ name: model.name, ...info }))
-      );
-      const infos = await Promise.all(infoPromises);
-      infos.forEach((showInfo) => {
-        if (showInfo.capabilities.includes("embedding")) return;
-        const contextWindowKey = Object.keys(showInfo.model_info).find((key) =>
-          key.endsWith(".context_length")
+        const infoPromises = models.map((model) =>
+          client
+            .show({ model: model.name })
+            .then((info) => ({ name: model.name, ...info }))
         );
-        if (!contextWindowKey)
-          return (OllamaAILLM.modelContextWindows[showInfo.name] = 4096);
-        OllamaAILLM.modelContextWindows[showInfo.name] =
-          showInfo.model_info[contextWindowKey];
-      });
-      OllamaAILLM.#slog(`Context windows cached for all models!`);
+        const infos = await Promise.all(infoPromises);
+        infos.forEach((showInfo) => {
+          if (showInfo.capabilities.includes("embedding")) return;
+          const contextWindowKey = Object.keys(showInfo.model_info).find(
+            (key) => key.endsWith(".context_length")
+          );
+          if (!contextWindowKey)
+            return (OllamaAILLM.modelContextWindows[showInfo.name] = 4096);
+          OllamaAILLM.modelContextWindows[showInfo.name] =
+            showInfo.model_info[contextWindowKey];
+        });
+        OllamaAILLM.#slog(`Context windows cached for all models!`);
+      })();
+
+      await OllamaAILLM._cachePromise;
     } catch (e) {
       OllamaAILLM.#slog(`Error caching context windows`, e);
       return;
+    } finally {
+      OllamaAILLM._cachePromise = null;
+    }
+  }
+
+  /**
+   * Ensure a specific model is cached. If the model is not in the cache,
+   * refresh the cache to get the latest models from Ollama.
+   * This handles the case where users download new models after the initial cache.
+   * @param {string} modelName - The model name to check
+   * @returns {Promise}
+   */
+  static async ensureModelCached(modelName) {
+    if (OllamaAILLM.modelContextWindows[modelName]) {
+      return;
+    }
+
+    // Model may have been downloaded after the initial cache so try to refresh
+    OllamaAILLM.#slog(
+      `Model "${modelName}" not in cache, refreshing model list...`
+    );
+    await OllamaAILLM.cacheContextWindows(true);
+
+    if (!OllamaAILLM.modelContextWindows[modelName]) {
+      OllamaAILLM.#slog(
+        `Model "${modelName}" still not found after refresh, will use fallback context window`
+      );
     }
   }
 
@@ -240,6 +287,10 @@
   }
 
   async getChatCompletion(messages = null, { temperature = 0.7 }) {
+    // Ensure context window is cached before proceeding
+    // Prevents race conditions and handles newly downloaded models
+    await OllamaAILLM.ensureModelCached(this.model);
+
     const result = await LLMPerformanceMonitor.measureAsyncFunction(
       this.client
         .chat({
@@ -289,6 +340,10 @@
   }
 
   async streamGetChatCompletion(messages = null, { temperature = 0.7 }) {
+    // Ensure context window is cached before proceeding
+    // Prevents race conditions and handles newly downloaded models
+    await OllamaAILLM.ensureModelCached(this.model);
+
     const measuredStreamRequest = await LLMPerformanceMonitor.measureStream(
       this.client.chat({
         model: this.model,
diff --git a/server/utils/agents/aibitat/providers/ollama.js b/server/utils/agents/aibitat/providers/ollama.js
index 532bf61af98..480d521ddc7 100644
--- a/server/utils/agents/aibitat/providers/ollama.js
+++ b/server/utils/agents/aibitat/providers/ollama.js
@@ -57,7 +57,10 @@ class OllamaProvider extends InheritMultiple([Provider, UnTooled]) {
    * @returns {Promise} The completion.
    */
   async #handleFunctionCallChat({ messages = [] }) {
-    await OllamaAILLM.cacheContextWindows();
+    // Ensure context window is cached before proceeding
+    // Prevents race conditions and handles newly downloaded models
+    await OllamaAILLM.ensureModelCached(this.model);
+
     const response = await this.client.chat({
       model: this.model,
       messages,
@@ -67,7 +70,10 @@
   }
 
   async #handleFunctionCallStream({ messages = [] }) {
-    await OllamaAILLM.cacheContextWindows();
+    // Ensure context window is cached before proceeding
+    // Prevents race conditions and handles newly downloaded models
+    await OllamaAILLM.ensureModelCached(this.model);
+
     return await this.client.chat({
       model: this.model,
       messages,