Skip to content

Commit 1221acb

Browse files
radamescoyotte508
andauthored
Feature sentence fix (#141)
Co-authored-by: Eliott C <[email protected]>
1 parent 02b4bb3 commit 1221acb

File tree

4 files changed

+91
-7
lines changed

4 files changed

+91
-7
lines changed

packages/inference/README.md

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ await hf.conversational({
109109
}
110110
})
111111

112-
await hf.featureExtraction({
112+
await hf.sentenceSimilarity({
113113
model: 'sentence-transformers/paraphrase-xlm-r-multilingual-v1',
114114
inputs: {
115115
source_sentence: 'That is a happy person',
@@ -121,6 +121,11 @@ await hf.featureExtraction({
121121
}
122122
})
123123

124+
await hf.featureExtraction({
125+
model: "sentence-transformers/distilbert-base-nli-mean-tokens",
126+
inputs: "That is a happy person",
127+
});
128+
124129
// Audio
125130

126131
await hf.automaticSpeechRecognition({

packages/inference/src/HfInference.ts

Lines changed: 55 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -449,8 +449,22 @@ export interface ConversationalReturn {
449449
generated_text: string;
450450
warnings: string[];
451451
}
452-
453452
export type FeatureExtractionArgs = Args & {
453+
/**
454+
* The inputs is a string or a list of strings to get the features from.
455+
*
456+
* inputs: "That is a happy person",
457+
*
458+
*/
459+
inputs: string | string[];
460+
};
461+
462+
/**
463+
* Returned values are a list of floats, or a list of list of floats (depending on if you sent a string or a list of string, and if the automatic reduction, usually mean_pooling for instance was applied for you or not. This should be explained on the model's README.
464+
*/
465+
export type FeatureExtractionReturn = (number | number[])[];
466+
467+
export type SentenceSimiliarityArgs = Args & {
454468
/**
455469
* The inputs vary based on the model. For example when using sentence-transformers/paraphrase-xlm-r-multilingual-v1 the inputs will look like this:
456470
*
@@ -463,9 +477,9 @@ export type FeatureExtractionArgs = Args & {
463477
};
464478

465479
/**
466-
* Returned values are a list of floats, or a list of list of floats (depending on if you sent a string or a list of string, and if the automatic reduction, usually mean_pooling for instance was applied for you or not. This should be explained on the model's README.
480+
* Returned values are a list of floats
467481
*/
468-
export type FeatureExtractionReturn = (number | number[])[];
482+
export type SentenceSimiliarityReturn = number[];
469483

470484
export type ImageClassificationArgs = Args & {
471485
/**
@@ -834,6 +848,44 @@ export class HfInference {
834848
*/
835849
public async featureExtraction(args: FeatureExtractionArgs, options?: Options): Promise<FeatureExtractionReturn> {
836850
const res = await this.request<FeatureExtractionReturn>(args, options);
851+
let isValidOutput = true;
852+
// Check if output is an array
853+
if (Array.isArray(res)) {
854+
for (const e of res) {
855+
// Check if output is an array of arrays or numbers
856+
if (Array.isArray(e)) {
857+
// if all elements are numbers, continue
858+
isValidOutput = e.every((x) => typeof x === "number");
859+
if (!isValidOutput) {
860+
break;
861+
}
862+
} else if (typeof e !== "number") {
863+
isValidOutput = false;
864+
break;
865+
}
866+
}
867+
} else {
868+
isValidOutput = false;
869+
}
870+
if (!isValidOutput) {
871+
throw new TypeError("Invalid inference output: output must be of type Array<Array<number> | number>");
872+
}
873+
return res;
874+
}
875+
876+
/**
877+
* Calculate the semantic similarity between one text and a list of other sentences by comparing their embeddings.
878+
*/
879+
public async sentenceSimiliarity(
880+
args: SentenceSimiliarityArgs,
881+
options?: Options
882+
): Promise<SentenceSimiliarityReturn> {
883+
const res = await this.request<SentenceSimiliarityReturn>(args, options);
884+
885+
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x === "number");
886+
if (!isValidOutput) {
887+
throw new TypeError("Invalid inference output: output must be of type Array<number>");
888+
}
837889
return res;
838890
}
839891

packages/inference/test/HfInference.spec.ts

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -267,9 +267,9 @@ describe.concurrent(
267267
warnings: ["Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation."],
268268
});
269269
});
270-
it("featureExtraction", async () => {
270+
it("SentenceSimiliarity", async () => {
271271
expect(
272-
await hf.featureExtraction({
272+
await hf.sentenceSimiliarity({
273273
model: "sentence-transformers/paraphrase-xlm-r-multilingual-v1",
274274
inputs: {
275275
source_sentence: "That is a happy person",
@@ -278,6 +278,13 @@ describe.concurrent(
278278
})
279279
).toEqual([expect.any(Number), expect.any(Number), expect.any(Number)]);
280280
});
281+
it("FeatureExtraction", async () => {
282+
const response = await hf.featureExtraction({
283+
model: "sentence-transformers/distilbert-base-nli-mean-tokens",
284+
inputs: "That is a happy person",
285+
});
286+
expect(response).toEqual(expect.arrayContaining([expect.any(Number)]));
287+
});
281288
it("automaticSpeechRecognition", async () => {
282289
expect(
283290
await hf.automaticSpeechRecognition({

0 commit comments

Comments
 (0)