Skip to content

Commit 65465f3

Browse files
Add test for Vicinity.load_from_hub method
- Implemented a new test case for loading a Vicinity instance from Hugging Face Hub - Added test to verify the print statement when loading from a repository - Introduced a constant for the print statement in the Hugging Face integration module - Updated the print statement to use string formatting for better flexibility
1 parent cab15e5 commit 65465f3

File tree

2 files changed

+29
-1
lines changed

2 files changed

+29
-1
lines changed

tests/test_vicinity.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
import io
4+
import sys
35
from pathlib import Path
46

57
import numpy as np
@@ -8,6 +10,7 @@
810

911
from vicinity import Vicinity
1012
from vicinity.datatypes import Backend
13+
from vicinity.integrations.huggingface import _MODEL_NAME_OR_PATH_PRINT_STATEMENT
1114

1215
BackendType = tuple[Backend, str]
1316

@@ -333,3 +336,25 @@ def test_vicinity_evaluate(vicinity_instance: Vicinity, vectors: np.ndarray) ->
333336
vicinity_instance.backend.arguments.metric = "manhattan"
334337
with pytest.raises(ValueError):
335338
vicinity_instance.evaluate(vectors, query_vectors)
339+
340+
341+
def test_load_from_hub(vicinity_instance: Vicinity) -> None:
342+
"""
343+
Test Vicinity.load_from_hub.
344+
345+
:param vicinity_instance: A Vicinity instance.
346+
"""
347+
repo_id = "davidberenstein1957/my-vicinity-repo"
348+
expected_print_statement = _MODEL_NAME_OR_PATH_PRINT_STATEMENT.split(":")[0]
349+
350+
# Capture the output
351+
captured_output = io.StringIO()
352+
sys.stdout = captured_output
353+
354+
Vicinity.load_from_hub(repo_id=repo_id)
355+
356+
# Reset redirect.
357+
sys.stdout = sys.__stdout__
358+
359+
# Check if the expected message is in the output
360+
assert expected_print_statement in captured_output.getvalue()

vicinity/integrations/huggingface.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717
_HUB_IMPORT_ERROR = ImportError(
1818
"`datasets` and `huggingface_hub` are required to push to the Hugging Face Hub. Please install them with `pip install 'vicinity[huggingface]'`"
1919
)
20+
_MODEL_NAME_OR_PATH_PRINT_STATEMENT = (
21+
"Embeddings in Vicinity instance were created from model name or path: {model_name_or_path}"
22+
)
2023

2124
logger = logging.getLogger(__name__)
2225

@@ -128,7 +131,7 @@ def load_from_hub(cls, repo_id: str, token: str | None = None, **kwargs: Any) ->
128131
config = json.load(f)
129132
model_name_or_path = config.pop("model_name_or_path")
130133

131-
print(f"Embeddings in Vicinity instance were created from model name or path: {model_name_or_path}")
134+
print(_MODEL_NAME_OR_PATH_PRINT_STATEMENT.format(model_name_or_path=model_name_or_path))
132135
backend_type = Backend(config["backend_type"])
133136
backend = get_backend_class(backend_type).load(repo_path / "backend")
134137

0 commit comments

Comments
 (0)