Skip to content

Commit 61c7d1b

Browse files
committed
Bug fixes for HF models
1 parent 462bed0 commit 61c7d1b

File tree

4 files changed

+25
-1
lines changed

4 files changed

+25
-1
lines changed

sagemaker-serve/src/sagemaker/serve/builder/schema_builder.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,11 @@ def _get_deserializer(self, obj):
196196
return StringDeserializer()
197197
if _is_jsonable(obj):
198198
return JSONDeserializer()
199+
if isinstance(obj, dict) and "content_type" in obj:
200+
try:
201+
return BytesDeserializer()
202+
except ValueError as e:
203+
logger.error(e)
199204

200205
raise ValueError(
201206
(

sagemaker-serve/src/sagemaker/serve/model_builder_servers.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -687,7 +687,20 @@ def _build_for_transformers(self) -> Model:
687687
hf_model_id, self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
688688
)
689689
elif isinstance(self.model, str): # Only set HF_MODEL_ID if model is a string
690+
# Get model metadata for task detection (same pattern as _build_for_triton)
691+
hf_model_md = self.get_huggingface_model_metadata(
692+
self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
693+
)
694+
model_task = hf_model_md.get("pipeline_tag")
695+
if model_task:
696+
self.env_vars.update({"HF_TASK": model_task})
697+
690698
self.env_vars.update({"HF_MODEL_ID": self.model})
699+
700+
# Add HuggingFace token if available (same as other methods)
701+
if self.env_vars.get("HUGGING_FACE_HUB_TOKEN"):
702+
self.env_vars["HF_TOKEN"] = self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
703+
691704
# Get HF config for string model IDs
692705
if hasattr(self.env_vars, "HF_API_TOKEN"):
693706
self.hf_model_config = _get_model_config_properties_from_hf(

sagemaker-serve/src/sagemaker/serve/model_builder_utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -992,6 +992,12 @@ def _hf_schema_builder_init(self, model_task: str) -> None:
992992
sample_outputs,
993993
) = remote_hf_schema_helper.get_resolved_hf_schema_for_task(model_task)
994994

995+
# Unwrap list outputs for binary tasks (text-to-image, audio, etc.)
996+
# Remote schema retriever returns [{'data': b'...', 'content_type': '...'}]
997+
# but SchemaBuilder expects {'data': b'...', 'content_type': '...'}
998+
if isinstance(sample_outputs, list) and len(sample_outputs) > 0:
999+
sample_outputs = sample_outputs[0]
1000+
9951001
self.schema_builder = SchemaBuilder(sample_inputs, sample_outputs)
9961002

9971003
except ValueError as e:

sagemaker-serve/tests/integ/test_tei_integration.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def build_and_deploy():
104104

105105
core_endpoint = model_builder.deploy(
106106
endpoint_name=f"{ENDPOINT_NAME_PREFIX}-{unique_id}",
107-
initial_instance_count=1
107+
initial_instance_count=1,
108108
)
109109
logger.info(f"Endpoint Successfully Created: {core_endpoint.endpoint_name}")
110110

0 commit comments

Comments
 (0)