Skip to content

Commit a4a20b9

Browse files
authored
Add model card loading (#45)
* Add model card loading
* Add tests
1 parent 0ca9d00 commit a4a20b9

File tree

3 files changed

+55
-5
lines changed

3 files changed

+55
-5
lines changed

model2vec/model.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,9 +154,11 @@ def from_pretrained(
154154
:param token: The huggingface token to use.
155155
:return: A StaticEmbedder
156156
"""
157-
embeddings, tokenizer, config = load_pretrained(path, token=token)
157+
embeddings, tokenizer, config, metadata = load_pretrained(path, token=token)
158158

159-
return cls(embeddings, tokenizer, config)
159+
return cls(
160+
embeddings, tokenizer, config, base_model_name=metadata.get("base_model"), language=metadata.get("language")
161+
)
160162

161163
def encode(
162164
self,

model2vec/utils.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def _create_model_card(
102102

103103
def load_pretrained(
104104
folder_or_repo_path: str | Path, token: str | None = None
105-
) -> tuple[np.ndarray, Tokenizer, dict[str, Any]]:
105+
) -> tuple[np.ndarray, Tokenizer, dict[str, Any], dict[str, Any]]:
106106
"""
107107
Loads a pretrained model from a folder.
108108
@@ -111,7 +111,7 @@ def load_pretrained(
111111
- If the local path is not found, we will attempt to load from the huggingface hub.
112112
:param token: The huggingface token to use.
113113
:raises: FileNotFoundError if the folder exists, but the file does not exist locally.
114-
:return: The embeddings, tokenizer, and config.
114+
:return: The embeddings, tokenizer, config, and metadata.
115115
116116
"""
117117
folder_or_repo_path = Path(folder_or_repo_path)
@@ -133,6 +133,10 @@ def load_pretrained(
133133
if not tokenizer_path.exists():
134134
raise FileNotFoundError(f"Tokenizer file does not exist in {folder_or_repo_path}")
135135

136+
# README is optional, so this is a bit finicky.
137+
readme_path = folder_or_repo_path / "README.md"
138+
metadata = _get_metadata_from_readme(readme_path)
139+
136140
else:
137141
logger.info("Folder does not exist locally, attempting to use huggingface hub.")
138142
try:
@@ -148,6 +152,13 @@ def load_pretrained(
148152
# Raise original exception.
149153
raise e
150154

155+
try:
156+
readme_path = huggingface_hub.hf_hub_download(folder_or_repo_path.as_posix(), "README.md", token=token)
157+
metadata = _get_metadata_from_readme(Path(readme_path))
158+
except huggingface_hub.utils.EntryNotFoundError:
159+
logger.info("No README found in the model folder. No model card loaded.")
160+
metadata = {}
161+
151162
config_path = huggingface_hub.hf_hub_download(folder_or_repo_path.as_posix(), "config.json", token=token)
152163
tokenizer_path = huggingface_hub.hf_hub_download(folder_or_repo_path.as_posix(), "tokenizer.json", token=token)
153164

@@ -162,7 +173,19 @@ def load_pretrained(
162173
f"Number of tokens does not match number of embeddings: `{len(tokenizer.get_vocab())}` vs `{len(embeddings)}`"
163174
)
164175

165-
return embeddings, tokenizer, config
176+
return embeddings, tokenizer, config, metadata
177+
178+
179+
def _get_metadata_from_readme(readme_path: Path) -> dict[str, Any]:
180+
"""Get metadata from a README file."""
181+
if not readme_path.exists():
182+
logger.info(f"README file not found in {readme_path}. No model card loaded.")
183+
return {}
184+
model_card = ModelCard.load(readme_path)
185+
data: dict[str, Any] = model_card.data.to_dict()
186+
if not data:
187+
logger.info("File README.md exists, but was empty. No model card loaded.")
188+
return data
166189

167190

168191
def push_folder_to_hub(folder_path: Path, repo_id: str, private: bool, token: str | None) -> None:

tests/test_utils.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from pathlib import Path
2+
from tempfile import NamedTemporaryFile
3+
4+
from model2vec.utils import _get_metadata_from_readme
5+
6+
7+
def test__get_metadata_from_readme_not_exists() -> None:
8+
"""Test getting metadata from a README."""
9+
assert _get_metadata_from_readme(Path("zzz")) == {}
10+
11+
12+
def test__get_metadata_from_readme_mocked_file() -> None:
13+
"""Test getting metadata from a README."""
14+
with NamedTemporaryFile() as f:
15+
f.write(b"---\nkey: value\n---\n")
16+
f.flush()
17+
assert _get_metadata_from_readme(Path(f.name))["key"] == "value"
18+
19+
20+
def test__get_metadata_from_readme_mocked_file_keys() -> None:
21+
"""Test getting metadata from a README."""
22+
with NamedTemporaryFile() as f:
23+
f.write(b"")
24+
f.flush()
25+
assert set(_get_metadata_from_readme(Path(f.name))) == set()

0 commit comments

Comments
 (0)