
Commit 822ab9e

julien-c authored, with Wauplin and SBrandeis as co-authors
[inference] fold openai support into provider param (#1205)
i.e. no need to override an `endpoint` anymore. This only works in "client-side" mode, i.e. when passing a provider key. WDYT?

---------

Co-authored-by: Wauplin <[email protected]>
Co-authored-by: SBrandeis <[email protected]>
1 parent fac3157 commit 822ab9e
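
In practice, this commit replaces the old `hf.endpoint(...)` workaround with a plain `provider` parameter. A minimal sketch of the new call pattern, mirroring the updated test further down (the `OPENAI_API_KEY` placeholder is assumed for illustration, not part of the diff):

```ts
import { HfInference } from "@huggingface/inference";

// Client-side routing: pass your own OpenAI key, not an hf_ token.
const hf = new HfInference(OPENAI_API_KEY);

const res = await hf.chatCompletion({
	provider: "openai", // handled by the new OPENAI_CONFIG
	model: "openai/gpt-3.5-turbo", // model IDs must be prefixed with "openai/"
	messages: [{ role: "user", content: "Complete the equation one + one =" }],
});
```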

File tree

6 files changed (+7561, -7458 lines)


packages/inference/src/lib/makeRequestOptions.ts

Lines changed: 38 additions & 13 deletions
@@ -10,6 +10,7 @@ import { NOVITA_CONFIG } from "../providers/novita";
 import { REPLICATE_CONFIG } from "../providers/replicate";
 import { SAMBANOVA_CONFIG } from "../providers/sambanova";
 import { TOGETHER_CONFIG } from "../providers/together";
+import { OPENAI_CONFIG } from "../providers/openai";
 import type { InferenceProvider, InferenceTask, Options, ProviderConfig, RequestArgs } from "../types";
 import { isUrl } from "./isUrl";
 import { version as packageVersion, name as packageName } from "../../package.json";
@@ -33,6 +34,7 @@ const providerConfigs: Record<InferenceProvider, ProviderConfig> = {
 	"fireworks-ai": FIREWORKS_AI_CONFIG,
 	"hf-inference": HF_INFERENCE_CONFIG,
 	hyperbolic: HYPERBOLIC_CONFIG,
+	openai: OPENAI_CONFIG,
 	nebius: NEBIUS_CONFIG,
 	novita: NOVITA_CONFIG,
 	replicate: REPLICATE_CONFIG,
@@ -72,22 +74,38 @@ export async function makeRequestOptions(
 	if (!providerConfig) {
 		throw new Error(`No provider config found for provider ${provider}`);
 	}
+	if (providerConfig.clientSideRoutingOnly && !maybeModel) {
+		throw new Error(`Provider ${provider} requires a model ID to be passed directly.`);
+	}
 	// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
 	const hfModel = maybeModel ?? (await loadDefaultModel(task!));
-	const model = await getProviderModelId({ model: hfModel, provider }, args, {
-		task,
-		chatCompletion,
-		fetch: options?.fetch,
-	});
+	const model = providerConfig.clientSideRoutingOnly
+		? // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
+		  removeProviderPrefix(maybeModel!, provider)
+		: // For closed-models API providers, one needs to pass the model ID directly (e.g. "gpt-3.5-turbo")
+		  await getProviderModelId({ model: hfModel, provider }, args, {
+				task,
+				chatCompletion,
+				fetch: options?.fetch,
+		  });
 
-	/// If accessToken is passed, it should take precedence over includeCredentials
-	const authMethod = accessToken
-		? accessToken.startsWith("hf_")
-			? "hf-token"
-			: "provider-key"
-		: includeCredentials === "include"
-		  ? "credentials-include"
-		  : "none";
+	const authMethod = (() => {
+		if (providerConfig.clientSideRoutingOnly) {
+			// Closed-source providers require an accessToken (cannot be routed).
+			if (accessToken && accessToken.startsWith("hf_")) {
+				throw new Error(`Provider ${provider} is closed-source and does not support HF tokens.`);
+			}
+			return "provider-key";
+		}
+		if (accessToken) {
+			return accessToken.startsWith("hf_") ? "hf-token" : "provider-key";
+		}
+		if (includeCredentials === "include") {
+			// If accessToken is passed, it should take precedence over includeCredentials
+			return "credentials-include";
+		}
+		return "none";
+	})();
 
 	// Make URL
 	const url = endpointUrl
@@ -176,3 +194,10 @@ async function loadTaskInfo(): Promise<Record<string, { models: { id: string }[]
 	}
 	return await res.json();
 }
+
+function removeProviderPrefix(model: string, provider: string): string {
+	if (!model.startsWith(`${provider}/`)) {
+		throw new Error(`Models from ${provider} must be prefixed by "${provider}/". Got "${model}".`);
+	}
+	return model.slice(provider.length + 1);
+}
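
For reference, the new `removeProviderPrefix` helper is small enough to trace by hand; here is a standalone copy with an illustrative call (the `console.log` usage is added for illustration, not part of the diff):

```ts
function removeProviderPrefix(model: string, provider: string): string {
	if (!model.startsWith(`${provider}/`)) {
		throw new Error(`Models from ${provider} must be prefixed by "${provider}/". Got "${model}".`);
	}
	return model.slice(provider.length + 1);
}

console.log(removeProviderPrefix("openai/gpt-3.5-turbo", "openai")); // -> "gpt-3.5-turbo"
// removeProviderPrefix("gpt-3.5-turbo", "openai") would throw instead.
```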

packages/inference/src/providers/consts.ts

Lines changed: 1 addition & 0 deletions
@@ -24,6 +24,7 @@ export const HARDCODED_MODEL_ID_MAPPING: Record<InferenceProvider, Record<ModelI
 	hyperbolic: {},
 	nebius: {},
 	novita: {},
+	openai: {},
 	replicate: {},
 	sambanova: {},
 	together: {},
packages/inference/src/providers/openai.ts

Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
+/**
+ * Special case: provider configuration for a private models provider (OpenAI in this case).
+ */
+import type { ProviderConfig, UrlParams, HeaderParams, BodyParams } from "../types";
+
+const OPENAI_API_BASE_URL = "https://api.openai.com";
+
+const makeBody = (params: BodyParams): Record<string, unknown> => {
+	if (!params.chatCompletion) {
+		throw new Error("OpenAI only supports chat completions.");
+	}
+	return {
+		...params.args,
+		model: params.model,
+	};
+};
+
+const makeHeaders = (params: HeaderParams): Record<string, string> => {
+	return { Authorization: `Bearer ${params.accessToken}` };
+};
+
+const makeUrl = (params: UrlParams): string => {
+	if (!params.chatCompletion) {
+		throw new Error("OpenAI only supports chat completions.");
+	}
+	return `${params.baseUrl}/v1/chat/completions`;
+};
+
+export const OPENAI_CONFIG: ProviderConfig = {
+	baseUrl: OPENAI_API_BASE_URL,
+	makeBody,
+	makeHeaders,
+	makeUrl,
+	clientSideRoutingOnly: true,
+};
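
Tracing this config by hand for a chat completion via provider "openai": assuming a provider key of the usual `sk-...` form and a model already stripped to "gpt-3.5-turbo" by `removeProviderPrefix`, the request resolves to the following values (a sketch, not code from the diff):

```ts
// Values produced by makeUrl / makeHeaders / makeBody for this provider:
const url = "https://api.openai.com/v1/chat/completions";
const headers = { Authorization: "Bearer sk-..." }; // provider key, never hf_
const body = {
	messages: [{ role: "user", content: "Complete the equation one + one =" }],
	model: "gpt-3.5-turbo", // prefix already removed before this point
};
```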

packages/inference/src/types.ts

Lines changed: 2 additions & 0 deletions
@@ -37,6 +37,7 @@ export const INFERENCE_PROVIDERS = [
 	"hyperbolic",
 	"nebius",
 	"novita",
+	"openai",
 	"replicate",
 	"sambanova",
 	"together",
@@ -96,6 +97,7 @@ export interface ProviderConfig {
 	makeBody: (params: BodyParams) => Record<string, unknown>;
 	makeHeaders: (params: HeaderParams) => Record<string, string>;
 	makeUrl: (params: UrlParams) => string;
+	clientSideRoutingOnly?: boolean;
 }
 
 export interface HeaderParams {

packages/inference/test/HfInference.spec.ts

Lines changed: 12 additions & 3 deletions
@@ -755,9 +755,9 @@ describe.concurrent("HfInference", () => {
 		it("custom openai - OpenAI Specs", async () => {
 			const OPENAI_KEY = env.OPENAI_KEY;
 			const hf = new HfInference(OPENAI_KEY);
-			const ep = hf.endpoint("https://api.openai.com");
-			const stream = ep.chatCompletionStream({
-				model: "gpt-3.5-turbo",
+			const stream = hf.chatCompletionStream({
+				provider: "openai",
+				model: "openai/gpt-3.5-turbo",
 				messages: [{ role: "user", content: "Complete the equation one + one =" }],
 			}) as AsyncGenerator<ChatCompletionStreamOutput>;
 			let out = "";
@@ -768,6 +768,15 @@ describe.concurrent("HfInference", () => {
 			}
 			expect(out).toContain("two");
 		});
+		it("OpenAI client side routing - model should have provider as prefix", async () => {
+			await expect(
+				new HfInference("dummy_token").chatCompletion({
+					model: "gpt-3.5-turbo", // must be "openai/gpt-3.5-turbo"
+					provider: "openai",
+					messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }],
+				})
+			).rejects.toThrowError(`Models from openai must be prefixed by "openai/". Got "gpt-3.5-turbo".`);
+		});
 	},
 	TIMEOUT
 );
