
Commit 5a3d6c6

Gnoale and SBrandeis authored
feat: add scaleway inference provider (#1674)
Hi 🤗 Register [Scaleway](https://www.scaleway.com/en/) as inference provider.

Tests passing:

```
> vitest run --config vitest.config.mts -t Scaleway

 RUN  v0.34.6 /Users/gnoale/git/huggingface.js/packages/inference

stderr | unknown test
Set HF_TOKEN in the env to run the tests for better rate limits

 ↓ src/vendor/fetch-event-source/parse.spec.ts (17) [skipped]
 ✓ test/InferenceClient.spec.ts (112) 6116ms

 Test Files  1 passed | 1 skipped (2)
      Tests  5 passed | 124 skipped (129)
   Start at  14:03:55
   Duration  6.58s (transform 253ms, setup 18ms, collect 327ms, tests 6.12s, environment 0ms, prepare 82ms)
```

---------

Co-authored-by: SBrandeis <[email protected]>
1 parent 0c1b347 commit 5a3d6c6
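
For reference, a minimal usage sketch of the new provider (assumptions: a Scaleway-enabled token in `HF_SCALEWAY_KEY`; the model and call shape mirror the tests added in this commit):

```ts
import { InferenceClient } from "@huggingface/inference";

// The token variable is an assumption; the new tests read it from env.HF_SCALEWAY_KEY.
const client = new InferenceClient(process.env.HF_SCALEWAY_KEY);

// Chat completion routed through the new "scaleway" provider,
// mirroring the chatCompletion test added in InferenceClient.spec.ts.
const res = await client.chatCompletion({
	model: "meta-llama/Llama-3.1-8B-Instruct",
	provider: "scaleway",
	messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }],
});

console.log(res.choices[0]?.message?.content);
```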

File tree

6 files changed: +215 -0 lines changed

packages/inference/README.md

Lines changed: 2 additions & 0 deletions
@@ -58,6 +58,7 @@ Currently, we support the following providers:
 - [OVHcloud](https://endpoints.ai.cloud.ovh.net/)
 - [Replicate](https://replicate.com)
 - [Sambanova](https://sambanova.ai)
+- [Scaleway](https://www.scaleway.com/en/generative-apis/)
 - [Together](https://together.xyz)
 - [Blackforestlabs](https://blackforestlabs.ai)
 - [Cohere](https://cohere.com)
@@ -92,6 +93,7 @@ Only a subset of models are supported when requesting third-party providers. You
 - [OVHcloud supported models](https://huggingface.co/api/partners/ovhcloud/models)
 - [Replicate supported models](https://huggingface.co/api/partners/replicate/models)
 - [Sambanova supported models](https://huggingface.co/api/partners/sambanova/models)
+- [Scaleway supported models](https://huggingface.co/api/partners/scaleway/models)
 - [Together supported models](https://huggingface.co/api/partners/together/models)
 - [Cohere supported models](https://huggingface.co/api/partners/cohere/models)
 - [Cerebras supported models](https://huggingface.co/api/partners/cerebras/models)

packages/inference/src/lib/getProviderHelper.ts

Lines changed: 6 additions & 0 deletions
@@ -47,6 +47,7 @@ import type {
 } from "../providers/providerHelper.js";
 import * as Replicate from "../providers/replicate.js";
 import * as Sambanova from "../providers/sambanova.js";
+import * as Scaleway from "../providers/scaleway.js";
 import * as Together from "../providers/together.js";
 import type { InferenceProvider, InferenceProviderOrPolicy, InferenceTask } from "../types.js";
 import { InferenceClientInputError } from "../errors.js";
@@ -148,6 +149,11 @@ export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask,
 		conversational: new Sambanova.SambanovaConversationalTask(),
 		"feature-extraction": new Sambanova.SambanovaFeatureExtractionTask(),
 	},
+	scaleway: {
+		conversational: new Scaleway.ScalewayConversationalTask(),
+		"text-generation": new Scaleway.ScalewayTextGenerationTask(),
+		"feature-extraction": new Scaleway.ScalewayFeatureExtractionTask(),
+	},
 	together: {
 		"text-to-image": new Together.TogetherTextToImageTask(),
 		conversational: new Together.TogetherConversationalTask(),
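
As a quick illustration of what this registration enables (a minimal sketch; the relative import path assumes code living inside packages/inference/src):

```ts
import { PROVIDERS } from "./lib/getProviderHelper.js";

// After this change, the task-indexed record resolves Scaleway helpers per task.
const conversational = PROVIDERS["scaleway"]?.["conversational"];
const embeddings = PROVIDERS["scaleway"]?.["feature-extraction"];

console.log(conversational?.constructor.name); // "ScalewayConversationalTask"
console.log(embeddings?.constructor.name); // "ScalewayFeatureExtractionTask"
```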

packages/inference/src/providers/consts.ts

Lines changed: 1 addition & 0 deletions
@@ -34,5 +34,6 @@ export const HARDCODED_MODEL_INFERENCE_MAPPING: Record<
 	ovhcloud: {},
 	replicate: {},
 	sambanova: {},
+	scaleway: {},
 	together: {},
 };
packages/inference/src/providers/scaleway.ts

Lines changed: 95 additions & 0 deletions
@@ -0,0 +1,95 @@
/**
 * See the registered mapping of HF model ID => Scaleway model ID here:
 *
 * https://huggingface.co/api/partners/scaleway/models
 *
 * This is a publicly available mapping.
 *
 * If you want to try to run inference for a new model locally before it's registered on huggingface.co,
 * you can add it to the dictionary "HARDCODED_MODEL_ID_MAPPING" in consts.ts, for dev purposes.
 *
 * - If you work at Scaleway and want to update this mapping, please use the model mapping API we provide on huggingface.co
 * - If you're a community member and want to add a new supported HF model to Scaleway, please open an issue on the present repo
 * and we will tag Scaleway team members.
 *
 * Thanks!
 */
import type { FeatureExtractionOutput, TextGenerationOutput } from "@huggingface/tasks";
import type { BodyParams } from "../types.js";
import { InferenceClientProviderOutputError } from "../errors.js";

import type { FeatureExtractionTaskHelper } from "./providerHelper.js";
import { BaseConversationalTask, TaskProviderHelper, BaseTextGenerationTask } from "./providerHelper.js";

const SCALEWAY_API_BASE_URL = "https://api.scaleway.ai";

interface ScalewayEmbeddingsResponse {
	data: Array<{
		embedding: number[];
	}>;
}

export class ScalewayConversationalTask extends BaseConversationalTask {
	constructor() {
		super("scaleway", SCALEWAY_API_BASE_URL);
	}
}

export class ScalewayTextGenerationTask extends BaseTextGenerationTask {
	constructor() {
		super("scaleway", SCALEWAY_API_BASE_URL);
	}

	override preparePayload(params: BodyParams): Record<string, unknown> {
		return {
			model: params.model,
			...params.args,
			prompt: params.args.inputs,
		};
	}

	override async getResponse(response: unknown): Promise<TextGenerationOutput> {
		if (
			typeof response === "object" &&
			response !== null &&
			"choices" in response &&
			Array.isArray(response.choices) &&
			response.choices.length > 0
		) {
			const completion: unknown = response.choices[0];
			if (
				typeof completion === "object" &&
				!!completion &&
				"text" in completion &&
				completion.text &&
				typeof completion.text === "string"
			) {
				return {
					generated_text: completion.text,
				};
			}
		}
		throw new InferenceClientProviderOutputError("Received malformed response from Scaleway text generation API");
	}
}

export class ScalewayFeatureExtractionTask extends TaskProviderHelper implements FeatureExtractionTaskHelper {
	constructor() {
		super("scaleway", SCALEWAY_API_BASE_URL);
	}

	preparePayload(params: BodyParams): Record<string, unknown> {
		return {
			input: params.args.inputs,
			model: params.model,
		};
	}

	makeRoute(): string {
		return "v1/embeddings";
	}

	async getResponse(response: ScalewayEmbeddingsResponse): Promise<FeatureExtractionOutput> {
		return response.data.map((item) => item.embedding);
	}
}
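
As a quick check of the response mapping above, here is a minimal sketch feeding ScalewayTextGenerationTask.getResponse a hypothetical, hand-written payload (not captured from the live Scaleway API; the import path assumes code inside the same providers directory):

```ts
import { ScalewayTextGenerationTask } from "./scaleway.js";

const task = new ScalewayTextGenerationTask();

// OpenAI-style completions payload: getResponse() reads choices[0].text.
const fakePayload = {
	choices: [{ text: " in a small village nestled in the rolling hills..." }],
};

const output = await task.getResponse(fakePayload);
console.log(output.generated_text); // " in a small village nestled in the rolling hills..."

// A payload without a non-empty string `text` in the first choice would instead
// throw InferenceClientProviderOutputError, per the guard clauses above.
```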

packages/inference/src/types.ts

Lines changed: 1 addition & 0 deletions
@@ -61,6 +61,7 @@ export const INFERENCE_PROVIDERS = [
 	"ovhcloud",
 	"replicate",
 	"sambanova",
+	"scaleway",
 	"together",
 ] as const;

packages/inference/test/InferenceClient.spec.ts

Lines changed: 110 additions & 0 deletions
@@ -1516,6 +1516,116 @@ describe.skip("InferenceClient", () => {
 		TIMEOUT
 	);
 
+	describe.concurrent(
+		"Scaleway",
+		() => {
+			const client = new InferenceClient(env.HF_SCALEWAY_KEY ?? "dummy");
+
+			HARDCODED_MODEL_INFERENCE_MAPPING.scaleway = {
+				"meta-llama/Llama-3.1-8B-Instruct": {
+					provider: "scaleway",
+					hfModelId: "meta-llama/Llama-3.1-8B-Instruct",
+					providerId: "llama-3.1-8b-instruct",
+					status: "live",
+					task: "conversational",
+				},
+				"BAAI/bge-multilingual-gemma2": {
+					provider: "scaleway",
+					hfModelId: "BAAI/bge-multilingual-gemma2",
+					providerId: "bge-multilingual-gemma2",
+					task: "feature-extraction",
+					status: "live",
+				},
+				"google/gemma-3-27b-it": {
+					provider: "scaleway",
+					hfModelId: "google/gemma-3-27b-it",
+					providerId: "gemma-3-27b-it",
+					task: "conversational",
+					status: "live",
+				},
+			};
+
+			it("chatCompletion", async () => {
+				const res = await client.chatCompletion({
+					model: "meta-llama/Llama-3.1-8B-Instruct",
+					provider: "scaleway",
+					messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }],
+					tool_choice: "none",
+				});
+				if (res.choices && res.choices.length > 0) {
+					const completion = res.choices[0].message?.content;
+					expect(completion).toMatch(/(to )?(two|2)/i);
+				}
+			});
+
+			it("chatCompletion stream", async () => {
+				const stream = client.chatCompletionStream({
+					model: "meta-llama/Llama-3.1-8B-Instruct",
+					provider: "scaleway",
+					messages: [{ role: "system", content: "Complete the equation 1 + 1 = , just the answer" }],
+				}) as AsyncGenerator<ChatCompletionStreamOutput>;
+				let out = "";
+				for await (const chunk of stream) {
+					if (chunk.choices && chunk.choices.length > 0) {
+						out += chunk.choices[0].delta.content;
+					}
+				}
+				expect(out).toMatch(/(two|2)/i);
+			});
+
+			it("chatCompletion multimodal", async () => {
+				const res = await client.chatCompletion({
+					model: "google/gemma-3-27b-it",
+					provider: "scaleway",
+					messages: [
+						{
+							role: "user",
+							content: [
+								{
+									type: "image_url",
+									image_url: {
+										url: "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg",
+									},
+								},
+								{ type: "text", text: "What is this?" },
+							],
+						},
+					],
+				});
+				expect(res.choices).toBeDefined();
+				expect(res.choices?.length).toBeGreaterThan(0);
+				expect(res.choices?.[0].message?.content).toContain("Statue of Liberty");
+			});
+
+			it("textGeneration", async () => {
+				const res = await client.textGeneration({
+					model: "meta-llama/Llama-3.1-8B-Instruct",
+					provider: "scaleway",
+					inputs: "Once upon a time,",
+					temperature: 0,
+					max_tokens: 19,
+				});
+
+				expect(res).toMatchObject({
+					generated_text:
+						" in a small village nestled in the rolling hills of the countryside, there lived a young girl named",
+				});
+			});
+
+			it("featureExtraction", async () => {
+				const res = await client.featureExtraction({
+					model: "BAAI/bge-multilingual-gemma2",
+					provider: "scaleway",
+					inputs: "That is a happy person",
+				});
+
+				expect(res).toBeInstanceOf(Array);
+				expect(res[0]).toEqual(expect.arrayContaining([expect.any(Number)]));
+			});
+		},
+		TIMEOUT
+	);
+
 	describe.concurrent("3rd party providers", () => {
 		it("chatCompletion - fails with unsupported model", async () => {
 			expect(
