
Commit 9dc304d

feat: add scaleway inference provider
1 parent eebfb1f commit 9dc304d

File tree

packages/inference/README.md
packages/inference/src/lib/getProviderHelper.ts
packages/inference/src/providers/consts.ts
packages/inference/src/providers/scaleway.ts
packages/inference/src/types.ts
packages/inference/test/InferenceClient.spec.ts

6 files changed: +220 -1 lines changed


packages/inference/README.md

Lines changed: 2 additions & 0 deletions
@@ -58,6 +58,7 @@ Currently, we support the following providers:
  - [OVHcloud](https://endpoints.ai.cloud.ovh.net/)
  - [Replicate](https://replicate.com)
  - [Sambanova](https://sambanova.ai)
+ - [Scaleway](https://www.scaleway.com/en/generative-apis/)
  - [Together](https://together.xyz)
  - [Blackforestlabs](https://blackforestlabs.ai)
  - [Cohere](https://cohere.com)
@@ -92,6 +93,7 @@ Only a subset of models are supported when requesting third-party providers. You
  - [OVHcloud supported models](https://huggingface.co/api/partners/ovhcloud/models)
  - [Replicate supported models](https://huggingface.co/api/partners/replicate/models)
  - [Sambanova supported models](https://huggingface.co/api/partners/sambanova/models)
+ - [Scaleway supported models](https://huggingface.co/api/partners/scaleway/models)
  - [Together supported models](https://huggingface.co/api/partners/together/models)
  - [Cohere supported models](https://huggingface.co/api/partners/cohere/models)
  - [Cerebras supported models](https://huggingface.co/api/partners/cerebras/models)
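
For orientation, here is what routing a chat completion through the new provider looks like from the client side. This is a minimal sketch mirroring the Scaleway `chatCompletion` test added later in this commit; the `HF_SCALEWAY_KEY` environment variable and model ID are taken from that test, not from the README itself:

```ts
import { InferenceClient } from "@huggingface/inference";

// Minimal sketch, mirroring the Scaleway chatCompletion test in this commit.
// HF_SCALEWAY_KEY is the key the test suite reads; substitute your own token.
const client = new InferenceClient(process.env.HF_SCALEWAY_KEY ?? "");

const res = await client.chatCompletion({
  model: "meta-llama/Llama-3.1-8B-Instruct",
  provider: "scaleway",
  messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }],
});

console.log(res.choices?.[0]?.message?.content);
```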

packages/inference/src/lib/getProviderHelper.ts

Lines changed: 7 additions & 0 deletions
@@ -47,6 +47,7 @@ import type {
  } from "../providers/providerHelper.js";
  import * as Replicate from "../providers/replicate.js";
  import * as Sambanova from "../providers/sambanova.js";
+ import * as Scaleway from "../providers/scaleway.js";
  import * as Together from "../providers/together.js";
  import type { InferenceProvider, InferenceProviderOrPolicy, InferenceTask } from "../types.js";
  import { InferenceClientInputError } from "../errors.js";
@@ -148,6 +149,12 @@ export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask,
      conversational: new Sambanova.SambanovaConversationalTask(),
      "feature-extraction": new Sambanova.SambanovaFeatureExtractionTask(),
    },
+   scaleway: {
+     conversational: new Scaleway.ScalewayConversationalTask(),
+     "image-to-text": new Scaleway.ScalewayConversationalTask(),
+     "text-generation": new Scaleway.ScalewayTextGenerationTask(),
+     "feature-extraction": new Scaleway.ScalewayFeatureExtractionTask(),
+   },
    together: {
      "text-to-image": new Together.TogetherTextToImageTask(),
      conversational: new Together.TogetherConversationalTask(),
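
For context on how these entries are consumed: the PROVIDERS record maps a provider name to its per-task helpers, and the lookup presumably done by getProviderHelper() is not shown in this diff. The snippet below is therefore only an illustrative sketch (resolveHelper is a hypothetical name), using the PROVIDERS record, types, and error class already present in this file:

```ts
// Illustrative sketch only; the real getProviderHelper() in this file is not shown in the diff.
function resolveHelper(provider: InferenceProvider, task: InferenceTask): TaskProviderHelper {
  // e.g. PROVIDERS["scaleway"]["feature-extraction"] -> ScalewayFeatureExtractionTask
  const helper = PROVIDERS[provider]?.[task];
  if (!helper) {
    throw new InferenceClientInputError(`Task ${task} is not supported for provider ${provider}`);
  }
  return helper;
}
```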

packages/inference/src/providers/consts.ts

Lines changed: 1 addition & 0 deletions
@@ -34,5 +34,6 @@ export const HARDCODED_MODEL_INFERENCE_MAPPING: Record<
    ovhcloud: {},
    replicate: {},
    sambanova: {},
+   scaleway: {},
    together: {},
  };
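
The scaleway entry ships empty; mappings normally come from the huggingface.co API. For local development before a model is registered there, an entry can be hardcoded, as the new tests do. A sketch of that shape, copied from the test setup later in this commit:

```ts
// Dev-only sketch: register a mapping by hand instead of fetching it from huggingface.co.
// This mirrors the assignment made in InferenceClient.spec.ts below.
HARDCODED_MODEL_INFERENCE_MAPPING.scaleway = {
  "meta-llama/Llama-3.1-8B-Instruct": {
    provider: "scaleway",
    hfModelId: "meta-llama/Llama-3.1-8B-Instruct",
    providerId: "llama-3.1-8b-instruct", // the model name on the Scaleway side
    status: "live",
    task: "conversational",
  },
};
```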
packages/inference/src/providers/scaleway.ts

Lines changed: 100 additions & 0 deletions
@@ -0,0 +1,100 @@
+ /**
+  * See the registered mapping of HF model ID => Scaleway model ID here:
+  *
+  * https://huggingface.co/api/partners/scaleway/models
+  *
+  * This is a publicly available mapping.
+  *
+  * If you want to try to run inference for a new model locally before it's registered on huggingface.co,
+  * you can add it to the dictionary "HARDCODED_MODEL_INFERENCE_MAPPING" in consts.ts, for dev purposes.
+  *
+  * - If you work at Scaleway and want to update this mapping, please use the model mapping API we provide on huggingface.co
+  * - If you're a community member and want to add a new supported HF model to Scaleway, please open an issue on the present repo
+  *   and we will tag Scaleway team members.
+  *
+  * Thanks!
+  */
+ import type {
+   FeatureExtractionOutput,
+   ImageToTextInput,
+   ImageToTextOutput,
+   TextGenerationOutput,
+ } from "@huggingface/tasks";
+ import type { BodyParams, RequestArgs } from "../types.js";
+ import { InferenceClientProviderOutputError } from "../errors.js";
+ import { base64FromBytes } from "../utils/base64FromBytes.js";
+
+ import {
+   BaseConversationalTask,
+   TaskProviderHelper,
+   FeatureExtractionTaskHelper,
+   BaseTextGenerationTask,
+   ImageToTextTaskHelper,
+ } from "./providerHelper.js";
+
+ const SCALEWAY_API_BASE_URL = "https://api.scaleway.ai";
+
+ interface ScalewayEmbeddingsResponse {
+   data: Array<{
+     embedding: number[];
+   }>;
+ }
+
+ export class ScalewayConversationalTask extends BaseConversationalTask {
+   constructor() {
+     super("scaleway", SCALEWAY_API_BASE_URL);
+   }
+ }
+
+ export class ScalewayTextGenerationTask extends BaseTextGenerationTask {
+   constructor() {
+     super("scaleway", SCALEWAY_API_BASE_URL);
+   }
+
+   override preparePayload(params: BodyParams): Record<string, unknown> {
+     return {
+       model: params.model,
+       ...params.args,
+       prompt: params.args.inputs,
+     };
+   }
+
+   override async getResponse(response: unknown): Promise<TextGenerationOutput> {
+     if (
+       typeof response === "object" &&
+       response !== null &&
+       "choices" in response &&
+       Array.isArray((response as any).choices) &&
+       (response as any).choices.length > 0
+     ) {
+       const completion = (response as any).choices[0];
+       if (completion.text) {
+         return {
+           generated_text: completion.text,
+         };
+       }
+     }
+     throw new InferenceClientProviderOutputError("Received malformed response from Scaleway text generation API");
+   }
+ }
+
+ export class ScalewayFeatureExtractionTask extends TaskProviderHelper implements FeatureExtractionTaskHelper {
+   constructor() {
+     super("scaleway", SCALEWAY_API_BASE_URL);
+   }
+
+   preparePayload(params: BodyParams): Record<string, unknown> {
+     return {
+       input: params.args.inputs,
+       model: params.model,
+     };
+   }
+
+   makeRoute(): string {
+     return "v1/embeddings";
+   }
+
+   async getResponse(response: ScalewayEmbeddingsResponse): Promise<FeatureExtractionOutput> {
+     return response.data.map((item) => item.embedding);
+   }
+ }
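
To see the feature-extraction path above end to end, here is a minimal client-side sketch; the model and input are taken from the featureExtraction test added in this commit, and ScalewayFeatureExtractionTask builds the `{ input, model }` payload and posts it to the v1/embeddings route:

```ts
import { InferenceClient } from "@huggingface/inference";

// Sketch based on the featureExtraction test in this commit; substitute your own key.
const client = new InferenceClient(process.env.HF_SCALEWAY_KEY ?? "");

const embeddings = await client.featureExtraction({
  model: "BAAI/bge-multilingual-gemma2",
  provider: "scaleway",
  inputs: "That is a happy person",
});

// One embedding vector per input, extracted from the provider's data[].embedding field.
console.log(embeddings.length);
```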

packages/inference/src/types.ts

Lines changed: 1 addition & 0 deletions
@@ -61,6 +61,7 @@ export const INFERENCE_PROVIDERS = [
    "ovhcloud",
    "replicate",
    "sambanova",
+   "scaleway",
    "together",
  ] as const;
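
Adding the literal here is what makes "scaleway" a valid value of the provider union. The type alias itself is outside this hunk, so the derivation below is only a sketch of the usual `as const` pattern, not a quote of types.ts:

```ts
// Illustrative only: deriving a string-literal union from a const tuple.
const INFERENCE_PROVIDERS = ["ovhcloud", "replicate", "sambanova", "scaleway", "together"] as const;
type InferenceProvider = (typeof INFERENCE_PROVIDERS)[number];

const p: InferenceProvider = "scaleway"; // now accepted by the type checker
```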

packages/inference/test/InferenceClient.spec.ts

Lines changed: 109 additions & 1 deletion
@@ -22,7 +22,7 @@ if (!env.HF_TOKEN) {
    console.warn("Set HF_TOKEN in the env to run the tests for better rate limits");
  }

- describe.skip("InferenceClient", () => {
+ describe.concurrent("InferenceClient", () => {
    // Individual tests can be ran without providing an api key, however running all tests without an api key will result in rate limiting error.

    describe("backward compatibility", () => {
@@ -1516,6 +1516,114 @@ describe.skip("InferenceClient", () => {
      TIMEOUT
    );

+   describe.concurrent(
+     "Scaleway",
+     () => {
+       const client = new InferenceClient(env.HF_SCALEWAY_KEY ?? "dummy");
+
+       HARDCODED_MODEL_INFERENCE_MAPPING.scaleway = {
+         "meta-llama/Llama-3.1-8B-Instruct": {
+           provider: "scaleway",
+           hfModelId: "meta-llama/Llama-3.1-8B-Instruct",
+           providerId: "llama-3.1-8b-instruct",
+           status: "live",
+           task: "conversational",
+         },
+         "BAAI/bge-multilingual-gemma2": {
+           provider: "scaleway",
+           hfModelId: "BAAI/bge-multilingual-gemma2",
+           providerId: "bge-multilingual-gemma2",
+           task: "feature-extraction",
+           status: "live",
+         },
+         "google/gemma-3-27b-it": {
+           provider: "scaleway",
+           hfModelId: "google/gemma-3-27b-it",
+           providerId: "gemma-3-27b-it",
+           task: "conversational",
+           status: "live",
+         },
+       };
+
+       it("chatCompletion", async () => {
+         const res = await client.chatCompletion({
+           model: "meta-llama/Llama-3.1-8B-Instruct",
+           provider: "scaleway",
+           messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }],
+         });
+         if (res.choices && res.choices.length > 0) {
+           const completion = res.choices[0].message?.content;
+           expect(completion).toMatch(/(to )?(two|2)/i);
+         }
+       });
+
+       it("chatCompletion stream", async () => {
+         const stream = client.chatCompletionStream({
+           model: "meta-llama/Llama-3.1-8B-Instruct",
+           provider: "scaleway",
+           messages: [{ role: "system", content: "Complete the equation 1 + 1 = , just the answer" }],
+         }) as AsyncGenerator<ChatCompletionStreamOutput>;
+         let out = "";
+         for await (const chunk of stream) {
+           if (chunk.choices && chunk.choices.length > 0) {
+             out += chunk.choices[0].delta.content;
+           }
+         }
+         expect(out).toMatch(/(two|2)/i);
+       });
+
+       it("imageToText", async () => {
+         const res = await client.chatCompletion({
+           model: "google/gemma-3-27b-it",
+           provider: "scaleway",
+           messages: [
+             {
+               role: "user",
+               content: [
+                 {
+                   type: "image_url",
+                   image_url: {
+                     url: "https://cdn.britannica.com/61/93061-050-99147DCE/Statue-of-Liberty-Island-New-York-Bay.jpg",
+                   },
+                 },
+               ],
+             },
+           ],
+         });
+         expect(res.choices).toBeDefined();
+         expect(res.choices?.length).toBeGreaterThan(0);
+         expect(res.choices?.[0].message?.content).toContain("Statue of Liberty");
+       });
+
+       it("textGeneration", async () => {
+         const res = await client.textGeneration({
+           model: "meta-llama/Llama-3.1-8B-Instruct",
+           provider: "scaleway",
+           inputs: "Once upon a time,",
+           temperature: 0,
+           max_tokens: 19,
+         });
+
+         expect(res).toMatchObject({
+           generated_text:
+             " in a small village nestled in the rolling hills of the countryside, there lived a young girl named",
+         });
+       });
+
+       it("featureExtraction", async () => {
+         const res = await client.featureExtraction({
+           model: "BAAI/bge-multilingual-gemma2",
+           provider: "scaleway",
+           inputs: "That is a happy person",
+         });
+
+         expect(res).toBeInstanceOf(Array);
+         expect(res[0]).toEqual(expect.arrayContaining([expect.any(Number)]));
+       });
+     },
+     TIMEOUT
+   );
+
    describe.concurrent("3rd party providers", () => {
      it("chatCompletion - fails with unsupported model", async () => {
        expect(
