Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions .github/workflows/unit-tests.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@


name: Py-marqo Unit Tests
run-name: Py-marqo unit tests

on:
workflow_dispatch:
pull_request:
branches:
- mainline
- releases/*
push:
branches:
- mainline
- releases/*

permissions:
contents: read

jobs:
test:
name: Run Pytest
runs-on: ubuntu-latest

steps:
- name: Checkout code
uses: actions/checkout@v2

- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: 3.9

- name: Install dependencies
run: pip install -r requirements-dev.txt

- name: Run UnitTests
run: |
export PYTHONPATH="${PYTHONPATH}:$(pwd)/src"
pytest unit_tests
8 changes: 7 additions & 1 deletion src/marqo/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,8 @@ def search(self, q: Optional[Union[str, dict]] = None, searchable_attributes: Op
track_total_hits: Optional[bool] = None,
approximate_threshold: Optional[float] = None,
language: Optional[str] = None,
sort_by: Optional[dict] = None,
relevance_cutoff: Optional[dict] = None
) -> Dict[str, Any]:
"""Search the index.

Expand Down Expand Up @@ -269,6 +271,8 @@ def search(self, q: Optional[Union[str, dict]] = None, searchable_attributes: Op
track_total_hits: return total number of lexical matches
approximate_threshold: hit ratio threshold for deciding if a nearest neighbor search should be performed as
an exact search, rather than an approximate search
sort_by: a dictionary of the sort_by parameters to be used for sorting the results
relevance_cutoff: a dictionary of the relevance cutoff parameters

Returns:
Dictionary with hits and other metadata
Expand Down Expand Up @@ -308,7 +312,9 @@ def search(self, q: Optional[Union[str, dict]] = None, searchable_attributes: Op
"hybridParameters": hybrid_parameters,
"facets": facets,
"trackTotalHits": track_total_hits,
"language": language
"language": language,
"sortBy": sort_by,
"relevanceCutoff": relevance_cutoff
}

body = {k: v for k, v in body.items() if v is not None}
Expand Down
Empty file added unit_tests/__init__.py
Empty file.
5 changes: 5 additions & 0 deletions unit_tests/marqo_unit_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from unittest import TestCase


class MarqoUnitTests(TestCase):
pass
132 changes: 132 additions & 0 deletions unit_tests/test_index_search_endpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
from unittest.mock import patch, MagicMock

from marqo.config import Config
from marqo.default_instance_mappings import DefaultInstanceMappings
from marqo.index import Index
from unit_tests.marqo_unit_tests import MarqoUnitTests


class TestIndexSearchEndpointSortBy(MarqoUnitTests):
"""
Test class for the Index search endpoint sort_by parameter.
"""

@classmethod
def setUpClass(cls):
"""
Set up the class by creating a test index.
"""
config = Config(
instance_mappings=DefaultInstanceMappings(url="http://unit-tests-url:8882"),
)
cls.index = Index(config=config, index_name="test_index")

@patch('marqo._httprequests.HttpRequests.send_request')
def test_search_by_parameters_one_fields(self, mock_send_request):
mock_response = MagicMock()
mock_send_request.return_value = mock_response
sort_by = {"fields":[{"fieldName": "price", "order": "asc"}]}
self.index.search(q= "test", sort_by= sort_by)

self.assertEqual(1, mock_send_request.call_count)
body = mock_send_request.call_args[0][2]
self.assertEqual({"fields": [{"fieldName": "price", "order": "asc"}]}, body["sortBy"])

@patch('marqo._httprequests.HttpRequests.send_request')
def test_search_by_parameters_three_fields(self, mock_send_request):
mock_response = MagicMock()
mock_send_request.return_value = mock_response

sort_by = {
"fields": [
{"fieldName": "price", "order": "asc"},
{"fieldName": "rating", "order": "desc"},
{"fieldName": "date", "order": "asc"}
]
}

self.index.search(q= "test", sort_by=sort_by)

self.assertEqual(1, mock_send_request.call_count)
body = mock_send_request.call_args[0][2]
self.assertEqual(sort_by, body["sortBy"])

@patch('marqo._httprequests.HttpRequests.send_request')
def test_search_by_parameters_none(self, mock_send_request):
mock_response = MagicMock()
mock_send_request.return_value = mock_response

sort_by = None

self.index.search(q= "test", sort_by=sort_by)

self.assertEqual(1, mock_send_request.call_count)
body = mock_send_request.call_args[0][2]
self.assertNotIn("sortBy", body)

@patch('marqo._httprequests.HttpRequests.send_request')
def test_search_by_parameters_not_exists(self, mock_send_request):
mock_response = MagicMock()
mock_send_request.return_value = mock_response

self.index.search(q= "test")

self.assertEqual(1, mock_send_request.call_count)
body = mock_send_request.call_args[0][2]
self.assertNotIn("sortBy", body)


class TestIndexSearchEndpointRelevanceCutoff(MarqoUnitTests):
"""
Test class for the Index search endpoint relevance_cutoff parameter.
"""
@classmethod
def setUpClass(cls):
"""
Set up the class by creating a test index.
"""
config = Config(
instance_mappings=DefaultInstanceMappings(url="http://unit-tests-url:8882"),
)
cls.index = Index(config=config, index_name="test_index")

@patch('marqo._httprequests.HttpRequests.send_request')
def test_relevance_cutoff_parameter(self, mock_send_request):
mock_response = MagicMock()
mock_send_request.return_value = mock_response

relevance_cutoff = {
"method": "mean_std_dev",
"parameters": {
"stdDevFactor": 1.0,
},
"probeDepth": 1000
}

self.index.search(q= "test", relevance_cutoff=relevance_cutoff)

self.assertEqual(1, mock_send_request.call_count)
body = mock_send_request.call_args[0][2]
self.assertEqual(relevance_cutoff, body["relevanceCutoff"])

@patch('marqo._httprequests.HttpRequests.send_request')
def test_relevance_cutoff_parameter_none(self, mock_send_request):
mock_response = MagicMock()
mock_send_request.return_value = mock_response

relevance_cutoff = None

self.index.search(q= "test", relevance_cutoff=relevance_cutoff)
body = mock_send_request.call_args[0][2]
self.assertNotIn("relevanceCutoff", body)

@patch('marqo._httprequests.HttpRequests.send_request')
def test_relevance_cutoff_parameter_not_exists(self, mock_send_request):
mock_response = MagicMock()
mock_send_request.return_value = mock_response

self.index.search(q= "test")

self.assertEqual(1, mock_send_request.call_count)
body = mock_send_request.call_args[0][2]
self.assertNotIn("relevanceCutoff", body)
Loading