diff --git a/src/huggingface_hub/inference/_providers/__init__.py b/src/huggingface_hub/inference/_providers/__init__.py index a40d87efda..73c3009904 100644 --- a/src/huggingface_hub/inference/_providers/__init__.py +++ b/src/huggingface_hub/inference/_providers/__init__.py @@ -23,7 +23,7 @@ from .nebius import NebiusConversationalTask, NebiusTextGenerationTask, NebiusTextToImageTask from .novita import NovitaConversationalTask, NovitaTextGenerationTask, NovitaTextToVideoTask from .openai import OpenAIConversationalTask -from .replicate import ReplicateTask, ReplicateTextToSpeechTask +from .replicate import ReplicateTask, ReplicateTextToImageTask, ReplicateTextToSpeechTask from .sambanova import SambanovaConversationalTask, SambanovaFeatureExtractionTask from .together import TogetherConversationalTask, TogetherTextGenerationTask, TogetherTextToImageTask @@ -115,7 +115,7 @@ "conversational": OpenAIConversationalTask(), }, "replicate": { - "text-to-image": ReplicateTask("text-to-image"), + "text-to-image": ReplicateTextToImageTask(), "text-to-speech": ReplicateTextToSpeechTask(), "text-to-video": ReplicateTask("text-to-video"), }, diff --git a/src/huggingface_hub/inference/_providers/replicate.py b/src/huggingface_hub/inference/_providers/replicate.py index d76eaa2b5a..2ba3127647 100644 --- a/src/huggingface_hub/inference/_providers/replicate.py +++ b/src/huggingface_hub/inference/_providers/replicate.py @@ -47,6 +47,19 @@ def get_response(self, response: Union[bytes, Dict], request_params: Optional[Re return get_session().get(output_url).content +class ReplicateTextToImageTask(ReplicateTask): + def __init__(self): + super().__init__("text-to-image") + + def _prepare_payload_as_dict( + self, inputs: Any, parameters: Dict, provider_mapping_info: InferenceProviderMapping + ) -> Optional[Dict]: + payload: Dict = super()._prepare_payload_as_dict(inputs, parameters, provider_mapping_info) # type: ignore[assignment] + if provider_mapping_info.adapter_weights_path is not None: + payload["input"]["lora_weights"] = f"https://huggingface.co/{provider_mapping_info.hf_model_id}" + return payload + + class ReplicateTextToSpeechTask(ReplicateTask): def __init__(self): super().__init__("text-to-speech")