Skip to content

Commit c67d71d

Browse files
committed
chore: add support for normal file upload for inference purposes
1 parent a9527a6 commit c67d71d

File tree

2 files changed

+211
-67
lines changed

2 files changed

+211
-67
lines changed

src/providers/google-vertex-ai/uploadFile.ts

Lines changed: 99 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import { ProviderConfig, RequestHandler } from '../types';
22
import {
3+
generateSignedURL,
34
getModelAndProvider,
45
GoogleResponseHandler,
56
vertexRequestLineHandler,
@@ -49,7 +50,11 @@ export const GoogleFileUploadRequestHandler: RequestHandler<
4950
vertexBatchEndpoint = BatchEndpoints.CHAT_COMPLETIONS, //default to inference endpoint
5051
} = providerOptions;
5152

52-
if (!vertexModelName || !vertexStorageBucketName) {
53+
let purpose = requestHeaders['x-portkey-file-purpose'] ?? '';
54+
if (
55+
(purpose === 'upload' ? false : !vertexModelName) ||
56+
!vertexStorageBucketName
57+
) {
5358
return GoogleResponseHandler(
5459
'Invalid request, please provide `x-portkey-provider-model` and `x-portkey-vertex-storage-bucket-name` in the request headers',
5560
400
@@ -73,73 +78,79 @@ export const GoogleFileUploadRequestHandler: RequestHandler<
7378
}
7479

7580
let isPurposeHeader = false;
76-
let purpose = '';
81+
let transformStream: ReadableStream<any> | TransformStream<any, any> =
82+
requestBody;
83+
let uploadMethod = 'PUT';
7784
// Create a reusable line splitter stream
7885
const lineSplitter = createLineSplitter();
7986

80-
// Transform stream to process each complete line.
81-
const transformStream = new TransformStream({
82-
transform: function (chunk, controller) {
83-
let buffer;
84-
try {
85-
const _chunk = chunk.toString();
86-
87-
const match = _chunk.match(/name="([^"]+)"/);
88-
const headerKey = match ? match[1] : null;
89-
90-
if (headerKey && headerKey === 'purpose') {
91-
isPurposeHeader = true;
92-
return;
93-
}
94-
95-
if (isPurposeHeader && _chunk?.length > 0 && !purpose) {
96-
isPurposeHeader = false;
97-
purpose = _chunk.trim();
98-
return;
99-
}
100-
101-
if (!_chunk) {
102-
return;
103-
}
104-
105-
const json = JSON.parse(chunk.toString());
106-
107-
if (json && !purpose) {
108-
// Close the stream.
109-
controller.terminate();
87+
if (purpose === 'upload') {
88+
uploadMethod = 'POST';
89+
} else {
90+
// Transform stream to process each complete line.
91+
transformStream = new TransformStream({
92+
transform: function (chunk, controller) {
93+
let buffer;
94+
try {
95+
const _chunk = chunk.toString();
96+
97+
const match = _chunk.match(/name="([^"]+)"/);
98+
const headerKey = match ? match[1] : null;
99+
100+
if (headerKey && headerKey === 'purpose') {
101+
isPurposeHeader = true;
102+
return;
103+
}
104+
105+
if (isPurposeHeader && _chunk?.length > 0 && !purpose) {
106+
isPurposeHeader = false;
107+
purpose = _chunk.trim();
108+
return;
109+
}
110+
111+
if (!_chunk) {
112+
return;
113+
}
114+
115+
const json = JSON.parse(chunk.toString());
116+
117+
if (json && !purpose) {
118+
// Close the stream.
119+
controller.terminate();
120+
}
121+
122+
const toTranspose = purpose === 'batch' ? json.body : json;
123+
const transformedBody = transformUsingProviderConfig(
124+
providerConfig,
125+
toTranspose
126+
);
127+
128+
delete transformedBody['model'];
129+
130+
const bufferTransposed = vertexRequestLineHandler(
131+
purpose,
132+
vertexBatchEndpoint,
133+
transformedBody,
134+
json['custom_id']
135+
);
136+
137+
buffer = JSON.stringify(bufferTransposed);
138+
} catch {
139+
buffer = null;
140+
} finally {
141+
if (buffer) {
142+
controller.enqueue(encoder.encode(buffer + '\n'));
143+
}
110144
}
111-
112-
const toTranspose = purpose === 'batch' ? json.body : json;
113-
const transformedBody = transformUsingProviderConfig(
114-
providerConfig,
115-
toTranspose
116-
);
117-
118-
delete transformedBody['model'];
119-
120-
const bufferTransposed = vertexRequestLineHandler(
121-
purpose,
122-
vertexBatchEndpoint,
123-
transformedBody,
124-
json['custom_id']
125-
);
126-
127-
buffer = JSON.stringify(bufferTransposed);
128-
} catch {
129-
buffer = null;
130-
} finally {
131-
if (buffer) {
132-
controller.enqueue(encoder.encode(buffer + '\n'));
133-
}
134-
}
135-
},
136-
flush(controller) {
137-
controller.terminate();
138-
},
139-
});
145+
},
146+
flush(controller) {
147+
controller.terminate();
148+
},
149+
});
150+
requestBody.pipeThrough(lineSplitter).pipeTo(transformStream.writable);
151+
}
140152

141153
// Pipe the node stream through our line splitter and into the transform stream.
142-
requestBody.pipeThrough(lineSplitter).pipeTo(transformStream.writable);
143154

144155
const providerHeaders = await GoogleApiConfig.headers({
145156
c,
@@ -151,15 +162,36 @@ export const GoogleFileUploadRequestHandler: RequestHandler<
151162
});
152163

153164
const encodedFile = encodeURIComponent(objectKey ?? '');
154-
const url = `https://storage.googleapis.com/${vertexStorageBucketName}/${encodedFile}`;
165+
let url;
166+
if (uploadMethod !== 'POST') {
167+
url = `https://storage.googleapis.com/${vertexStorageBucketName}/${encodedFile}`;
168+
} else {
169+
url = await generateSignedURL(
170+
providerOptions.vertexServiceAccountJson ?? {},
171+
vertexStorageBucketName,
172+
objectKey,
173+
10 * 60,
174+
'POST',
175+
c.req.param(),
176+
{}
177+
);
178+
}
155179

156180
const options = {
157-
body: transformStream.readable,
181+
body:
182+
uploadMethod === 'POST'
183+
? (transformStream as ReadableStream<any>)
184+
: (transformStream as TransformStream).readable,
158185
headers: {
159-
Authorization: providerHeaders.Authorization,
160-
'Content-Type': 'application/octet-stream',
186+
...(uploadMethod !== 'POST'
187+
? { Authorization: providerHeaders.Authorization }
188+
: {}),
189+
'Content-Type':
190+
uploadMethod === 'POST'
191+
? requestHeaders['content-type']
192+
: 'application/octet-stream',
161193
},
162-
method: 'PUT',
194+
method: uploadMethod,
163195
duplex: 'half',
164196
};
165197

src/providers/google-vertex-ai/utils.ts

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -569,3 +569,115 @@ export const vertexRequestLineHandler = (
569569
/**
 * Tells whether a model name refers to an embedding model.
 *
 * The check is a plain, case-sensitive substring match on `embedding`.
 *
 * @param modelName - Model identifier to inspect.
 * @returns `true` when `modelName` contains `embedding`, otherwise `false`.
 */
export const isEmbeddingModel = (modelName: string) =>
  modelName.indexOf('embedding') !== -1;
572+
573+
/**
 * Builds a V4 (`GOOG4-RSA-SHA256`) signed URL for an object addressed as
 * `https://<bucketName>.storage.googleapis.com/<objectName>`.
 *
 * The canonical request is signed with `UNSIGNED-PAYLOAD`, so the request body
 * is not covered by the signature. The `host` header is always signed; any
 * entries in `headers` are signed in addition to it.
 *
 * @param serviceAccountInfo - Service account JSON. `client_email` and
 *   `private_key` are read from it.
 * @param bucketName - Bucket used to build the request host.
 * @param objectName - Object key. `/` separators are kept as-is; every other
 *   reserved character is percent-encoded.
 * @param expiration - Lifetime of the URL in seconds (1 to 604800).
 * @param httpMethod - HTTP method the signed URL will be used with.
 * @param queryParameters - Extra query parameters to sign and append. The
 *   `X-Goog-*` signing parameters override entries with the same name.
 * @param headers - Extra headers to sign. The object is not modified.
 * @returns The signed URL, ending in `&x-goog-signature=<hex>`.
 * @throws Error when `expiration` is out of range or the service account info
 *   lacks `client_email` / `private_key`.
 */
export const generateSignedURL = async (
  serviceAccountInfo: Record<string, any>,
  bucketName: string,
  objectName: string,
  expiration: number = 604800,
  httpMethod: string = 'GET',
  queryParameters: Record<string, string> = {},
  headers: Record<string, string> = {}
): Promise<string> => {
  if (expiration > 604800) {
    throw new Error(
      "Expiration Time can't be longer than 604800 seconds (7 days)."
    );
  }
  if (!Number.isFinite(expiration) || expiration < 1) {
    throw new Error('Expiration Time must be at least 1 second.');
  }

  // Fail early with a clear message instead of signing with "undefined" as
  // the credential or handing an undefined key to the importer.
  const clientEmail: unknown = serviceAccountInfo?.client_email;
  const privateKeyPem: unknown = serviceAccountInfo?.private_key;
  if (
    typeof clientEmail !== 'string' ||
    clientEmail.length === 0 ||
    typeof privateKeyPem !== 'string' ||
    privateKeyPem.length === 0
  ) {
    throw new Error(
      'Service account info must include `client_email` and `private_key`.'
    );
  }

  // encodeURIComponent leaves !'()* untouched; the canonical request needs
  // them percent-encoded as well, otherwise the signature will not match for
  // names or values containing those characters.
  const encodeStrict = (value: string): string =>
    encodeURIComponent(value).replace(
      /[!'()*]/g,
      (ch) => `%${ch.charCodeAt(0).toString(16).toUpperCase()}`
    );

  const toHex = (bytes: ArrayBuffer): string =>
    Array.from(new Uint8Array(bytes), (b) =>
      b.toString(16).padStart(2, '0')
    ).join('');

  const canonicalUri = `/${encodeStrict(objectName).replace(/%2F/g, '/')}`;

  const now = new Date();
  const requestTimestamp = now
    .toISOString()
    .replace(/[-:]/g, '') // Remove hyphens and colons
    .replace(/\.\d{3}Z$/, 'Z'); // Remove milliseconds and keep the trailing Z
  const datestamp = now.toISOString().slice(0, 10).replace(/-/g, '');

  const credentialScope = `${datestamp}/auto/storage/goog4_request`;
  const credential = `${clientEmail}/${credentialScope}`;

  const host = `${bucketName}.storage.googleapis.com`;

  // Work on a private, lower-cased copy so the caller's object (or the shared
  // default) is never mutated, and so sorting happens on the names that are
  // actually signed.
  const headersToSign = new Map<string, string>();
  for (const [name, value] of Object.entries(headers)) {
    headersToSign.set(name.toLowerCase(), String(value).toLowerCase());
  }
  headersToSign.set('host', host.toLowerCase());

  const headerNames = [...headersToSign.keys()].sort();
  const canonicalHeaders = headerNames
    .map((name) => `${name}:${headersToSign.get(name)}\n`)
    .join('');
  const signedHeaders = headerNames.join(';');

  // Required signing parameters come last so they win over caller-supplied
  // entries with the same name.
  const queryParams: Record<string, string> = {
    ...queryParameters,
    'X-Goog-Algorithm': 'GOOG4-RSA-SHA256',
    'X-Goog-Credential': credential,
    'X-Goog-Date': requestTimestamp,
    'X-Goog-Expires': expiration.toString(),
    'X-Goog-SignedHeaders': signedHeaders,
  };

  const canonicalQueryString = Object.keys(queryParams)
    .sort()
    .map((key) => `${encodeStrict(key)}=${encodeStrict(queryParams[key])}`)
    .join('&');

  const canonicalRequest = [
    httpMethod,
    canonicalUri,
    canonicalQueryString,
    canonicalHeaders,
    signedHeaders,
    'UNSIGNED-PAYLOAD',
  ].join('\n');

  const canonicalRequestHash = await crypto.subtle.digest(
    'SHA-256',
    new TextEncoder().encode(canonicalRequest)
  );

  const stringToSign = [
    'GOOG4-RSA-SHA256',
    requestTimestamp,
    credentialScope,
    toHex(canonicalRequestHash),
  ].join('\n');

  const privateKey = await importPrivateKey(privateKeyPem);
  const signature = await crypto.subtle.sign(
    {
      name: 'RSASSA-PKCS1-v1_5',
      hash: { name: 'SHA-256' },
    },
    privateKey,
    new TextEncoder().encode(stringToSign)
  );

  return `https://${host}${canonicalUri}?${canonicalQueryString}&x-goog-signature=${toHex(signature)}`;
};

0 commit comments

Comments
 (0)