Skip to content

Commit 047708b

Browse files
committed
Merge remote-tracking branch 'upstream/main'
2 parents 7d64997 + 6ed9d44 commit 047708b

File tree

6 files changed

+149
-0
lines changed

6 files changed

+149
-0
lines changed

packages/inference/README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ Currently, we support the following providers:
5252
- [Hyperbolic](https://hyperbolic.xyz)
5353
- [Nebius](https://studio.nebius.ai)
5454
- [Novita](https://novita.ai/?utm_source=github_huggingface&utm_medium=github_readme&utm_campaign=link)
55+
- [Nscale](https://nscale.com)
5556
- [Replicate](https://replicate.com)
5657
- [Sambanova](https://sambanova.ai)
5758
- [Together](https://together.xyz)
@@ -80,6 +81,7 @@ Only a subset of models are supported when requesting third-party providers. You
8081
- [Fireworks AI supported models](https://huggingface.co/api/partners/fireworks-ai/models)
8182
- [Hyperbolic supported models](https://huggingface.co/api/partners/hyperbolic/models)
8283
- [Nebius supported models](https://huggingface.co/api/partners/nebius/models)
84+
- [Nscale supported models](https://huggingface.co/api/partners/nscale/models)
8385
- [Replicate supported models](https://huggingface.co/api/partners/replicate/models)
8486
- [Sambanova supported models](https://huggingface.co/api/partners/sambanova/models)
8587
- [Together supported models](https://huggingface.co/api/partners/together/models)

packages/inference/src/lib/getProviderHelper.ts

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import * as HFInference from "../providers/hf-inference";
99
import * as Hyperbolic from "../providers/hyperbolic";
1010
import * as Nebius from "../providers/nebius";
1111
import * as Novita from "../providers/novita";
12+
import * as Nscale from "../providers/nscale";
1213
import * as OpenAI from "../providers/openai";
1314
import type {
1415
AudioClassificationTaskHelper,
@@ -114,6 +115,10 @@ export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask,
114115
conversational: new Novita.NovitaConversationalTask(),
115116
"text-generation": new Novita.NovitaTextGenerationTask(),
116117
},
118+
nscale: {
119+
"text-to-image": new Nscale.NscaleTextToImageTask(),
120+
conversational: new Nscale.NscaleConversationalTask(),
121+
},
117122
openai: {
118123
conversational: new OpenAI.OpenAIConversationalTask(),
119124
},

packages/inference/src/providers/consts.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ export const HARDCODED_MODEL_INFERENCE_MAPPING: Record<
2929
hyperbolic: {},
3030
nebius: {},
3131
novita: {},
32+
nscale: {},
3233
openai: {},
3334
replicate: {},
3435
sambanova: {},
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
/**
2+
* See the registered mapping of HF model ID => Nscale model ID here:
3+
*
4+
 * https://huggingface.co/api/partners/nscale/models
5+
*
6+
* This is a publicly available mapping.
7+
*
8+
* If you want to try to run inference for a new model locally before it's registered on huggingface.co,
9+
 * you can add it to the dictionary "HARDCODED_MODEL_INFERENCE_MAPPING" in consts.ts, for dev purposes.
10+
*
11+
* - If you work at Nscale and want to update this mapping, please use the model mapping API we provide on huggingface.co
12+
* - If you're a community member and want to add a new supported HF model to Nscale, please open an issue on the present repo
13+
* and we will tag Nscale team members.
14+
*
15+
* Thanks!
16+
*/
17+
import type { TextToImageInput } from "@huggingface/tasks";
18+
import { InferenceOutputError } from "../lib/InferenceOutputError";
19+
import type { BodyParams } from "../types";
20+
import { omit } from "../utils/omit";
21+
import { BaseConversationalTask, TaskProviderHelper, type TextToImageTaskHelper } from "./providerHelper";
22+
23+
// Base URL for all Nscale inference requests; routes are appended by each task helper.
const NSCALE_API_BASE_URL = "https://inference.api.nscale.com";

/**
 * Response shape of Nscale's image-generation endpoint when
 * `response_format: "b64_json"` is requested: one base64-encoded
 * image per entry in `data`.
 */
interface NscaleCloudBase64ImageGeneration {
	data: Array<{
		b64_json: string;
	}>;
}
30+
31+
/**
 * Conversational (chat-completion) task for the Nscale provider.
 *
 * All behavior (route, payload, response parsing) comes from
 * BaseConversationalTask; only the provider name and base URL are specific.
 */
export class NscaleConversationalTask extends BaseConversationalTask {
	constructor() {
		super("nscale", NSCALE_API_BASE_URL);
	}
}
36+
37+
export class NscaleTextToImageTask extends TaskProviderHelper implements TextToImageTaskHelper {
38+
constructor() {
39+
super("nscale", NSCALE_API_BASE_URL);
40+
}
41+
42+
preparePayload(params: BodyParams<TextToImageInput>): Record<string, unknown> {
43+
return {
44+
...omit(params.args, ["inputs", "parameters"]),
45+
...params.args.parameters,
46+
response_format: "b64_json",
47+
prompt: params.args.inputs,
48+
model: params.model,
49+
};
50+
}
51+
52+
makeRoute(): string {
53+
return "v1/images/generations";
54+
}
55+
56+
async getResponse(
57+
response: NscaleCloudBase64ImageGeneration,
58+
url?: string,
59+
headers?: HeadersInit,
60+
outputType?: "url" | "blob"
61+
): Promise<string | Blob> {
62+
if (
63+
typeof response === "object" &&
64+
"data" in response &&
65+
Array.isArray(response.data) &&
66+
response.data.length > 0 &&
67+
"b64_json" in response.data[0] &&
68+
typeof response.data[0].b64_json === "string"
69+
) {
70+
const base64Data = response.data[0].b64_json;
71+
if (outputType === "url") {
72+
return `data:image/jpeg;base64,${base64Data}`;
73+
}
74+
return fetch(`data:image/jpeg;base64,${base64Data}`).then((res) => res.blob());
75+
}
76+
77+
throw new InferenceOutputError("Expected Nscale text-to-image response format");
78+
}
79+
}

packages/inference/src/types.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ export const INFERENCE_PROVIDERS = [
4848
"hyperbolic",
4949
"nebius",
5050
"novita",
51+
"nscale",
5152
"openai",
5253
"replicate",
5354
"sambanova",

packages/inference/test/InferenceClient.spec.ts

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1690,6 +1690,67 @@ describe.skip("InferenceClient", () => {
16901690
},
16911691
TIMEOUT
16921692
);
1693+
// Integration tests for the Nscale provider (chat completion, streaming, text-to-image).
// Uses a hardcoded model mapping so the tests don't depend on the live
// huggingface.co partner-mapping API.
describe.concurrent(
	"Nscale",
	() => {
		// Falls back to "dummy" so the suite still constructs when no key is set.
		const client = new InferenceClient(env.HF_NSCALE_KEY ?? "dummy");

		HARDCODED_MODEL_INFERENCE_MAPPING["nscale"] = {
			"meta-llama/Llama-3.1-8B-Instruct": {
				hfModelId: "meta-llama/Llama-3.1-8B-Instruct",
				providerId: "nscale",
				status: "live",
				task: "conversational",
			},
			"black-forest-labs/FLUX.1-schnell": {
				hfModelId: "black-forest-labs/FLUX.1-schnell",
				// Provider-side model id differs from the HF repo id here.
				providerId: "flux-schnell",
				status: "live",
				task: "text-to-image",
			},
		};

		it("chatCompletion", async () => {
			const res = await client.chatCompletion({
				model: "meta-llama/Llama-3.1-8B-Instruct",
				provider: "nscale",
				messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }],
			});
			// Only assert on content when the provider returned at least one choice.
			if (res.choices && res.choices.length > 0) {
				const completion = res.choices[0].message?.content;
				expect(completion).toContain("two");
			}
		});
		it("chatCompletion stream", async () => {
			const stream = client.chatCompletionStream({
				model: "meta-llama/Llama-3.1-8B-Instruct",
				provider: "nscale",
				messages: [{ role: "user", content: "Say 'this is a test'" }],
				stream: true,
			}) as AsyncGenerator<ChatCompletionStreamOutput>;
			// Accumulate all streamed deltas and assert something was produced.
			let fullResponse = "";
			for await (const chunk of stream) {
				if (chunk.choices && chunk.choices.length > 0) {
					const content = chunk.choices[0].delta?.content;
					if (content) {
						fullResponse += content;
					}
				}
			}
			expect(fullResponse).toBeTruthy();
			expect(fullResponse.length).toBeGreaterThan(0);
		});
		it("textToImage", async () => {
			const res = await client.textToImage({
				model: "black-forest-labs/FLUX.1-schnell",
				provider: "nscale",
				inputs: "An astronaut riding a horse",
			});
			// Default output type is a Blob (no outputType option passed).
			expect(res).toBeInstanceOf(Blob);
		});
	},
	TIMEOUT
);
16931754
describe.concurrent(
16941755
"Groq",
16951756
() => {

0 commit comments

Comments
 (0)