Skip to content

Commit d5e797e

Browse files
committed
Restrict known additional model fields to the specific providers
1 parent 913a414 commit d5e797e

File tree

3 files changed

+203
-42
lines changed

3 files changed

+203
-42
lines changed

src/providers/bedrock/chatComplete.ts

Lines changed: 128 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@ import {
1717
import { BedrockErrorResponse } from './embed';
1818
import {
1919
transformAdditionalModelRequestFields,
20+
transformAI21AdditionalModelRequestFields,
21+
transformAnthropicAdditionalModelRequestFields,
22+
transformCohereAdditionalModelRequestFields,
2023
transformInferenceConfig,
2124
} from './utils';
2225

@@ -32,6 +35,29 @@ export interface BedrockChatCompletionsParams extends Params {
3235
countPenalty?: number;
3336
}
3437

38+
export interface BedrockConverseAnthropicChatCompletionsParams
39+
extends BedrockChatCompletionsParams {
40+
anthropic_version?: string;
41+
user?: string;
42+
}
43+
44+
export interface BedrockConverseCohereChatCompletionsParams
45+
extends BedrockChatCompletionsParams {
46+
frequency_penalty?: number;
47+
presence_penalty?: number;
48+
logit_bias?: Record<string, number>;
49+
n?: number;
50+
}
51+
52+
export interface BedrockConverseAI21ChatCompletionsParams
53+
extends BedrockChatCompletionsParams {
54+
frequency_penalty?: number;
55+
presence_penalty?: number;
56+
frequencyPenalty?: number;
57+
presencePenalty?: number;
58+
countPenalty?: number;
59+
}
60+
3561
const getMessageTextContentArray = (message: Message): { text: string }[] => {
3662
if (message.content && typeof message.content === 'object') {
3763
return message.content
@@ -232,41 +258,6 @@ export const BedrockConverseChatCompleteConfig: ProviderConfig = {
232258
transform: (params: BedrockChatCompletionsParams) =>
233259
transformAdditionalModelRequestFields(params),
234260
},
235-
anthropic_version: {
236-
param: 'additionalModelRequestFields',
237-
transform: (params: BedrockChatCompletionsParams) =>
238-
transformAdditionalModelRequestFields(params),
239-
},
240-
frequency_penalty: {
241-
param: 'additionalModelRequestFields',
242-
transform: (params: BedrockChatCompletionsParams) =>
243-
transformAdditionalModelRequestFields(params),
244-
},
245-
presence_penalty: {
246-
param: 'additionalModelRequestFields',
247-
transform: (params: BedrockChatCompletionsParams) =>
248-
transformAdditionalModelRequestFields(params),
249-
},
250-
logit_bias: {
251-
param: 'additionalModelRequestFields',
252-
transform: (params: BedrockChatCompletionsParams) =>
253-
transformAdditionalModelRequestFields(params),
254-
},
255-
n: {
256-
param: 'additionalModelRequestFields',
257-
transform: (params: BedrockChatCompletionsParams) =>
258-
transformAdditionalModelRequestFields(params),
259-
},
260-
stream: {
261-
param: 'additionalModelRequestFields',
262-
transform: (params: BedrockChatCompletionsParams) =>
263-
transformAdditionalModelRequestFields(params),
264-
},
265-
countPenalty: {
266-
param: 'additionalModelRequestFields',
267-
transform: (params: BedrockChatCompletionsParams) =>
268-
transformAdditionalModelRequestFields(params),
269-
},
270261
};
271262

272263
interface BedrockChatCompletionResponse {
@@ -460,6 +451,108 @@ export const BedrockChatCompleteStreamChunkTransform: (
460451
})}\n\n`;
461452
};
462453

454+
export const BedrockConverseAnthropicChatCompleteConfig: ProviderConfig = {
455+
...BedrockConverseChatCompleteConfig,
456+
additionalModelRequestFields: {
457+
param: 'additionalModelRequestFields',
458+
transform: (params: BedrockChatCompletionsParams) =>
459+
transformAnthropicAdditionalModelRequestFields(params),
460+
},
461+
top_k: {
462+
param: 'additionalModelRequestFields',
463+
transform: (params: BedrockChatCompletionsParams) =>
464+
transformAnthropicAdditionalModelRequestFields(params),
465+
},
466+
anthropic_version: {
467+
param: 'additionalModelRequestFields',
468+
transform: (params: BedrockChatCompletionsParams) =>
469+
transformAnthropicAdditionalModelRequestFields(params),
470+
},
471+
user: {
472+
param: 'user',
473+
transform: (params: BedrockChatCompletionsParams) =>
474+
transformAnthropicAdditionalModelRequestFields(params),
475+
},
476+
};
477+
478+
export const BedrockConverseCohereChatCompleteConfig: ProviderConfig = {
479+
...BedrockConverseChatCompleteConfig,
480+
additionalModelRequestFields: {
481+
param: 'additionalModelRequestFields',
482+
transform: (params: BedrockChatCompletionsParams) =>
483+
transformCohereAdditionalModelRequestFields(params),
484+
},
485+
top_k: {
486+
param: 'additionalModelRequestFields',
487+
transform: (params: BedrockChatCompletionsParams) =>
488+
transformCohereAdditionalModelRequestFields(params),
489+
},
490+
frequency_penalty: {
491+
param: 'additionalModelRequestFields',
492+
transform: (params: BedrockChatCompletionsParams) =>
493+
transformCohereAdditionalModelRequestFields(params),
494+
},
495+
presence_penalty: {
496+
param: 'additionalModelRequestFields',
497+
transform: (params: BedrockChatCompletionsParams) =>
498+
transformCohereAdditionalModelRequestFields(params),
499+
},
500+
logit_bias: {
501+
param: 'additionalModelRequestFields',
502+
transform: (params: BedrockChatCompletionsParams) =>
503+
transformCohereAdditionalModelRequestFields(params),
504+
},
505+
n: {
506+
param: 'additionalModelRequestFields',
507+
transform: (params: BedrockChatCompletionsParams) =>
508+
transformCohereAdditionalModelRequestFields(params),
509+
},
510+
stream: {
511+
param: 'additionalModelRequestFields',
512+
transform: (params: BedrockChatCompletionsParams) =>
513+
transformCohereAdditionalModelRequestFields(params),
514+
},
515+
};
516+
517+
export const BedrockConverseAI21ChatCompleteConfig: ProviderConfig = {
518+
...BedrockConverseChatCompleteConfig,
519+
additionalModelRequestFields: {
520+
param: 'additionalModelRequestFields',
521+
transform: (params: BedrockChatCompletionsParams) =>
522+
transformAI21AdditionalModelRequestFields(params),
523+
},
524+
top_k: {
525+
param: 'additionalModelRequestFields',
526+
transform: (params: BedrockChatCompletionsParams) =>
527+
transformAI21AdditionalModelRequestFields(params),
528+
},
529+
frequency_penalty: {
530+
param: 'additionalModelRequestFields',
531+
transform: (params: BedrockChatCompletionsParams) =>
532+
transformAI21AdditionalModelRequestFields(params),
533+
},
534+
presence_penalty: {
535+
param: 'additionalModelRequestFields',
536+
transform: (params: BedrockChatCompletionsParams) =>
537+
transformAI21AdditionalModelRequestFields(params),
538+
},
539+
frequencyPenalty: {
540+
param: 'additionalModelRequestFields',
541+
transform: (params: BedrockChatCompletionsParams) =>
542+
transformAI21AdditionalModelRequestFields(params),
543+
},
544+
presencePenalty: {
545+
param: 'additionalModelRequestFields',
546+
transform: (params: BedrockChatCompletionsParams) =>
547+
transformAI21AdditionalModelRequestFields(params),
548+
},
549+
countPenalty: {
550+
param: 'additionalModelRequestFields',
551+
transform: (params: BedrockChatCompletionsParams) =>
552+
transformAI21AdditionalModelRequestFields(params),
553+
},
554+
};
555+
463556
export const BedrockCohereChatCompleteConfig: ProviderConfig = {
464557
messages: {
465558
param: 'prompt',

src/providers/bedrock/index.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@ import {
1212
BedrockCohereChatCompleteResponseTransform,
1313
BedrockAI21ChatCompleteConfig,
1414
BedrockAI21ChatCompleteResponseTransform,
15+
BedrockConverseAnthropicChatCompleteConfig,
16+
BedrockConverseCohereChatCompleteConfig,
17+
BedrockConverseAI21ChatCompleteConfig,
1518
} from './chatComplete';
1619
import {
1720
BedrockAI21CompleteConfig,
@@ -64,6 +67,7 @@ const BedrockConfig: ProviderConfigs = {
6467
case ANTHROPIC:
6568
config = {
6669
complete: BedrockAnthropicCompleteConfig,
70+
chatComplete: BedrockConverseAnthropicChatCompleteConfig,
6771
api: BedrockAPIConfig,
6872
responseTransforms: {
6973
'stream-complete': BedrockAnthropicCompleteStreamChunkTransform,
@@ -74,6 +78,7 @@ const BedrockConfig: ProviderConfigs = {
7478
case COHERE:
7579
config = {
7680
complete: BedrockCohereCompleteConfig,
81+
chatComplete: BedrockConverseCohereChatCompleteConfig,
7782
embed: BedrockCohereEmbedConfig,
7883
api: BedrockAPIConfig,
7984
responseTransforms: {
@@ -126,6 +131,7 @@ const BedrockConfig: ProviderConfigs = {
126131
config = {
127132
complete: BedrockAI21CompleteConfig,
128133
api: BedrockAPIConfig,
134+
chatComplete: BedrockConverseAI21ChatCompleteConfig,
129135
responseTransforms: {
130136
complete: BedrockAI21CompleteResponseTransform,
131137
},

src/providers/bedrock/utils.ts

Lines changed: 69 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
import { SignatureV4 } from '@smithy/signature-v4';
22
import { Sha256 } from '@aws-crypto/sha256-js';
3-
import { BedrockChatCompletionsParams } from './chatComplete';
3+
import {
4+
BedrockConverseAI21ChatCompletionsParams,
5+
BedrockConverseAnthropicChatCompletionsParams,
6+
BedrockChatCompletionsParams,
7+
BedrockConverseCohereChatCompletionsParams,
8+
} from './chatComplete';
49

510
export const generateAWSHeaders = async (
611
body: Record<string, any>,
@@ -66,26 +71,83 @@ export const transformAdditionalModelRequestFields = (
6671
const additionalModelRequestFields: Record<string, any> =
6772
params.additionalModelRequestFields || {};
6873
if (params['top_k']) {
69-
additionalModelRequestFields['topK'] = params['top_k'];
74+
additionalModelRequestFields['top_k'] = params['top_k'];
75+
}
76+
return additionalModelRequestFields;
77+
};
78+
79+
export const transformAnthropicAdditionalModelRequestFields = (
80+
params: BedrockConverseAnthropicChatCompletionsParams
81+
) => {
82+
const additionalModelRequestFields: Record<string, any> =
83+
params.additionalModelRequestFields || {};
84+
if (params['top_k']) {
85+
additionalModelRequestFields['top_k'] = params['top_k'];
7086
}
71-
// Backward compatibility
7287
if (params['anthropic_version']) {
7388
additionalModelRequestFields['anthropic_version'] =
7489
params['anthropic_version'];
7590
}
91+
if (params['user']) {
92+
additionalModelRequestFields['metadata'] = {
93+
user_id: params['user'],
94+
};
95+
}
96+
return additionalModelRequestFields;
97+
};
98+
99+
export const transformCohereAdditionalModelRequestFields = (
100+
params: BedrockConverseCohereChatCompletionsParams
101+
) => {
102+
const additionalModelRequestFields: Record<string, any> =
103+
params.additionalModelRequestFields || {};
104+
if (params['top_k']) {
105+
additionalModelRequestFields['top_k'] = params['top_k'];
106+
}
107+
if (params['n']) {
108+
additionalModelRequestFields['n'] = params['n'];
109+
}
76110
if (params['frequency_penalty']) {
77-
additionalModelRequestFields['frequencyPenalty'] =
111+
additionalModelRequestFields['frequency_penalty'] =
78112
params['frequency_penalty'];
79113
}
80114
if (params['presence_penalty']) {
81-
additionalModelRequestFields['presencePenalty'] =
115+
additionalModelRequestFields['presence_penalty'] =
82116
params['presence_penalty'];
83117
}
84118
if (params['logit_bias']) {
85119
additionalModelRequestFields['logitBias'] = params['logit_bias'];
86120
}
87-
if (params['n']) {
88-
additionalModelRequestFields['n'] = params['n'];
121+
if (params['stream']) {
122+
additionalModelRequestFields['stream'] = params['stream'];
123+
}
124+
return additionalModelRequestFields;
125+
};
126+
127+
export const transformAI21AdditionalModelRequestFields = (
128+
params: BedrockConverseAI21ChatCompletionsParams
129+
) => {
130+
const additionalModelRequestFields: Record<string, any> =
131+
params.additionalModelRequestFields || {};
132+
if (params['top_k']) {
133+
additionalModelRequestFields['top_k'] = params['top_k'];
134+
}
135+
if (params['frequency_penalty']) {
136+
additionalModelRequestFields['frequencyPenalty'] = {
137+
scale: params['frequency_penalty'],
138+
};
139+
}
140+
if (params['presence_penalty']) {
141+
additionalModelRequestFields['presencePenalty'] = {
142+
scale: params['presence_penalty'],
143+
};
144+
}
145+
if (params['frequencyPenalty']) {
146+
additionalModelRequestFields['frequencyPenalty'] =
147+
params['frequencyPenalty'];
148+
}
149+
if (params['presencePenalty']) {
150+
additionalModelRequestFields['presencePenalty'] = params['presencePenalty'];
89151
}
90152
if (params['countPenalty']) {
91153
additionalModelRequestFields['countPenalty'] = params['countPenalty'];

0 commit comments

Comments
 (0)