Skip to content

Commit 1a812bc

Browse files
authored
[Inference] "auto" snippet and filter out clients (#1441)
1. Get inference snippets with `provider: "auto"`. 2. Build the list of languages/clients dynamically (so we include only the clients that support the `auto` provider policy).
1 parent b993cc5 commit 1a812bc

File tree

1 file changed

+15
-7
lines changed

1 file changed

+15
-7
lines changed

packages/inference/src/snippets/getInferenceSnippets.ts

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ import type { ChatCompletionInputMessage, GenerationParameters } from "@huggingf
1111
import type { InferenceProviderModelMapping } from "../lib/getInferenceProviderMapping";
1212
import { getProviderHelper } from "../lib/getProviderHelper";
1313
import { makeRequestOptionsFromResolvedModel } from "../lib/makeRequestOptions";
14-
import type { InferenceProvider, InferenceTask, RequestArgs } from "../types";
14+
import type { InferenceProviderOrPolicy, InferenceTask, RequestArgs } from "../types";
1515
import { templates } from "./templates.exported";
1616

1717
export type InferenceSnippetOptions = { streaming?: boolean; billTo?: string } & Record<string, unknown>;
@@ -28,6 +28,11 @@ const CLIENTS: Record<InferenceSnippetLanguage, Client[]> = {
2828
sh: [...SH_CLIENTS],
2929
};
3030

31+
const CLIENTS_AUTO_POLICY: Partial<Record<InferenceSnippetLanguage, Client[]>> = {
32+
js: ["huggingface.js"],
33+
python: ["huggingface_hub"],
34+
};
35+
3136
type InputPreparationFn = (model: ModelDataMinimal, opts?: Record<string, unknown>) => object;
3237
interface TemplateParams {
3338
accessToken?: string;
@@ -37,7 +42,7 @@ interface TemplateParams {
3742
inputs?: object;
3843
providerInputs?: object;
3944
model?: ModelDataMinimal;
40-
provider?: InferenceProvider;
45+
provider?: InferenceProviderOrPolicy;
4146
providerModelId?: string;
4247
billTo?: string;
4348
methodName?: string; // specific to snippetBasic
@@ -121,7 +126,7 @@ const snippetGenerator = (templateName: string, inputPreparationFn?: InputPrepar
121126
return (
122127
model: ModelDataMinimal,
123128
accessToken: string,
124-
provider: InferenceProvider,
129+
provider: InferenceProviderOrPolicy,
125130
inferenceProviderMapping?: InferenceProviderModelMapping,
126131
opts?: InferenceSnippetOptions
127132
): InferenceSnippet[] => {
@@ -139,7 +144,8 @@ const snippetGenerator = (templateName: string, inputPreparationFn?: InputPrepar
139144
}
140145
let providerHelper: ReturnType<typeof getProviderHelper>;
141146
try {
142-
providerHelper = getProviderHelper(provider, task);
147+
/// For the "auto" provider policy we use hf-inference snippets
148+
providerHelper = getProviderHelper(provider === "auto" ? "hf-inference" : provider, task);
143149
} catch (e) {
144150
console.error(`Failed to get provider helper for ${provider} (${task})`, e);
145151
return [];
@@ -200,9 +206,11 @@ const snippetGenerator = (templateName: string, inputPreparationFn?: InputPrepar
200206
};
201207

202208
/// Iterate over clients => check if a snippet exists => generate
209+
const clients = provider === "auto" ? CLIENTS_AUTO_POLICY : CLIENTS;
203210
return inferenceSnippetLanguages
204211
.map((language) => {
205-
return CLIENTS[language]
212+
const langClients = clients[language] ?? [];
213+
return langClients
206214
.map((client) => {
207215
if (!hasTemplate(language, client, templateName)) {
208216
return;
@@ -283,7 +291,7 @@ const snippets: Partial<
283291
(
284292
model: ModelDataMinimal,
285293
accessToken: string,
286-
provider: InferenceProvider,
294+
provider: InferenceProviderOrPolicy,
287295
inferenceProviderMapping?: InferenceProviderModelMapping,
288296
opts?: InferenceSnippetOptions
289297
) => InferenceSnippet[]
@@ -323,7 +331,7 @@ const snippets: Partial<
323331
export function getInferenceSnippets(
324332
model: ModelDataMinimal,
325333
accessToken: string,
326-
provider: InferenceProvider,
334+
provider: InferenceProviderOrPolicy,
327335
inferenceProviderMapping?: InferenceProviderModelMapping,
328336
opts?: Record<string, unknown>
329337
): InferenceSnippet[] {

0 commit comments

Comments
 (0)