
Commit 6150f06

Merge pull request #1214 from narengogi/chore/finish-reason-mapping-part-2
finish reason mapping part 2
2 parents 1464cb4 + a9a982b commit 6150f06

15 files changed: +399 −61 lines changed


src/providers/anthropic/complete.ts

Lines changed: 42 additions & 11 deletions
@@ -1,9 +1,16 @@
 import { ANTHROPIC } from '../../globals';
 import { Params } from '../../types/requestBody';
 import { CompletionResponse, ErrorResponse, ProviderConfig } from '../types';
-import { generateInvalidProviderResponseError } from '../utils';
+import {
+  generateInvalidProviderResponseError,
+  transformFinishReason,
+} from '../utils';
+import {
+  ANTHROPIC_STOP_REASON,
+  AnthropicStreamState,
+  AnthropicErrorResponse,
+} from './types';
 import { AnthropicErrorResponseTransform } from './utils';
-import { AnthropicErrorResponse } from './types';

 // TODO: this configuration does not enforce the maximum token limit for the input parameter. If you want to enforce this, you might need to add a custom validation function or a max property to the ParameterConfig interface, and then use it in the input configuration. However, this might be complex because the token count is not a simple length check, but depends on the specific tokenization method used by the model.

@@ -57,7 +64,7 @@ export const AnthropicCompleteConfig: ProviderConfig = {

 interface AnthropicCompleteResponse {
   completion: string;
-  stop_reason: string;
+  stop_reason: ANTHROPIC_STOP_REASON;
   model: string;
   truncated: boolean;
   stop: null | string;
@@ -68,10 +75,20 @@ interface AnthropicCompleteResponse {
 // TODO: The token calculation is wrong atm
 export const AnthropicCompleteResponseTransform: (
   response: AnthropicCompleteResponse | AnthropicErrorResponse,
-  responseStatus: number
-) => CompletionResponse | ErrorResponse = (response, responseStatus) => {
-  if (responseStatus !== 200 && 'error' in response) {
-    return AnthropicErrorResponseTransform(response);
+  responseStatus: number,
+  responseHeaders: Headers,
+  strictOpenAiCompliance: boolean
+) => CompletionResponse | ErrorResponse = (
+  response,
+  responseStatus,
+  _responseHeaders,
+  strictOpenAiCompliance
+) => {
+  if (responseStatus !== 200) {
+    const errorResposne = AnthropicErrorResponseTransform(
+      response as AnthropicErrorResponse
+    );
+    if (errorResposne) return errorResposne;
   }

   if ('completion' in response) {
@@ -86,7 +103,10 @@ export const AnthropicCompleteResponseTransform: (
           text: response.completion,
           index: 0,
           logprobs: null,
-          finish_reason: response.stop_reason,
+          finish_reason: transformFinishReason(
+            response.stop_reason,
+            strictOpenAiCompliance
+          ),
         },
       ],
     };
@@ -96,8 +116,16 @@ export const AnthropicCompleteResponseTransform: (
 };

 export const AnthropicCompleteStreamChunkTransform: (
-  response: string
-) => string | undefined = (responseChunk) => {
+  response: string,
+  fallbackId: string,
+  streamState: AnthropicStreamState,
+  strictOpenAiCompliance: boolean
+) => string | undefined = (
+  responseChunk,
+  fallbackId,
+  streamState,
+  strictOpenAiCompliance
+) => {
   let chunk = responseChunk.trim();
   if (chunk.startsWith('event: ping')) {
     return;
@@ -110,6 +138,9 @@ export const AnthropicCompleteStreamChunkTransform: (
     return chunk;
   }
   const parsedChunk: AnthropicCompleteResponse = JSON.parse(chunk);
+  const finishReason = parsedChunk.stop_reason
+    ? transformFinishReason(parsedChunk.stop_reason, strictOpenAiCompliance)
+    : null;
   return (
     `data: ${JSON.stringify({
       id: parsedChunk.log_id,
@@ -122,7 +153,7 @@ export const AnthropicCompleteStreamChunkTransform: (
           text: parsedChunk.completion,
           index: 0,
           logprobs: null,
-          finish_reason: parsedChunk.stop_reason,
+          finish_reason: finishReason,
         },
       ],
     })}` + '\n\n'
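
Both transforms above now accept `strictOpenAiCompliance` and route Anthropic's `stop_reason` through the shared `transformFinishReason` helper imported from `../utils`. A minimal sketch of what such a helper could look like — the mapping table, the pass-through behaviour when strict compliance is off, and the enum members shown are assumptions, not the gateway's actual implementation:

```typescript
// Hypothetical sketch; the real helper lives in src/providers/utils.ts and
// may differ in members, fallbacks, and naming.
export enum ANTHROPIC_STOP_REASON {
  end_turn = 'end_turn',
  max_tokens = 'max_tokens',
  stop_sequence = 'stop_sequence',
  tool_use = 'tool_use',
}

type OpenAIFinishReason = 'stop' | 'length' | 'tool_calls' | 'content_filter';

// Assumed lookup table from Anthropic stop reasons to the OpenAI set.
const finishReasonMap: Record<ANTHROPIC_STOP_REASON, OpenAIFinishReason> = {
  [ANTHROPIC_STOP_REASON.end_turn]: 'stop',
  [ANTHROPIC_STOP_REASON.stop_sequence]: 'stop',
  [ANTHROPIC_STOP_REASON.max_tokens]: 'length',
  [ANTHROPIC_STOP_REASON.tool_use]: 'tool_calls',
};

export function transformFinishReason(
  reason: ANTHROPIC_STOP_REASON | undefined,
  strictOpenAiCompliance: boolean
): string | null {
  if (!reason) return null;
  // Assumption: when strict compliance is off, the provider's value passes through.
  if (!strictOpenAiCompliance) return reason;
  return finishReasonMap[reason] ?? 'stop';
}
```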

src/providers/bedrock/complete.ts

Lines changed: 18 additions & 5 deletions
@@ -1,9 +1,13 @@
 import { BEDROCK } from '../../globals';
 import { Params } from '../../types/requestBody';
 import { CompletionResponse, ErrorResponse, ProviderConfig } from '../types';
-import { generateInvalidProviderResponseError } from '../utils';
+import {
+  generateInvalidProviderResponseError,
+  transformFinishReason,
+} from '../utils';
 import { BedrockErrorResponseTransform } from './chatComplete';
 import { BedrockErrorResponse } from './embed';
+import { TITAN_STOP_REASON as TITAN_COMPLETION_REASON } from './types';

 export const BedrockAnthropicCompleteConfig: ProviderConfig = {
   prompt: {
@@ -380,7 +384,7 @@ export interface BedrockTitanCompleteResponse {
   results: {
     tokenCount: number;
     outputText: string;
-    completionReason: string;
+    completionReason: TITAN_COMPLETION_REASON;
   }[];
 }

@@ -420,7 +424,10 @@ export const BedrockTitanCompleteResponseTransform: (
       text: generation.outputText,
       index: index,
       logprobs: null,
-      finish_reason: generation.completionReason,
+      finish_reason: transformFinishReason(
+        generation.completionReason,
+        strictOpenAiCompliance
+      ),
     })),
     usage: {
       prompt_tokens: response.inputTextTokenCount,
@@ -437,7 +444,7 @@ export interface BedrockTitanStreamChunk {
   outputText: string;
   index: number;
   totalOutputTextTokenCount: number;
-  completionReason: string | null;
+  completionReason: TITAN_COMPLETION_REASON | null;
   'amazon-bedrock-invocationMetrics': {
     inputTokenCount: number;
     outputTokenCount: number;
@@ -462,6 +469,12 @@ export const BedrockTitanCompleteStreamChunkTransform: (
   let chunk = responseChunk.trim();
   chunk = chunk.trim();
   const parsedChunk: BedrockTitanStreamChunk = JSON.parse(chunk);
+  const finishReason = parsedChunk.completionReason
+    ? transformFinishReason(
+        parsedChunk.completionReason,
+        _strictOpenAiCompliance
+      )
+    : null;

   return [
     `data: ${JSON.stringify({
@@ -490,7 +503,7 @@ export const BedrockTitanCompleteStreamChunkTransform: (
           text: '',
           index: 0,
           logprobs: null,
-          finish_reason: parsedChunk.completionReason,
+          finish_reason: finishReason,
         },
       ],
       usage: {
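
The Titan stream transform computes `finishReason` up front with a null guard because `completionReason` is typed `TITAN_COMPLETION_REASON | null` and is absent on intermediate chunks. A self-contained illustration of that guard — the mapping targets here are assumptions, not what `transformFinishReason` necessarily returns:

```typescript
// Illustration only: intermediate chunks carry completionReason: null,
// so finish_reason must stay null until the terminal chunk arrives.
type TitanChunk = { outputText: string; completionReason: string | null };

// Assumed mapping from Titan completion reasons to OpenAI finish reasons.
const titanToOpenAI: Record<string, string> = {
  FINISHED: 'stop',
  STOP_CRITERIA_MET: 'stop',
  LENGTH: 'length',
  CONTENT_FILTERED: 'content_filter',
};

const finishReasonFor = (chunk: TitanChunk): string | null =>
  chunk.completionReason
    ? (titanToOpenAI[chunk.completionReason] ?? 'stop')
    : null;

console.log(finishReasonFor({ outputText: 'partial text', completionReason: null })); // null
console.log(finishReasonFor({ outputText: '', completionReason: 'FINISHED' })); // 'stop' (assumed)
```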

src/providers/bedrock/types.ts

Lines changed: 13 additions & 4 deletions
@@ -108,7 +108,7 @@ export interface BedrockChatCompletionResponse {
       content: BedrockContentItem[];
     };
   };
-  stopReason: BEDROCK_STOP_REASON;
+  stopReason: BEDROCK_CONVERSE_STOP_REASON;
   usage: {
     inputTokens: number;
     outputTokens: number;
@@ -156,7 +156,7 @@ export type BedrockContentItem = {
 };

 export interface BedrockStreamState {
-  stopReason?: BEDROCK_STOP_REASON;
+  stopReason?: BEDROCK_CONVERSE_STOP_REASON;
   currentToolCallIndex?: number;
   currentContentBlockIndex?: number;
 }
@@ -186,7 +186,7 @@ export interface BedrockChatCompleteStreamChunk {
       input?: object;
     };
   };
-  stopReason?: BEDROCK_STOP_REASON;
+  stopReason?: BEDROCK_CONVERSE_STOP_REASON;
   metrics?: {
     latencyMs: number;
   };
@@ -199,13 +199,22 @@ export interface BedrockChatCompleteStreamChunk {
     cacheWriteInputTokenCount?: number;
     cacheWriteInputTokens?: number;
   };
+  message?: string;
 }

-export enum BEDROCK_STOP_REASON {
+export enum BEDROCK_CONVERSE_STOP_REASON {
   end_turn = 'end_turn',
   tool_use = 'tool_use',
   max_tokens = 'max_tokens',
   stop_sequence = 'stop_sequence',
   guardrail_intervened = 'guardrail_intervened',
   content_filtered = 'content_filtered',
 }
+
+export enum TITAN_STOP_REASON {
+  FINISHED = 'FINISHED',
+  LENGTH = 'LENGTH',
+  STOP_CRITERIA_MET = 'STOP_CRITERIA_MET',
+  RAG_QUERY_WHEN_RAG_DISABLED = 'RAG_QUERY_WHEN_RAG_DISABLED',
+  CONTENT_FILTERED = 'CONTENT_FILTERED',
+}
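
Typing `completionReason` as the new `TITAN_STOP_REASON` enum instead of `string` also makes exhaustiveness checks possible when mapping reasons. A hypothetical sketch, not code from this PR — the enum is copied from the hunk above, while the chosen OpenAI targets are assumptions:

```typescript
// Hypothetical mapper demonstrating compile-time exhaustiveness over the enum.
enum TITAN_STOP_REASON {
  FINISHED = 'FINISHED',
  LENGTH = 'LENGTH',
  STOP_CRITERIA_MET = 'STOP_CRITERIA_MET',
  RAG_QUERY_WHEN_RAG_DISABLED = 'RAG_QUERY_WHEN_RAG_DISABLED',
  CONTENT_FILTERED = 'CONTENT_FILTERED',
}

function mapTitanReason(reason: TITAN_STOP_REASON): string {
  switch (reason) {
    case TITAN_STOP_REASON.FINISHED:
    case TITAN_STOP_REASON.STOP_CRITERIA_MET:
      return 'stop';
    case TITAN_STOP_REASON.LENGTH:
      return 'length';
    case TITAN_STOP_REASON.CONTENT_FILTERED:
      return 'content_filter';
    case TITAN_STOP_REASON.RAG_QUERY_WHEN_RAG_DISABLED:
      return 'stop'; // assumption: no direct OpenAI equivalent
    default: {
      // If a new enum member is added later, this assignment fails to compile.
      const exhaustive: never = reason;
      return exhaustive;
    }
  }
}
```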

src/providers/deepseek/chatComplete.ts

Lines changed: 34 additions & 6 deletions
@@ -9,7 +9,9 @@ import {
 import {
   generateErrorResponse,
   generateInvalidProviderResponseError,
+  transformFinishReason,
 } from '../utils';
+import { DEEPSEEK_STOP_REASON } from './types';

 export const DeepSeekChatCompleteConfig: ProviderConfig = {
   model: {
@@ -127,8 +129,15 @@ interface DeepSeekStreamChunk {

 export const DeepSeekChatCompleteResponseTransform: (
   response: DeepSeekChatCompleteResponse | DeepSeekErrorResponse,
-  responseStatus: number
-) => ChatCompletionResponse | ErrorResponse = (response, responseStatus) => {
+  responseStatus: number,
+  responseHeaders: Headers,
+  strictOpenAiCompliance: boolean
+) => ChatCompletionResponse | ErrorResponse = (
+  response,
+  responseStatus,
+  _responseHeaders,
+  strictOpenAiCompliance
+) => {
   if ('message' in response && responseStatus !== 200) {
     return generateErrorResponse(
       {
@@ -154,7 +163,10 @@ export const DeepSeekChatCompleteResponseTransform: (
           role: c.message.role,
           content: c.message.content,
         },
-        finish_reason: c.finish_reason,
+        finish_reason: transformFinishReason(
+          c.finish_reason as DEEPSEEK_STOP_REASON,
+          strictOpenAiCompliance
+        ),
       })),
       usage: {
         prompt_tokens: response.usage?.prompt_tokens,
@@ -168,15 +180,31 @@ export const DeepSeekChatCompleteResponseTransform: (
 };

 export const DeepSeekChatCompleteStreamChunkTransform: (
-  response: string
-) => string = (responseChunk) => {
+  response: string,
+  fallbackId: string,
+  streamState: any,
+  strictOpenAiCompliance: boolean,
+  gatewayRequest: Params
+) => string | string[] = (
+  responseChunk,
+  fallbackId,
+  _streamState,
+  strictOpenAiCompliance,
+  _gatewayRequest
+) => {
   let chunk = responseChunk.trim();
   chunk = chunk.replace(/^data: /, '');
   chunk = chunk.trim();
   if (chunk === '[DONE]') {
     return `data: ${chunk}\n\n`;
   }
   const parsedChunk: DeepSeekStreamChunk = JSON.parse(chunk);
+  const finishReason = parsedChunk.choices[0].finish_reason
+    ? transformFinishReason(
+        parsedChunk.choices[0].finish_reason as DEEPSEEK_STOP_REASON,
+        strictOpenAiCompliance
+      )
+    : null;
   return (
     `data: ${JSON.stringify({
       id: parsedChunk.id,
@@ -188,7 +216,7 @@ export const DeepSeekChatCompleteStreamChunkTransform: (
       {
         index: parsedChunk.choices[0].index,
         delta: parsedChunk.choices[0].delta,
-        finish_reason: parsedChunk.choices[0].finish_reason,
+        finish_reason: finishReason,
       },
     ],
     usage: parsedChunk.usage,

src/providers/deepseek/types.ts

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+export enum DEEPSEEK_STOP_REASON {
+  stop = 'stop',
+  length = 'length',
+  tool_calls = 'tool_calls',
+  content_filter = 'content_filter',
+  insufficient_system_resource = 'insufficient_system_resource',
+}
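
Four of these DeepSeek values (`stop`, `length`, `tool_calls`, `content_filter`) already coincide with OpenAI finish reasons; only `insufficient_system_resource` is provider-specific. A rough sketch of how a compliance mapper might treat them — the fallback to `'stop'` and the pass-through when strict mode is off are assumptions, not necessarily what `transformFinishReason` does:

```typescript
// Hypothetical mapper for DeepSeek stop reasons; for illustration only.
import { DEEPSEEK_STOP_REASON } from './types';

const OPENAI_FINISH_REASONS = new Set([
  'stop',
  'length',
  'tool_calls',
  'content_filter',
]);

function mapDeepSeekReason(reason: DEEPSEEK_STOP_REASON, strict: boolean): string {
  // Values already in the OpenAI set pass through unchanged.
  if (!strict || OPENAI_FINISH_REASONS.has(reason)) return reason;
  return 'stop'; // assumed fallback for insufficient_system_resource
}

console.log(mapDeepSeekReason(DEEPSEEK_STOP_REASON.length, true)); // 'length'
console.log(mapDeepSeekReason(DEEPSEEK_STOP_REASON.insufficient_system_resource, true)); // 'stop' (assumed)
```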

src/providers/google-vertex-ai/chatComplete.ts

Lines changed: 21 additions & 4 deletions
@@ -37,6 +37,7 @@ import {
 import {
   generateErrorResponse,
   generateInvalidProviderResponseError,
+  transformFinishReason,
 } from '../utils';
 import { transformGenerationConfig } from './transformGenerationConfig';
 import type {
@@ -500,7 +501,10 @@ export const GoogleChatCompleteResponseTransform: (
       return {
         message: message,
         index: index,
-        finish_reason: generation.finishReason,
+        finish_reason: transformFinishReason(
+          generation.finishReason,
+          strictOpenAiCompliance
+        ),
         logprobs,
         ...(!strictOpenAiCompliance && {
           safetyRatings: generation.safetyRatings,
@@ -621,6 +625,13 @@ export const GoogleChatCompleteStreamChunkTransform: (
     provider: GOOGLE_VERTEX_AI,
     choices:
       parsedChunk.candidates?.map((generation, index) => {
+        const finishReason = generation.finishReason
+          ? transformFinishReason(
+              parsedChunk.candidates[0].finishReason,
+              strictOpenAiCompliance
+            )
+          : null;
+
         let message: any = { role: 'assistant', content: '' };
         if (generation.content?.parts[0]?.text) {
           const contentBlocks = [];
@@ -667,7 +678,7 @@ export const GoogleChatCompleteStreamChunkTransform: (
         return {
           delta: message,
           index: index,
-          finish_reason: generation.finishReason,
+          finish_reason: finishReason,
           ...(!strictOpenAiCompliance && {
             safetyRatings: generation.safetyRatings,
           }),
@@ -767,7 +778,10 @@ export const VertexAnthropicChatCompleteResponseTransform: (
         },
         index: 0,
         logprobs: null,
-        finish_reason: response.stop_reason,
+        finish_reason: transformFinishReason(
+          response.stop_reason,
+          strictOpenAiCompliance
+        ),
       },
     ],
     usage: {
@@ -883,7 +897,10 @@ export const VertexAnthropicChatCompleteStreamChunkTransform: (
       {
         index: 0,
         delta: {},
-        finish_reason: parsedChunk.delta?.stop_reason,
+        finish_reason: transformFinishReason(
+          parsedChunk.delta?.stop_reason,
+          strictOpenAiCompliance
+        ),
       },
     ],
     usage: {
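
Across these hunks the pattern is consistent: derive `finish_reason` only when the provider reports one, and let `strictOpenAiCompliance` decide whether to surface an OpenAI-compatible value or the provider's own. A hypothetical, self-contained illustration for Gemini finish reasons — the values listed come from the Vertex AI API, but the mapping targets are assumptions rather than this PR's implementation:

```typescript
// Illustration of strict vs. non-strict handling of Gemini finish reasons.
type GeminiFinishReason = 'STOP' | 'MAX_TOKENS' | 'SAFETY' | 'RECITATION' | 'OTHER';

// Assumed mapping to the OpenAI finish-reason set.
const geminiToOpenAI: Record<GeminiFinishReason, string> = {
  STOP: 'stop',
  MAX_TOKENS: 'length',
  SAFETY: 'content_filter',
  RECITATION: 'content_filter',
  OTHER: 'stop',
};

function toFinishReason(
  reason: GeminiFinishReason | undefined,
  strictOpenAiCompliance: boolean
): string | null {
  if (!reason) return null; // mid-stream chunks carry no finishReason
  return strictOpenAiCompliance ? geminiToOpenAI[reason] : reason;
}

console.log(toFinishReason('SAFETY', true)); // 'content_filter' (assumed mapping)
console.log(toFinishReason('SAFETY', false)); // 'SAFETY' — provider value preserved
```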
