Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 93 additions & 0 deletions truss/remote/baseten/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,9 @@ def _prepare_push(

config = truss_handle._spec._config

# Pin HuggingFace revisions in weights to commit SHAs
self._pin_hf_revisions_in_weights(truss_handle)

config.validate_forbid_extra()
encoded_config_str = base64_encoded_json_str(config.to_dict())
validate_truss_config_against_backend(self._api, encoded_config_str)
Expand All @@ -199,6 +202,96 @@ def _prepare_push(
team_id=team_id,
)

def _pin_hf_revisions_in_weights(self, truss_handle: TrussHandle) -> None:
    """Pin HuggingFace revisions in weights to commit SHAs (best effort).

    Every HuggingFace weight source whose revision is not already a
    40-character commit SHA (e.g. "main", "dev", or no revision at all) is
    resolved to a commit SHA. If at least one source was pinned, the updated
    config is written back to ``config.yaml`` in the truss directory and the
    changes are reported via warning logs.

    This is a best-effort operation: any failure (missing optional
    dependency, HF API or network error, unexpected resolver result, write
    error) is logged and swallowed so that the push itself never fails
    because of revision pinning.

    Args:
        truss_handle: Handle of the truss being pushed; its config is
            modified in place when revisions are pinned.
    """
    try:
        config = truss_handle._spec._config
        weights = config.weights

        if not weights or not weights.sources:
            return

        # Cheap checks come first so that pushes without HuggingFace weights
        # never trigger the resolver import below.
        hf_sources = [source for source in weights.sources if source.is_huggingface]
        if not hf_sources:
            return

        # Imported lazily: the resolver module pulls in `huggingface_hub` at
        # import time, which is only needed when there is an HF source.
        from truss.util.hf_revision_resolver import (
            build_hf_source,
            is_commit_sha,
            parse_hf_source,
            resolve_hf_revision,
        )

        # (repo_id, original revision, resolved SHA) for every pinned source.
        changes = []

        for source in hf_sources:
            try:
                repo_id, revision = parse_hf_source(source.source)
            except ValueError:
                # Not a valid HF source, skip
                continue

            # Skip if already a commit SHA
            if is_commit_sha(revision):
                continue

            # If no revision specified, assume "main"
            revision_to_resolve = revision if revision else "main"

            # Resolve to commit SHA (best effort)
            try:
                resolved_sha = resolve_hf_revision(
                    repo_id, revision_to_resolve, token=None
                )

                # Validate BEFORE touching the config: a missing or malformed
                # SHA must never be written into config.yaml.
                if not is_commit_sha(resolved_sha):
                    raise ValueError(
                        f"resolver returned an invalid commit SHA: {resolved_sha!r}"
                    )

                source.source = build_hf_source(repo_id, resolved_sha)
                changes.append((repo_id, revision_to_resolve, resolved_sha))

                logging.info(
                    f"Pinning HF revision: {repo_id}@{revision_to_resolve} "
                    f"-> {resolved_sha[:8]}..."
                )
            except Exception as e:
                # Best effort - don't fail the push if we can't resolve
                logging.info(
                    f"Could not resolve {repo_id}@{revision_to_resolve} to SHA: {e}. "
                    "Push will proceed with unpinned revision."
                )

        if changes:
            # Write updated config back to file
            config_path = truss_handle._truss_dir / "config.yaml"
            config.write_to_yaml_file(config_path, verbose=True)

            logging.warning(
                "\n⚠️ Your config.yaml has been modified to pin HuggingFace revisions to commit SHAs."
            )
            logging.warning(
                " This ensures reproducible deployments. Modified sources:\n"
            )
            for repo_id, old_rev, new_sha in changes:
                logging.warning(f" {repo_id}@{old_rev} -> {new_sha[:8]}...\n")

    except Exception as e:
        # Catch-all: never fail the push due to revision pinning
        logging.info(
            f"HF revision pinning encountered an error: {e}. "
            "Push will proceed without pinning revisions."
        )

def push( # type: ignore
self,
truss_handle: TrussHandle,
Expand Down
120 changes: 120 additions & 0 deletions truss/tests/util/test_hf_revision_resolver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
"""Tests for HuggingFace revision resolver."""

from unittest.mock import Mock, patch

import pytest

from truss.util.hf_revision_resolver import (
build_hf_source,
is_commit_sha,
parse_hf_source,
resolve_hf_revision,
)


def test_parse_hf_source_with_revision():
    """An explicit ``@revision`` suffix is split off from the repo ID."""
    uri = "hf://meta-llama/Llama-2-7b@main"
    parsed_repo, parsed_revision = parse_hf_source(uri)
    assert parsed_repo == "meta-llama/Llama-2-7b"
    assert parsed_revision == "main"


def test_parse_hf_source_with_sha():
    """A 40-character commit SHA after ``@`` is returned as the revision."""
    commit = "a" * 40
    parsed_repo, parsed_revision = parse_hf_source(
        f"hf://meta-llama/Llama-2-7b@{commit}"
    )
    assert parsed_repo == "meta-llama/Llama-2-7b"
    assert parsed_revision == commit


def test_parse_hf_source_without_revision():
    """A URI with no ``@revision`` suffix yields ``None`` as the revision."""
    uri = "hf://meta-llama/Llama-2-7b"
    parsed_repo, parsed_revision = parse_hf_source(uri)
    assert parsed_repo == "meta-llama/Llama-2-7b"
    assert parsed_revision is None


def test_parse_hf_source_invalid():
    """Sources that do not use the ``hf://`` scheme are rejected with ValueError."""
    # Checked in the same order as before: S3 first, then GCS.
    for non_hf_uri in ("s3://bucket/path", "gs://bucket/path"):
        with pytest.raises(ValueError, match="Not a HuggingFace source"):
            parse_hf_source(non_hf_uri)


def test_is_commit_sha():
    """Only 40-character lowercase hex strings count as commit SHAs."""
    accepted = (
        "a" * 40,
        "0123456789abcdef" * 2 + "01234567",
        "f" * 40,
    )
    for candidate in accepted:
        assert is_commit_sha(candidate)

    rejected = (
        "main",
        "v1.0.0",
        "a" * 39,  # one character too short
        "a" * 41,  # one character too long
        None,
        "",
        "gggggggggggggggggggggggggggggggggggggggg",  # right length, not hex
    )
    for candidate in rejected:
        assert not is_commit_sha(candidate)


def test_build_hf_source():
    """A repo ID and a commit SHA are joined into an ``hf://repo@sha`` URI."""
    commit = "abc1230000000000000000000000000000000000"
    built = build_hf_source("meta-llama/Llama-2-7b", commit)
    assert built == f"hf://meta-llama/Llama-2-7b@{commit}"


def test_build_hf_source_with_branch():
    """A branch name is accepted as the revision part of the URI."""
    built = build_hf_source("meta-llama/Llama-2-7b", "main")
    expected = "hf://meta-llama/Llama-2-7b@main"
    assert built == expected


@patch("truss.util.hf_revision_resolver.HfApi")
def test_resolve_hf_revision(mock_hf_api):
    """A branch name is resolved to the SHA reported by ``HfApi.repo_info``."""
    expected_sha = "a" * 40
    api_instance = mock_hf_api.return_value
    api_instance.repo_info.return_value = Mock(sha=expected_sha)

    resolved = resolve_hf_revision("meta-llama/Llama-2-7b", "main")
    assert resolved == expected_sha

    # The client is built without a token and queried for the model repo.
    mock_hf_api.assert_called_once_with(token=None)
    api_instance.repo_info.assert_called_once_with(
        repo_id="meta-llama/Llama-2-7b", revision="main", repo_type="model"
    )


@patch("truss.util.hf_revision_resolver.HfApi")
def test_resolve_hf_revision_without_revision(mock_hf_api):
    """With no revision given, ``revision=None`` is forwarded to the API."""
    expected_sha = "b" * 40
    api_instance = mock_hf_api.return_value
    api_instance.repo_info.return_value = Mock(sha=expected_sha)

    resolved = resolve_hf_revision("meta-llama/Llama-2-7b", None)
    assert resolved == expected_sha

    # The missing revision must be passed through untouched.
    api_instance.repo_info.assert_called_once_with(
        repo_id="meta-llama/Llama-2-7b", revision=None, repo_type="model"
    )


@patch("truss.util.hf_revision_resolver.HfApi")
def test_resolve_hf_revision_with_token(mock_hf_api):
    """An auth token is handed to the ``HfApi`` constructor."""
    expected_sha = "c" * 40
    mock_hf_api.return_value.repo_info.return_value = Mock(sha=expected_sha)

    resolved = resolve_hf_revision(
        "meta-llama/Llama-2-7b", "main", token="hf_test_token"
    )
    assert resolved == expected_sha

    # The token must reach the API client unchanged.
    mock_hf_api.assert_called_once_with(token="hf_test_token")
91 changes: 91 additions & 0 deletions truss/util/hf_revision_resolver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
"""Resolve HuggingFace revisions to commit SHAs."""

import logging
import re
from typing import Optional

from huggingface_hub import HfApi
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bad: this imports an optional dependency at module level. Move the import inside `resolve_hf_revision()`.

from huggingface_hub.utils import HfHubHTTPError

logger = logging.getLogger(__name__)


def resolve_hf_revision(
    repo_id: str, revision: Optional[str] = None, token: Optional[str] = None
) -> str:
    """Resolve HF revision to commit SHA (mirrors Rust truss-transfer logic).

    This does the same thing as the Rust code in truss-transfer/src/create/hf_metadata.rs:
    - Calls api_repo.info().await
    - Extracts repo_info.sha

    This is used for best-effort revision pinning during push. Callers should handle
    exceptions gracefully and allow the push to proceed even if resolution fails.

    Args:
        repo_id: HuggingFace repo ID (e.g., "meta-llama/Llama-2-7b")
        revision: Branch, tag, or SHA (None = default branch)
        token: Optional HF token for private repos

    Returns:
        Resolved commit SHA as reported by the Hub for the requested revision.

    Raises:
        HfHubHTTPError: If repo doesn't exist, revision is invalid, or network issues
        ValueError: If the Hub response carries no commit SHA.
    """
    try:
        api = HfApi(token=token)
        repo_info = api.repo_info(repo_id=repo_id, revision=revision, repo_type="model")
    except HfHubHTTPError as e:
        logger.debug(f"Failed to resolve HF revision for {repo_id}@{revision}: {e}")
        raise

    # `sha` is optional on the returned repo info; never hand a missing value
    # back to callers that build a pinned source URI from it.
    sha = repo_info.sha
    if not sha:
        raise ValueError(f"No commit SHA returned for {repo_id}@{revision}")
    return sha


def is_commit_sha(revision: Optional[str]) -> bool:
    """Check if revision is already a 40-character commit SHA.

    Args:
        revision: Revision string to inspect; ``None`` and ``""`` are allowed.

    Returns:
        True only when ``revision`` consists of exactly 40 lowercase
        hexadecimal characters and nothing else.
    """
    if not revision:
        return False
    # fullmatch instead of match + "$": "$" also matches just before a
    # trailing newline, which would accept a 40-hex string followed by "\n".
    return re.fullmatch(r"[0-9a-f]{40}", revision) is not None


def parse_hf_source(source: str) -> tuple[str, Optional[str]]:
    """Parse HuggingFace source URI into repo_id and revision.

    Args:
        source: URI like "hf://owner/repo@revision" or "hf://owner/repo"

    Returns:
        Tuple of (repo_id, revision or None). The revision is None both when
        the URI has no "@" and when nothing follows the last "@".

    Raises:
        ValueError: If ``source`` does not start with "hf://".

    Examples:
        >>> parse_hf_source("hf://meta-llama/Llama-2-7b@main")
        ('meta-llama/Llama-2-7b', 'main')
        >>> parse_hf_source("hf://meta-llama/Llama-2-7b")
        ('meta-llama/Llama-2-7b', None)
    """
    prefix = "hf://"
    if not source.startswith(prefix):
        raise ValueError(f"Not a HuggingFace source: {source}")

    path = source[len(prefix):]

    # Split on the LAST "@" so everything before it stays part of the repo ID.
    repo_id, separator, revision = path.rpartition("@")
    if not separator:
        return path, None

    # A dangling "@" carries no revision; report it as None, not "".
    return repo_id, revision or None


def build_hf_source(repo_id: str, revision: str) -> str:
    """Assemble a HuggingFace source URI.

    Args:
        repo_id: HuggingFace repo ID
        revision: Commit SHA or branch/tag name

    Returns:
        The two parts joined as "hf://<repo_id>@<revision>".
    """
    return "hf://{}@{}".format(repo_id, revision)
Loading