Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/marqo/api/models/recommend_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from marqo.core.models.interpolation_method import InterpolationMethod
from marqo.tensor_search.models.api_models import BaseMarqoModel
from pydantic.v1 import root_validator
from pydantic.v1 import root_validator, Field
from marqo.tensor_search.models.score_modifiers_object import ScoreModifierLists


Expand All @@ -22,6 +22,8 @@ class RecommendQuery(BaseMarqoModel):
attributesToRetrieve: Union[None, List[str]] = None
scoreModifiers: Optional[ScoreModifierLists] = None
rerankDepth: Optional[int] = None
allow_missing_documents: bool = Field(default=False, alias="allowMissingDocuments")
allow_missing_embeddings: bool = Field(default=False, alias="allowMissingEmbeddings")

@root_validator(pre=False)
def validate_rerank_depth(cls, values):
Expand Down
8 changes: 0 additions & 8 deletions src/marqo/core/models/hybrid_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,14 +99,6 @@ def validate_properties(cls, values):
if values.get('rankingMethod') not in [RankingMethod.Lexical, RankingMethod.Tensor]:
raise ValueError("For retrievalMethod: tensor or lexical, rankingMethod must be: tensor or lexical")

# if tensor query is an empty dict
if isinstance(values.get('queryTensor'), dict):
if not len(values.get('queryTensor')):
raise ValueError(
"Multi-term query for queryTensor requires at least one query. Received empty dictionary"
)


return values

@validator('alpha')
Expand Down
5 changes: 3 additions & 2 deletions src/marqo/core/search/hybrid_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,10 +198,11 @@ def search(
tensor_query = query
lexical_query = query

if (tensor_query is None) != (lexical_query is None):
if lexical_query is None:
# We could allow queryTensor to be None as tensors might be provided with context
if hybrid_parameters.retrievalMethod == RetrievalMethod.Disjunction:
raise core_exceptions.InvalidArgumentError(
"Either both of 'hybridParameters.queryLexical' and 'hybridParameters.queryTensor' or just 'q'"
"Either 'hybridParameters.queryLexical' or just 'q'"
"must be present when 'disjunction' retrieval method is used."
)

Expand Down
116 changes: 95 additions & 21 deletions src/marqo/core/search/recommender.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,14 @@ def __init__(self, vespa_client: VespaClient, index_management: IndexManagement,
self.index_management = index_management
self.inference = inference

def get_doc_vectors_from_ids(self,
index_name: str,
documents: Union[List[str], Dict[str, float]],
tensor_fields: Optional[List[str]] = None) -> Dict[str, List[List[float]]]:
def get_doc_vectors_from_ids(
self,
index_name: str,
documents: Union[List[str], Dict[str, float]],
tensor_fields: Optional[List[str]] = None,
allow_missing_documents: bool = False,
allow_missing_embeddings: bool = False
) -> Dict[str, List[List[float]]]:
"""
This method gets documents from Vespa using their IDs, removes any unnecessary data, checks for
lack of vectors, then returns a list of document vectors. Can be used internally (in recommend)
Expand All @@ -36,10 +40,17 @@ def get_doc_vectors_from_ids(self,
index_name: Name of the index to search
documents: A list of document IDs or a dictionary where the keys are document IDs and the values are weights
tensor_fields: List of tensor fields to use for recommendation (can include text, image, audio, and video fields)
allow_missing_documents: If True, will not raise an error if some document IDs are not found
allow_missing_embeddings: If True, will not raise an error if some documents do not have embeddings

Returns:
A dictionary mapping document IDs to lists of vector embeddings. This is flattened to 1 list per document
ID (not separated by tensor field). Order of embeddings is not guaranteed.

Raises:
InvalidArgumentError:
- If any document IDs are not found and allow_missing_documents is False
- If any document ID does not have embeddings and allow_missing_embeddings is False
"""

# TODO - Extract search and get_docs from tensor_search and refactor this
Expand Down Expand Up @@ -84,55 +95,107 @@ def get_doc_vectors_from_ids(self,
f'Available tensor fields: {", ".join(valid_tensor_fields)}')

# Use the new optimized method to get only embeddings
# TODO - Consolidate these two methods into one place
doc_embeddings_by_field = tensor_search.get_doc_vectors_per_tensor_field_by_ids(
config.Config(self.vespa_client, inference=self.inference),
index_name,
document_ids,
tensor_fields=tensor_fields
tensor_fields=tensor_fields,
allow_missing_documents=allow_missing_documents,
)

return self._sanitize_doc_embeddins_by_field(
all_documents_ids = document_ids,
marqo_index=marqo_index,
doc_embeddings_by_field=doc_embeddings_by_field,
tensor_fields=tensor_fields,
allow_missing_documents=allow_missing_documents,
allow_missing_embeddings=allow_missing_embeddings,
)

def _sanitize_doc_embeddins_by_field(
self,
all_documents_ids: List[str],
marqo_index: MarqoIndex,
doc_embeddings_by_field: Dict[str, Dict[str, List[List[float]]]],
tensor_fields: Optional[List[str]],
allow_missing_documents: bool,
allow_missing_embeddings: bool
) -> Dict[str, List[List[float]]]:
"""
Sanitize the document embeddings by checking for missing documents and embeddings,
and flattening the structure to a simple mapping of document ID to list of embeddings.

If allow_missing_documents is False, raises an error if any document IDs are not found.
If allow_missing_embeddings is False, raises an error if any documents do not have embeddings.

Documents with no embeddings are removed from the result.
Args:
all_documents_ids: The list of all document IDs that were requested
marqo_index: The marqo index object containing metadata about the index
doc_embeddings_by_field: The document embeddings by field returned from
tensor_search.get_doc_vectors_per_tensor_field_by_ids
tensor_fields: tensor fields to include in the result. If None, all fields are included.
allow_missing_documents: If True, will not raise an error if some document IDs are not found.
allow_missing_embeddings: If True, will not raise an error if some documents do not have embeddings.

Returns:
A dictionary mapping document IDs to lists of vector embeddings.
E.g.,
{
"doc1": [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
"doc2": [[0.7, 0.8, 0.9]]
}
where each list contains embeddings from all tensor fields where order of embeddings is not preserved.

Raises:
InvalidArgumentError: If any document IDs are not found and allow_missing_documents is False,
or if any documents do not have embeddings and allow_missing_embeddings is False.
"""

# Check that all documents were found
not_found = []
for doc_id in document_ids:
not_found_docs = []
for doc_id in all_documents_ids:
if doc_id not in doc_embeddings_by_field:
not_found.append(doc_id)
not_found_docs.append(doc_id)

if len(not_found) > 0:
raise InvalidArgumentError(f'The following document IDs were not found: {", ".join(not_found)}')
if len(not_found_docs) > 0 and not allow_missing_documents:
raise InvalidArgumentError(f'The following document IDs were not found: {", ".join(not_found_docs)}')

# Flatten the embeddings structure to match the expected return format
# Convert from Dict[doc_id, Dict[field_name, List[List[float]]]]
# Convert from Dict[doc_id, Dict[field_name, List[List[float]]]]
# to Dict[doc_id, List[List[float]]]
doc_vectors: Dict[str, List[List[float]]] = {}
docs_without_vectors = []

for doc_id, field_embeddings in doc_embeddings_by_field.items():
vectors: List[List[float]] = []

# Flatten all embeddings from all fields for this document
for field_name, embedding_list in field_embeddings.items():
# For legacy unstructured indices, field_name will be "marqo__embeddings"
# and we should include all embeddings regardless of tensor_fields filter
# since all embeddings are stored together in marqo__embeddings
if (tensor_fields is None or
if (tensor_fields is None or
field_name in tensor_fields or
(marqo_index.type == IndexType.Unstructured and
field_name == unstructured_common.VESPA_DOC_EMBEDDINGS)):
vectors.extend(embedding_list)

doc_vectors[doc_id] = vectors

if len(vectors) == 0:
docs_without_vectors.append(doc_id)

if len(docs_without_vectors) > 0:

if len(docs_without_vectors) > 0 and not allow_missing_embeddings:
raise InvalidArgumentError(
f'The following documents do not have embeddings: {", ".join(docs_without_vectors)}'
)

for doc_id in docs_without_vectors:
del doc_vectors[doc_id]
return doc_vectors


def recommend(self,
index_name: str,
documents: Union[List[str], Dict[str, float]],
Expand All @@ -150,7 +213,9 @@ def recommend(self,
filter: str = None,
attributes_to_retrieve: Optional[List[str]] = None,
score_modifiers: Optional[ScoreModifierLists] = None,
rerank_depth: Optional[int] = None
rerank_depth: Optional[int] = None,
allow_missing_documents: bool = False,
allow_missing_embeddings: bool = False,
):
"""
Recommend documents similar to the provided documents.
Expand Down Expand Up @@ -191,7 +256,9 @@ def recommend(self,
doc_vectors = self.get_doc_vectors_from_ids(
index_name=index_name,
documents=documents,
tensor_fields=tensor_fields
tensor_fields=tensor_fields,
allow_missing_documents=allow_missing_documents,
allow_missing_embeddings=allow_missing_embeddings,
)

# Save original document IDs for filtering
Expand All @@ -208,9 +275,16 @@ def recommend(self,
weight = documents[document_id]
else:
weight = 1

vectors.extend(vector_list)
weights.extend([weight] * len(vector_list))

if len(vectors) == 0:
raise InvalidArgumentError(
"Marqo could not collect any valid vector from the documents. "
"Please check if the provided documents exist or if the documents have valid embeddings. "
)

try:
interpolated_vector = vector_interpolation.interpolate(
vectors, weights
Expand Down Expand Up @@ -253,7 +327,7 @@ def recommend(self,
attributes_to_retrieve=attributes_to_retrieve,
score_modifiers=score_modifiers,
processing_start=t0,
rerank_depth=rerank_depth
rerank_depth=rerank_depth,
)

return results
Expand Down
4 changes: 3 additions & 1 deletion src/marqo/tensor_search/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,9 @@ def recommend(query_dict: dict, index_name: str,
filter=query.filter,
attributes_to_retrieve=query.attributesToRetrieve,
score_modifiers=query.scoreModifiers,
rerank_depth=query.rerankDepth
rerank_depth=query.rerankDepth,
allow_missing_documents=query.allow_missing_documents,
allow_missing_embeddings=query.allow_missing_embeddings,
)


Expand Down
2 changes: 2 additions & 0 deletions src/marqo/tensor_search/models/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ class SearchContextTensor(BaseModel):
class SearchContextDocumentsParameters(BaseModel):
tensor_fields: Optional[List[str]] = Field(None, alias='tensorFields')
exclude_input_documents: bool = Field(True, alias='excludeInputDocuments')
allow_missing_documents: bool = Field(False, alias='allowMissingDocuments')
allow_missing_embeddings: bool = Field(False, alias='allowMissingEmbeddings')

@validator('tensor_fields', pre=True, always=True)
def check_tensor_fields_not_empty(cls, v):
Expand Down
49 changes: 39 additions & 10 deletions src/marqo/tensor_search/tensor_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -887,14 +887,16 @@ def get_query_vectors_from_jobs(
) -> Dict[Qidx, List[float]]:
"""
Retrieve the vectorised content associated to each query from the set of batch vectorise jobs.
Handles multi-modal queries, by weighting and combining queries into a single vector
Handles multi-modal queries, by weighting and combining queries into a single vector.

Args:
- queries: Original search queries.
- qidx_to_job: VectorisedJobPointer for each query
- job_to_vectors: inference output from each VectorisedJob
- config: standard Marqo config.

Raises:
api_exceptions.InvalidArgError: If this method can not collect a valid vector from the query
"""
result: Dict[Qidx, List[float]] = defaultdict(list)
for qidx, ptrs in qidx_to_job.items():
Expand Down Expand Up @@ -943,7 +945,9 @@ def get_query_vectors_from_jobs(
context_doc_vectors = config.recommender.get_doc_vectors_from_ids(
index_name=q.index.name,
documents=context_documents.ids,
tensor_fields=context_documents.parameters.tensor_fields
tensor_fields=context_documents.parameters.tensor_fields,
allow_missing_documents=context_documents.parameters.allow_missing_documents,
allow_missing_embeddings= context_documents.parameters.allow_missing_embeddings
)

# Update weights and vectors list
Expand All @@ -968,13 +972,14 @@ def get_query_vectors_from_jobs(
# Use interpolation to combine all vectors
vector_interpolation = from_interpolation_method(interpolation_method)
with RequestMetricsStore.for_request().time(f"search.vectorise.interpolate_vectors"):
merged_vector = vector_interpolation.interpolate(
vectors=collected_vectors,
weights=collected_weights
)

result[qidx] = list(merged_vector)

if collected_vectors:
merged_vector = vector_interpolation.interpolate(
vectors=collected_vectors,
weights=collected_weights
)
result[qidx] = list(merged_vector)
else:
result[qidx] = []
elif isinstance(q.q, str):
if q.context:
raise core_exceptions.InvalidArgumentError(
Expand All @@ -988,6 +993,14 @@ def get_query_vectors_from_jobs(
)
else:
raise ValueError(f"Unexpected query type: {type(q.q).__name__}")

if not result[qidx]:
raise api_exceptions.InvalidArgError(
f"Marqo could not collect any vectors from the search query but the retrieval or ranking method requires "
f"at least one valid vector. "
f"Please check the provided query, context (if any), or queryTensor(for Hybrid search) "
)

return result


Expand Down Expand Up @@ -1069,7 +1082,7 @@ def add_prefix_to_queries(queries: List[BulkSearchQueryEntity]) -> List[BulkSear
def run_vectorise_pipeline(config: Config, queries: List[BulkSearchQueryEntity], device: Union[Device, str],
interpolation_method: InterpolationMethod = None) -> Dict[
Qidx, List[float]]:
"""Run the query vectorisation process
"""Run the query vectorisation process. This is a pipeline used for both Tensor search and Hybrid search.

Raise:
api_exceptions.InvalidArgError: If the vectorisation process fails or if the media cannot be downloaded.
Expand Down Expand Up @@ -1185,6 +1198,9 @@ def _vector_text_search(
qidx_to_vectors: Dict[Qidx, List[float]] = run_vectorise_pipeline(config, queries, device, interpolation_method)
vectorised_text = list(qidx_to_vectors.values())[0]

if not vectorised_text: # pragma: no cover
raise InternalError(f"No vector is generated for the tensor query: {query}. ")

marqo_query = MarqoTensorQuery(
index_name=index_name,
vector_query=vectorised_text,
Expand Down Expand Up @@ -1353,6 +1369,7 @@ def get_doc_vectors_per_tensor_field_by_ids(
index_name: str,
document_ids: List[str],
tensor_fields: Optional[List[str]] = None,
allow_missing_documents: bool = False,
) -> Dict[str, Dict[str, List[List[float]]]]:
"""
Get only the embeddings for documents by their IDs.
Expand All @@ -1362,9 +1379,18 @@ def get_doc_vectors_per_tensor_field_by_ids(
index_name: Name of the index
document_ids: List of document IDs to fetch
tensor_fields: Specific tensor fields to get. If None, get all tensor fields.
allow_missing_documents: If True, will not raise an error if a document is not found

Returns:
Dict mapping document_id to field_name to list of embedding vectors
E.g.,
{
"doc_id_1": {
"field_name_1": [[0.1, 0.2, ...], ...],
"field_name_2": [[0.3, 0.4, ...], ...],
},
"doc_id_2": {"field_name_1": [[0.5, 0.6, ...], ...]}
}
"""

# We can just use the cache here since we refresh every 1s.
Expand Down Expand Up @@ -1417,6 +1443,9 @@ def get_doc_vectors_per_tensor_field_by_ids(
else:
# Otherwise, field is empty list
result[doc_id][marqo_tensor_field_name] = []
elif response.status == 404 and allow_missing_documents:
# If the document is not found and we are allowing missing documents, continue to next response
continue
else:
# If the response is not successful, error out
raise core_exceptions.InvalidArgumentError(
Expand Down
Loading
Loading