src/modelgauge/suts/huggingface_chat_completion.py (4 additions, 1 deletion)
@@ -124,9 +124,10 @@ def translate_response(
 class HuggingFaceChatCompletionDedicatedSUT(BaseHuggingFaceChatCompletionSUT):
     """A Hugging Face SUT that is hosted on a dedicated inference endpoint and uses the chat_completion API."""
 
-    def __init__(self, uid: str, inference_endpoint: str, token: HuggingFaceInferenceToken):
+    def __init__(self, uid: str, inference_endpoint: str, model: str, token: HuggingFaceInferenceToken):
         super().__init__(uid, token)
         self.inference_endpoint = inference_endpoint
+        self.model = model
 
     def _create_client(self):
         endpoint = get_inference_endpoint(self.inference_endpoint, token=self.token.value)
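For context on the last line above: `huggingface_hub.get_inference_endpoint` returns an `InferenceEndpoint` handle, not a ready-to-use client. A minimal sketch of the usual library pattern for turning that handle into a client (this is the general `huggingface_hub` API, not necessarily the elided remainder of `_create_client`; the endpoint name and token are placeholders):

    from huggingface_hub import get_inference_endpoint

    # Placeholder names; the dedicated endpoint must already exist in your account.
    endpoint = get_inference_endpoint("my-endpoint", token="hf_...")
    endpoint.wait()           # block until the endpoint reports "running"
    client = endpoint.client  # an InferenceClient bound to the endpoint's URL
    reply = client.chat_completion(messages=[{"role": "user", "content": "Hello"}])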
@@ -158,6 +159,7 @@ def translate_text_prompt(self, prompt: TextPrompt, options: SUTOptions) -> Hugg
         if options.top_logprobs is not None:
             logprobs = True
         return HuggingFaceChatCompletionRequest(
+            model=self.model,
             messages=[ChatMessage(role="user", content=prompt.text)],
             logprobs=logprobs,
             **options.model_dump(),
@@ -168,6 +170,7 @@ def translate_chat_prompt(self, prompt: ChatPrompt, options: SUTOptions) -> Hugg
         if options.top_logprobs is not None:
             logprobs = True
         return HuggingFaceChatCompletionRequest(
+            model=self.model,
             messages=[ChatMessage(role=p.role.lower(), content=p.text) for p in prompt.messages],
             logprobs=logprobs,
             **options.model_dump(),
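Taken together, this file's change threads a required `model` name through the dedicated-endpoint SUT so that every chat-completion request carries it. A schematic usage sketch, assuming the classes shown above; the UID, endpoint name, and model id are placeholders, and in practice the token is injected from secrets configuration rather than built inline:

    hf_token = ...  # a HuggingFaceInferenceToken, normally supplied by the secrets machinery
    sut = HuggingFaceChatCompletionDedicatedSUT(
        uid="hf-demo",                     # hypothetical SUT UID
        inference_endpoint="my-endpoint",  # dedicated endpoint name, not a repo id
        model="my-org/my-model",           # new: repository id stamped into each request
        token=hf_token,
    )
    # translate_text_prompt and translate_chat_prompt now both set model=self.model
    # on the HuggingFaceChatCompletionRequest they build.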
src/modelgauge/suts/huggingface_sut_factory.py (11 additions, 10 deletions)
@@ -104,23 +104,24 @@ def _find(sut_definition: SUTDefinition) -> str | None:
         try:
             endpoints = hfh.list_inference_endpoints()
             for e in endpoints:
-                if e.repository == model_name and e.status != "running":
-                    try:
-                        e.resume()
-                    except Exception as ie:
-                        logger.error(
-                            f"Found endpoint for {model_name} but unable to start it. Check your token's permissions. {ie}"
-                        )
-                    return e.name
+                if e.repository.lower() == model_name:
+                    if e.status != "running":
+                        try:
+                            e.resume()
+                        except Exception as ie:
+                            logger.error(
+                                f"Found endpoint for {model_name} but unable to start it. Check your token's permissions. {ie}"
+                            )
+                    return e.name, e.repository
         except Exception as oe:
             logger.error(f"Error looking up dedicated endpoints for {model_name}: {oe}")
         return None
 
     def make_sut(self, sut_definition: SUTDefinition) -> HuggingFaceChatCompletionDedicatedSUT:
-        endpoint_name = HuggingFaceChatCompletionDedicatedSUTFactory._find(sut_definition)
+        endpoint_name, model_name = HuggingFaceChatCompletionDedicatedSUTFactory._find(sut_definition)
         if not endpoint_name:
             raise ProviderNotFoundError(
                 f"No dedicated inference endpoint found for {sut_definition.external_model_name()}."
             )
         sut_uid = sut_definition.dynamic_uid
-        return HuggingFaceChatCompletionDedicatedSUT(sut_uid, endpoint_name, self.injected_secrets())
+        return HuggingFaceChatCompletionDedicatedSUT(sut_uid, endpoint_name, model_name, self.injected_secrets()[0])
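On the factory side, `_find` now returns both the endpoint's name and its repository, and `make_sut` forwards the repository as the SUT's `model`. One caveat: when no endpoint matches, `_find` returns a bare `None`, so the tuple unpacking in `make_sut` raises a `TypeError` before the `ProviderNotFoundError` guard can fire. A hedged sketch of a caller-side pattern that preserves the intended error (names as in the diff above; `token` stands in for `self.injected_secrets()[0]`):

    result = HuggingFaceChatCompletionDedicatedSUTFactory._find(sut_definition)
    if result is None:
        # Check before unpacking so a missing endpoint surfaces as
        # ProviderNotFoundError rather than a TypeError.
        raise ProviderNotFoundError(
            f"No dedicated inference endpoint found for {sut_definition.external_model_name()}."
        )
    endpoint_name, model_name = result
    sut = HuggingFaceChatCompletionDedicatedSUT(
        sut_definition.dynamic_uid, endpoint_name, model_name, token
    )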