diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml new file mode 100644 index 00000000..ea582edd --- /dev/null +++ b/.github/workflows/unit-tests.yml @@ -0,0 +1,40 @@ + + +name: Py-marqo Unit Tests +run-name: Py-marqo unit tests + +on: + workflow_dispatch: + pull_request: + branches: + - mainline + - releases/* + push: + branches: + - mainline + - releases/* + +permissions: + contents: read + +jobs: + test: + name: Run Pytest + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v2 + + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: 3.9 + + - name: Install dependencies + run: pip install -r requirements-dev.txt + + - name: Run UnitTests + run: | + export PYTHONPATH="${PYTHONPATH}:$(pwd)/src" + pytest unit_tests \ No newline at end of file diff --git a/src/marqo/index.py b/src/marqo/index.py index 7d07dd13..30ea28ee 100644 --- a/src/marqo/index.py +++ b/src/marqo/index.py @@ -226,6 +226,8 @@ def search(self, q: Optional[Union[str, dict]] = None, searchable_attributes: Op track_total_hits: Optional[bool] = None, approximate_threshold: Optional[float] = None, language: Optional[str] = None, + sort_by: Optional[dict] = None, + relevance_cutoff: Optional[dict] = None ) -> Dict[str, Any]: """Search the index. @@ -269,6 +271,8 @@ def search(self, q: Optional[Union[str, dict]] = None, searchable_attributes: Op track_total_hits: return total number of lexical matches approximate_threshold: hit ratio threshold for deciding if a nearest neighbor search should be performed as an exact search, rather than an approximate search + sort_by: a dictionary of the sort_by parameters to be used for sorting the results + relevance_cutoff: a dictionary of the relevance cutoff parameters Returns: Dictionary with hits and other metadata @@ -308,7 +312,9 @@ def search(self, q: Optional[Union[str, dict]] = None, searchable_attributes: Op "hybridParameters": hybrid_parameters, "facets": facets, "trackTotalHits": track_total_hits, - "language": language + "language": language, + "sortBy": sort_by, + "relevanceCutoff": relevance_cutoff } body = {k: v for k, v in body.items() if v is not None} diff --git a/unit_tests/__init__.py b/unit_tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/unit_tests/marqo_unit_tests.py b/unit_tests/marqo_unit_tests.py new file mode 100644 index 00000000..03a6741c --- /dev/null +++ b/unit_tests/marqo_unit_tests.py @@ -0,0 +1,5 @@ +from unittest import TestCase + + +class MarqoUnitTests(TestCase): + pass \ No newline at end of file diff --git a/unit_tests/test_index_search_endpoint.py b/unit_tests/test_index_search_endpoint.py new file mode 100644 index 00000000..ae188aad --- /dev/null +++ b/unit_tests/test_index_search_endpoint.py @@ -0,0 +1,132 @@ +from unittest.mock import patch, MagicMock + +from marqo.config import Config +from marqo.default_instance_mappings import DefaultInstanceMappings +from marqo.index import Index +from unit_tests.marqo_unit_tests import MarqoUnitTests + + +class TestIndexSearchEndpointSortBy(MarqoUnitTests): + """ + Test class for the Index search endpoint sort_by parameter. + """ + + @classmethod + def setUpClass(cls): + """ + Set up the class by creating a test index. + """ + config = Config( + instance_mappings=DefaultInstanceMappings(url="http://unit-tests-url:8882"), + ) + cls.index = Index(config=config, index_name="test_index") + + @patch('marqo._httprequests.HttpRequests.send_request') + def test_search_by_parameters_one_fields(self, mock_send_request): + mock_response = MagicMock() + mock_send_request.return_value = mock_response + sort_by = {"fields":[{"fieldName": "price", "order": "asc"}]} + self.index.search(q= "test", sort_by= sort_by) + + self.assertEqual(1, mock_send_request.call_count) + body = mock_send_request.call_args[0][2] + self.assertEqual({"fields": [{"fieldName": "price", "order": "asc"}]}, body["sortBy"]) + + @patch('marqo._httprequests.HttpRequests.send_request') + def test_search_by_parameters_three_fields(self, mock_send_request): + mock_response = MagicMock() + mock_send_request.return_value = mock_response + + sort_by = { + "fields": [ + {"fieldName": "price", "order": "asc"}, + {"fieldName": "rating", "order": "desc"}, + {"fieldName": "date", "order": "asc"} + ] + } + + self.index.search(q= "test", sort_by=sort_by) + + self.assertEqual(1, mock_send_request.call_count) + body = mock_send_request.call_args[0][2] + self.assertEqual(sort_by, body["sortBy"]) + + @patch('marqo._httprequests.HttpRequests.send_request') + def test_search_by_parameters_none(self, mock_send_request): + mock_response = MagicMock() + mock_send_request.return_value = mock_response + + sort_by = None + + self.index.search(q= "test", sort_by=sort_by) + + self.assertEqual(1, mock_send_request.call_count) + body = mock_send_request.call_args[0][2] + self.assertNotIn("sortBy", body) + + @patch('marqo._httprequests.HttpRequests.send_request') + def test_search_by_parameters_not_exists(self, mock_send_request): + mock_response = MagicMock() + mock_send_request.return_value = mock_response + + self.index.search(q= "test") + + self.assertEqual(1, mock_send_request.call_count) + body = mock_send_request.call_args[0][2] + self.assertNotIn("sortBy", body) + + +class TestIndexSearchEndpointRelevanceCutoff(MarqoUnitTests): + """ + Test class for the Index search endpoint relevance_cutoff parameter. + """ + @classmethod + def setUpClass(cls): + """ + Set up the class by creating a test index. + """ + config = Config( + instance_mappings=DefaultInstanceMappings(url="http://unit-tests-url:8882"), + ) + cls.index = Index(config=config, index_name="test_index") + + @patch('marqo._httprequests.HttpRequests.send_request') + def test_relevance_cutoff_parameter(self, mock_send_request): + mock_response = MagicMock() + mock_send_request.return_value = mock_response + + relevance_cutoff = { + "method": "mean_std_dev", + "parameters": { + "stdDevFactor": 1.0, + }, + "probeDepth": 1000 + } + + self.index.search(q= "test", relevance_cutoff=relevance_cutoff) + + self.assertEqual(1, mock_send_request.call_count) + body = mock_send_request.call_args[0][2] + self.assertEqual(relevance_cutoff, body["relevanceCutoff"]) + + @patch('marqo._httprequests.HttpRequests.send_request') + def test_relevance_cutoff_parameter_none(self, mock_send_request): + mock_response = MagicMock() + mock_send_request.return_value = mock_response + + relevance_cutoff = None + + self.index.search(q= "test", relevance_cutoff=relevance_cutoff) + body = mock_send_request.call_args[0][2] + self.assertNotIn("relevanceCutoff", body) + + @patch('marqo._httprequests.HttpRequests.send_request') + def test_relevance_cutoff_parameter_not_exists(self, mock_send_request): + mock_response = MagicMock() + mock_send_request.return_value = mock_response + + self.index.search(q= "test") + + self.assertEqual(1, mock_send_request.call_count) + body = mock_send_request.call_args[0][2] + self.assertNotIn("relevanceCutoff", body)