Skip to content

Commit 3fd6e72

Browse files
SBrandeisDeep-Unlearning
authored andcommitted
InferenceClient methods have the proper types now (#1410)
# TL;DR Types of `InferenceClient` methods are now correct 🥳 ```typescript const client = new InferenceClient(); const output = client.textToImage( { inputs: "test input", parameters: { /// Correctly typed! And auto-completion works! }, }, options ) ```
1 parent 21678ca commit 3fd6e72

File tree

5 files changed

+30
-23
lines changed

5 files changed

+30
-23
lines changed

packages/inference/src/InferenceClient.ts

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,13 @@
11
import * as tasks from "./tasks";
2-
import type { Options, RequestArgs } from "./types";
3-
import type { DistributiveOmit } from "./utils/distributive-omit";
2+
import type { Options } from "./types";
43
import { omit } from "./utils/omit";
4+
import { typedEntries } from "./utils/typedEntries";
55

66
/* eslint-disable @typescript-eslint/no-empty-interface */
77
/* eslint-disable @typescript-eslint/no-unsafe-declaration-merging */
88

99
type Task = typeof tasks;
1010

11-
type TaskWithNoAccessToken = {
12-
[key in keyof Task]: (
13-
args: DistributiveOmit<Parameters<Task[key]>[0], "accessToken">,
14-
options?: Parameters<Task[key]>[1]
15-
) => ReturnType<Task[key]>;
16-
};
17-
1811
export class InferenceClient {
1912
private readonly accessToken: string;
2013
private readonly defaultOptions: Options;
@@ -28,15 +21,19 @@ export class InferenceClient {
2821
this.accessToken = accessToken;
2922
this.defaultOptions = defaultOptions;
3023

31-
for (const [name, fn] of Object.entries(tasks)) {
24+
for (const [name, fn] of typedEntries(tasks)) {
3225
Object.defineProperty(this, name, {
3326
enumerable: false,
34-
value: (params: RequestArgs, options: Options) =>
27+
value: (params: Parameters<typeof fn>[0], options: Parameters<typeof fn>[1]) =>
3528
// eslint-disable-next-line @typescript-eslint/no-explicit-any
36-
fn({ endpointUrl: defaultOptions.endpointUrl, accessToken, ...params } as any, {
37-
...omit(defaultOptions, ["endpointUrl"]),
38-
...options,
39-
}),
29+
(fn as any)(
30+
/// ^ The cast of fn to any is necessary, otherwise TS can't compile because the generated union type is too complex
31+
{ endpointUrl: defaultOptions.endpointUrl, accessToken, ...params },
32+
{
33+
...omit(defaultOptions, ["endpointUrl"]),
34+
...options,
35+
}
36+
),
4037
});
4138
}
4239
}
@@ -51,7 +48,7 @@ export class InferenceClient {
5148
}
5249
}
5350

54-
export interface InferenceClient extends TaskWithNoAccessToken {}
51+
export interface InferenceClient extends Task {}
5552

5653
/**
5754
* For backward compatibility only, will remove soon.

packages/inference/src/lib/getInferenceProviderMapping.ts

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,10 @@ export async function resolveProvider(
112112
throw new Error("Specifying a model is required when provider is 'auto'");
113113
}
114114
const inferenceProviderMapping = await fetchInferenceProviderMappingForModel(modelId);
115-
provider = Object.keys(inferenceProviderMapping)[0] as InferenceProvider;
115+
provider = Object.keys(inferenceProviderMapping)[0] as InferenceProvider | undefined;
116+
}
117+
if (!provider) {
118+
throw new Error(`No Inference Provider available for model ${modelId}.`);
116119
}
117120
return provider;
118121
}
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
export function typedEntries<T extends { [s: string]: T[keyof T] } | ArrayLike<T[keyof T]>>(
2+
obj: T
3+
): [keyof T, T[keyof T]][] {
4+
return Object.entries(obj) as [keyof T, T[keyof T]][];
5+
}

packages/inference/test/InferenceClient.spec.ts

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,7 @@ describe.skip("InferenceClient", () => {
294294
await hf.tableQuestionAnswering({
295295
model: "google/tapas-base-finetuned-wtq",
296296
inputs: {
297-
query: "How many stars does the transformers repository have?",
297+
question: "How many stars does the transformers repository have?",
298298
table: {
299299
Repository: ["Transformers", "Datasets", "Tokenizers"],
300300
Stars: ["36542", "4512", "3934"],
@@ -488,7 +488,8 @@ describe.skip("InferenceClient", () => {
488488
expect(
489489
await hf.translation({
490490
model: "t5-base",
491-
inputs: ["My name is Wolfgang and I live in Berlin", "I work as programmer"],
491+
// eslint-disable-next-line @typescript-eslint/no-explicit-any
492+
inputs: ["My name is Wolfgang and I live in Berlin", "I work as programmer"] as any,
492493
})
493494
).toMatchObject([
494495
{
@@ -505,7 +506,8 @@ describe.skip("InferenceClient", () => {
505506
model: "facebook/bart-large-mnli",
506507
inputs: [
507508
"Hi, I recently bought a device from your company but it is not working as advertised and I would like to get reimbursed!",
508-
],
509+
// eslint-disable-next-line @typescript-eslint/no-explicit-any
510+
] as any,
509511
parameters: { candidate_labels: ["refund", "legal", "faq"] },
510512
})
511513
).toEqual(

packages/mcp-client/src/McpClient.ts

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ import { Client } from "@modelcontextprotocol/sdk/client/index.js";
22
import type { StdioServerParameters } from "@modelcontextprotocol/sdk/client/stdio.js";
33
import { StdioClientTransport } from "@modelcontextprotocol/sdk/client/stdio.js";
44
import { InferenceClient } from "@huggingface/inference";
5-
import type { InferenceClientEndpoint, InferenceProvider } from "@huggingface/inference";
5+
import type { InferenceClientEndpoint, InferenceProviderOrPolicy } from "@huggingface/inference";
66
import type {
77
ChatCompletionInputMessage,
88
ChatCompletionInputTool,
@@ -23,7 +23,7 @@ export interface ChatCompletionInputMessageTool extends ChatCompletionInputMessa
2323

2424
export class McpClient {
2525
protected client: InferenceClient | InferenceClientEndpoint;
26-
protected provider: string | undefined;
26+
protected provider: InferenceProviderOrPolicy | undefined;
2727

2828
protected model: string;
2929
private clients: Map<ToolName, Client> = new Map();
@@ -36,7 +36,7 @@ export class McpClient {
3636
apiKey,
3737
}: (
3838
| {
39-
provider: InferenceProvider;
39+
provider: InferenceProviderOrPolicy;
4040
baseUrl?: undefined;
4141
}
4242
| {

0 commit comments

Comments
 (0)