Skip to content

Commit 3457f16

Browse files
committed
Pin HuggingFace revisions in weights to commit SHAs on push
- Add hf_revision_resolver utility to resolve HF revisions to SHAs
- Modify BasetenRemote._prepare_push to pin unpinned HF revisions
- Resolves 'main' and missing revisions to commit SHAs
- Writes updated config.yaml with pinned revisions
- Ensures reproducible deployments (same config = same weights)
- Mirrors existing model_cache behavior but at push time
- Adds comprehensive tests for revision resolution

Fixes INF-2314
1 parent 26ba4e6 commit 3457f16

File tree

3 files changed

+290
-0
lines changed

3 files changed

+290
-0
lines changed

truss/remote/baseten/remote.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,9 @@ def _prepare_push(
179179

180180
config = truss_handle._spec._config
181181

182+
# Pin HuggingFace revisions in weights to commit SHAs
183+
self._pin_hf_revisions_in_weights(truss_handle)
184+
182185
config.validate_forbid_extra()
183186
encoded_config_str = base64_encoded_json_str(config.to_dict())
184187
validate_truss_config_against_backend(self._api, encoded_config_str)
@@ -199,6 +202,85 @@ def _prepare_push(
199202
team_id=team_id,
200203
)
201204

205+
def _pin_hf_revisions_in_weights(self, truss_handle: TrussHandle) -> None:
    """Pin HuggingFace revisions in weights to commit SHAs.

    Every HuggingFace weight source whose revision is not already a
    40-character commit SHA (e.g. "main", "dev", or no revision at all) is
    resolved to a commit SHA. The in-memory config is updated in place and,
    if at least one source changed, written back to ``config.yaml`` in the
    truss directory so the same config always pulls the same weights.

    Resolution is best-effort: a source that cannot be parsed is skipped
    silently, and a source whose revision cannot be resolved is left
    untouched with a warning logged.

    Args:
        truss_handle: Handle whose config weights are pinned in place.
    """
    from truss.util.hf_revision_resolver import (
        build_hf_source,
        is_commit_sha,
        parse_hf_source,
        resolve_hf_revision,
    )

    config = truss_handle._spec._config
    weights = config.weights

    if not weights or not weights.sources:
        return

    # (repo_id, requested revision, resolved SHA) for every source rewritten.
    changes = []

    for source in weights.sources:
        if not source.is_huggingface:
            continue

        try:
            repo_id, revision = parse_hf_source(source.source)
        except ValueError:
            continue

        # Skip if already a commit SHA
        if is_commit_sha(revision):
            continue

        # If no revision specified, assume "main"
        revision_to_resolve = revision if revision else "main"

        # Only the lookup itself is guarded, so a failure can never leave a
        # half-applied change behind (source rewritten but reported as
        # "unpinned").
        # TODO: Extract token from source.auth_secret_name if available
        try:
            resolved_sha = resolve_hf_revision(
                repo_id, revision_to_resolve, token=None
            )
        except Exception as e:
            logging.warning(
                f"Failed to resolve {repo_id}@{revision_to_resolve}: {e}. "
                "Proceeding with unpinned revision."
            )
            continue

        # Never write a source such as "hf://repo@None" into the config.
        if not resolved_sha:
            logging.warning(
                f"Failed to resolve {repo_id}@{revision_to_resolve}: "
                "no commit SHA returned. "
                "Proceeding with unpinned revision."
            )
            continue

        if revision == resolved_sha:
            continue

        source.source = build_hf_source(repo_id, resolved_sha)
        changes.append((repo_id, revision_to_resolve, resolved_sha))

        logging.info(
            f"Pinning HF revision: {repo_id}@{revision_to_resolve} "
            f"-> {resolved_sha[:8]}..."
        )

    if not changes:
        return

    # Write updated config back to file
    config_path = truss_handle._truss_dir / "config.yaml"
    config.write_to_yaml_file(config_path, verbose=True)

    logging.warning(
        "\n⚠️ Your config.yaml has been modified to pin HuggingFace revisions to commit SHAs."
    )
    logging.warning(
        " This ensures reproducible deployments. Modified sources:\n"
    )
    for repo_id, old_rev, new_sha in changes:
        logging.warning(f" {repo_id}@{old_rev} -> {new_sha[:8]}...\n")
283+
202284
def push( # type: ignore
203285
self,
204286
truss_handle: TrussHandle,
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
"""Tests for HuggingFace revision resolver."""
2+
3+
from unittest.mock import Mock, patch
4+
5+
import pytest
6+
7+
from truss.util.hf_revision_resolver import (
8+
build_hf_source,
9+
is_commit_sha,
10+
parse_hf_source,
11+
resolve_hf_revision,
12+
)
13+
14+
15+
def test_parse_hf_source_with_revision():
    """An explicit ``@revision`` suffix is split off from the repo ID."""
    uri = "hf://meta-llama/Llama-2-7b@main"
    parsed_repo, parsed_revision = parse_hf_source(uri)
    assert parsed_repo == "meta-llama/Llama-2-7b"
    assert parsed_revision == "main"
20+
21+
22+
def test_parse_hf_source_with_sha():
    """A 40-character commit SHA after ``@`` is returned unchanged."""
    commit = "a" * 40
    uri = f"hf://meta-llama/Llama-2-7b@{commit}"
    parsed_repo, parsed_revision = parse_hf_source(uri)
    assert parsed_repo == "meta-llama/Llama-2-7b"
    assert parsed_revision == commit
28+
29+
30+
def test_parse_hf_source_without_revision():
    """A URI with no ``@`` yields the repo ID and ``None`` as revision."""
    uri = "hf://meta-llama/Llama-2-7b"
    parsed_repo, parsed_revision = parse_hf_source(uri)
    assert parsed_repo == "meta-llama/Llama-2-7b"
    assert parsed_revision is None
35+
36+
37+
def test_parse_hf_source_invalid():
    """URIs with a non-``hf://`` scheme are rejected with ValueError."""
    for foreign_uri in ("s3://bucket/path", "gs://bucket/path"):
        with pytest.raises(ValueError, match="Not a HuggingFace source"):
            parse_hf_source(foreign_uri)
44+
45+
46+
def test_is_commit_sha():
    """Only 40-character lowercase hex strings count as commit SHAs."""
    accepted = [
        "a" * 40,
        "0123456789abcdef" * 2 + "01234567",
        "f" * 40,
    ]
    for candidate in accepted:
        assert is_commit_sha(candidate)

    rejected = [
        "main",
        "v1.0.0",
        "a" * 39,  # Too short
        "a" * 41,  # Too long
        None,
        "",
        "gggggggggggggggggggggggggggggggggggggggg",  # Invalid hex
    ]
    for candidate in rejected:
        assert not is_commit_sha(candidate)
61+
62+
63+
def test_build_hf_source():
    """A repo ID and a commit SHA are joined into an ``hf://`` URI."""
    commit = "abc1230000000000000000000000000000000000"
    expected = f"hf://meta-llama/Llama-2-7b@{commit}"
    assert build_hf_source("meta-llama/Llama-2-7b", commit) == expected
68+
69+
70+
def test_build_hf_source_with_branch():
    """A branch name is accepted as the revision part of the URI."""
    built = build_hf_source("meta-llama/Llama-2-7b", "main")
    assert built == "hf://meta-llama/Llama-2-7b@main"
74+
75+
76+
@patch("truss.util.hf_revision_resolver.HfApi")
def test_resolve_hf_revision(mock_hf_api):
    """A named revision is resolved to the SHA reported by the Hub API."""
    expected_sha = "a" * 40
    api_instance = mock_hf_api.return_value
    api_instance.repo_info.return_value = Mock(sha=expected_sha)

    resolved = resolve_hf_revision("meta-llama/Llama-2-7b", "main")
    assert resolved == expected_sha

    # The client is built without a token and queried for the model repo.
    mock_hf_api.assert_called_once_with(token=None)
    api_instance.repo_info.assert_called_once_with(
        repo_id="meta-llama/Llama-2-7b", revision="main", repo_type="model"
    )
91+
92+
93+
@patch("truss.util.hf_revision_resolver.HfApi")
def test_resolve_hf_revision_without_revision(mock_hf_api):
    """With no revision given, ``revision=None`` is forwarded to the API."""
    expected_sha = "b" * 40
    api_instance = mock_hf_api.return_value
    api_instance.repo_info.return_value = Mock(sha=expected_sha)

    resolved = resolve_hf_revision("meta-llama/Llama-2-7b", None)
    assert resolved == expected_sha

    # revision=None must reach repo_info unchanged.
    api_instance.repo_info.assert_called_once_with(
        repo_id="meta-llama/Llama-2-7b", revision=None, repo_type="model"
    )
107+
108+
109+
@patch("truss.util.hf_revision_resolver.HfApi")
def test_resolve_hf_revision_with_token(mock_hf_api):
    """An auth token is handed to the ``HfApi`` constructor."""
    expected_sha = "c" * 40
    mock_hf_api.return_value.repo_info.return_value = Mock(sha=expected_sha)

    resolved = resolve_hf_revision(
        "meta-llama/Llama-2-7b", "main", token="hf_test_token"
    )
    assert resolved == expected_sha

    # The token must be used when constructing the API client.
    mock_hf_api.assert_called_once_with(token="hf_test_token")

truss/util/hf_revision_resolver.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
"""Resolve HuggingFace revisions to commit SHAs."""
2+
3+
import logging
4+
import re
5+
from typing import Optional
6+
7+
from huggingface_hub import HfApi
8+
from huggingface_hub.utils import HfHubHTTPError
9+
10+
logger = logging.getLogger(__name__)
11+
12+
13+
def resolve_hf_revision(
    repo_id: str, revision: Optional[str] = None, token: Optional[str] = None
) -> str:
    """Resolve HF revision to commit SHA (mirrors Rust truss-transfer logic).

    This does the same thing as the Rust code in truss-transfer/src/create/hf_metadata.rs:
    - Calls api_repo.info().await
    - Extracts repo_info.sha

    Args:
        repo_id: HuggingFace repo ID (e.g., "meta-llama/Llama-2-7b")
        revision: Branch, tag, or SHA (None = default branch "main")
        token: Optional HF token for private repos

    Returns:
        Resolved commit SHA (40-char hex string)

    Raises:
        HfHubHTTPError: If repo doesn't exist or revision is invalid
        ValueError: If the Hub response carries no commit SHA
    """
    # Keep the try body to the network call so only API failures are
    # logged as resolution errors.
    try:
        api = HfApi(token=token)
        repo_info = api.repo_info(
            repo_id=repo_id, revision=revision, repo_type="model"
        )
    except HfHubHTTPError as e:
        logger.error(f"Failed to resolve HF revision for {repo_id}@{revision}: {e}")
        raise

    sha = repo_info.sha
    if not sha:
        # Honour the ``-> str`` contract: callers slice and interpolate the
        # result, so a missing SHA must surface as an error, not as None.
        raise ValueError(
            f"HuggingFace returned no commit SHA for {repo_id}@{revision}"
        )
    return sha
40+
41+
42+
def is_commit_sha(revision: Optional[str]) -> bool:
    """Check if revision is already a 40-character commit SHA.

    Args:
        revision: Revision string to test; ``None`` and ``""`` are allowed.

    Returns:
        True only when ``revision`` consists of exactly 40 lowercase
        hexadecimal characters and nothing else.
    """
    if not revision:
        return False
    # fullmatch instead of match with "$": "$" also matches just before a
    # trailing newline, which would let "<40 hex chars>\n" pass as a SHA.
    return re.fullmatch(r"[0-9a-f]{40}", revision) is not None
47+
48+
49+
def parse_hf_source(source: str) -> tuple[str, Optional[str]]:
    """Parse HuggingFace source URI into repo_id and revision.

    The revision is whatever follows the last "@" in the URI. An empty
    revision ("hf://owner/repo@") is reported as ``None``, the same as a URI
    without any "@".

    Args:
        source: URI like "hf://owner/repo@revision" or "hf://owner/repo"

    Returns:
        Tuple of (repo_id, revision or None)

    Raises:
        ValueError: If ``source`` does not start with "hf://" or names no
            repository.

    Examples:
        >>> parse_hf_source("hf://meta-llama/Llama-2-7b@main")
        ('meta-llama/Llama-2-7b', 'main')
        >>> parse_hf_source("hf://meta-llama/Llama-2-7b")
        ('meta-llama/Llama-2-7b', None)
    """
    prefix = "hf://"
    if not source.startswith(prefix):
        raise ValueError(f"Not a HuggingFace source: {source}")

    path = source[len(prefix):]

    # Split on the last "@"; without one the whole path is the repo ID.
    repo_id, separator, revision = path.rpartition("@")
    if not separator:
        repo_id, revision = path, ""

    if not repo_id:
        raise ValueError(f"Missing repository ID in HuggingFace source: {source}")

    # Normalize an empty revision to None so callers see a single
    # "no revision" value.
    return repo_id, revision or None
76+
77+
78+
def build_hf_source(repo_id: str, revision: str) -> str:
    """Compose a HuggingFace source URI.

    Args:
        repo_id: HuggingFace repo ID
        revision: Commit SHA or branch/tag name

    Returns:
        The URI "hf://<repo_id>@<revision>"
    """
    uri_template = "hf://{}@{}"
    return uri_template.format(repo_id, revision)

0 commit comments

Comments
 (0)