@@ -44,7 +44,7 @@
     Bedrock,
     Cohere,
     GPT4All,
-    HuggingFaceHub,
+    HuggingFaceEndpoint,
     OpenAI,
     SagemakerEndpoint,
     Together,
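For context, `HuggingFaceEndpoint` is langchain_community's replacement for the deprecated `HuggingFaceHub` wrapper removed above. A minimal standalone sketch of the new class, where the repo id, token, and parameters are placeholders rather than values from this PR:

    from langchain_community.llms import HuggingFaceEndpoint

    # Hypothetical usage; repo_id, token, and max_new_tokens are placeholders.
    llm = HuggingFaceEndpoint(
        repo_id="google/flan-t5-large",
        task="text2text-generation",
        huggingfacehub_api_token="hf_...",
        max_new_tokens=128,
    )
    print(llm.invoke("Tell me a joke."))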
@@ -318,7 +318,6 @@ def __init__(self, *args, **kwargs):
             ),
             "text": PromptTemplate.from_template("{prompt}"),  # No customization
         }
-
         super().__init__(*args, **kwargs, **model_kwargs)

     async def _call_in_executor(self, *args, **kwargs) -> Coroutine[Any, Any, str]:
@@ -582,14 +581,10 @@ def allows_concurrency(self):
         return False


-HUGGINGFACE_HUB_VALID_TASKS = (
-    "text2text-generation",
-    "text-generation",
-    "text-to-image",
-)
-
-
-class HfHubProvider(BaseProvider, HuggingFaceHub):
+# References for using HuggingFaceEndpoint and InferenceClient:
+# https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient
+# https://github.com/langchain-ai/langchain/blob/master/libs/community/langchain_community/llms/huggingface_endpoint.py
+class HfHubProvider(BaseProvider, HuggingFaceEndpoint):
     id = "huggingface_hub"
     name = "Hugging Face Hub"
     models = ["*"]
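Re-parenting `HfHubProvider` onto `HuggingFaceEndpoint` is what makes `values["model"]`, `values["timeout"]`, and `values["server_kwargs"]` available in the validator below. A rough sketch of the relevant inherited declarations, assumed from the linked huggingface_endpoint.py source; names follow the diff, defaults are illustrative only:

    from typing import Dict, List, Optional

    from langchain_core.language_models.llms import LLM

    # Assumed sketch, not the real class definition.
    class HuggingFaceEndpoint(LLM):
        model: str                      # repo id or inference endpoint URL
        task: Optional[str] = None      # e.g. "text-generation"
        timeout: int = 120              # request timeout in seconds
        server_kwargs: Dict = {}        # extra kwargs forwarded to InferenceClient
        stop_sequences: List[str] = []  # consumed by _invocation_params() in _call()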
@@ -609,33 +604,35 @@ class HfHubProvider(BaseProvider, HuggingFaceHub):
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that api key and python package exists in environment."""
-        huggingfacehub_api_token = get_from_dict_or_env(
-            values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
-        )
         try:
-            from huggingface_hub.inference_api import InferenceApi
+            huggingfacehub_api_token = get_from_dict_or_env(
+                values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
+            )
+        except Exception as e:
+            raise ValueError(
+                "Could not authenticate with huggingface_hub. "
+                "Please check your API token."
+            ) from e
+        try:
+            from huggingface_hub import InferenceClient

-            repo_id = values["repo_id"]
-            client = InferenceApi(
-                repo_id=repo_id,
+            values["client"] = InferenceClient(
+                model=values["model"],
+                timeout=values["timeout"],
                 token=huggingfacehub_api_token,
-                task=values.get("task"),
+                **values["server_kwargs"],
             )
-            if client.task not in HUGGINGFACE_HUB_VALID_TASKS:
-                raise ValueError(
-                    f"Got invalid task {client.task}, "
-                    f"currently only {HUGGINGFACE_HUB_VALID_TASKS} are supported"
-                )
-            values["client"] = client
         except ImportError:
             raise ValueError(
                 "Could not import huggingface_hub python package. "
                 "Please install it with `pip install huggingface_hub`."
             )
         return values

-    # Handle image outputs
-    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
+    # Handle text and image outputs
+    def _call(
+        self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any
+    ) -> str:
         """Call out to Hugging Face Hub's inference endpoint.

         Args:
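For reference, the `InferenceClient` stored on `values["client"]` above is the task-agnostic client from huggingface_hub. A minimal sketch of its construction and of the raw `post()` call that the rewritten `_call` builds on below; the model id, token, and parameters are placeholders:

    from huggingface_hub import InferenceClient

    client = InferenceClient(model="gpt2", token="hf_...", timeout=120)
    # post() sends a raw Inference API request and returns response bytes;
    # decoding the JSON payload is left to the caller.
    raw = client.post(
        json={"inputs": "Tell me a joke.", "parameters": {"max_new_tokens": 32}},
        stream=False,
    )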
@@ -650,45 +647,51 @@ def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:

             response = hf("Tell me a joke.")
         """
-        _model_kwargs = self.model_kwargs or {}
-        response = self.client(inputs=prompt, params=_model_kwargs)
-
-        if type(response) is dict and "error" in response:
-            raise ValueError(f"Error raised by inference API: {response['error']}")
-
-        # Custom code for responding to image generation responses
-        if self.client.task == "text-to-image":
-            imageFormat = response.format  # Presume it's a PIL ImageFile
-            mimeType = ""
-            if imageFormat == "JPEG":
-                mimeType = "image/jpeg"
-            elif imageFormat == "PNG":
-                mimeType = "image/png"
-            elif imageFormat == "GIF":
-                mimeType = "image/gif"
+        invocation_params = self._invocation_params(stop, **kwargs)
+        invocation_params["stop"] = invocation_params[
+            "stop_sequences"
+        ]  # porting 'stop_sequences' into the 'stop' argument
+        response = self.client.post(
+            json={"inputs": prompt, "parameters": invocation_params},
+            stream=False,
+            task=self.task,
+        )
+
+        try:
+            if "generated_text" in str(response):
+                # text2text-generation or text-generation task
+                response_text = json.loads(response.decode())[0]["generated_text"]
+                # Maybe the generation has stopped at one of the stop sequences:
+                # then we remove this stop sequence from the end of the generated text
+                for stop_seq in invocation_params["stop_sequences"]:
+                    if response_text[-len(stop_seq) :] == stop_seq:
+                        response_text = response_text[: -len(stop_seq)]
+                return response_text
             else:
-                raise ValueError(f"Unrecognized image format {imageFormat}")
-
-            buffer = io.BytesIO()
-            response.save(buffer, format=imageFormat)
-            # Encode image data to Base64 bytes, then decode bytes to str
-            return mimeType + ";base64," + base64.b64encode(buffer.getvalue()).decode()
-
-        if self.client.task == "text-generation":
-            # Text generation return includes the starter text.
-            text = response[0]["generated_text"][len(prompt) :]
-        elif self.client.task == "text2text-generation":
-            text = response[0]["generated_text"]
-        else:
+                # text-to-image task
+                # https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_to_image.example
+                # Custom code for responding to image generation responses
+                image = self.client.text_to_image(prompt)
+                imageFormat = image.format  # Presume it's a PIL ImageFile
+                mimeType = ""
+                if imageFormat == "JPEG":
+                    mimeType = "image/jpeg"
+                elif imageFormat == "PNG":
+                    mimeType = "image/png"
+                elif imageFormat == "GIF":
+                    mimeType = "image/gif"
+                else:
+                    raise ValueError(f"Unrecognized image format {imageFormat}")
+                buffer = io.BytesIO()
+                image.save(buffer, format=imageFormat)
+                # Encode image data to Base64 bytes, then decode bytes to str
+                return (
+                    mimeType + ";base64," + base64.b64encode(buffer.getvalue()).decode()
+                )
+        except:
             raise ValueError(
-                f"Got invalid task {self.client.task}, "
-                f"currently only {HUGGINGFACE_HUB_VALID_TASKS} are supported"
+                "Task not supported, only text-generation and text-to-image tasks are valid."
             )
-        if stop is not None:
-            # This is a bit hacky, but I can't figure out a better way to enforce
-            # stop tokens when making calls to huggingface_hub.
-            text = enforce_stop_tokens(text, stop)
-        return text

     async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]:
         return await self._call_in_executor(*args, **kwargs)
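End-to-end, the rewritten `_call` treats any response containing `generated_text` as a text task and otherwise re-issues the prompt through `text_to_image`. A standalone sketch of both decode paths as the provider performs them; the model ids and token are placeholders:

    import base64
    import io
    import json

    from huggingface_hub import InferenceClient

    # Text path: post() returns bytes holding a JSON list of generations.
    text_client = InferenceClient(model="gpt2", token="hf_...")
    raw = text_client.post(json={"inputs": "Tell me a joke."}, stream=False)
    text = json.loads(raw.decode())[0]["generated_text"]

    # Image path: text_to_image() returns a PIL image, encoded here the
    # same way the provider does (mime type prefix + base64 payload).
    image_client = InferenceClient(model="stabilityai/stable-diffusion-2-1", token="hf_...")
    image = image_client.text_to_image("a watercolor fox")
    buffer = io.BytesIO()
    image.save(buffer, format="PNG")
    encoded = "image/png;base64," + base64.b64encode(buffer.getvalue()).decode()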