Skip to content

Commit 06545dd

Browse files
Remove test files for utils and vicinity modules
- Deleted `tests/test_utils.py` containing tests for normalization utility functions - Removed `tests/test_vicinity.py` with comprehensive test cases for the Vicinity class - These test files are no longer needed, likely due to refactoring or migration of tests
1 parent 65465f3 commit 06545dd

File tree

3 files changed

+33
-25
lines changed

3 files changed

+33
-25
lines changed
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
from __future__ import annotations
2+
3+
import io
4+
import sys
5+
6+
from vicinity import Vicinity
7+
from vicinity.datatypes import Backend
8+
from vicinity.integrations.huggingface import _MODEL_NAME_OR_PATH_PRINT_STATEMENT
9+
10+
BackendType = tuple[Backend, str]
11+
12+
13+
def test_load_from_hub(vicinity_instance: Vicinity) -> None:
14+
"""
15+
Test Vicinity.load_from_hub.
16+
17+
:param vicinity_instance: A Vicinity instance.
18+
"""
19+
repo_id = "davidberenstein1957/my-vicinity-repo"
20+
# get the first part of the print statement to test if model name or path is printed
21+
expected_print_statement = _MODEL_NAME_OR_PATH_PRINT_STATEMENT.split(":")[0]
22+
23+
# Capture the output
24+
captured_output = io.StringIO()
25+
sys.stdout = captured_output
26+
27+
Vicinity.load_from_hub(repo_id=repo_id)
28+
29+
# Reset redirect.
30+
sys.stdout = sys.__stdout__
31+
32+
# Check if the expected message is in the output
33+
assert expected_print_statement in captured_output.getvalue()
Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
from __future__ import annotations
22

3-
import io
4-
import sys
53
from pathlib import Path
64

75
import numpy as np
@@ -10,7 +8,6 @@
108

119
from vicinity import Vicinity
1210
from vicinity.datatypes import Backend
13-
from vicinity.integrations.huggingface import _MODEL_NAME_OR_PATH_PRINT_STATEMENT
1411

1512
BackendType = tuple[Backend, str]
1613

@@ -336,25 +333,3 @@ def test_vicinity_evaluate(vicinity_instance: Vicinity, vectors: np.ndarray) ->
336333
vicinity_instance.backend.arguments.metric = "manhattan"
337334
with pytest.raises(ValueError):
338335
vicinity_instance.evaluate(vectors, query_vectors)
339-
340-
341-
def test_load_from_hub(vicinity_instance: Vicinity) -> None:
342-
"""
343-
Test Vicinity.load_from_hub.
344-
345-
:param vicinity_instance: A Vicinity instance.
346-
"""
347-
repo_id = "davidberenstein1957/my-vicinity-repo"
348-
expected_print_statement = _MODEL_NAME_OR_PATH_PRINT_STATEMENT.split(":")[0]
349-
350-
# Capture the output
351-
captured_output = io.StringIO()
352-
sys.stdout = captured_output
353-
354-
Vicinity.load_from_hub(repo_id=repo_id)
355-
356-
# Reset redirect.
357-
sys.stdout = sys.__stdout__
358-
359-
# Check if the expected message is in the output
360-
assert expected_print_statement in captured_output.getvalue()

0 commit comments

Comments
 (0)