Skip to content

Commit c667dad

Browse files
authored
EFF-730 Fix ai LanguageModel incremental prompt fallback (#1780)
1 parent 3015c2d commit c667dad

File tree

3 files changed

+163
-39
lines changed

3 files changed

+163
-39
lines changed
Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,5 @@
1+
---
2+
"effect": patch
3+
---
4+
5+
Fix `LanguageModel` incremental prompt fallback to reliably retry with the full prompt when an incremental request fails with `InvalidRequestError`.

packages/effect/src/unstable/ai/LanguageModel.ts

Lines changed: 42 additions & 38 deletions
Original file line number | Diff line number | Diff line change
@@ -937,19 +937,23 @@ export const make: (params: {
937937
const tracker = Option.getOrUndefined(yield* Effect.serviceOption(ResponseIdTracker.ResponseIdTracker))
938938
const toolChoice = options.toolChoice ?? "auto"
939939

940-
const withNonIncrementalFallback = <R>(
941-
effect: Effect.Effect<Array<Response.PartEncoded>, AiError.AiError, R>
942-
): Effect.Effect<Array<Response.PartEncoded>, AiError.AiError, R | IdGenerator> =>
943-
providerOptions.incrementalPrompt ?
944-
effect.pipe(
945-
Effect.catchReason("AiError", "InvalidRequestError", (_) =>
946-
params.generateText({
947-
...providerOptions,
948-
incrementalPrompt: undefined,
949-
previousResponseId: undefined
950-
}))
951-
) :
952-
effect
940+
const generateWithNonIncrementalFallback = () => {
941+
const requestOptions: ProviderOptions = {
942+
...providerOptions
943+
}
944+
const fallbackPrompt = requestOptions.prompt
945+
const fallbackOptions: ProviderOptions = {
946+
...requestOptions,
947+
prompt: fallbackPrompt,
948+
incrementalPrompt: undefined,
949+
previousResponseId: undefined
950+
}
951+
return requestOptions.incrementalPrompt
952+
? params.generateText(requestOptions).pipe(
953+
Effect.catchReason("AiError", "InvalidRequestError", (_) => params.generateText(fallbackOptions))
954+
)
955+
: params.generateText(requestOptions)
956+
}
953957

954958
// Check for pending approvals that need resolution
955959
const { approved, denied } = collectToolApprovals(
@@ -982,7 +986,7 @@ export const make: (params: {
982986
const ResponseSchema = Schema.mutable(
983987
Schema.Array(Response.Part(Toolkit.empty))
984988
)
985-
const rawContent = yield* withNonIncrementalFallback(params.generateText(providerOptions))
989+
const rawContent = yield* generateWithNonIncrementalFallback()
986990
const content = yield* Schema.decodeEffect(ResponseSchema)(rawContent)
987991
if (tracker) {
988992
const responseMetadata = content.find((part) => part.type === "response-metadata")
@@ -1020,7 +1024,7 @@ export const make: (params: {
10201024
const ResponseSchema = Schema.mutable(
10211025
Schema.Array(Response.Part(Toolkit.empty))
10221026
)
1023-
const rawContent = yield* withNonIncrementalFallback(params.generateText(providerOptions))
1027+
const rawContent = yield* generateWithNonIncrementalFallback()
10241028
const content = yield* Schema.decodeEffect(ResponseSchema)(rawContent)
10251029
if (tracker) {
10261030
const responseMetadata = content.find((part) => part.type === "response-metadata")
@@ -1099,7 +1103,7 @@ export const make: (params: {
10991103
// If tool call resolution is disabled, return the response without
11001104
// resolving the tool calls that were generated
11011105
if (options.disableToolCallResolution === true) {
1102-
const rawContent = yield* withNonIncrementalFallback(params.generateText(providerOptions))
1106+
const rawContent = yield* generateWithNonIncrementalFallback()
11031107
const content = yield* Schema.decodeEffect(ResponseSchema)(rawContent)
11041108
if (tracker) {
11051109
const responseMetadata = content.find((part) => part.type === "response-metadata")
@@ -1110,7 +1114,7 @@ export const make: (params: {
11101114
return content as Array<Response.Part<Tools>>
11111115
}
11121116

1113-
const rawContent = yield* withNonIncrementalFallback(params.generateText(providerOptions))
1117+
const rawContent = yield* generateWithNonIncrementalFallback()
11141118

11151119
// Resolve the generated tool calls
11161120
const toolResults = yield* resolveToolCalls(
@@ -1168,19 +1172,23 @@ export const make: (params: {
11681172
const tracker = Option.getOrUndefined(yield* Effect.serviceOption(ResponseIdTracker.ResponseIdTracker))
11691173
const toolChoice = options.toolChoice ?? "auto"
11701174

1171-
const withNonIncrementalFallback = <R>(
1172-
stream: Stream.Stream<Response.StreamPartEncoded, AiError.AiError, R>
1173-
): Stream.Stream<Response.StreamPartEncoded, AiError.AiError, R | IdGenerator> =>
1174-
providerOptions.incrementalPrompt ?
1175-
stream.pipe(
1176-
Stream.catchReason("AiError", "InvalidRequestError", (_) =>
1177-
params.streamText({
1178-
...providerOptions,
1179-
incrementalPrompt: undefined,
1180-
previousResponseId: undefined
1181-
}))
1182-
) :
1183-
stream
1175+
const streamWithNonIncrementalFallback = () => {
1176+
const requestOptions: ProviderOptions = {
1177+
...providerOptions
1178+
}
1179+
const fallbackPrompt = requestOptions.prompt
1180+
const fallbackOptions: ProviderOptions = {
1181+
...requestOptions,
1182+
prompt: fallbackPrompt,
1183+
incrementalPrompt: undefined,
1184+
previousResponseId: undefined
1185+
}
1186+
return requestOptions.incrementalPrompt
1187+
? params.streamText(requestOptions).pipe(
1188+
Stream.catchReason("AiError", "InvalidRequestError", (_) => params.streamText(fallbackOptions))
1189+
)
1190+
: params.streamText(requestOptions)
1191+
}
11841192

11851193
// Check for pending approvals that need resolution
11861194
const { approved: pendingApproved, denied: pendingDenied } = collectToolApprovals(providerOptions.prompt.content, {
@@ -1212,8 +1220,7 @@ export const make: (params: {
12121220
const schema = Schema.NonEmptyArray(Response.StreamPart(Toolkit.empty))
12131221
const decodeParts = Schema.decodeEffect(schema)
12141222
return pipe(
1215-
params.streamText(providerOptions),
1216-
withNonIncrementalFallback,
1223+
streamWithNonIncrementalFallback(),
12171224
Stream.mapArrayEffect((parts) =>
12181225
decodeParts(parts).pipe(
12191226
tracker ?
@@ -1262,8 +1269,7 @@ export const make: (params: {
12621269
const schema = Schema.NonEmptyArray(Response.StreamPart(Toolkit.empty))
12631270
const decodeParts = Schema.decodeEffect(schema)
12641271
return pipe(
1265-
params.streamText(providerOptions),
1266-
withNonIncrementalFallback,
1272+
streamWithNonIncrementalFallback(),
12671273
Stream.mapArrayEffect((parts) =>
12681274
decodeParts(parts).pipe(
12691275
tracker ?
@@ -1369,8 +1375,7 @@ export const make: (params: {
13691375
if (options.disableToolCallResolution === true) {
13701376
const schema = Schema.NonEmptyArray(Response.StreamPart(toolkit))
13711377
const decodeParts = Schema.decodeEffect(schema)
1372-
return params.streamText(providerOptions).pipe(
1373-
withNonIncrementalFallback,
1378+
return streamWithNonIncrementalFallback().pipe(
13741379
Stream.mapArrayEffect((parts) =>
13751380
decodeParts(parts).pipe(
13761381
tracker ?
@@ -1449,8 +1454,7 @@ export const make: (params: {
14491454
)
14501455
})
14511456

1452-
yield* params.streamText(providerOptions).pipe(
1453-
withNonIncrementalFallback,
1457+
yield* streamWithNonIncrementalFallback().pipe(
14541458
Stream.runForEachArray(
14551459
Effect.fnUntraced(function*(chunk) {
14561460
const parts = yield* decodeParts(chunk)

packages/effect/test/unstable/ai/LanguageModel.test.ts

Lines changed: 116 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@ import { describe, it } from "@effect/vitest"
22
import { assertDefined, assertTrue, deepStrictEqual, strictEqual } from "@effect/vitest/utils"
33
import { Effect, Latch, Option, Schema, Stream } from "effect"
44
import { TestClock } from "effect/testing"
5-
import { LanguageModel, Prompt, Response, ResponseIdTracker, Tool, Toolkit } from "effect/unstable/ai"
5+
import { AiError, LanguageModel, Prompt, Response, ResponseIdTracker, Tool, Toolkit } from "effect/unstable/ai"
66
import * as TestUtils from "./utils.ts"
77

88
const MyTool = Tool.make("MyTool", {
@@ -204,6 +204,121 @@ describe("LanguageModel", () => {
204204
strictEqual(capturedOptions.incrementalPrompt, undefined)
205205
}))
206206

207+
it("falls back to full prompt in generateText when incremental request fails", () =>
208+
Effect.gen(function*() {
209+
const fullPrompt = Prompt.make([
210+
Prompt.systemMessage({ content: "system" }),
211+
Prompt.userMessage({ content: [Prompt.textPart({ text: "user" })] }),
212+
Prompt.assistantMessage({ content: [Prompt.textPart({ text: "assistant" })] }),
213+
Prompt.userMessage({ content: [Prompt.textPart({ text: "next" })] })
214+
])
215+
216+
const incrementalPrompt = Prompt.make([
217+
Prompt.userMessage({ content: [Prompt.textPart({ text: "next" })] })
218+
])
219+
220+
const calls: Array<LanguageModel.ProviderOptions> = []
221+
222+
yield* LanguageModel.generateText({
223+
prompt: fullPrompt
224+
}).pipe(
225+
Effect.provideServiceEffect(
226+
LanguageModel.LanguageModel,
227+
LanguageModel.make({
228+
generateText: (options) => {
229+
calls.push(options)
230+
if (calls.length === 1) {
231+
;(options as any).prompt = options.incrementalPrompt ?? options.prompt
232+
return Effect.fail(AiError.make({
233+
module: "LanguageModelTest",
234+
method: "generateText",
235+
reason: new AiError.InvalidRequestError({
236+
description: "invalid previous response id"
237+
})
238+
}))
239+
}
240+
return Effect.succeed([finishPart])
241+
},
242+
streamText: () => Stream.empty
243+
})
244+
),
245+
Effect.provideService(ResponseIdTracker.ResponseIdTracker, {
246+
clearUnsafe() {},
247+
markParts() {},
248+
prepareUnsafe: () =>
249+
Option.some({
250+
previousResponseId: "resp_prev",
251+
prompt: incrementalPrompt
252+
})
253+
})
254+
)
255+
256+
strictEqual(calls.length, 2)
257+
strictEqual(calls[0]!.previousResponseId, "resp_prev")
258+
strictEqual(calls[0]!.incrementalPrompt, incrementalPrompt)
259+
strictEqual(calls[1]!.previousResponseId, undefined)
260+
strictEqual(calls[1]!.incrementalPrompt, undefined)
261+
deepStrictEqual(calls[1]!.prompt, fullPrompt)
262+
}))
263+
264+
it("falls back to full prompt in streamText when incremental request fails", () =>
265+
Effect.gen(function*() {
266+
const fullPrompt = Prompt.make([
267+
Prompt.systemMessage({ content: "system" }),
268+
Prompt.userMessage({ content: [Prompt.textPart({ text: "user" })] }),
269+
Prompt.assistantMessage({ content: [Prompt.textPart({ text: "assistant" })] }),
270+
Prompt.userMessage({ content: [Prompt.textPart({ text: "next" })] })
271+
])
272+
273+
const incrementalPrompt = Prompt.make([
274+
Prompt.userMessage({ content: [Prompt.textPart({ text: "next" })] })
275+
])
276+
277+
const calls: Array<LanguageModel.ProviderOptions> = []
278+
279+
yield* LanguageModel.streamText({
280+
prompt: fullPrompt
281+
}).pipe(
282+
Stream.runDrain,
283+
Effect.provideServiceEffect(
284+
LanguageModel.LanguageModel,
285+
LanguageModel.make({
286+
generateText: () => Effect.succeed([finishPart]),
287+
streamText: (options) => {
288+
calls.push(options)
289+
if (calls.length === 1) {
290+
;(options as any).prompt = options.incrementalPrompt ?? options.prompt
291+
return Stream.fail(AiError.make({
292+
module: "LanguageModelTest",
293+
method: "streamText",
294+
reason: new AiError.InvalidRequestError({
295+
description: "invalid previous response id"
296+
})
297+
}))
298+
}
299+
return Stream.fromIterable([finishPart])
300+
}
301+
})
302+
),
303+
Effect.provideService(ResponseIdTracker.ResponseIdTracker, {
304+
clearUnsafe() {},
305+
markParts() {},
306+
prepareUnsafe: () =>
307+
Option.some({
308+
previousResponseId: "resp_prev",
309+
prompt: incrementalPrompt
310+
})
311+
})
312+
)
313+
314+
strictEqual(calls.length, 2)
315+
strictEqual(calls[0]!.previousResponseId, "resp_prev")
316+
strictEqual(calls[0]!.incrementalPrompt, incrementalPrompt)
317+
strictEqual(calls[1]!.previousResponseId, undefined)
318+
strictEqual(calls[1]!.incrementalPrompt, undefined)
319+
deepStrictEqual(calls[1]!.prompt, fullPrompt)
320+
}))
321+
207322
it("uses tracker prepareUnsafe and markParts in generateText without toolkit", () =>
208323
Effect.gen(function*() {
209324
let capturedOptions: LanguageModel.ProviderOptions | undefined

0 commit comments

Comments (0)