Skip to content

Commit a33160c

Browse files
authored
Pass inferenceProvider as mapping (#553)
* Pass inferenceProvider as mapping * txt * run formatter * fix lint * should fix lint * go for it * finally?
1 parent 4450d31 commit a33160c

File tree

3 files changed

+70
-36
lines changed

3 files changed

+70
-36
lines changed

README.md

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -398,43 +398,47 @@ The `InferenceSnippet` component is used to render an interactive interface for
398398

399399
Below is a description of the props that can be passed to this component:
400400

401-
- **modelId** (string, required):
402-
The identifier of the AI model to be used for inference. This should be a valid model ID, such as `"deepseek-ai/DeepSeek-R1"`.
403-
404401
- **pipeline** (string, required):
405402
Specifies the type of pipeline to be used for inference. Common values include `"text-generation"`, `"text-classification"`, etc.
406403

404+
- **providersMapping** (mapping of {modelId: string, providerModelId: string}, required):
405+
A mapping which keys are provider names and values are objects with `modelId` and `providerModelId`.
406+
Example: `{"fireworks-ai": {modelId: "deepseek-ai/DeepSeek-R1", providerModelId: "accounts/fireworks/models/deepseek-r1", novita: {modelId: "deepseek-ai/DeepSeek-V3-0324", providerModelId: "deepseek/deepseek-v3-0324"}}`
407+
407408
- **conversational** (boolean, optional):
408409
If set to `true`, the component will enable conversational mode, allowing for multi-turn interactions for `text-generation` models.
409410

410-
- **providers** (array of strings, required):
411-
A list of provider names that support the specified model and pipeline. Example: `["fireworks-ai", "cerebras", "cohere", "hyperbolic"]`.
412-
413411
#### Example Usage
414412

415413
```svelte
416414
<InferenceSnippet
417-
modelId="deepseek-ai/DeepSeek-R1"
418415
pipeline="text-generation"
419416
conversational
420-
providers={["fireworks-ai", "cerebras", "cohere", "hyperbolic"]}
417+
providersMapping={{
418+
"fireworks-ai": {modelId: "deepseek-ai/DeepSeek-R1", providerModelId: "accounts/fireworks/models/deepseek-r1"},
419+
novita: {modelId: "deepseek-ai/DeepSeek-V3-0324", providerModelId: "deepseek/deepseek-v3-0324"}
420+
}}
421421
/>
422422
```
423423

424424
```svelte
425425
<InferenceSnippet
426-
modelId="deepseek-ai/DeepSeek-R1"
427426
pipeline="text-generation"
428427
conversational
429-
providers={["fireworks-ai"]}
428+
providers={{
429+
"fireworks-ai": {modelId: "deepseek-ai/DeepSeek-R1", providerModelId: "accounts/fireworks/models/deepseek-r1"}
430+
}}
430431
/>
431432
```
432433

433434
```svelte
434435
<InferenceSnippet
435-
modelId="black-forest-labs/FLUX.1-dev"
436436
pipeline="text-to-image"
437-
providers={["black-forest-labs", "replicate", "fal-ai"]}
437+
providers={{
438+
"black-forest-labs": {modelId: "black-forest-labs/FLUX.1-dev", providerModelId: "flux-dev"},
439+
"replicate": {modelId: "black-forest-labs/FLUX.1-dev", providerModelId: "black-forest-labs/flux-dev"},
440+
"fal-ai": {modelId: "black-forest-labs/FLUX.1-dev", providerModelId: "fal-ai/flux/dev"},
441+
}}
438442
/>
439443
```
440444

kit/src/lib/InferenceSnippet/InferenceSnippet.svelte

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -28,16 +28,26 @@
2828
import Dropdown from "$lib/Dropdown.svelte";
2929
import DropdownEntry from "$lib/DropdownEntry.svelte";
3030
31-
export let modelId: string;
31+
type InferenceProviderNotOpenAI = Exclude<InferenceProvider, "openai">;
32+
3233
export let pipeline: PipelineType;
3334
export let conversational = false;
34-
export let providers: Exclude<InferenceProvider, "openai">[] = [];
35+
export let providersMapping: Partial<
36+
Record<
37+
InferenceProviderNotOpenAI,
38+
{
39+
modelId: string;
40+
providerModelId: string;
41+
}
42+
>
43+
> = {};
3544
45+
let providers = Object.keys(providersMapping) as InferenceProviderNotOpenAI[];
3646
let selectedProvider = providers[0];
3747
let streaming = false;
3848
3949
const model = {
40-
id: modelId,
50+
id: providersMapping[selectedProvider]!.modelId,
4151
pipeline_tag: pipeline,
4252
tags: conversational ? ["conversational"] : [],
4353
};
@@ -46,7 +56,8 @@
4656
const availableSnippets = snippets.getInferenceSnippets(
4757
model as ModelDataMinimal,
4858
accessToken,
49-
selectedProvider
59+
selectedProvider,
60+
providersMapping[selectedProvider]!.providerModelId
5061
);
5162
const languages = [...new Set(availableSnippets.map((s) => s.language))];
5263
let selectedLanguage = languages[0];
@@ -60,15 +71,18 @@
6071
$: selectedClient = clients?.[0];
6172
6273
$: code = snippets
63-
.getInferenceSnippets(model as ModelDataMinimal, accessToken, selectedProvider, undefined, {
64-
streaming,
65-
})
74+
.getInferenceSnippets(
75+
model as ModelDataMinimal,
76+
accessToken,
77+
selectedProvider,
78+
providersMapping[selectedProvider]!.providerModelId,
79+
{
80+
streaming,
81+
}
82+
)
6683
.find((s) => s.language === selectedLanguage && s.client === selectedClient)?.content;
6784
68-
const PRETTY_NAMES: Record<
69-
Exclude<InferenceProvider, "openai"> | InferenceSnippetLanguage,
70-
string
71-
> = {
85+
const PRETTY_NAMES: Record<InferenceProviderNotOpenAI | InferenceSnippetLanguage, string> = {
7286
// inference providers
7387
"black-forest-labs": "Black Forest Labs",
7488
cerebras: "Cerebras",
@@ -90,7 +104,7 @@
90104
};
91105
92106
const ICONS: Record<
93-
Exclude<InferenceProvider, "openai"> | InferenceSnippetLanguage,
107+
InferenceProviderNotOpenAI | InferenceSnippetLanguage,
94108
new (...args: any) => SvelteComponent
95109
> = {
96110
// inference providers

kit/src/lib/InferenceSnippet/README.md

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,43 +6,59 @@ The `InferenceSnippet` component is used to render an interactive interface for
66

77
Below is a description of the props that can be passed to this component:
88

9-
- **modelId** (string, required):
10-
The identifier of the AI model to be used for inference. This should be a valid model ID, such as `"deepseek-ai/DeepSeek-R1"`.
11-
129
- **pipeline** (string, required):
1310
Specifies the type of pipeline to be used for inference. Common values include `"text-generation"`, `"text-classification"`, etc.
1411

12+
- **providersMapping** (mapping of {modelId: string, providerModelId: string}, required):
13+
A mapping which keys are provider names and values are objects with `modelId` and `providerModelId`.
14+
Example: `{"fireworks-ai": {modelId: "deepseek-ai/DeepSeek-R1", providerModelId: "accounts/fireworks/models/deepseek-r1", novita: {modelId: "deepseek-ai/DeepSeek-V3-0324", providerModelId: "deepseek/deepseek-v3-0324"}}`
15+
1516
- **conversational** (boolean, optional):
1617
If set to `true`, the component will enable conversational mode, allowing for multi-turn interactions for `text-generation` models.
1718

18-
- **providers** (array of strings, required):
19-
A list of provider names that support the specified model and pipeline. Example: `["fireworks-ai", "cerebras", "cohere", "hyperbolic"]`.
20-
2119
#### Example Usage
2220

2321
```svelte
2422
<InferenceSnippet
25-
modelId="deepseek-ai/DeepSeek-R1"
2623
pipeline="text-generation"
2724
conversational
28-
providers={["fireworks-ai", "cerebras", "cohere", "hyperbolic"]}
25+
providersMapping={{
26+
"fireworks-ai": {
27+
modelId: "deepseek-ai/DeepSeek-R1",
28+
providerModelId: "accounts/fireworks/models/deepseek-r1",
29+
},
30+
novita: {
31+
modelId: "deepseek-ai/DeepSeek-V3-0324",
32+
providerModelId: "deepseek/deepseek-v3-0324",
33+
},
34+
}}
2935
/>
3036
```
3137

3238
```svelte
3339
<InferenceSnippet
34-
modelId="deepseek-ai/DeepSeek-R1"
3540
pipeline="text-generation"
3641
conversational
37-
providers={["fireworks-ai"]}
42+
providers={{
43+
"fireworks-ai": {
44+
modelId: "deepseek-ai/DeepSeek-R1",
45+
providerModelId: "accounts/fireworks/models/deepseek-r1",
46+
},
47+
}}
3848
/>
3949
```
4050

4151
```svelte
4252
<InferenceSnippet
43-
modelId="black-forest-labs/FLUX.1-dev"
4453
pipeline="text-to-image"
45-
providers={["black-forest-labs", "replicate", "fal-ai"]}
54+
providers={{
55+
"black-forest-labs": { modelId: "black-forest-labs/FLUX.1-dev", providerModelId: "flux-dev" },
56+
replicate: {
57+
modelId: "black-forest-labs/FLUX.1-dev",
58+
providerModelId: "black-forest-labs/flux-dev",
59+
},
60+
"fal-ai": { modelId: "black-forest-labs/FLUX.1-dev", providerModelId: "fal-ai/flux/dev" },
61+
}}
4662
/>
4763
```
4864

0 commit comments

Comments
 (0)