Skip to content

Commit 2b77772

Browse files
authored
Support collapse_fields param in both create_index and search endpoints (#296)
* Support collapse fields params in both create_index and search endpoints. * downgrade minimum version
1 parent ad89a7d commit 2b77772

File tree

5 files changed

+77
-2
lines changed

5 files changed

+77
-2
lines changed

src/marqo/index.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ def create(config: Config,
101101
wait_for_readiness: bool = True,
102102
text_chunk_prefix: Optional[str] = None,
103103
text_query_prefix: Optional[str] = None,
104+
collapse_fields: Optional[List[marqo_index.CollapseField]] = None,
104105
) -> Dict[str, Any]:
105106
"""Create the index. Please refer to the marqo cloud to see options for inference and storage node types.
106107
Creates CreateIndexSettings object and then uses it to create the index.
@@ -137,6 +138,7 @@ def create(config: Config,
137138
number_of_inferences: number of inferences for the index
138139
number_of_shards: number of shards for the index
139140
number_of_replicas: number of replicas for the index
141+
collapse_fields: list of fields that can be collapsed on at query time
140142
Note:
141143
wait_for_readiness, inference_type, storage_class, number_of_inferences,
142144
number_of_shards, number_of_replicas are Marqo Cloud specific parameters,
@@ -166,6 +168,7 @@ def create(config: Config,
166168
annParameters=ann_parameters,
167169
textChunkPrefix=text_chunk_prefix,
168170
textQueryPrefix=text_query_prefix,
171+
collapseFields=collapse_fields,
169172
)
170173
return req.post(f"indexes/{index_name}", body=local_create_index_settings.generate_request_body())
171174

@@ -195,6 +198,7 @@ def create(config: Config,
195198
storageClass=storage_class,
196199
textChunkPrefix=text_chunk_prefix,
197200
textQueryPrefix=text_query_prefix,
201+
collapseFields=collapse_fields,
198202
)
199203

200204
response = req.post(f"indexes/{index_name}", body=cloud_index_settings.generate_request_body())
@@ -228,7 +232,8 @@ def search(self, q: Optional[Union[str, dict]] = None, searchable_attributes: Op
228232
language: Optional[str] = None,
229233
sort_by: Optional[dict] = None,
230234
relevance_cutoff: Optional[dict] = None,
231-
interpolation_method: Optional[str] = None
235+
interpolation_method: Optional[str] = None,
236+
collapse_fields: Optional[List[dict]] = None,
232237
) -> Dict[str, Any]:
233238
"""Search the index.
234239
@@ -275,6 +280,7 @@ def search(self, q: Optional[Union[str, dict]] = None, searchable_attributes: Op
275280
sort_by: a dictionary of the sort_by parameters to be used for sorting the results
276281
relevance_cutoff: a dictionary of the relevance cutoff parameters
277282
interpolation_method: the interpolation method to use for combining query & context embeddings.
283+
collapse_fields: a list of fields to collapse/group on
278284
279285
Returns:
280286
Dictionary with hits and other metadata
@@ -317,7 +323,8 @@ def search(self, q: Optional[Union[str, dict]] = None, searchable_attributes: Op
317323
"language": language,
318324
"sortBy": sort_by,
319325
"relevanceCutoff": relevance_cutoff,
320-
"interpolationMethod": interpolation_method
326+
"interpolationMethod": interpolation_method,
327+
"collapseFields": collapse_fields,
321328
}
322329

323330
body = {k: v for k, v in body.items() if v is not None}

src/marqo/models/create_index_settings.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ class IndexSettings(MarqoBaseModel):
2626
imagePreprocessing: The image preprocessing method to use.
2727
vectorNumericType: The numeric type of the vector.
2828
annParameters: The ANN parameters to use.
29+
collapseFields: list of fields that can be collapsed on at query time
2930
3031
Please note, we don't note set default values in the py-marqo side. All the
3132
values are set to be None and will not be sent to Marqo in the HttpRequest.
@@ -50,6 +51,7 @@ class IndexSettings(MarqoBaseModel):
5051
annParameters: Optional[marqo_index.AnnParameters] = None
5152
textQueryPrefix: Optional[str] = None
5253
textChunkPrefix: Optional[str] = None
54+
collapseFields: Optional[List[marqo_index.CollapseField]] = None
5355

5456
def generate_request_body(self) -> dict:
5557
"""A json encoded string of the request body"""

src/marqo/models/marqo_index.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,3 +104,8 @@ class FieldRequest(StrictBaseModel):
104104
class AnnParameters(StrictBaseModel):
105105
spaceType: Optional[DistanceMetric] = Field(None, alias="space_type")
106106
parameters: Optional[HnswConfig] = None
107+
108+
109+
class CollapseField(StrictBaseModel):
110+
name: str
111+
minGroups: int = 500
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
from unittest.mock import patch, MagicMock
2+
3+
from marqo.config import Config
4+
from marqo.default_instance_mappings import DefaultInstanceMappings
5+
from marqo.index import Index
6+
from marqo.models.marqo_index import CollapseField
7+
from unit_tests.marqo_unit_tests import MarqoUnitTests
8+
9+
10+
class TestIndexCreationEndpoint(MarqoUnitTests):
11+
@classmethod
12+
def setUpClass(cls):
13+
"""
14+
Set up the class by creating a test index.
15+
"""
16+
cls.config = Config(
17+
instance_mappings=DefaultInstanceMappings(url="http://unit-tests-url:8882"),
18+
)
19+
20+
@patch('marqo._httprequests.HttpRequests.send_request')
21+
def test_collapse_field_parameter(self, mock_send_request):
22+
mock_response = MagicMock()
23+
mock_send_request.return_value = mock_response
24+
25+
Index.create(self.config, index_name='test_index',
26+
collapse_fields=[
27+
CollapseField(name='parent_id'), # test default minGroups value is 500
28+
CollapseField(name='color', minGroups=10),
29+
])
30+
31+
self.assertEqual(1, mock_send_request.call_count)
32+
body = mock_send_request.call_args[0][2]
33+
self.assertListEqual([{'name': 'parent_id', 'minGroups': 500}, {'name': 'color', 'minGroups': 10}],
34+
body["collapseFields"])

unit_tests/test_index_search_endpoint.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,3 +179,30 @@ def test_interpolation_method_parameter_not_exists(self, mock_send_request):
179179
self.assertEqual(1, mock_send_request.call_count)
180180
body = mock_send_request.call_args[0][2]
181181
self.assertNotIn("interpolationMethod", body)
182+
183+
184+
class TestIndexSearchEndpointCollapseFields(MarqoUnitTests):
185+
"""
186+
Test class for the Index search endpoint collapse_fields parameter.
187+
"""
188+
@classmethod
189+
def setUpClass(cls):
190+
"""
191+
Set up the class by creating a test index.
192+
"""
193+
config = Config(
194+
instance_mappings=DefaultInstanceMappings(url="http://unit-tests-url:8882"),
195+
)
196+
cls.index = Index(config=config, index_name="test_index")
197+
198+
@patch('marqo._httprequests.HttpRequests.send_request')
199+
def test_collapse_field_parameter(self, mock_send_request):
200+
mock_response = MagicMock()
201+
mock_send_request.return_value = mock_response
202+
203+
collapse_fields = [{"name": "parent_id"}]
204+
self.index.search(q="test", collapse_fields=collapse_fields)
205+
206+
self.assertEqual(1, mock_send_request.call_count)
207+
body = mock_send_request.call_args[0][2]
208+
self.assertEqual(collapse_fields, body["collapseFields"])

0 commit comments

Comments
 (0)