Skip to content

Commit 72ee387

Browse files
authored
Merge pull request #1130 from narengogi/feat/bedrock-multimodal-embeddings
feat: multimodal embeddings for bedrock titan and cohere
2 parents 76676ba + c4a5c45 commit 72ee387

File tree

2 files changed

+154
-31
lines changed

2 files changed

+154
-31
lines changed

src/providers/bedrock/embed.ts

Lines changed: 108 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,129 @@
11
import { BEDROCK } from '../../globals';
2-
import { EmbedResponse } from '../../types/embedRequestBody';
2+
import { EmbedParams, EmbedResponse } from '../../types/embedRequestBody';
33
import { Params } from '../../types/requestBody';
44
import { ErrorResponse, ProviderConfig } from '../types';
55
import { generateInvalidProviderResponseError } from '../utils';
66
import { BedrockErrorResponseTransform } from './chatComplete';
77

88
export const BedrockCohereEmbedConfig: ProviderConfig = {
9-
input: {
10-
param: 'texts',
11-
required: true,
12-
transform: (params: any): string[] => {
13-
if (Array.isArray(params.input)) {
14-
return params.input;
15-
} else {
16-
return [params.input];
17-
}
9+
input: [
10+
{
11+
param: 'texts',
12+
required: false,
13+
transform: (params: EmbedParams): string[] | undefined => {
14+
if (typeof params.input === 'string') return [params.input];
15+
else if (Array.isArray(params.input) && params.input.length > 0) {
16+
const texts: string[] = [];
17+
params.input.forEach((item) => {
18+
if (typeof item === 'string') {
19+
texts.push(item);
20+
} else if (item.text) {
21+
texts.push(item.text);
22+
}
23+
});
24+
return texts.length > 0 ? texts : undefined;
25+
}
26+
},
1827
},
19-
},
28+
{
29+
param: 'images',
30+
required: false,
31+
transform: (params: EmbedParams): string[] | undefined => {
32+
if (Array.isArray(params.input) && params.input.length > 0) {
33+
const images: string[] = [];
34+
params.input.forEach((item) => {
35+
if (typeof item === 'object' && item.image?.base64) {
36+
images.push(item.image.base64);
37+
}
38+
});
39+
return images.length > 0 ? images : undefined;
40+
}
41+
},
42+
},
43+
],
2044
input_type: {
2145
param: 'input_type',
2246
required: true,
2347
},
2448
truncate: {
2549
param: 'truncate',
50+
required: false,
51+
},
52+
encoding_format: {
53+
param: 'embedding_types',
54+
required: false,
55+
transform: (params: any): string[] => {
56+
if (Array.isArray(params.encoding_format)) return params.encoding_format;
57+
return [params.encoding_format];
58+
},
2659
},
2760
};
2861

2962
export const BedrockTitanEmbedConfig: ProviderConfig = {
30-
input: {
31-
param: 'inputText',
32-
required: true,
63+
input: [
64+
{
65+
param: 'inputText',
66+
required: false,
67+
transform: (params: EmbedParams): string | undefined => {
68+
if (
69+
Array.isArray(params.input) &&
70+
typeof params.input[0] === 'object' &&
71+
params.input[0].text
72+
) {
73+
return params.input[0].text;
74+
}
75+
if (typeof params.input === 'string') return params.input;
76+
},
77+
},
78+
{
79+
param: 'inputImage',
80+
required: false,
81+
transform: (params: EmbedParams) => {
82+
// Titan models only support one image per request
83+
if (
84+
Array.isArray(params.input) &&
85+
typeof params.input[0] === 'object' &&
86+
params.input[0].image?.base64
87+
) {
88+
return params.input[0].image.base64;
89+
}
90+
},
91+
},
92+
],
93+
dimensions: [
94+
{
95+
param: 'dimensions',
96+
required: false,
97+
transform: (params: EmbedParams): number | undefined => {
98+
if (typeof params.input === 'string') return params.dimensions;
99+
},
100+
},
101+
{
102+
param: 'embeddingConfig',
103+
required: false,
104+
transform: (
105+
params: EmbedParams
106+
): { outputEmbeddingLength: number } | undefined => {
107+
if (Array.isArray(params.input) && params.dimensions) {
108+
return {
109+
outputEmbeddingLength: params.dimensions,
110+
};
111+
}
112+
},
113+
},
114+
],
115+
encoding_format: {
116+
param: 'embeddingTypes',
117+
required: false,
118+
transform: (params: any): string[] => {
119+
if (Array.isArray(params.encoding_format)) return params.encoding_format;
120+
return [params.encoding_format];
121+
},
122+
},
123+
// Titan specific parameters
124+
normalize: {
125+
param: 'normalize',
126+
required: false,
33127
},
34128
};
35129

src/providers/cohere/embed.ts

Lines changed: 46 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,31 +4,60 @@ import { generateErrorResponse } from '../utils';
44
import { COHERE } from '../../globals';
55

66
export const CohereEmbedConfig: ProviderConfig = {
7-
input: {
8-
param: 'texts',
9-
required: true,
10-
transform: (params: EmbedParams): string[] => {
11-
if (Array.isArray(params.input)) {
12-
return params.input as string[];
13-
} else {
14-
return [params.input];
15-
}
7+
input: [
8+
{
9+
param: 'texts',
10+
required: false,
11+
transform: (params: EmbedParams): string[] | undefined => {
12+
if (typeof params.input === 'string') return [params.input];
13+
else if (Array.isArray(params.input) && params.input.length > 0) {
14+
const texts: string[] = [];
15+
params.input.forEach((item) => {
16+
if (typeof item === 'string') {
17+
texts.push(item);
18+
} else if (item.text) {
19+
texts.push(item.text);
20+
}
21+
});
22+
return texts.length > 0 ? texts : undefined;
23+
}
24+
},
1625
},
17-
},
18-
model: {
19-
param: 'model',
20-
default: 'embed-english-light-v2.0',
21-
},
26+
{
27+
param: 'images',
28+
required: false,
29+
transform: (params: EmbedParams): string[] | undefined => {
30+
if (Array.isArray(params.input) && params.input.length > 0) {
31+
const images: string[] = [];
32+
params.input.forEach((item) => {
33+
if (typeof item === 'object' && item.image?.base64) {
34+
images.push(item.image.base64);
35+
}
36+
});
37+
return images.length > 0 ? images : undefined;
38+
}
39+
},
40+
},
41+
],
2242
input_type: {
2343
param: 'input_type',
44+
required: true,
45+
},
46+
truncate: {
47+
param: 'truncate',
2448
required: false,
2549
},
26-
embedding_types: {
50+
encoding_format: {
2751
param: 'embedding_types',
2852
required: false,
53+
transform: (params: any): string[] => {
54+
if (Array.isArray(params.encoding_format)) return params.encoding_format;
55+
return [params.encoding_format];
56+
},
2957
},
30-
truncate: {
31-
param: 'truncate',
58+
//backwards compatibility
59+
embedding_types: {
60+
param: 'embedding_types',
3261
required: false,
3362
},
3463
};

0 commit comments

Comments
 (0)