Commit 7316934

Merge branch 'main' into cerebras-provider
2 parents: bd56b42 + 822ab9e

File tree: 20 files changed (+7609, -7485 lines)

README.md

Lines changed: 1 addition & 1 deletion
````diff
@@ -96,7 +96,7 @@ You can run our packages with vanilla JS, without any bundler, by using a CDN or
 
 ```html
 <script type="module">
-import { HfInference } from 'https://cdn.jsdelivr.net/npm/@huggingface/inference@3.4.0/+esm';
+import { HfInference } from 'https://cdn.jsdelivr.net/npm/@huggingface/inference@3.4.1/+esm';
 import { createRepo, commit, deleteRepo, listFiles } from "https://cdn.jsdelivr.net/npm/@huggingface/[email protected]/+esm";
 </script>
 ```
````

packages/inference/package.json

Lines changed: 1 addition & 1 deletion
```diff
@@ -1,6 +1,6 @@
 {
 	"name": "@huggingface/inference",
-	"version": "3.4.0",
+	"version": "3.4.1",
 	"packageManager": "[email protected]",
 	"license": "MIT",
 	"author": "Tim Mikeladze <[email protected]>",
```

packages/inference/src/lib/makeRequestOptions.ts

Lines changed: 38 additions & 13 deletions
```diff
@@ -11,6 +11,7 @@ import { NOVITA_CONFIG } from "../providers/novita";
 import { REPLICATE_CONFIG } from "../providers/replicate";
 import { SAMBANOVA_CONFIG } from "../providers/sambanova";
 import { TOGETHER_CONFIG } from "../providers/together";
+import { OPENAI_CONFIG } from "../providers/openai";
 import type { InferenceProvider, InferenceTask, Options, ProviderConfig, RequestArgs } from "../types";
 import { isUrl } from "./isUrl";
 import { version as packageVersion, name as packageName } from "../../package.json";
@@ -35,6 +36,7 @@ const providerConfigs: Record<InferenceProvider, ProviderConfig> = {
 	"fireworks-ai": FIREWORKS_AI_CONFIG,
 	"hf-inference": HF_INFERENCE_CONFIG,
 	hyperbolic: HYPERBOLIC_CONFIG,
+	openai: OPENAI_CONFIG,
 	nebius: NEBIUS_CONFIG,
 	novita: NOVITA_CONFIG,
 	replicate: REPLICATE_CONFIG,
@@ -74,22 +76,38 @@ export async function makeRequestOptions(
 	if (!providerConfig) {
 		throw new Error(`No provider config found for provider ${provider}`);
 	}
+	if (providerConfig.clientSideRoutingOnly && !maybeModel) {
+		throw new Error(`Provider ${provider} requires a model ID to be passed directly.`);
+	}
 	// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
 	const hfModel = maybeModel ?? (await loadDefaultModel(task!));
-	const model = await getProviderModelId({ model: hfModel, provider }, args, {
-		task,
-		chatCompletion,
-		fetch: options?.fetch,
-	});
+	const model = providerConfig.clientSideRoutingOnly
+		? // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
+		  removeProviderPrefix(maybeModel!, provider)
+		: // For closed-models API providers, one needs to pass the model ID directly (e.g. "gpt-3.5-turbo")
+		  await getProviderModelId({ model: hfModel, provider }, args, {
+				task,
+				chatCompletion,
+				fetch: options?.fetch,
+		  });
 
-	/// If accessToken is passed, it should take precedence over includeCredentials
-	const authMethod = accessToken
-		? accessToken.startsWith("hf_")
-			? "hf-token"
-			: "provider-key"
-		: includeCredentials === "include"
-			? "credentials-include"
-			: "none";
+	const authMethod = (() => {
+		if (providerConfig.clientSideRoutingOnly) {
+			// Closed-source providers require an accessToken (cannot be routed).
+			if (accessToken && accessToken.startsWith("hf_")) {
+				throw new Error(`Provider ${provider} is closed-source and does not support HF tokens.`);
+			}
+			return "provider-key";
+		}
+		if (accessToken) {
+			return accessToken.startsWith("hf_") ? "hf-token" : "provider-key";
+		}
+		if (includeCredentials === "include") {
+			// If accessToken is passed, it should take precedence over includeCredentials
+			return "credentials-include";
+		}
+		return "none";
+	})();
 
 	// Make URL
 	const url = endpointUrl
@@ -178,3 +196,10 @@ async function loadTaskInfo(): Promise<Record<string, { models: { id: string }[]
 	}
 	return await res.json();
 }
+
+function removeProviderPrefix(model: string, provider: string): string {
+	if (!model.startsWith(`${provider}/`)) {
+		throw new Error(`Models from ${provider} must be prefixed by "${provider}/". Got "${model}".`);
+	}
+	return model.slice(provider.length + 1);
+}
```
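With `clientSideRoutingOnly` set, the model ID must arrive fully prefixed and the token must be a provider key rather than an HF token. A minimal sketch of the resulting call shape (the API key is a placeholder; the call mirrors the new test further down):

```ts
import { HfInference } from "@huggingface/inference";

// Placeholder key: an "hf_..." token would throw here, since closed-source
// providers cannot be routed through Hugging Face.
const client = new HfInference("sk-your-openai-api-key");

const res = await client.chatCompletion({
	provider: "openai",
	model: "openai/gpt-3.5-turbo", // removeProviderPrefix() strips "openai/" before the request is sent
	messages: [{ role: "user", content: "Complete the equation one + one =" }],
});
```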

packages/inference/src/providers/consts.ts

Lines changed: 1 addition & 0 deletions
```diff
@@ -25,6 +25,7 @@ export const HARDCODED_MODEL_ID_MAPPING: Record<InferenceProvider, Record<ModelI
 	hyperbolic: {},
 	nebius: {},
 	novita: {},
+	openai: {},
 	replicate: {},
 	sambanova: {},
 	together: {},
```
packages/inference/src/providers/openai.ts

Lines changed: 35 additions & 0 deletions

```diff
@@ -0,0 +1,35 @@
+/**
+ * Special case: provider configuration for a private models provider (OpenAI in this case).
+ */
+import type { ProviderConfig, UrlParams, HeaderParams, BodyParams } from "../types";
+
+const OPENAI_API_BASE_URL = "https://api.openai.com";
+
+const makeBody = (params: BodyParams): Record<string, unknown> => {
+	if (!params.chatCompletion) {
+		throw new Error("OpenAI only supports chat completions.");
+	}
+	return {
+		...params.args,
+		model: params.model,
+	};
+};
+
+const makeHeaders = (params: HeaderParams): Record<string, string> => {
+	return { Authorization: `Bearer ${params.accessToken}` };
+};
+
+const makeUrl = (params: UrlParams): string => {
+	if (!params.chatCompletion) {
+		throw new Error("OpenAI only supports chat completions.");
+	}
+	return `${params.baseUrl}/v1/chat/completions`;
+};
+
+export const OPENAI_CONFIG: ProviderConfig = {
+	baseUrl: OPENAI_API_BASE_URL,
+	makeBody,
+	makeHeaders,
+	makeUrl,
+	clientSideRoutingOnly: true,
+};
```

packages/inference/src/tasks/nlp/featureExtraction.ts

Lines changed: 2 additions & 9 deletions
```diff
@@ -1,16 +1,9 @@
+import type { FeatureExtractionInput } from "@huggingface/tasks";
 import { InferenceOutputError } from "../../lib/InferenceOutputError";
 import type { BaseArgs, Options } from "../../types";
 import { request } from "../custom/request";
 
-export type FeatureExtractionArgs = BaseArgs & {
-	/**
-	 * The inputs is a string or a list of strings to get the features from.
-	 *
-	 * inputs: "That is a happy person",
-	 *
-	 */
-	inputs: string | string[];
-};
+export type FeatureExtractionArgs = BaseArgs & FeatureExtractionInput;
 
 /**
 * Returned values are a multidimensional array of floats (dimension depending on if you sent a string or a list of string, and if the automatic reduction, usually mean_pooling for instance was applied for you or not. This should be explained on the model's README).
```
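Call sites are unchanged by the type swap: `inputs` may still be a single string or a list of strings, now typed via the shared `FeatureExtractionInput`. A minimal sketch (model ID and token are illustrative):

```ts
import { HfInference } from "@huggingface/inference";

const hf = new HfInference("hf_xxx"); // placeholder token

// FeatureExtractionInput still accepts a string or a list of strings.
const embedding = await hf.featureExtraction({
	model: "sentence-transformers/all-MiniLM-L6-v2", // illustrative model ID
	inputs: "That is a happy person",
});
```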

packages/inference/src/types.ts

Lines changed: 2 additions & 0 deletions
```diff
@@ -38,6 +38,7 @@ export const INFERENCE_PROVIDERS = [
 	"hyperbolic",
 	"nebius",
 	"novita",
+	"openai",
 	"replicate",
 	"sambanova",
 	"together",
@@ -97,6 +98,7 @@ export interface ProviderConfig {
 	makeBody: (params: BodyParams) => Record<string, unknown>;
 	makeHeaders: (params: HeaderParams) => Record<string, string>;
 	makeUrl: (params: UrlParams) => string;
+	clientSideRoutingOnly?: boolean;
 }
 
 export interface HeaderParams {
```

packages/inference/test/HfInference.spec.ts

Lines changed: 12 additions & 3 deletions
```diff
@@ -755,9 +755,9 @@ describe.concurrent("HfInference", () => {
 		it("custom openai - OpenAI Specs", async () => {
 			const OPENAI_KEY = env.OPENAI_KEY;
 			const hf = new HfInference(OPENAI_KEY);
-			const ep = hf.endpoint("https://api.openai.com");
-			const stream = ep.chatCompletionStream({
-				model: "gpt-3.5-turbo",
+			const stream = hf.chatCompletionStream({
+				provider: "openai",
+				model: "openai/gpt-3.5-turbo",
 				messages: [{ role: "user", content: "Complete the equation one + one =" }],
 			}) as AsyncGenerator<ChatCompletionStreamOutput>;
 			let out = "";
@@ -768,6 +768,15 @@ describe.concurrent("HfInference", () => {
 			}
 			expect(out).toContain("two");
 		});
+		it("OpenAI client side routing - model should have provider as prefix", async () => {
+			await expect(
+				new HfInference("dummy_token").chatCompletion({
+					model: "gpt-3.5-turbo", // must be "openai/gpt-3.5-turbo"
+					provider: "openai",
+					messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }],
+				})
+			).rejects.toThrowError(`Models from openai must be prefixed by "openai/". Got "gpt-3.5-turbo".`);
+		});
 	},
 	TIMEOUT
 );
```
