Skip to content

Commit 457119e

Browse files
[Inference] fix feature extraction (embeddings) for sambanova (#1364)
This PR fixes the widget of https://huggingface.co/intfloat/e5-mistral-7b-instruct. I didn’t want to introduce a breaking change in this PR, but we should consider adding support for the OpenAI [Embeddings API](https://platform.openai.com/docs/api-reference/embeddings). Other providers like [Fireworks AI](https://docs.fireworks.ai/guides/querying-embeddings-models) and [Together](https://docs.together.ai/docs/embeddings-overview) also host embedding models with OpenAI-compatible endpoints. AFAIK, TEI supports the OAI Embeddings API as well.
1 parent 4f4e176 commit 457119e

File tree

5 files changed

+7541
-7464
lines changed

5 files changed

+7541
-7464
lines changed

packages/inference/src/lib/getProviderHelper.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask,
119119
},
120120
sambanova: {
121121
conversational: new Sambanova.SambanovaConversationalTask(),
122+
"feature-extraction": new Sambanova.SambanovaFeatureExtractionTask(),
122123
},
123124
together: {
124125
"text-to-image": new Together.TogetherTextToImageTask(),

packages/inference/src/providers/sambanova.ts

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,42 @@
1414
*
1515
* Thanks!
1616
*/
17-
import { BaseConversationalTask } from "./providerHelper";
17+
import { InferenceOutputError } from "../lib/InferenceOutputError";
18+
19+
import type { FeatureExtractionOutput } from "@huggingface/tasks";
20+
import type { BodyParams } from "../types";
21+
import type { FeatureExtractionTaskHelper } from "./providerHelper";
22+
import { BaseConversationalTask, TaskProviderHelper } from "./providerHelper";
1823

1924
export class SambanovaConversationalTask extends BaseConversationalTask {
2025
constructor() {
2126
super("sambanova", "https://api.sambanova.ai");
2227
}
2328
}
29+
30+
/**
 * Feature-extraction (embeddings) helper for the Sambanova provider.
 *
 * Requests are sent to the `/v1/embeddings` route with a `{ model, input, ... }` body,
 * and the `{ data: [{ embedding }, ...] }` response envelope is unwrapped into a plain
 * list of embeddings.
 */
export class SambanovaFeatureExtractionTask extends TaskProviderHelper implements FeatureExtractionTaskHelper {
	constructor() {
		super("sambanova", "https://api.sambanova.ai");
	}

	/** @returns the embeddings route, relative to the provider base URL. */
	override makeRoute(): string {
		return `/v1/embeddings`;
	}

	/**
	 * Unwraps the provider response into a list of embeddings.
	 *
	 * @param response - Parsed response body; expected shape is `{ data: Array<{ embedding }> }`.
	 * @returns the `embedding` of every item in `data`, in order.
	 * @throws InferenceOutputError when `data` is not an array or an item carries no `embedding`.
	 */
	override async getResponse(response: unknown): Promise<FeatureExtractionOutput> {
		// Optional chaining keeps `null`, `undefined` and primitive bodies from raising a raw
		// TypeError; they fall through to the InferenceOutputError below instead.
		const data = (response as { data?: unknown } | null | undefined)?.data;
		if (Array.isArray(data)) {
			const embeddings = data.map((item: unknown) => {
				const embedding = (item as { embedding?: unknown } | null | undefined)?.embedding;
				// Only presence is enforced: the value is passed through untouched so that
				// non-array encodings (e.g. when `encoding_format` is "base64") still work.
				if (embedding === undefined) {
					throw new InferenceOutputError(
						"Expected Sambanova feature-extraction (embeddings) response format to be {'data' : list of {'embedding' : number[]}}"
					);
				}
				return embedding;
			});
			return embeddings as FeatureExtractionOutput;
		}
		throw new InferenceOutputError(
			"Expected Sambanova feature-extraction (embeddings) response format to be {'data' : list of {'embedding' : number[]}}"
		);
	}

	/**
	 * Builds the JSON body for the embeddings endpoint.
	 *
	 * @param params - Request parameters; `params.args.inputs` holds the text(s) to embed.
	 * @returns the remaining caller args plus `model` and `input`.
	 */
	override preparePayload(params: BodyParams): Record<string, unknown> {
		// Rename `inputs` to `input` rather than sending both keys.
		const { inputs, ...extraArgs } = params.args;
		return {
			...extraArgs,
			// Set after the spread so a stray key in args cannot override the resolved model or input.
			model: params.model,
			input: inputs,
		};
	}
}

packages/inference/src/tasks/nlp/featureExtraction.ts

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,12 @@ import { getProviderHelper } from "../../lib/getProviderHelper";
33
import type { BaseArgs, Options } from "../../types";
44
import { innerRequest } from "../../utils/request";
55

6-
export type FeatureExtractionArgs = BaseArgs & FeatureExtractionInput;
6+
/**
 * Optional request fields accepted in addition to the standard feature-extraction input.
 * NOTE(review): the names appear to mirror the OpenAI Embeddings API request — confirm against provider docs.
 */
interface FeatureExtractionOAICompatInput {
7+
/** Either "float" or "base64"; presumably selects how returned embeddings are encoded — TODO confirm. */
encoding_format?: "float" | "base64";
8+
/** Optional and nullable number; presumably the requested embedding size — TODO confirm semantics. */
dimensions?: number | null;
9+
}
10+
11+
/** Arguments for feature extraction: base args, the task input, and the optional fields declared above. */
export type FeatureExtractionArgs = BaseArgs & FeatureExtractionInput & FeatureExtractionOAICompatInput;
712

813
/**
914
* Returned values are a multidimensional array of floats (dimension depending on if you sent a string or a list of string, and if the automatic reduction, usually mean_pooling for instance was applied for you or not. This should be explained on the model's README).

packages/inference/test/InferenceClient.spec.ts

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,15 @@ import type { TextToImageArgs } from "../src";
66
import {
77
chatCompletion,
88
chatCompletionStream,
9+
HfInference,
910
InferenceClient,
1011
textGeneration,
1112
textToImage,
12-
HfInference,
1313
} from "../src";
14+
import { isUrl } from "../src/lib/isUrl";
15+
import { HARDCODED_MODEL_INFERENCE_MAPPING } from "../src/providers/consts";
1416
import { readTestFile } from "./test-files";
1517
import "./vcr";
16-
import { HARDCODED_MODEL_INFERENCE_MAPPING } from "../src/providers/consts";
17-
import { isUrl } from "../src/lib/isUrl";
1818

1919
const TIMEOUT = 60000 * 3;
2020
const env = import.meta.env;
@@ -1176,6 +1176,15 @@ describe.concurrent("InferenceClient", () => {
11761176
}
11771177
expect(out).toContain("2");
11781178
});
1179+
it("featureExtraction", async () => {
1180+
const res = await client.featureExtraction({
1181+
model: "intfloat/e5-mistral-7b-instruct",
1182+
provider: "sambanova",
1183+
inputs: "Today is a sunny day and I will get some ice cream.",
1184+
});
1185+
expect(res).toBeInstanceOf(Array);
1186+
expect(res[0]).toBeInstanceOf(Array);
1187+
});
11791188
},
11801189
TIMEOUT
11811190
);

0 commit comments

Comments
 (0)