Skip to content

Commit 5acd856

Browse files
authored
Merge pull request #1089 from b4s36t4/feat/ft-batch-improvements
chore: batch & fine-tune improvements
2 parents 5d2ac80 + ee33dfe commit 5acd856

20 files changed

+1078
-127
lines changed

src/globals.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,3 +229,9 @@ export const documentMimeTypes = [
229229
fileExtensionMimeTypeMap.md,
230230
fileExtensionMimeTypeMap.txt,
231231
];
232+
233+
export enum BatchEndpoints {
234+
CHAT_COMPLETIONS = '/v1/chat/completions',
235+
COMPLETIONS = '/v1/completions',
236+
EMBEDDINGS = '/v1/embeddings',
237+
}

src/handlers/handlerUtils.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -931,10 +931,13 @@ export function constructConfigFromRequestHeaders(
931931
requestHeaders[`x-${POWERED_BY}-vertex-storage-bucket-name`],
932932
filename: requestHeaders[`x-${POWERED_BY}-provider-file-name`],
933933
vertexModelName: requestHeaders[`x-${POWERED_BY}-provider-model`],
934+
vertexBatchEndpoint:
935+
requestHeaders[`x-${POWERED_BY}-provider-batch-endpoint`],
934936
};
935937

936938
const fireworksConfig = {
937939
fireworksAccountId: requestHeaders[`x-${POWERED_BY}-fireworks-account-id`],
940+
fireworksFileLength: requestHeaders[`x-${POWERED_BY}-file-upload-size`],
938941
};
939942

940943
const anthropicConfig = {

src/handlers/streamHandlerUtils.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,7 @@ export function createLineSplitter(): TransformStream {
309309
leftover = lines.pop() || '';
310310
for (const line of lines) {
311311
if (line.trim()) {
312-
controller.enqueue(line);
312+
controller.enqueue(line.trim());
313313
}
314314
}
315315
return;

src/providers/fireworks-ai/api.ts

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,23 @@ const FireworksAIAPIConfig: ProviderAPIConfig = {
2121
Accept: 'application/json',
2222
};
2323
},
24-
getEndpoint: ({ fn, gatewayRequestBodyJSON: gatewayRequestBody, c }) => {
24+
getEndpoint: ({
25+
fn,
26+
gatewayRequestBodyJSON: gatewayRequestBody,
27+
c,
28+
gatewayRequestURL,
29+
}) => {
2530
const model = gatewayRequestBody?.model;
31+
32+
const jobIdIndex = ['cancelFinetune'].includes(fn ?? '') ? -2 : -1;
33+
const jobId = gatewayRequestURL.split('/').at(jobIdIndex);
34+
35+
const url = new URL(gatewayRequestURL);
36+
const params = url.searchParams;
37+
38+
const size = params.get('limit') ?? 50;
39+
const page = params.get('after') ?? '1';
40+
2641
switch (fn) {
2742
case 'complete':
2843
return '/completions';
@@ -33,7 +48,7 @@ const FireworksAIAPIConfig: ProviderAPIConfig = {
3348
case 'imageGenerate':
3449
return `/image_generation/${model}`;
3550
case 'uploadFile':
36-
return `/datasets`;
51+
return '';
3752
case 'retrieveFile': {
3853
const datasetId = c.req.param('id');
3954
return `/datasets/${datasetId}`;
@@ -45,13 +60,13 @@ const FireworksAIAPIConfig: ProviderAPIConfig = {
4560
return `/datasets/${datasetId}`;
4661
}
4762
case 'createFinetune':
48-
return `/fineTuningJobs`;
63+
return `/supervisedFineTuningJobs`;
4964
case 'retrieveFinetune':
50-
return `/fineTuningJobs/${c.req.param('jobId')}`;
65+
return `/supervisedFineTuningJobs/${jobId}`;
5166
case 'listFinetunes':
52-
return `/fineTuningJobs`;
67+
return `/supervisedFineTuningJobs?pageToken=${page}&pageSize=${size}`;
5368
case 'cancelFinetune':
54-
return `/fineTuningJobs/${c.req.param('jobId')}`;
69+
return `/supervisedFineTuningJobs/${jobId}`;
5570
default:
5671
return '';
5772
}
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
import { FIREWORKS_AI } from '../../globals';
2+
import { Params } from '../../types/requestBody';
3+
import { RequestHandler } from '../types';
4+
import FireworksAIAPIConfig from './api';
5+
import { fireworkFinetuneToOpenAIFinetune } from './utils';
6+
7+
export const FireworkCancelFinetuneResponseTransform = (
8+
response: any,
9+
status: number
10+
) => {
11+
if (status !== 200) {
12+
const error = response?.error || 'Failed to cancel finetune';
13+
return new Response(JSON.stringify({ error: { message: error } }), {
14+
status: status || 500,
15+
});
16+
}
17+
18+
return fireworkFinetuneToOpenAIFinetune(response);
19+
};
20+
21+
export const FireworksCancelFinetuneRequestHandler: RequestHandler<
22+
Params
23+
> = async ({ requestBody, requestURL, providerOptions, c }) => {
24+
const headers = await FireworksAIAPIConfig.headers({
25+
c,
26+
fn: 'cancelFinetune',
27+
providerOptions,
28+
transformedRequestUrl: requestURL,
29+
transformedRequestBody: requestBody,
30+
});
31+
32+
const baseURL = await FireworksAIAPIConfig.getBaseURL({
33+
c,
34+
gatewayRequestURL: requestURL,
35+
providerOptions,
36+
});
37+
38+
const endpoint = FireworksAIAPIConfig.getEndpoint({
39+
c,
40+
fn: 'cancelFinetune',
41+
gatewayRequestBodyJSON: requestBody,
42+
gatewayRequestURL: requestURL,
43+
providerOptions,
44+
});
45+
46+
try {
47+
const request = await fetch(baseURL + endpoint, {
48+
method: 'DELETE',
49+
headers,
50+
body: JSON.stringify(requestBody),
51+
});
52+
53+
if (!request.ok) {
54+
const error = await request.json();
55+
return new Response(
56+
JSON.stringify({
57+
error: { message: (error as any).error },
58+
provider: FIREWORKS_AI,
59+
}),
60+
{
61+
status: 500,
62+
headers: {
63+
'Content-Type': 'application/json',
64+
},
65+
}
66+
);
67+
}
68+
69+
const response = await request.json();
70+
71+
const mappedResponse = fireworkFinetuneToOpenAIFinetune(response as any);
72+
73+
return new Response(JSON.stringify(mappedResponse), {
74+
status: 200,
75+
headers: { 'Content-Type': 'application/json' },
76+
});
77+
} catch (error) {
78+
const errorMessage =
79+
error instanceof Error ? error.message : 'Unknown error';
80+
return new Response(JSON.stringify({ error: { message: errorMessage } }), {
81+
status: 500,
82+
});
83+
}
84+
};
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
import { FIREWORKS_AI } from '../../globals';
2+
import { constructConfigFromRequestHeaders } from '../../handlers/handlerUtils';
3+
import { transformUsingProviderConfig } from '../../services/transformToProviderRequest';
4+
import { Options } from '../../types/requestBody';
5+
import { FinetuneRequest, ProviderConfig } from '../types';
6+
import { fireworkFinetuneToOpenAIFinetune } from './utils';
7+
8+
export const getHyperparameters = (value: FinetuneRequest) => {
9+
let hyperparameters = value?.hyperparameters;
10+
if (!hyperparameters) {
11+
const method = value?.method?.type;
12+
const methodHyperparameters =
13+
method && value.method?.[method]?.hyperparameters;
14+
hyperparameters = methodHyperparameters;
15+
}
16+
return hyperparameters ?? {};
17+
};
18+
19+
export const FireworksFinetuneCreateConfig: ProviderConfig = {
20+
training_file: {
21+
param: 'dataset',
22+
required: true,
23+
},
24+
validation_file: {
25+
param: 'evaluationDataset',
26+
required: true,
27+
},
28+
suffix: {
29+
param: 'displayName',
30+
required: true,
31+
},
32+
model: {
33+
param: 'baseModel',
34+
required: true,
35+
},
36+
hyperparameters: {
37+
param: 'epochs',
38+
required: true,
39+
transform: (value: FinetuneRequest) => {
40+
return getHyperparameters(value).n_epochs;
41+
},
42+
},
43+
learning_rate: {
44+
param: 'learning_rate',
45+
required: true,
46+
transform: (value: FinetuneRequest) => {
47+
return getHyperparameters(value).learning_rate_multiplier;
48+
},
49+
default: (value: FinetuneRequest) => {
50+
return getHyperparameters(value).learning_rate_multiplier;
51+
},
52+
},
53+
output_model: {
54+
// use the suffix as the output model name
55+
param: 'outputModel',
56+
required: true,
57+
},
58+
};
59+
60+
export const FireworksRequestTransform = (
61+
requestBody: Record<string, any>,
62+
requestHeaders: Record<string, string>
63+
) => {
64+
const providerOptions = constructConfigFromRequestHeaders(
65+
requestHeaders
66+
) as Options;
67+
68+
if (requestBody.training_file) {
69+
requestBody.training_file = `accounts/${providerOptions.fireworksAccountId}/datasets/${requestBody.training_file}`;
70+
}
71+
72+
if (requestBody.validation_file) {
73+
requestBody.validation_file = `accounts/${providerOptions.fireworksAccountId}/datasets/${requestBody.validation_file}`;
74+
}
75+
76+
if (requestBody.model) {
77+
requestBody.model = `accounts/fireworks/models/${requestBody.model}`;
78+
}
79+
80+
if (requestBody.output_model) {
81+
requestBody.output_model = `accounts/${providerOptions.fireworksAccountId}/models/${requestBody.suffix}`;
82+
}
83+
84+
const transformedRequestBody = transformUsingProviderConfig(
85+
FireworksFinetuneCreateConfig,
86+
requestBody,
87+
providerOptions as Options
88+
);
89+
90+
return transformedRequestBody;
91+
};
92+
93+
export const FireworkFinetuneTransform = (response: any, status: number) => {
94+
if (status !== 200) {
95+
const error = response?.error || 'Failed to create finetune';
96+
return new Response(
97+
JSON.stringify({
98+
error: {
99+
message: error,
100+
},
101+
provider: FIREWORKS_AI,
102+
}),
103+
{
104+
status: status || 500,
105+
headers: {
106+
'content-type': 'application/json',
107+
},
108+
}
109+
);
110+
}
111+
112+
const mappedResponse = fireworkFinetuneToOpenAIFinetune(response);
113+
114+
return mappedResponse;
115+
};

src/providers/fireworks-ai/index.ts

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
import { ProviderConfigs } from '../types';
22
import FireworksAIAPIConfig from './api';
3+
import {
4+
FireworkCancelFinetuneResponseTransform,
5+
FireworksCancelFinetuneRequestHandler,
6+
} from './cancelFinetune';
37
import {
48
FireworksAIChatCompleteConfig,
59
FireworksAIChatCompleteResponseTransform,
@@ -10,6 +14,11 @@ import {
1014
FireworksAICompleteResponseTransform,
1115
FireworksAICompleteStreamChunkTransform,
1216
} from './complete';
17+
import {
18+
FireworkFinetuneTransform,
19+
FireworksFinetuneCreateConfig,
20+
FireworksRequestTransform,
21+
} from './createFinetune';
1322
import {
1423
FireworksAIEmbedConfig,
1524
FireworksAIEmbedResponseTransform,
@@ -19,13 +28,16 @@ import {
1928
FireworksAIImageGenerateResponseTransform,
2029
} from './imageGenerate';
2130
import { FireworksFileListResponseTransform } from './listFiles';
31+
import { FireworkListFinetuneResponseTransform } from './listFinetune';
2232
import { FireworksFileRetrieveResponseTransform } from './retrieveFile';
33+
import { FireworkFileUploadRequestHandler } from './uploadFile';
2334

2435
const FireworksAIConfig: ProviderConfigs = {
2536
complete: FireworksAICompleteConfig,
2637
chatComplete: FireworksAIChatCompleteConfig,
2738
embed: FireworksAIEmbedConfig,
2839
imageGenerate: FireworksAIImageGenerateConfig,
40+
createFinetune: FireworksFinetuneCreateConfig,
2941
api: FireworksAIAPIConfig,
3042
responseTransforms: {
3143
complete: FireworksAICompleteResponseTransform,
@@ -36,6 +48,17 @@ const FireworksAIConfig: ProviderConfigs = {
3648
imageGenerate: FireworksAIImageGenerateResponseTransform,
3749
listFiles: FireworksFileListResponseTransform,
3850
retrieveFile: FireworksFileRetrieveResponseTransform,
51+
listFinetunes: FireworkListFinetuneResponseTransform,
52+
retrieveFinetune: FireworkFinetuneTransform,
53+
createFinetune: FireworkFinetuneTransform,
54+
cancelFinetune: FireworkCancelFinetuneResponseTransform,
55+
},
56+
requestHandlers: {
57+
uploadFile: FireworkFileUploadRequestHandler,
58+
cancelFinetune: FireworksCancelFinetuneRequestHandler,
59+
},
60+
requestTransforms: {
61+
createFinetune: FireworksRequestTransform,
3962
},
4063
};
4164

src/providers/fireworks-ai/listFiles.ts

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import { FIREWORKS_AI } from '../../globals';
12
import {
23
FireworksAIErrorResponse,
34
FireworksAIErrorResponseTransform,
@@ -25,5 +26,11 @@ export const FireworksFileListResponseTransform = (
2526
};
2627
}
2728

28-
return FireworksAIErrorResponseTransform(response);
29+
return {
30+
error: {
31+
message: (response as any).message ?? 'unable to fetch files.',
32+
param: null,
33+
},
34+
provider: FIREWORKS_AI,
35+
};
2936
};
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import { FinetuneResponse } from './types';
2+
import { fireworkFinetuneToOpenAIFinetune } from './utils';
3+
4+
export const FireworkListFinetuneResponseTransform = (
5+
response: any,
6+
status: number
7+
) => {
8+
if (status !== 200) {
9+
const error = response?.error || 'Failed to list finetunes';
10+
return new Response(
11+
JSON.stringify({
12+
error: { message: error },
13+
}),
14+
{
15+
status: status || 500,
16+
headers: {
17+
'Content-Type': 'application/json',
18+
},
19+
}
20+
);
21+
}
22+
23+
const list = response?.supervisedFineTuningJobs ?? [];
24+
const mappedResponse = list.map((finetune: FinetuneResponse) => {
25+
return fireworkFinetuneToOpenAIFinetune(finetune);
26+
});
27+
28+
const firstId = mappedResponse[0]?.id;
29+
const lastId = mappedResponse[mappedResponse.length - 1]?.id;
30+
31+
return {
32+
object: 'list',
33+
data: mappedResponse,
34+
first_id: firstId,
35+
last_id: lastId,
36+
has_more: !!response.nextPageToken,
37+
};
38+
};

0 commit comments

Comments
 (0)