Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -140,10 +140,11 @@ def build_vector_search_proto_request(params):
source_config = common_pb2.SourceConfigParam(bool=fetch_source)
timeout = params.get("request-timeout")

if body.get("stored_fields") is None or body.get("stored_fields") is STORED_FIELDS_NONE:
stored_fields = body.get("stored_fields")
if stored_fields is None or stored_fields == STORED_FIELDS_NONE:
stored_fields = [STORED_FIELDS_NONE]
else:
stored_fields = body.get("stored_fields")
elif not isinstance(stored_fields, list):
raise Exception("Error parsing query params - Stored fields must be a list")

if isinstance(params.get("cache"), bool):
cache = params.get("cache")
Expand Down
35 changes: 35 additions & 0 deletions tests/worker_coordinator/proto_query_helper_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,41 @@ def test_build_vector_search_proto_defaults_with_optional_fields_empty(self):
self.assertEqual(request.timeout, '')
self.assertEqual(request.request_body.profile, False)

def test_build_vector_search_proto_stored_fields_non_list_parsing(self):
params = {
'body': {
'query': {
'knn': {
'knn_field': {
'vector': np.array([1.0], dtype=np.float32),
'k': 100
}
}
},
'docvalue_fields': ['_id'],
'stored_fields': "_none_",
},
'request-params': {},
'index': 'index_required',
'k': 100,
}

request = ProtoQueryHelper.build_vector_search_proto_request(params)

self.assertEqual(request.stored_fields, ["_none_"])

params['body']['stored_fields'] = ['field1', 'field2']
request = ProtoQueryHelper.build_vector_search_proto_request(params)

self.assertEqual(request.stored_fields, ['field1', 'field2'])

params['body']['stored_fields'] = 'field1'

with self.assertRaises(Exception) as context:
ProtoQueryHelper.build_vector_search_proto_request(params)

self.assertIn('Stored fields must be a list', str(context.exception))

def test_build_vector_search_proto_string_or_bool(self):
params = {
'body': {
Expand Down
Loading