66 changes: 65 additions & 1 deletion packages/llm/src/model.middleware.test.ts
@@ -1,10 +1,17 @@
import { dir } from "@synstack/fs";
import { fsCache } from "@synstack/fs-cache";
import { MockLanguageModelV1, simulateReadableStream } from "ai/test";
import assert from "node:assert/strict";
-import { describe, it } from "node:test";
+import { afterEach, describe, it } from "node:test";
import type { Llm } from "./llm.types.ts";
import { cacheCalls, includeAssistantMessage } from "./model.middleware.ts";

const TMP_DIR = dir(import.meta.dirname).to("tmp");

afterEach(async () => {
await TMP_DIR.rm();
});

describe("includeAssistantMessage", () => {
it("adds the last assistant message to a generated output", async () => {
const model = new MockLanguageModelV1({
@@ -155,4 +162,61 @@ describe("cache", () => {

assert.equal(output, " world!");
});

it("caches a streamed response on miss", async () => {
const cache = fsCache(TMP_DIR.path).key(["stream-miss"]);

let callCount = 0;
const model = new MockLanguageModelV1({
doStream: () => {
callCount++;
return Promise.resolve({
rawCall: { rawPrompt: null, rawSettings: {} },
stream: simulateReadableStream({
chunks: [
{ type: "text-delta", textDelta: "Hello" },
{ type: "text-delta", textDelta: " " },
{ type: "text-delta", textDelta: "world" },
{
type: "finish",
finishReason: "stop",
usage: { promptTokens: 10, completionTokens: 20 },
},
] satisfies Array<Llm.Model.Stream.Part>,
chunkDelayInMs: 0,
}),
});
},
});

const wrappedModel = cacheCalls(cache)(model);

const params: Llm.Model.Stream.Options = {
inputFormat: "messages",
mode: { type: "regular" },
prompt: [
{
role: "user",
content: [
{
type: "text",
text: "Hello",
},
],
},
],
};

const res1 = await wrappedModel.doStream(params);

let output1 = "";
for await (const chunk of res1.stream) {
if (chunk.type !== "text-delta") continue;
output1 += chunk.textDelta;
}

assert.equal(output1, "Hello world");
// Ensure the stream can be consumed without throwing and collects all parts
assert.equal(callCount, 1);
});
});
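The new test exercises the cache-miss path: it consumes the piped stream, checks the collected text, and asserts the underlying model was called exactly once. A natural companion check, sketched below, would issue the same request a second time and confirm the mock is not invoked again. This continuation is illustrative only and is not part of the diff; it assumes cacheCalls serves a repeated call from the cache instead of re-running doStream.

// Hypothetical follow-up (not part of this PR): a second call with the
// same params should be served from the cache, so callCount stays at 1.
const res2 = await wrappedModel.doStream(params);

let output2 = "";
for await (const chunk of res2.stream) {
  if (chunk.type !== "text-delta") continue;
  output2 += chunk.textDelta;
}

assert.equal(output2, "Hello world");
assert.equal(callCount, 1); // unchanged: the response came from the cache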
7 changes: 5 additions & 2 deletions packages/llm/src/model.middleware.ts
@@ -138,7 +138,7 @@ export const cacheCalls =
// We need to collect the parts and then set the cache
const res = await doStream();
const collector: Array<Llm.Model.Stream.Part> = [];
-const piped = res.stream.pipeThrough(
+const piped = res.stream.pipeThrough(
new TransformStream<Llm.Model.Stream.Part, Llm.Model.Stream.Part>({
transform(chunk, controller) {
collector.push(chunk);
@@ -150,7 +150,10 @@
...res,
stream: collector,
});
-return res;
+return {
+...res,
+stream: piped,
+};
},
},
});
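The middleware change addresses a subtlety of web streams: pipeThrough() locks its source and returns a new readable. The original code returned the un-piped res.stream, so callers received a locked stream, and the TransformStream that fills the collector was never fully consumed, meaning the cache entry could not be written once the stream completed. Returning the piped stream keeps the response readable and guarantees every part passes through the collector. Below is a minimal standalone sketch of this behaviour using plain web streams rather than the library's types; the variable names are illustrative.

// Sketch: why the piped stream, not the source, must be returned.
const source = new ReadableStream<string>({
  start(controller) {
    controller.enqueue("Hello");
    controller.enqueue(" world");
    controller.close();
  },
});

const collected: string[] = [];
const piped = source.pipeThrough(
  new TransformStream<string, string>({
    transform(chunk, controller) {
      collected.push(chunk); // side effect: record every chunk
      controller.enqueue(chunk); // forward the chunk to the consumer
    },
  }),
);

console.log(source.locked); // true: reading `source` now would throw

let out = "";
for await (const chunk of piped) {
  out += chunk; // consuming `piped` is what drives the collector
}

console.log(out); // "Hello world"
console.log(collected); // ["Hello", " world"]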