diff --git a/packages/llm/src/model.middleware.test.ts b/packages/llm/src/model.middleware.test.ts
index 7be7c8e..0578e2d 100644
--- a/packages/llm/src/model.middleware.test.ts
+++ b/packages/llm/src/model.middleware.test.ts
@@ -1,10 +1,17 @@
+import { dir } from "@synstack/fs";
 import { fsCache } from "@synstack/fs-cache";
 import { MockLanguageModelV1, simulateReadableStream } from "ai/test";
 import assert from "node:assert/strict";
-import { describe, it } from "node:test";
+import { afterEach, describe, it } from "node:test";
 import type { Llm } from "./llm.types.ts";
 import { cacheCalls, includeAssistantMessage } from "./model.middleware.ts";
 
+const TMP_DIR = dir(import.meta.dirname).to("tmp");
+
+afterEach(async () => {
+  await TMP_DIR.rm();
+});
+
 describe("includeAssistantMessage", () => {
   it("adds the last assistant message to a generated output", async () => {
     const model = new MockLanguageModelV1({
@@ -155,4 +162,61 @@ describe("cache", () => {
 
     assert.equal(output, " world!");
   });
+
+  it("caches a streamed response on miss", async () => {
+    const cache = fsCache(TMP_DIR.path).key(["stream-miss"]);
+
+    let callCount = 0;
+    const model = new MockLanguageModelV1({
+      doStream: () => {
+        callCount++;
+        return Promise.resolve({
+          rawCall: { rawPrompt: null, rawSettings: {} },
+          stream: simulateReadableStream({
+            chunks: [
+              { type: "text-delta", textDelta: "Hello" },
+              { type: "text-delta", textDelta: " " },
+              { type: "text-delta", textDelta: "world" },
+              {
+                type: "finish",
+                finishReason: "stop",
+                usage: { promptTokens: 10, completionTokens: 20 },
+              },
+            ] satisfies Array,
+            chunkDelayInMs: 0,
+          }),
+        });
+      },
+    });
+
+    const wrappedModel = cacheCalls(cache)(model);
+
+    const params: Llm.Model.Stream.Options = {
+      inputFormat: "messages",
+      mode: { type: "regular" },
+      prompt: [
+        {
+          role: "user",
+          content: [
+            {
+              type: "text",
+              text: "Hello",
+            },
+          ],
+        },
+      ],
+    };
+
+    const res1 = await wrappedModel.doStream(params);
+
+    let output1 = "";
+    for await (const chunk of res1.stream) {
+      if (chunk.type !== "text-delta") continue;
+      output1 += chunk.textDelta;
+    }
+
+    assert.equal(output1, "Hello world");
+    // Ensure the stream can be consumed without throwing and collects all parts
+    assert.equal(callCount, 1);
+  });
 });
diff --git a/packages/llm/src/model.middleware.ts b/packages/llm/src/model.middleware.ts
index a71a795..91e8b3f 100644
--- a/packages/llm/src/model.middleware.ts
+++ b/packages/llm/src/model.middleware.ts
@@ -138,7 +138,7 @@ export const cacheCalls =
           // We need to collect the parts and then set the cache
           const res = await doStream();
           const collector: Array = [];
-          res.stream.pipeThrough(
+          const piped = res.stream.pipeThrough(
             new TransformStream({
               transform(chunk, controller) {
                 collector.push(chunk);
@@ -150,7 +150,10 @@
             ...res,
             stream: collector,
           });
-          return res;
+          return {
+            ...res,
+            stream: piped,
+          };
         },
       },
     });