@@ -519,21 +519,52 @@ export class HfInference {
519
519
* Tries to fill in a hole with a missing word (token to be precise). That’s the base task for BERT models.
520
520
*/
521
521
public async fillMask ( args : FillMaskArgs , options ?: Options ) : Promise < FillMaskReturn > {
522
- return await this . request ( args , options ) ;
522
+ const res = await this . request < FillMaskReturn > ( args , options ) ;
523
+ const isValidOutput =
524
+ Array . isArray ( res ) &&
525
+ res . every (
526
+ ( x ) =>
527
+ typeof x . score === "number" &&
528
+ typeof x . sequence === "string" &&
529
+ typeof x . token === "number" &&
530
+ typeof x . token_str === "string"
531
+ ) ;
532
+ if ( ! isValidOutput ) {
533
+ throw new TypeError (
534
+ "Invalid inference output: output must be of type Array<score: number, sequence:string, token:number, token_str:string>"
535
+ ) ;
536
+ }
537
+ return res ;
523
538
}
524
539
525
540
/**
526
541
* This task is well known to summarize longer text into shorter text. Be careful, some models have a maximum length of input. That means that the summary cannot handle full books for instance. Be careful when choosing your model.
527
542
*/
528
543
public async summarization ( args : SummarizationArgs , options ?: Options ) : Promise < SummarizationReturn > {
529
- return ( await this . request < SummarizationReturn [ ] > ( args , options ) ) ?. [ 0 ] ;
544
+ const res = await this . request < SummarizationReturn [ ] > ( args , options ) ;
545
+ const isValidOutput = Array . isArray ( res ) && res . every ( ( x ) => typeof x . summary_text === "string" ) ;
546
+ if ( ! isValidOutput ) {
547
+ throw new TypeError ( "Invalid inference output: output must be of type Array<summary_text: string>" ) ;
548
+ }
549
+ return res ?. [ 0 ] ;
530
550
}
531
551
532
552
/**
533
553
* Want to have a nice know-it-all bot that can answer any question?. Recommended model: deepset/roberta-base-squad2
534
554
*/
535
555
public async questionAnswer ( args : QuestionAnswerArgs , options ?: Options ) : Promise < QuestionAnswerReturn > {
536
- return await this . request ( args , options ) ;
556
+ const res = await this . request < QuestionAnswerReturn > ( args , options ) ;
557
+ const isValidOutput =
558
+ typeof res . answer === "string" &&
559
+ typeof res . end === "number" &&
560
+ typeof res . score === "number" &&
561
+ typeof res . start === "number" ;
562
+ if ( ! isValidOutput ) {
563
+ throw new TypeError (
564
+ "Invalid inference output: output must be of type <answer: string, end: number, score: number, start: number>"
565
+ ) ;
566
+ }
567
+ return res ;
537
568
}
538
569
539
570
/**
@@ -543,21 +574,45 @@ export class HfInference {
543
574
args : TableQuestionAnswerArgs ,
544
575
options ?: Options
545
576
) : Promise < TableQuestionAnswerReturn > {
546
- return await this . request ( args , options ) ;
577
+ const res = await this . request < TableQuestionAnswerReturn > ( args , options ) ;
578
+ const isValidOutput =
579
+ typeof res . aggregator === "string" &&
580
+ typeof res . answer === "string" &&
581
+ Array . isArray ( res . cells ) &&
582
+ res . cells . every ( ( x ) => typeof x === "string" ) &&
583
+ Array . isArray ( res . coordinates ) &&
584
+ res . coordinates . every ( ( coord ) => Array . isArray ( coord ) && coord . every ( ( x ) => typeof x === "number" ) ) ;
585
+ if ( ! isValidOutput ) {
586
+ throw new TypeError (
587
+ "Invalid inference output: output must be of type <aggregator: string, answer: string, cells: string[], coordinates: number[][]>"
588
+ ) ;
589
+ }
590
+ return res ;
547
591
}
548
592
549
593
/**
550
594
* Usually used for sentiment-analysis this will output the likelihood of classes of an input. Recommended model: distilbert-base-uncased-finetuned-sst-2-english
551
595
*/
552
596
public async textClassification ( args : TextClassificationArgs , options ?: Options ) : Promise < TextClassificationReturn > {
553
- return ( await this . request < TextClassificationReturn [ ] > ( args , options ) ) ?. [ 0 ] ;
597
+ const res = ( await this . request < TextClassificationReturn [ ] > ( args , options ) ) ?. [ 0 ] ;
598
+ const isValidOutput =
599
+ Array . isArray ( res ) && res . every ( ( x ) => typeof x . label === "string" && typeof x . score === "number" ) ;
600
+ if ( ! isValidOutput ) {
601
+ throw new TypeError ( "Invalid inference output: output must be of type Array<label: string, score: number>" ) ;
602
+ }
603
+ return res ;
554
604
}
555
605
556
606
/**
557
607
* Use to continue text from a prompt. This is a very generic task. Recommended model: gpt2 (it’s a simple model, but fun to play with).
558
608
*/
559
609
public async textGeneration ( args : TextGenerationArgs , options ?: Options ) : Promise < TextGenerationReturn > {
560
- return ( await this . request < TextGenerationReturn [ ] > ( args , options ) ) ?. [ 0 ] ;
610
+ const res = await this . request < TextGenerationReturn [ ] > ( args , options ) ;
611
+ const isValidOutput = Array . isArray ( res ) && res . every ( ( x ) => typeof x . generated_text === "string" ) ;
612
+ if ( ! isValidOutput ) {
613
+ throw new TypeError ( "Invalid inference output: output must be of type Array<generated_text: string>" ) ;
614
+ }
615
+ return res ?. [ 0 ] ;
561
616
}
562
617
563
618
/**
@@ -567,14 +622,35 @@ export class HfInference {
567
622
args : TokenClassificationArgs ,
568
623
options ?: Options
569
624
) : Promise < TokenClassificationReturn > {
570
- return toArray ( await this . request ( args , options ) ) ;
625
+ const res = toArray ( await this . request < TokenClassificationReturnValue | TokenClassificationReturn > ( args , options ) ) ;
626
+ const isValidOutput =
627
+ Array . isArray ( res ) &&
628
+ res . every (
629
+ ( x ) =>
630
+ typeof x . end === "number" &&
631
+ typeof x . entity_group === "string" &&
632
+ typeof x . score === "number" &&
633
+ typeof x . start === "number" &&
634
+ typeof x . word === "string"
635
+ ) ;
636
+ if ( ! isValidOutput ) {
637
+ throw new TypeError (
638
+ "Invalid inference output: output must be of type Array<end: number, entity_group: string, score: number, start: number, word: string>"
639
+ ) ;
640
+ }
641
+ return res ;
571
642
}
572
643
573
644
/**
574
645
* This task is well known to translate text from one language to another. Recommended model: Helsinki-NLP/opus-mt-ru-en.
575
646
*/
576
647
public async translation ( args : TranslationArgs , options ?: Options ) : Promise < TranslationReturn > {
577
- return ( await this . request < TranslationReturn [ ] > ( args , options ) ) ?. [ 0 ] ;
648
+ const res = await this . request < TranslationReturn [ ] > ( args , options ) ;
649
+ const isValidOutput = Array . isArray ( res ) && res . every ( ( x ) => typeof x . translation_text === "string" ) ;
650
+ if ( ! isValidOutput ) {
651
+ throw new TypeError ( "Invalid inference output: output must be of type Array<translation_text: string>" ) ;
652
+ }
653
+ return res ?. [ 0 ] ;
578
654
}
579
655
580
656
/**
@@ -584,24 +660,55 @@ export class HfInference {
584
660
args : ZeroShotClassificationArgs ,
585
661
options ?: Options
586
662
) : Promise < ZeroShotClassificationReturn > {
587
- return toArray (
588
- await this . request < ZeroShotClassificationReturnValue | ZeroShotClassificationReturnValue [ ] > ( args , options )
663
+ const res = toArray (
664
+ await this . request < ZeroShotClassificationReturnValue | ZeroShotClassificationReturn > ( args , options )
589
665
) ;
666
+ const isValidOutput =
667
+ Array . isArray ( res ) &&
668
+ res . every (
669
+ ( x ) =>
670
+ Array . isArray ( x . labels ) &&
671
+ x . labels . every ( ( _label ) => typeof _label === "string" ) &&
672
+ Array . isArray ( x . scores ) &&
673
+ x . scores . every ( ( _score ) => typeof _score === "number" ) &&
674
+ typeof x . sequence === "string"
675
+ ) ;
676
+ if ( ! isValidOutput ) {
677
+ throw new TypeError (
678
+ "Invalid inference output: output must be of type Array<labels: string[], scores: number[], sequence: string>"
679
+ ) ;
680
+ }
681
+ return res ;
590
682
}
591
683
592
684
/**
593
685
* This task corresponds to any chatbot like structure. Models tend to have shorter max_length, so please check with caution when using a given model if you need long range dependency or not. Recommended model: microsoft/DialoGPT-large.
594
686
*
595
687
*/
596
688
public async conversational ( args : ConversationalArgs , options ?: Options ) : Promise < ConversationalReturn > {
597
- return await this . request ( args , options ) ;
689
+ const res = await this . request < ConversationalReturn > ( args , options ) ;
690
+ const isValidOutput =
691
+ Array . isArray ( res . conversation . generated_responses ) &&
692
+ res . conversation . generated_responses . every ( ( x ) => typeof x === "string" ) &&
693
+ Array . isArray ( res . conversation . past_user_inputs ) &&
694
+ res . conversation . past_user_inputs . every ( ( x ) => typeof x === "string" ) &&
695
+ typeof res . generated_text === "string" &&
696
+ Array . isArray ( res . warnings ) &&
697
+ res . warnings . every ( ( x ) => typeof x === "string" ) ;
698
+ if ( ! isValidOutput ) {
699
+ throw new TypeError (
700
+ "Invalid inference output: output must be of type <conversation: {generated_responses: string[], past_user_inputs: string[]}, generated_text: string, warnings: string[]>"
701
+ ) ;
702
+ }
703
+ return res ;
598
704
}
599
705
600
706
/**
601
707
* This task reads some text and outputs raw float values, that are usually consumed as part of a semantic database/semantic search.
602
708
*/
603
709
public async featureExtraction ( args : FeatureExtractionArgs , options ?: Options ) : Promise < FeatureExtractionReturn > {
604
- return await this . request ( args , options ) ;
710
+ const res = await this . request < FeatureExtractionReturn > ( args , options ) ;
711
+ return res ;
605
712
}
606
713
607
714
/**
@@ -612,10 +719,15 @@ export class HfInference {
612
719
args : AutomaticSpeechRecognitionArgs ,
613
720
options ?: Options
614
721
) : Promise < AutomaticSpeechRecognitionReturn > {
615
- return await this . request ( args , {
722
+ const res = await this . request < AutomaticSpeechRecognitionReturn > ( args , {
616
723
...options ,
617
724
binary : true ,
618
725
} ) ;
726
+ const isValidOutput = typeof res . text === "string" ;
727
+ if ( ! isValidOutput ) {
728
+ throw new TypeError ( "Invalid inference output: output must be of type <text: string>" ) ;
729
+ }
730
+ return res ;
619
731
}
620
732
621
733
/**
@@ -626,10 +738,16 @@ export class HfInference {
626
738
args : AudioClassificationArgs ,
627
739
options ?: Options
628
740
) : Promise < AudioClassificationReturn > {
629
- return await this . request ( args , {
741
+ const res = await this . request < AudioClassificationReturn > ( args , {
630
742
...options ,
631
743
binary : true ,
632
744
} ) ;
745
+ const isValidOutput =
746
+ Array . isArray ( res ) && res . every ( ( x ) => typeof x . label === "string" && typeof x . score === "number" ) ;
747
+ if ( ! isValidOutput ) {
748
+ throw new TypeError ( "Invalid inference output: output must be of type Array<label: string, score: number>" ) ;
749
+ }
750
+ return res ;
633
751
}
634
752
635
753
/**
@@ -640,43 +758,80 @@ export class HfInference {
640
758
args : ImageClassificationArgs ,
641
759
options ?: Options
642
760
) : Promise < ImageClassificationReturn > {
643
- return await this . request ( args , {
761
+ const res = await this . request < ImageClassificationReturn > ( args , {
644
762
...options ,
645
763
binary : true ,
646
764
} ) ;
765
+ const isValidOutput =
766
+ Array . isArray ( res ) && res . every ( ( x ) => typeof x . label === "string" && typeof x . score === "number" ) ;
767
+ if ( ! isValidOutput ) {
768
+ throw new TypeError ( "Invalid inference output: output must be of type Array<label: string, score: number>" ) ;
769
+ }
770
+ return res ;
647
771
}
648
772
649
773
/**
650
774
* This task reads some image input and outputs the likelihood of classes & bounding boxes of detected objects.
651
775
* Recommended model: facebook/detr-resnet-50
652
776
*/
653
777
public async objectDetection ( args : ObjectDetectionArgs , options ?: Options ) : Promise < ObjectDetectionReturn > {
654
- return await this . request ( args , {
778
+ const res = await this . request < ObjectDetectionReturn > ( args , {
655
779
...options ,
656
780
binary : true ,
657
781
} ) ;
782
+ const isValidOutput =
783
+ Array . isArray ( res ) &&
784
+ res . every (
785
+ ( x ) =>
786
+ typeof x . label === "string" &&
787
+ typeof x . score === "number" &&
788
+ typeof x . box . xmin === "number" &&
789
+ typeof x . box . ymin === "number" &&
790
+ typeof x . box . xmax === "number" &&
791
+ typeof x . box . ymax === "number"
792
+ ) ;
793
+ if ( ! isValidOutput ) {
794
+ throw new TypeError (
795
+ "Invalid inference output: output must be of type Array<{label:string; score:number; box:{xmin:number; ymin:number; xmax:number; ymax:number}}>"
796
+ ) ;
797
+ }
798
+ return res ;
658
799
}
659
800
660
801
/**
661
802
* This task reads some image input and outputs the likelihood of classes & bounding boxes of detected objects.
662
803
* Recommended model: facebook/detr-resnet-50-panoptic
663
804
*/
664
805
public async imageSegmentation ( args : ImageSegmentationArgs , options ?: Options ) : Promise < ImageSegmentationReturn > {
665
- return await this . request ( args , {
806
+ const res = await this . request < ImageSegmentationReturn > ( args , {
666
807
...options ,
667
808
binary : true ,
668
809
} ) ;
810
+ const isValidOutput =
811
+ Array . isArray ( res ) &&
812
+ res . every ( ( x ) => typeof x . label === "string" && typeof x . mask === "string" && typeof x . score === "number" ) ;
813
+ if ( ! isValidOutput ) {
814
+ throw new TypeError (
815
+ "Invalid inference output: output must be of type Array<label: string, mask: string, score: number>"
816
+ ) ;
817
+ }
818
+ return res ;
669
819
}
670
820
671
821
/**
672
822
* This task reads some text input and outputs an image.
673
823
* Recommended model: stabilityai/stable-diffusion-2
674
824
*/
675
825
public async textToImage ( args : TextToImageArgs , options ?: Options ) : Promise < TextToImageReturn > {
676
- return await this . request ( args , {
826
+ const res = await this . request < TextToImageReturn > ( args , {
677
827
...options ,
678
828
blob : true ,
679
829
} ) ;
830
+ const isValidOutput = res && res instanceof Blob ;
831
+ if ( ! isValidOutput ) {
832
+ throw new TypeError ( "Invalid inference output: output must be of type object & of instance Blob" ) ;
833
+ }
834
+ return res ;
680
835
}
681
836
682
837
public async request < T > (
0 commit comments