Skip to content

Commit b5d8fb0

Browse files
author
Mishig
authored
Check if api inference output is correct (#125)
This PR adds checks to api inference calls: 1. Checks if api inference outputs are as expected (for example, TextClaissification api output should be `array<{score: number, label: string}>`) 2. If not, throw `new TypeError`
1 parent 99f9eed commit b5d8fb0

File tree

1 file changed

+173
-18
lines changed

1 file changed

+173
-18
lines changed

packages/inference/src/HfInference.ts

Lines changed: 173 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -519,21 +519,52 @@ export class HfInference {
519519
* Tries to fill in a hole with a missing word (token to be precise). That’s the base task for BERT models.
520520
*/
521521
public async fillMask(args: FillMaskArgs, options?: Options): Promise<FillMaskReturn> {
522-
return await this.request(args, options);
522+
const res = await this.request<FillMaskReturn>(args, options);
523+
const isValidOutput =
524+
Array.isArray(res) &&
525+
res.every(
526+
(x) =>
527+
typeof x.score === "number" &&
528+
typeof x.sequence === "string" &&
529+
typeof x.token === "number" &&
530+
typeof x.token_str === "string"
531+
);
532+
if (!isValidOutput) {
533+
throw new TypeError(
534+
"Invalid inference output: output must be of type Array<score: number, sequence:string, token:number, token_str:string>"
535+
);
536+
}
537+
return res;
523538
}
524539

525540
/**
526541
* This task is well known to summarize longer text into shorter text. Be careful, some models have a maximum length of input. That means that the summary cannot handle full books for instance. Be careful when choosing your model.
527542
*/
528543
public async summarization(args: SummarizationArgs, options?: Options): Promise<SummarizationReturn> {
529-
return (await this.request<SummarizationReturn[]>(args, options))?.[0];
544+
const res = await this.request<SummarizationReturn[]>(args, options);
545+
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.summary_text === "string");
546+
if (!isValidOutput) {
547+
throw new TypeError("Invalid inference output: output must be of type Array<summary_text: string>");
548+
}
549+
return res?.[0];
530550
}
531551

532552
/**
533553
* Want to have a nice know-it-all bot that can answer any question?. Recommended model: deepset/roberta-base-squad2
534554
*/
535555
public async questionAnswer(args: QuestionAnswerArgs, options?: Options): Promise<QuestionAnswerReturn> {
536-
return await this.request(args, options);
556+
const res = await this.request<QuestionAnswerReturn>(args, options);
557+
const isValidOutput =
558+
typeof res.answer === "string" &&
559+
typeof res.end === "number" &&
560+
typeof res.score === "number" &&
561+
typeof res.start === "number";
562+
if (!isValidOutput) {
563+
throw new TypeError(
564+
"Invalid inference output: output must be of type <answer: string, end: number, score: number, start: number>"
565+
);
566+
}
567+
return res;
537568
}
538569

539570
/**
@@ -543,21 +574,45 @@ export class HfInference {
543574
args: TableQuestionAnswerArgs,
544575
options?: Options
545576
): Promise<TableQuestionAnswerReturn> {
546-
return await this.request(args, options);
577+
const res = await this.request<TableQuestionAnswerReturn>(args, options);
578+
const isValidOutput =
579+
typeof res.aggregator === "string" &&
580+
typeof res.answer === "string" &&
581+
Array.isArray(res.cells) &&
582+
res.cells.every((x) => typeof x === "string") &&
583+
Array.isArray(res.coordinates) &&
584+
res.coordinates.every((coord) => Array.isArray(coord) && coord.every((x) => typeof x === "number"));
585+
if (!isValidOutput) {
586+
throw new TypeError(
587+
"Invalid inference output: output must be of type <aggregator: string, answer: string, cells: string[], coordinates: number[][]>"
588+
);
589+
}
590+
return res;
547591
}
548592

549593
/**
550594
* Usually used for sentiment-analysis this will output the likelihood of classes of an input. Recommended model: distilbert-base-uncased-finetuned-sst-2-english
551595
*/
552596
public async textClassification(args: TextClassificationArgs, options?: Options): Promise<TextClassificationReturn> {
553-
return (await this.request<TextClassificationReturn[]>(args, options))?.[0];
597+
const res = (await this.request<TextClassificationReturn[]>(args, options))?.[0];
598+
const isValidOutput =
599+
Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
600+
if (!isValidOutput) {
601+
throw new TypeError("Invalid inference output: output must be of type Array<label: string, score: number>");
602+
}
603+
return res;
554604
}
555605

556606
/**
557607
* Use to continue text from a prompt. This is a very generic task. Recommended model: gpt2 (it’s a simple model, but fun to play with).
558608
*/
559609
public async textGeneration(args: TextGenerationArgs, options?: Options): Promise<TextGenerationReturn> {
560-
return (await this.request<TextGenerationReturn[]>(args, options))?.[0];
610+
const res = await this.request<TextGenerationReturn[]>(args, options);
611+
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.generated_text === "string");
612+
if (!isValidOutput) {
613+
throw new TypeError("Invalid inference output: output must be of type Array<generated_text: string>");
614+
}
615+
return res?.[0];
561616
}
562617

563618
/**
@@ -567,14 +622,35 @@ export class HfInference {
567622
args: TokenClassificationArgs,
568623
options?: Options
569624
): Promise<TokenClassificationReturn> {
570-
return toArray(await this.request(args, options));
625+
const res = toArray(await this.request<TokenClassificationReturnValue | TokenClassificationReturn>(args, options));
626+
const isValidOutput =
627+
Array.isArray(res) &&
628+
res.every(
629+
(x) =>
630+
typeof x.end === "number" &&
631+
typeof x.entity_group === "string" &&
632+
typeof x.score === "number" &&
633+
typeof x.start === "number" &&
634+
typeof x.word === "string"
635+
);
636+
if (!isValidOutput) {
637+
throw new TypeError(
638+
"Invalid inference output: output must be of type Array<end: number, entity_group: string, score: number, start: number, word: string>"
639+
);
640+
}
641+
return res;
571642
}
572643

573644
/**
574645
* This task is well known to translate text from one language to another. Recommended model: Helsinki-NLP/opus-mt-ru-en.
575646
*/
576647
public async translation(args: TranslationArgs, options?: Options): Promise<TranslationReturn> {
577-
return (await this.request<TranslationReturn[]>(args, options))?.[0];
648+
const res = await this.request<TranslationReturn[]>(args, options);
649+
const isValidOutput = Array.isArray(res) && res.every((x) => typeof x.translation_text === "string");
650+
if (!isValidOutput) {
651+
throw new TypeError("Invalid inference output: output must be of type Array<translation_text: string>");
652+
}
653+
return res?.[0];
578654
}
579655

580656
/**
@@ -584,24 +660,55 @@ export class HfInference {
584660
args: ZeroShotClassificationArgs,
585661
options?: Options
586662
): Promise<ZeroShotClassificationReturn> {
587-
return toArray(
588-
await this.request<ZeroShotClassificationReturnValue | ZeroShotClassificationReturnValue[]>(args, options)
663+
const res = toArray(
664+
await this.request<ZeroShotClassificationReturnValue | ZeroShotClassificationReturn>(args, options)
589665
);
666+
const isValidOutput =
667+
Array.isArray(res) &&
668+
res.every(
669+
(x) =>
670+
Array.isArray(x.labels) &&
671+
x.labels.every((_label) => typeof _label === "string") &&
672+
Array.isArray(x.scores) &&
673+
x.scores.every((_score) => typeof _score === "number") &&
674+
typeof x.sequence === "string"
675+
);
676+
if (!isValidOutput) {
677+
throw new TypeError(
678+
"Invalid inference output: output must be of type Array<labels: string[], scores: number[], sequence: string>"
679+
);
680+
}
681+
return res;
590682
}
591683

592684
/**
593685
* This task corresponds to any chatbot like structure. Models tend to have shorter max_length, so please check with caution when using a given model if you need long range dependency or not. Recommended model: microsoft/DialoGPT-large.
594686
*
595687
*/
596688
public async conversational(args: ConversationalArgs, options?: Options): Promise<ConversationalReturn> {
597-
return await this.request(args, options);
689+
const res = await this.request<ConversationalReturn>(args, options);
690+
const isValidOutput =
691+
Array.isArray(res.conversation.generated_responses) &&
692+
res.conversation.generated_responses.every((x) => typeof x === "string") &&
693+
Array.isArray(res.conversation.past_user_inputs) &&
694+
res.conversation.past_user_inputs.every((x) => typeof x === "string") &&
695+
typeof res.generated_text === "string" &&
696+
Array.isArray(res.warnings) &&
697+
res.warnings.every((x) => typeof x === "string");
698+
if (!isValidOutput) {
699+
throw new TypeError(
700+
"Invalid inference output: output must be of type <conversation: {generated_responses: string[], past_user_inputs: string[]}, generated_text: string, warnings: string[]>"
701+
);
702+
}
703+
return res;
598704
}
599705

600706
/**
601707
* This task reads some text and outputs raw float values, that are usually consumed as part of a semantic database/semantic search.
602708
*/
603709
public async featureExtraction(args: FeatureExtractionArgs, options?: Options): Promise<FeatureExtractionReturn> {
604-
return await this.request(args, options);
710+
const res = await this.request<FeatureExtractionReturn>(args, options);
711+
return res;
605712
}
606713

607714
/**
@@ -612,10 +719,15 @@ export class HfInference {
612719
args: AutomaticSpeechRecognitionArgs,
613720
options?: Options
614721
): Promise<AutomaticSpeechRecognitionReturn> {
615-
return await this.request(args, {
722+
const res = await this.request<AutomaticSpeechRecognitionReturn>(args, {
616723
...options,
617724
binary: true,
618725
});
726+
const isValidOutput = typeof res.text === "string";
727+
if (!isValidOutput) {
728+
throw new TypeError("Invalid inference output: output must be of type <text: string>");
729+
}
730+
return res;
619731
}
620732

621733
/**
@@ -626,10 +738,16 @@ export class HfInference {
626738
args: AudioClassificationArgs,
627739
options?: Options
628740
): Promise<AudioClassificationReturn> {
629-
return await this.request(args, {
741+
const res = await this.request<AudioClassificationReturn>(args, {
630742
...options,
631743
binary: true,
632744
});
745+
const isValidOutput =
746+
Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
747+
if (!isValidOutput) {
748+
throw new TypeError("Invalid inference output: output must be of type Array<label: string, score: number>");
749+
}
750+
return res;
633751
}
634752

635753
/**
@@ -640,43 +758,80 @@ export class HfInference {
640758
args: ImageClassificationArgs,
641759
options?: Options
642760
): Promise<ImageClassificationReturn> {
643-
return await this.request(args, {
761+
const res = await this.request<ImageClassificationReturn>(args, {
644762
...options,
645763
binary: true,
646764
});
765+
const isValidOutput =
766+
Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
767+
if (!isValidOutput) {
768+
throw new TypeError("Invalid inference output: output must be of type Array<label: string, score: number>");
769+
}
770+
return res;
647771
}
648772

649773
/**
650774
* This task reads some image input and outputs the likelihood of classes & bounding boxes of detected objects.
651775
* Recommended model: facebook/detr-resnet-50
652776
*/
653777
public async objectDetection(args: ObjectDetectionArgs, options?: Options): Promise<ObjectDetectionReturn> {
654-
return await this.request(args, {
778+
const res = await this.request<ObjectDetectionReturn>(args, {
655779
...options,
656780
binary: true,
657781
});
782+
const isValidOutput =
783+
Array.isArray(res) &&
784+
res.every(
785+
(x) =>
786+
typeof x.label === "string" &&
787+
typeof x.score === "number" &&
788+
typeof x.box.xmin === "number" &&
789+
typeof x.box.ymin === "number" &&
790+
typeof x.box.xmax === "number" &&
791+
typeof x.box.ymax === "number"
792+
);
793+
if (!isValidOutput) {
794+
throw new TypeError(
795+
"Invalid inference output: output must be of type Array<{label:string; score:number; box:{xmin:number; ymin:number; xmax:number; ymax:number}}>"
796+
);
797+
}
798+
return res;
658799
}
659800

660801
/**
661802
* This task reads some image input and outputs the likelihood of classes & bounding boxes of detected objects.
662803
* Recommended model: facebook/detr-resnet-50-panoptic
663804
*/
664805
public async imageSegmentation(args: ImageSegmentationArgs, options?: Options): Promise<ImageSegmentationReturn> {
665-
return await this.request(args, {
806+
const res = await this.request<ImageSegmentationReturn>(args, {
666807
...options,
667808
binary: true,
668809
});
810+
const isValidOutput =
811+
Array.isArray(res) &&
812+
res.every((x) => typeof x.label === "string" && typeof x.mask === "string" && typeof x.score === "number");
813+
if (!isValidOutput) {
814+
throw new TypeError(
815+
"Invalid inference output: output must be of type Array<label: string, mask: string, score: number>"
816+
);
817+
}
818+
return res;
669819
}
670820

671821
/**
672822
* This task reads some text input and outputs an image.
673823
* Recommended model: stabilityai/stable-diffusion-2
674824
*/
675825
public async textToImage(args: TextToImageArgs, options?: Options): Promise<TextToImageReturn> {
676-
return await this.request(args, {
826+
const res = await this.request<TextToImageReturn>(args, {
677827
...options,
678828
blob: true,
679829
});
830+
const isValidOutput = res && res instanceof Blob;
831+
if (!isValidOutput) {
832+
throw new TypeError("Invalid inference output: output must be of type object & of instance Blob");
833+
}
834+
return res;
680835
}
681836

682837
public async request<T>(

0 commit comments

Comments
 (0)