Skip to content

Commit f4e2aba

Browse files
authored
retry incremental prompt on invalid request (#1765)
1 parent 342fc4b commit f4e2aba

File tree

4 files changed

+50
-8
lines changed

4 files changed

+50
-8
lines changed

.changeset/puny-pens-clap.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
---
2+
"@effect/ai-openai": patch
3+
"effect": patch
4+
---
5+
6+
retry incremental prompt on invalid request

packages/ai/openai/src/OpenAiClient.ts

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -434,6 +434,9 @@ const makeSocket = Effect.gen(function*() {
434434
const text = typeof msg === "string" ? msg : decoder.decode(msg)
435435
try {
436436
const event = decodeEvent(text)
437+
if (event.type === "error") {
438+
tracker.clearUnsafe()
439+
}
437440
if (event.type === "error" && "status" in event) {
438441
return Queue.fail(
439442
currentQueue,
@@ -442,7 +445,10 @@ const makeSocket = Effect.gen(function*() {
442445
method: "createResponseStream",
443446
reason: AiError.reasonFromHttpStatus({
444447
status: event.status,
445-
metadata: event.error
448+
metadata: {
449+
...event.error,
450+
description: event.error.message
451+
}
446452
})
447453
})
448454
)

packages/effect/src/unstable/ai/AiError.ts

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1396,10 +1396,8 @@ export class AiError extends Schema.ErrorClass<AiError>(
13961396
method: Schema.String,
13971397
reason: AiErrorReason
13981398
}) {
1399-
/**
1400-
* @since 1.0.0
1401-
*/
14021399
readonly [TypeId] = TypeId
1400+
override readonly cause = this.reason
14031401

14041402
/**
14051403
* Delegates to the underlying reason's `isRetryable` getter.

packages/effect/src/unstable/ai/LanguageModel.ts

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -937,6 +937,20 @@ export const make: (params: {
937937
const tracker = Option.getOrUndefined(yield* Effect.serviceOption(ResponseIdTracker.ResponseIdTracker))
938938
const toolChoice = options.toolChoice ?? "auto"
939939

940+
const withNonIncrementalFallback = <R>(
941+
effect: Effect.Effect<Array<Response.PartEncoded>, AiError.AiError, R>
942+
): Effect.Effect<Array<Response.PartEncoded>, AiError.AiError, R | IdGenerator> =>
943+
providerOptions.incrementalPrompt ?
944+
effect.pipe(
945+
Effect.catchReason("AiError", "InvalidRequestError", (_) =>
946+
params.generateText({
947+
...providerOptions,
948+
incrementalPrompt: undefined,
949+
previousResponseId: undefined
950+
}))
951+
) :
952+
effect
953+
940954
// Check for pending approvals that need resolution
941955
const { approved, denied } = collectToolApprovals(
942956
providerOptions.prompt.content,
@@ -968,7 +982,7 @@ export const make: (params: {
968982
const ResponseSchema = Schema.mutable(
969983
Schema.Array(Response.Part(Toolkit.empty))
970984
)
971-
const rawContent = yield* params.generateText(providerOptions)
985+
const rawContent = yield* withNonIncrementalFallback(params.generateText(providerOptions))
972986
const content = yield* Schema.decodeEffect(ResponseSchema)(rawContent)
973987
if (tracker) {
974988
const responseMetadata = content.find((part) => part.type === "response-metadata")
@@ -1006,7 +1020,7 @@ export const make: (params: {
10061020
const ResponseSchema = Schema.mutable(
10071021
Schema.Array(Response.Part(Toolkit.empty))
10081022
)
1009-
const rawContent = yield* params.generateText(providerOptions)
1023+
const rawContent = yield* withNonIncrementalFallback(params.generateText(providerOptions))
10101024
const content = yield* Schema.decodeEffect(ResponseSchema)(rawContent)
10111025
if (tracker) {
10121026
const responseMetadata = content.find((part) => part.type === "response-metadata")
@@ -1085,7 +1099,7 @@ export const make: (params: {
10851099
// If tool call resolution is disabled, return the response without
10861100
// resolving the tool calls that were generated
10871101
if (options.disableToolCallResolution === true) {
1088-
const rawContent = yield* params.generateText(providerOptions)
1102+
const rawContent = yield* withNonIncrementalFallback(params.generateText(providerOptions))
10891103
const content = yield* Schema.decodeEffect(ResponseSchema)(rawContent)
10901104
if (tracker) {
10911105
const responseMetadata = content.find((part) => part.type === "response-metadata")
@@ -1096,7 +1110,7 @@ export const make: (params: {
10961110
return content as Array<Response.Part<Tools>>
10971111
}
10981112

1099-
const rawContent = yield* params.generateText(providerOptions)
1113+
const rawContent = yield* withNonIncrementalFallback(params.generateText(providerOptions))
11001114

11011115
// Resolve the generated tool calls
11021116
const toolResults = yield* resolveToolCalls(
@@ -1154,6 +1168,20 @@ export const make: (params: {
11541168
const tracker = Option.getOrUndefined(yield* Effect.serviceOption(ResponseIdTracker.ResponseIdTracker))
11551169
const toolChoice = options.toolChoice ?? "auto"
11561170

1171+
const withNonIncrementalFallback = <R>(
1172+
stream: Stream.Stream<Response.StreamPartEncoded, AiError.AiError, R>
1173+
): Stream.Stream<Response.StreamPartEncoded, AiError.AiError, R | IdGenerator> =>
1174+
providerOptions.incrementalPrompt ?
1175+
stream.pipe(
1176+
Stream.catchReason("AiError", "InvalidRequestError", (_) =>
1177+
params.streamText({
1178+
...providerOptions,
1179+
incrementalPrompt: undefined,
1180+
previousResponseId: undefined
1181+
}))
1182+
) :
1183+
stream
1184+
11571185
// Check for pending approvals that need resolution
11581186
const { approved: pendingApproved, denied: pendingDenied } = collectToolApprovals(providerOptions.prompt.content, {
11591187
excludeResolved: true
@@ -1185,6 +1213,7 @@ export const make: (params: {
11851213
const decodeParts = Schema.decodeEffect(schema)
11861214
return pipe(
11871215
params.streamText(providerOptions),
1216+
withNonIncrementalFallback,
11881217
Stream.mapArrayEffect((parts) =>
11891218
decodeParts(parts).pipe(
11901219
tracker ?
@@ -1234,6 +1263,7 @@ export const make: (params: {
12341263
const decodeParts = Schema.decodeEffect(schema)
12351264
return pipe(
12361265
params.streamText(providerOptions),
1266+
withNonIncrementalFallback,
12371267
Stream.mapArrayEffect((parts) =>
12381268
decodeParts(parts).pipe(
12391269
tracker ?
@@ -1340,6 +1370,7 @@ export const make: (params: {
13401370
const schema = Schema.NonEmptyArray(Response.StreamPart(toolkit))
13411371
const decodeParts = Schema.decodeEffect(schema)
13421372
return params.streamText(providerOptions).pipe(
1373+
withNonIncrementalFallback,
13431374
Stream.mapArrayEffect((parts) =>
13441375
decodeParts(parts).pipe(
13451376
tracker ?
@@ -1419,6 +1450,7 @@ export const make: (params: {
14191450
})
14201451

14211452
yield* params.streamText(providerOptions).pipe(
1453+
withNonIncrementalFallback,
14221454
Stream.runForEachArray(
14231455
Effect.fnUntraced(function*(chunk) {
14241456
const parts = yield* decodeParts(chunk)

0 commit comments

Comments
 (0)