
Commit 20875ad

Updated Hugging Face chat and magics processing with new APIs, clients (#784)

* Updated HF chat processing: (1) the API has changed and uses the HuggingFaceEndpoint class instead of HuggingFaceHub, which is deprecated; (2) InferenceClient replaces InferenceApi; (3) removed legacy code that does not work with the new APIs.
* Handled text-generation and text-to-image tasks: added logic to branch to one of the two tasks based on the type of response received.
* Reworked conditional branching for text vs. image: used a different approach to check the task type.
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

1 parent e0eaeaa commit 20875ad
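
For context, the core of the migration is swapping the deprecated huggingface_hub.inference_api.InferenceApi client for huggingface_hub.InferenceClient. Below is a minimal sketch of that swap, assuming a valid HUGGINGFACEHUB_API_TOKEN is set in the environment; the model id "gpt2" and the max_new_tokens value are illustrative, not taken from this commit.

    import os

    # Before (deprecated): the old low-level inference client
    # from huggingface_hub.inference_api import InferenceApi
    # client = InferenceApi(repo_id="gpt2", token=os.environ["HUGGINGFACEHUB_API_TOKEN"])

    # After: InferenceClient, as adopted in this commit
    from huggingface_hub import InferenceClient

    client = InferenceClient(model="gpt2", token=os.environ["HUGGINGFACEHUB_API_TOKEN"])
    # text_generation() returns the generated string directly
    print(client.text_generation("Tell me a joke.", max_new_tokens=30))

InferenceClient also exposes task-specific helpers such as text_to_image(), which the reworked _call below uses for image tasks.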

File tree: 3 files changed (+67, -64 lines)


packages/jupyter-ai-magics/jupyter_ai_magics/providers.py (65 additions, 62 deletions)

@@ -44,7 +44,7 @@
     Bedrock,
     Cohere,
     GPT4All,
-    HuggingFaceHub,
+    HuggingFaceEndpoint,
     OpenAI,
     SagemakerEndpoint,
     Together,
@@ -318,7 +318,6 @@ def __init__(self, *args, **kwargs):
             ),
             "text": PromptTemplate.from_template("{prompt}"),  # No customization
         }
-
         super().__init__(*args, **kwargs, **model_kwargs)
 
     async def _call_in_executor(self, *args, **kwargs) -> Coroutine[Any, Any, str]:
@@ -582,14 +581,10 @@ def allows_concurrency(self):
         return False
 
 
-HUGGINGFACE_HUB_VALID_TASKS = (
-    "text2text-generation",
-    "text-generation",
-    "text-to-image",
-)
-
-
-class HfHubProvider(BaseProvider, HuggingFaceHub):
+# References for using HuggingFaceEndpoint and InferenceClient:
+# https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient
+# https://github.com/langchain-ai/langchain/blob/master/libs/community/langchain_community/llms/huggingface_endpoint.py
+class HfHubProvider(BaseProvider, HuggingFaceEndpoint):
     id = "huggingface_hub"
     name = "Hugging Face Hub"
     models = ["*"]
@@ -609,33 +604,35 @@ class HfHubProvider(BaseProvider, HuggingFaceHub):
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that api key and python package exists in environment."""
-        huggingfacehub_api_token = get_from_dict_or_env(
-            values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
-        )
         try:
-            from huggingface_hub.inference_api import InferenceApi
+            huggingfacehub_api_token = get_from_dict_or_env(
+                values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
+            )
+        except Exception as e:
+            raise ValueError(
+                "Could not authenticate with huggingface_hub. "
+                "Please check your API token."
+            ) from e
+        try:
+            from huggingface_hub import InferenceClient
 
-            repo_id = values["repo_id"]
-            client = InferenceApi(
-                repo_id=repo_id,
+            values["client"] = InferenceClient(
+                model=values["model"],
+                timeout=values["timeout"],
                 token=huggingfacehub_api_token,
-                task=values.get("task"),
+                **values["server_kwargs"],
             )
-            if client.task not in HUGGINGFACE_HUB_VALID_TASKS:
-                raise ValueError(
-                    f"Got invalid task {client.task}, "
-                    f"currently only {HUGGINGFACE_HUB_VALID_TASKS} are supported"
-                )
-            values["client"] = client
         except ImportError:
             raise ValueError(
                 "Could not import huggingface_hub python package. "
                 "Please install it with `pip install huggingface_hub`."
            )
         return values
 
-    # Handle image outputs
-    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
+    # Handle text and image outputs
+    def _call(
+        self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any
+    ) -> str:
         """Call out to Hugging Face Hub's inference endpoint.
 
         Args:
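
Before the next hunk: the reworked validator reduces to resolving the token, then lazily importing and constructing the client, with a distinct error message for each failure. A simplified standalone sketch, not code from this commit; the function name and timeout default are illustrative assumptions.

    import os


    def build_client(model: str, timeout: float = 120.0):
        # Resolve the API token first, failing with an authentication hint
        try:
            token = os.environ["HUGGINGFACEHUB_API_TOKEN"]
        except KeyError as e:
            raise ValueError(
                "Could not authenticate with huggingface_hub. Please check your API token."
            ) from e
        # Import lazily, failing with an installation hint
        try:
            from huggingface_hub import InferenceClient
        except ImportError:
            raise ValueError(
                "Could not import huggingface_hub python package. "
                "Please install it with `pip install huggingface_hub`."
            )
        return InferenceClient(model=model, timeout=timeout, token=token)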
@@ -650,45 +647,51 @@ def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
 
                 response = hf("Tell me a joke.")
         """
-        _model_kwargs = self.model_kwargs or {}
-        response = self.client(inputs=prompt, params=_model_kwargs)
-
-        if type(response) is dict and "error" in response:
-            raise ValueError(f"Error raised by inference API: {response['error']}")
-
-        # Custom code for responding to image generation responses
-        if self.client.task == "text-to-image":
-            imageFormat = response.format  # Presume it's a PIL ImageFile
-            mimeType = ""
-            if imageFormat == "JPEG":
-                mimeType = "image/jpeg"
-            elif imageFormat == "PNG":
-                mimeType = "image/png"
-            elif imageFormat == "GIF":
-                mimeType = "image/gif"
+        invocation_params = self._invocation_params(stop, **kwargs)
+        invocation_params["stop"] = invocation_params[
+            "stop_sequences"
+        ]  # porting 'stop_sequences' into the 'stop' argument
+        response = self.client.post(
+            json={"inputs": prompt, "parameters": invocation_params},
+            stream=False,
+            task=self.task,
+        )
+
+        try:
+            if "generated_text" in str(response):
+                # text2text-generation or text-generation task
+                response_text = json.loads(response.decode())[0]["generated_text"]
+                # Maybe the generation has stopped at one of the stop sequences:
+                # then we remove this stop sequence from the end of the generated text
+                for stop_seq in invocation_params["stop_sequences"]:
+                    if response_text[-len(stop_seq) :] == stop_seq:
+                        response_text = response_text[: -len(stop_seq)]
+                return response_text
             else:
-                raise ValueError(f"Unrecognized image format {imageFormat}")
-
-            buffer = io.BytesIO()
-            response.save(buffer, format=imageFormat)
-            # Encode image data to Base64 bytes, then decode bytes to str
-            return mimeType + ";base64," + base64.b64encode(buffer.getvalue()).decode()
-
-        if self.client.task == "text-generation":
-            # Text generation return includes the starter text.
-            text = response[0]["generated_text"][len(prompt) :]
-        elif self.client.task == "text2text-generation":
-            text = response[0]["generated_text"]
-        else:
+                # text-to-image task
+                # https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_to_image.example
+                # Custom code for responding to image generation responses
+                image = self.client.text_to_image(prompt)
+                imageFormat = image.format  # Presume it's a PIL ImageFile
+                mimeType = ""
+                if imageFormat == "JPEG":
+                    mimeType = "image/jpeg"
+                elif imageFormat == "PNG":
+                    mimeType = "image/png"
+                elif imageFormat == "GIF":
+                    mimeType = "image/gif"
+                else:
+                    raise ValueError(f"Unrecognized image format {imageFormat}")
+                buffer = io.BytesIO()
+                image.save(buffer, format=imageFormat)
+                # Encode image data to Base64 bytes, then decode bytes to str
+                return (
+                    mimeType + ";base64," + base64.b64encode(buffer.getvalue()).decode()
+                )
+        except:
             raise ValueError(
-                f"Got invalid task {self.client.task}, "
-                f"currently only {HUGGINGFACE_HUB_VALID_TASKS} are supported"
+                "Task not supported, only text-generation and text-to-image tasks are valid."
             )
-        if stop is not None:
-            # This is a bit hacky, but I can't figure out a better way to enforce
-            # stop tokens when making calls to huggingface_hub.
-            text = enforce_stop_tokens(text, stop)
-        return text
 
     async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]:
         return await self._call_in_executor(*args, **kwargs)
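
Taken together, the new _call posts the prompt once and then branches on the shape of the response. Below is a condensed standalone sketch of the two paths using InferenceClient's high-level helpers, text_generation() and text_to_image(), instead of the raw client.post(...) call above; the model ids and the stop sequence here are illustrative assumptions.

    import base64
    import io
    import os

    from huggingface_hub import InferenceClient

    token = os.environ["HUGGINGFACEHUB_API_TOKEN"]

    # Text path: text_generation() returns the generated string directly.
    text_client = InferenceClient(model="gpt2", token=token)
    text = text_client.text_generation("Tell me a joke.")

    # Trim a trailing stop sequence, mirroring the loop in _call above.
    for stop_seq in ["\n\n"]:
        if text.endswith(stop_seq):
            text = text[: -len(stop_seq)]
    print(text)

    # Image path: text_to_image() returns a PIL image; encode it the way
    # _call does, as "<mime type>;base64,<data>".
    image_client = InferenceClient(model="stabilityai/stable-diffusion-2", token=token)
    image = image_client.text_to_image("An astronaut riding a horse")
    image_format = image.format or "PNG"  # format may be unset on some PIL images
    buffer = io.BytesIO()
    image.save(buffer, format=image_format)
    data_uri = f"image/{image_format.lower()};base64," + base64.b64encode(
        buffer.getvalue()
    ).decode()
    print(data_uri[:60] + "...")

The "<mime type>;base64,<data>" return shape mirrors exactly what the reworked _call produces for image responses.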

packages/jupyter-ai/jupyter_ai/chat_handlers/ask.py (1 addition, 1 deletion)

@@ -5,7 +5,7 @@
 from jupyter_ai_magics.providers import BaseProvider
 from langchain.chains import ConversationalRetrievalChain
 from langchain.memory import ConversationBufferWindowMemory
-from langchain.prompts import PromptTemplate
+from langchain_core.prompts import PromptTemplate
 
 from .base import BaseChatHandler, SlashCommandRoutingType
 

packages/jupyter-ai/jupyter_ai/chat_handlers/generate.py (1 addition, 1 deletion)

@@ -12,9 +12,9 @@
 from langchain.chains import LLMChain
 from langchain.llms import BaseLLM
 from langchain.output_parsers import PydanticOutputParser
-from langchain.prompts import PromptTemplate
 from langchain.pydantic_v1 import BaseModel
 from langchain.schema.output_parser import BaseOutputParser
+from langchain_core.prompts import PromptTemplate
 
 
 class OutlineSection(BaseModel):
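
The two chat-handler changes are the same import move: PromptTemplate now comes from langchain_core.prompts rather than langchain.prompts. Usage is unchanged; a tiny sketch (the template text is illustrative, not from the diff):

    from langchain_core.prompts import PromptTemplate

    template = PromptTemplate.from_template("Answer briefly: {question}")
    print(template.format(question="What does jupyter-ai's /ask command do?"))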
