@@ -23,10 +23,14 @@ def __init__(self, vespa_client: VespaClient, index_management: IndexManagement,
2323 self .index_management = index_management
2424 self .inference = inference
2525
26- def get_doc_vectors_from_ids (self ,
27- index_name : str ,
28- documents : Union [List [str ], Dict [str , float ]],
29- tensor_fields : Optional [List [str ]] = None ) -> Dict [str , List [List [float ]]]:
26+ def get_doc_vectors_from_ids (
27+ self ,
28+ index_name : str ,
29+ documents : Union [List [str ], Dict [str , float ]],
30+ tensor_fields : Optional [List [str ]] = None ,
31+ allow_missing_documents : bool = False ,
32+ allow_missing_embeddings : bool = False
33+ ) -> Dict [str , List [List [float ]]]:
3034 """
3135 This method gets documents from Vespa using their IDs, removes any unnecessary data, checks for
3236 lack of vectors, then returns a list of document vectors. Can be used internally (in recommend)
@@ -36,10 +40,17 @@ def get_doc_vectors_from_ids(self,
3640 index_name: Name of the index to search
3741 documents: A list of document IDs or a dictionary where the keys are document IDs and the values are weights
3842 tensor_fields: List of tensor fields to use for recommendation (can include text, image, audio, and video fields)
43+ allow_missing_documents: If True, will not raise an error if some document IDs are not found
44+ allow_missing_embeddings: If True, will not raise an error if some documents do not have embeddings
3945
4046 Returns:
4147 A dictionary mapping document IDs to lists of vector embeddings. This is flattened to 1 list per document
4248 ID (not separated by tensor field). Order of embeddings is not guaranteed.
49+
50+ Raises:
51+ InvalidArgumentError:
52+ - If any document IDs are not found and allow_missing_documents is False
53+ - If any documents do not have embeddings and allow_missing_embeddings is False
4354 """
4455
4556 # TODO - Extract search and get_docs from tensor_search and refactor this
@@ -84,55 +95,107 @@ def get_doc_vectors_from_ids(self,
8495 f'Available tensor fields: { ", " .join (valid_tensor_fields )} ' )
8596
8697 # Use the new optimized method to get only embeddings
98+ # TODO - Consolidate these two methods into one place
8799 doc_embeddings_by_field = tensor_search .get_doc_vectors_per_tensor_field_by_ids (
88100 config .Config (self .vespa_client , inference = self .inference ),
89101 index_name ,
90102 document_ids ,
91- tensor_fields = tensor_fields
103+ tensor_fields = tensor_fields ,
104+ allow_missing_documents = allow_missing_documents ,
92105 )
93106
107+ return self ._sanitize_doc_embeddins_by_field (
108+ all_documents_ids = document_ids ,
109+ marqo_index = marqo_index ,
110+ doc_embeddings_by_field = doc_embeddings_by_field ,
111+ tensor_fields = tensor_fields ,
112+ allow_missing_documents = allow_missing_documents ,
113+ allow_missing_embeddings = allow_missing_embeddings ,
114+ )
115+
116+ def _sanitize_doc_embeddins_by_field (
117+ self ,
118+ all_documents_ids : List [str ],
119+ marqo_index : MarqoIndex ,
120+ doc_embeddings_by_field : Dict [str , Dict [str , List [List [float ]]]],
121+ tensor_fields : Optional [List [str ]],
122+ allow_missing_documents : bool ,
123+ allow_missing_embeddings : bool
124+ ) -> Dict [str , List [List [float ]]]:
125+ """
126+ Sanitize the document embeddings by checking for missing documents and embeddings,
127+ and flattening the structure to a simple mapping of document ID to list of embeddings.
128+
129+ If allow_missing_documents is False, raises an error if any document IDs are not found.
130+ If allow_missing_embeddings is False, raises an error if any documents do not have embeddings.
131+
132+ Documents with no embeddings are removed from the result.
133+ Args:
134+ all_documents_ids: The list of all document IDs that were requested
135+ marqo_index: The marqo index object containing metadata about the index
136+ doc_embeddings_by_field: The document embeddings by field returned from
137+ tensor_search.get_doc_vectors_per_tensor_field_by_ids
138+ tensor_fields: tensor fields to include in the result. If None, all fields are included.
139+ allow_missing_documents: If True, will not raise an error if some document IDs are not found.
140+ allow_missing_embeddings: If True, will not raise an error if some documents do not have embeddings.
141+
142+ Returns:
143+ A dictionary mapping document IDs to lists of vector embeddings.
144+ E.g.,
145+ {
146+ "doc1": [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
147+ "doc2": [[0.7, 0.8, 0.9]]
148+ }
149+ Each list contains the embeddings from all included tensor fields; the order of embeddings is not preserved.
150+
151+ Raises:
152+ InvalidArgumentError: If any document IDs are not found and allow_missing_documents is False,
153+ or if any documents do not have embeddings and allow_missing_embeddings is False.
154+ """
155+
94156 # Check that all documents were found
95- not_found = []
96- for doc_id in document_ids :
157+ not_found_docs = []
158+ for doc_id in all_documents_ids :
97159 if doc_id not in doc_embeddings_by_field :
98- not_found .append (doc_id )
160+ not_found_docs .append (doc_id )
99161
100- if len (not_found ) > 0 :
101- raise InvalidArgumentError (f'The following document IDs were not found: { ", " .join (not_found )} ' )
162+ if len (not_found_docs ) > 0 and not allow_missing_documents :
163+ raise InvalidArgumentError (f'The following document IDs were not found: { ", " .join (not_found_docs )} ' )
102164
103165 # Flatten the embeddings structure to match the expected return format
104- # Convert from Dict[doc_id, Dict[field_name, List[List[float]]]]
166+ # Convert from Dict[doc_id, Dict[field_name, List[List[float]]]]
105167 # to Dict[doc_id, List[List[float]]]
106168 doc_vectors : Dict [str , List [List [float ]]] = {}
107169 docs_without_vectors = []
108-
170+
109171 for doc_id , field_embeddings in doc_embeddings_by_field .items ():
110172 vectors : List [List [float ]] = []
111-
173+
112174 # Flatten all embeddings from all fields for this document
113175 for field_name , embedding_list in field_embeddings .items ():
114176 # For legacy unstructured indices, field_name will be "marqo__embeddings"
115177 # and we should include all embeddings regardless of tensor_fields filter
116178 # since all embeddings are stored together in marqo__embeddings
117- if (tensor_fields is None or
179+ if (tensor_fields is None or
118180 field_name in tensor_fields or
119181 (marqo_index .type == IndexType .Unstructured and
120182 field_name == unstructured_common .VESPA_DOC_EMBEDDINGS )):
121183 vectors .extend (embedding_list )
122-
184+
123185 doc_vectors [doc_id ] = vectors
124186
125187 if len (vectors ) == 0 :
126188 docs_without_vectors .append (doc_id )
127189
128- if len (docs_without_vectors ) > 0 :
190+
191+ if len (docs_without_vectors ) > 0 and not allow_missing_embeddings :
129192 raise InvalidArgumentError (
130193 f'The following documents do not have embeddings: { ", " .join (docs_without_vectors )} '
131194 )
132-
195+ for doc_id in docs_without_vectors :
196+ del doc_vectors [doc_id ]
133197 return doc_vectors
134198
135-
136199 def recommend (self ,
137200 index_name : str ,
138201 documents : Union [List [str ], Dict [str , float ]],
@@ -150,7 +213,9 @@ def recommend(self,
150213 filter : str = None ,
151214 attributes_to_retrieve : Optional [List [str ]] = None ,
152215 score_modifiers : Optional [ScoreModifierLists ] = None ,
153- rerank_depth : Optional [int ] = None
216+ rerank_depth : Optional [int ] = None ,
217+ allow_missing_documents : bool = False ,
218+ allow_missing_embeddings : bool = False ,
154219 ):
155220 """
156221 Recommend documents similar to the provided documents.
@@ -191,7 +256,9 @@ def recommend(self,
191256 doc_vectors = self .get_doc_vectors_from_ids (
192257 index_name = index_name ,
193258 documents = documents ,
194- tensor_fields = tensor_fields
259+ tensor_fields = tensor_fields ,
260+ allow_missing_documents = allow_missing_documents ,
261+ allow_missing_embeddings = allow_missing_embeddings ,
195262 )
196263
197264 # Save original document IDs for filtering
@@ -208,9 +275,16 @@ def recommend(self,
208275 weight = documents [document_id ]
209276 else :
210277 weight = 1
278+
211279 vectors .extend (vector_list )
212280 weights .extend ([weight ] * len (vector_list ))
213281
282+ if len (vectors ) == 0 :
283+ raise InvalidArgumentError (
284+ "Marqo could not collect any valid vector from the documents. "
285+ "Please check if the provided documents exist or if the documents have valid embeddings. "
286+ )
287+
214288 try :
215289 interpolated_vector = vector_interpolation .interpolate (
216290 vectors , weights
@@ -253,7 +327,7 @@ def recommend(self,
253327 attributes_to_retrieve = attributes_to_retrieve ,
254328 score_modifiers = score_modifiers ,
255329 processing_start = t0 ,
256- rerank_depth = rerank_depth
330+ rerank_depth = rerank_depth ,
257331 )
258332
259333 return results
0 commit comments