
Commit c242011

Feat/add baseten provider (#1778)
This PR adds Baseten as a new inference provider for conversational tasks, routed through the OpenAI-compatible `/v1/chat/completions` endpoint. Tests were added to `InferenceClient.spec.ts` and run locally to check they all pass.

Co-authored-by: AlexKer <[email protected]>

1 parent 74d949c · commit c242011
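For context, a minimal usage sketch of the new provider, assembled from the test cases added in this PR (the model ID, `provider` value, and request options are taken from `InferenceClient.spec.ts` below; reading the token from `HF_BASETEN_KEY` mirrors the test setup and is otherwise an assumption):

```ts
import { InferenceClient } from "@huggingface/inference";

const client = new InferenceClient(process.env.HF_BASETEN_KEY);

// Requests are routed to Baseten's OpenAI-compatible /v1/chat/completions endpoint.
const res = await client.chatCompletion({
	model: "Qwen/Qwen3-235B-A22B-Instruct-2507",
	provider: "baseten",
	messages: [{ role: "user", content: "What is 5 + 3?" }],
	max_tokens: 20,
});

console.log(res.choices[0]?.message?.content);
```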

File tree: 6 files changed (+91, −0 lines)


packages/inference/README.md

Lines changed: 2 additions & 0 deletions

```diff
@@ -61,6 +61,7 @@ Currently, we support the following providers:
 - [Sambanova](https://sambanova.ai)
 - [Scaleway](https://www.scaleway.com/en/generative-apis/)
 - [Together](https://together.xyz)
+- [Baseten](https://baseten.co)
 - [Blackforestlabs](https://blackforestlabs.ai)
 - [Cohere](https://cohere.com)
 - [Cerebras](https://cerebras.ai/)
@@ -97,6 +98,7 @@ Only a subset of models are supported when requesting third-party providers. You
 - [Sambanova supported models](https://huggingface.co/api/partners/sambanova/models)
 - [Scaleway supported models](https://huggingface.co/api/partners/scaleway/models)
 - [Together supported models](https://huggingface.co/api/partners/together/models)
+- [Baseten supported models](https://huggingface.co/api/partners/baseten/models)
 - [Cohere supported models](https://huggingface.co/api/partners/cohere/models)
 - [Cerebras supported models](https://huggingface.co/api/partners/cerebras/models)
 - [Groq supported models](https://console.groq.com/docs/models)
```

packages/inference/src/lib/getProviderHelper.ts

Lines changed: 4 additions & 0 deletions

```diff
@@ -1,3 +1,4 @@
+import * as Baseten from "../providers/baseten.js";
 import * as BlackForestLabs from "../providers/black-forest-labs.js";
 import * as Cerebras from "../providers/cerebras.js";
 import * as Cohere from "../providers/cohere.js";
@@ -55,6 +56,9 @@ import type { InferenceProvider, InferenceProviderOrPolicy, InferenceTask } from
 import { InferenceClientInputError } from "../errors.js";
 
 export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask, TaskProviderHelper>>> = {
+	baseten: {
+		conversational: new Baseten.BasetenConversationalTask(),
+	},
 	"black-forest-labs": {
 		"text-to-image": new BlackForestLabs.BlackForestLabsTextToImageTask(),
 	},
```
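The `PROVIDERS` registry above maps each provider to the task helpers it supports. The diff does not show the lookup itself, so the following is only a hedged sketch of how a provider/task pair is presumably resolved, reusing the `InferenceClientInputError` import visible in the hunk (`resolveHelper` is a hypothetical name, not the module's real export):

```ts
// Hypothetical sketch; PROVIDERS and the types come from getProviderHelper.ts above.
function resolveHelper(provider: InferenceProvider, task: InferenceTask): TaskProviderHelper {
	const helper = PROVIDERS[provider]?.[task];
	if (!helper) {
		throw new InferenceClientInputError(`Provider "${provider}" does not support task "${task}"`);
	}
	return helper;
}
```

With the new entry, `resolveHelper("baseten", "conversational")` would return the `BasetenConversationalTask` instance registered above.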
packages/inference/src/providers/baseten.ts

Lines changed: 25 additions & 0 deletions

```diff
@@ -0,0 +1,25 @@
+/**
+ * See the registered mapping of HF model ID => Baseten model ID here:
+ *
+ * https://huggingface.co/api/partners/baseten/models
+ *
+ * This is a publicly available mapping.
+ *
+ * If you want to try to run inference for a new model locally before it's registered on huggingface.co,
+ * you can add it to the dictionary "HARDCODED_MODEL_ID_MAPPING" in consts.ts, for dev purposes.
+ *
+ * - If you work at Baseten and want to update this mapping, please use the model mapping API we provide on huggingface.co
+ * - If you're a community member and want to add a new supported HF model to Baseten, please open an issue on the present repo
+ *   and we will tag Baseten team members.
+ *
+ * Thanks!
+ */
+import { BaseConversationalTask } from "./providerHelper.js";
+
+const BASETEN_API_BASE_URL = "https://inference.baseten.co";
+
+export class BasetenConversationalTask extends BaseConversationalTask {
+	constructor() {
+		super("baseten", BASETEN_API_BASE_URL);
+	}
+}
```
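The header comment describes a dev workflow for models not yet registered on huggingface.co. Note that the dictionary exported from `consts.ts` is `HARDCODED_MODEL_INFERENCE_MAPPING` (the comment's `HARDCODED_MODEL_ID_MAPPING` looks like an older name). A sketch of such a dev-only entry, with the entry shape copied from the test file in this PR:

```ts
import { HARDCODED_MODEL_INFERENCE_MAPPING } from "./consts.js";

// Dev-only: map an HF model ID to a Baseten model ID before it appears in
// the public mapping at https://huggingface.co/api/partners/baseten/models.
HARDCODED_MODEL_INFERENCE_MAPPING["baseten"] = {
	"Qwen/Qwen3-235B-A22B-Instruct-2507": {
		provider: "baseten",
		hfModelId: "Qwen/Qwen3-235B-A22B-Instruct-2507",
		providerId: "Qwen/Qwen3-235B-A22B-Instruct-2507",
		status: "live",
		task: "conversational",
	},
};
```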

packages/inference/src/providers/consts.ts

Lines changed: 1 addition & 0 deletions

```diff
@@ -18,6 +18,7 @@ export const HARDCODED_MODEL_INFERENCE_MAPPING: Record<
 	 * Example:
 	 * "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen2.5-Coder-32B-Instruct",
 	 */
+	baseten: {},
 	"black-forest-labs": {},
 	cerebras: {},
 	cohere: {},
```

packages/inference/src/types.ts

Lines changed: 1 addition & 0 deletions

```diff
@@ -45,6 +45,7 @@ export interface Options {
 export type InferenceTask = Exclude<PipelineType, "other"> | "conversational";
 
 export const INFERENCE_PROVIDERS = [
+	"baseten",
 	"black-forest-labs",
 	"cerebras",
 	"cohere",
```

packages/inference/test/InferenceClient.spec.ts

Lines changed: 58 additions & 0 deletions

```diff
@@ -2343,4 +2343,62 @@ describe.skip("InferenceClient", () => {
 		},
 		TIMEOUT
 	);
+
+	describe.concurrent(
+		"Baseten",
+		() => {
+			const client = new InferenceClient(env.HF_BASETEN_KEY ?? "dummy");
+
+			HARDCODED_MODEL_INFERENCE_MAPPING["baseten"] = {
+				"Qwen/Qwen3-235B-A22B-Instruct-2507": {
+					provider: "baseten",
+					hfModelId: "Qwen/Qwen3-235B-A22B-Instruct-2507",
+					providerId: "Qwen/Qwen3-235B-A22B-Instruct-2507",
+					status: "live",
+					task: "conversational",
+				},
+			};
+
+			it("chatCompletion - Qwen3 235B Instruct", async () => {
+				const res = await client.chatCompletion({
+					model: "Qwen/Qwen3-235B-A22B-Instruct-2507",
+					provider: "baseten",
+					messages: [{ role: "user", content: "What is 5 + 3?" }],
+					max_tokens: 20,
+				});
+				if (res.choices && res.choices.length > 0) {
+					const completion = res.choices[0].message?.content;
+					expect(completion).toBeDefined();
+					expect(typeof completion).toBe("string");
+					expect(completion).toMatch(/(eight|8)/i);
+				}
+			});
+
+			it("chatCompletion stream - Qwen3 235B", async () => {
+				const stream = client.chatCompletionStream({
+					model: "Qwen/Qwen3-235B-A22B-Instruct-2507",
+					provider: "baseten",
+					messages: [{ role: "user", content: "Count from 1 to 3" }],
+					stream: true,
+					max_tokens: 20,
+				}) as AsyncGenerator<ChatCompletionStreamOutput>;
+
+				let fullResponse = "";
+				for await (const chunk of stream) {
+					if (chunk.choices && chunk.choices.length > 0) {
+						const content = chunk.choices[0].delta?.content;
+						if (content) {
+							fullResponse += content;
+						}
+					}
+				}
+
+				// Verify we got a meaningful response
+				expect(fullResponse).toBeTruthy();
+				expect(fullResponse.length).toBeGreaterThan(0);
+				expect(fullResponse).toMatch(/1.*2.*3/);
+			});
+		},
+		TIMEOUT
+	);
 });
```
