2424import java .io .IOException ;
2525import java .lang .invoke .MethodHandles ;
2626import java .util .ArrayList ;
27+ import java .util .Collection ;
2728import java .util .HashMap ;
29+ import java .util .HashSet ;
2830import java .util .List ;
2931import java .util .Map ;
32+ import java .util .Set ;
33+ import org .apache .lucene .document .StoredField ;
3034import org .apache .lucene .index .DocValues ;
3135import org .apache .lucene .index .LeafReader ;
3236import org .apache .lucene .index .LeafReaderContext ;
3539import org .apache .lucene .index .SortedDocValues ;
3640import org .apache .lucene .index .Terms ;
3741import org .apache .lucene .index .TermsEnum ;
42+ import org .apache .lucene .index .VectorEncoding ;
3843import org .apache .lucene .search .DocIdSetIterator ;
3944import org .apache .lucene .search .join .BitSetProducer ;
4045import org .apache .lucene .util .BitSet ;
4146import org .apache .lucene .util .Bits ;
4247import org .apache .lucene .util .BytesRef ;
4348import org .apache .solr .common .SolrDocument ;
4449import org .apache .solr .common .SolrException ;
50+ import org .apache .solr .schema .DenseVectorField ;
4551import org .apache .solr .schema .IndexSchema ;
52+ import org .apache .solr .schema .SchemaField ;
4653import org .apache .solr .search .BitsFilteredPostingsEnum ;
4754import org .apache .solr .search .DocIterationInfo ;
4855import org .apache .solr .search .DocSet ;
@@ -138,6 +145,20 @@ public void transform(SolrDocument rootDoc, int rootDocId, DocIterationInfo docI
138145 final Bits liveDocs = leafReaderContext .reader ().getLiveDocs ();
139146 final int segBaseId = leafReaderContext .docBase ;
140147 final int segRootId = rootDocId - segBaseId ;
148+ Set <String > multiValuedFLoatVectorFields =
149+ this .getMultiValuedVectorFields (
150+ searcher .getSchema (), childReturnFields , VectorEncoding .FLOAT32 );
151+ Set <String > multiValuedByteVectorFields =
152+ this .getMultiValuedVectorFields (
153+ searcher .getSchema (), childReturnFields , VectorEncoding .BYTE );
154+ if ((multiValuedFLoatVectorFields .size () + multiValuedByteVectorFields .size ()) > 0
155+ && (multiValuedFLoatVectorFields .size () + multiValuedByteVectorFields .size ())
156+ != childReturnFields .getExplicitlyRequestedFieldNames ().size ()) {
157+ throw new SolrException (
158+ SolrException .ErrorCode .BAD_REQUEST ,
159+ "When using the Child transformer to flatten nested vectors, all 'fl' must be "
160+ + "multivalued vector fields" );
161+ }
141162
142163 // can return be -1 and that's okay (happens for very first block)
143164 final int segPrevRootId ;
@@ -219,8 +240,21 @@ public void transform(SolrDocument rootDoc, int rootDocId, DocIterationInfo docI
219240
220241 if (isAncestor ) {
221242 // if this path has pending child docs, add them.
222- addChildrenToParent (
223- doc , pendingParentPathsToChildren .remove (fullDocPath )); // no longer pending
243+ if (!multiValuedFLoatVectorFields .isEmpty () || !multiValuedByteVectorFields .isEmpty ()) {
244+ addFlatMultiValuedVectorsToParent (
245+ rootDoc ,
246+ pendingParentPathsToChildren .values ().iterator ().next (),
247+ multiValuedFLoatVectorFields ,
248+ VectorEncoding .FLOAT32 );
249+ addFlatMultiValuedVectorsToParent (
250+ rootDoc ,
251+ pendingParentPathsToChildren .values ().iterator ().next (),
252+ multiValuedByteVectorFields ,
253+ VectorEncoding .BYTE );
254+ } else {
255+ addChildrenToParent (
256+ doc , pendingParentPathsToChildren .remove (fullDocPath )); // no longer pending
257+ }
224258 }
225259
226260 // get parent path
@@ -248,7 +282,20 @@ public void transform(SolrDocument rootDoc, int rootDocId, DocIterationInfo docI
248282 assert pendingParentPathsToChildren .keySet ().size () == 1 ;
249283
250284 // size == 1, so get the last remaining entry
251- addChildrenToParent (rootDoc , pendingParentPathsToChildren .values ().iterator ().next ());
285+ if (!multiValuedFLoatVectorFields .isEmpty () || !multiValuedByteVectorFields .isEmpty ()) {
286+ addFlatMultiValuedVectorsToParent (
287+ rootDoc ,
288+ pendingParentPathsToChildren .values ().iterator ().next (),
289+ multiValuedFLoatVectorFields ,
290+ VectorEncoding .FLOAT32 );
291+ addFlatMultiValuedVectorsToParent (
292+ rootDoc ,
293+ pendingParentPathsToChildren .values ().iterator ().next (),
294+ multiValuedByteVectorFields ,
295+ VectorEncoding .BYTE );
296+ } else {
297+ addChildrenToParent (rootDoc , pendingParentPathsToChildren .values ().iterator ().next ());
298+ }
252299
253300 } catch (IOException e ) {
254301 // TODO DWS: reconsider this unusual error handling approach; shouldn't we rethrow?
@@ -257,6 +304,25 @@ public void transform(SolrDocument rootDoc, int rootDocId, DocIterationInfo docI
257304 }
258305 }
259306
307+ private Set <String > getMultiValuedVectorFields (
308+ IndexSchema schema , SolrReturnFields childReturnFields , VectorEncoding encoding ) {
309+ Set <String > multiValuedVectorsFields = new HashSet <>();
310+ Set <String > explicitlyRequestedFieldNames =
311+ childReturnFields .getExplicitlyRequestedFieldNames ();
312+ if (explicitlyRequestedFieldNames != null ) {
313+ for (String fieldName : explicitlyRequestedFieldNames ) {
314+ SchemaField sfield = schema .getFieldOrNull (fieldName );
315+ if (sfield != null
316+ && sfield .getType () instanceof DenseVectorField
317+ && sfield .multiValued ()
318+ && ((DenseVectorField ) sfield .getType ()).getVectorEncoding () == encoding ) {
319+ multiValuedVectorsFields .add (fieldName );
320+ }
321+ }
322+ }
323+ return multiValuedVectorsFields ;
324+ }
325+
260326 private static void addChildrenToParent (
261327 SolrDocument parent , Map <String , List <SolrDocument >> children ) {
262328 for (Map .Entry <String , List <SolrDocument >> entry : children .entrySet ()) {
@@ -285,6 +351,54 @@ private static void addChildrenToParent(
285351 parent .setField (trimmedPath , children .get (0 ));
286352 }
287353
354+ private void addFlatMultiValuedVectorsToParent (
355+ SolrDocument parent ,
356+ Map <String , List <SolrDocument >> children ,
357+ Set <String > multiValuedVectorFields ,
358+ VectorEncoding encoding ) {
359+ for (String multiValuedVectorField : multiValuedVectorFields ) {
360+ List <SolrDocument > solrDocuments = children .get (multiValuedVectorField );
361+ List <List <Number >> multiValuedVectors = new ArrayList <>(solrDocuments .size ());
362+ for (SolrDocument singleVector : solrDocuments ) {
363+ List <Number > extractedVectors ;
364+ switch (encoding ) {
365+ case FLOAT32 :
366+ extractedVectors =
367+ this .extractFloatVector (singleVector .getFieldValues (multiValuedVectorField ));
368+ break ;
369+ case BYTE :
370+ extractedVectors =
371+ this .extractByteVector (singleVector .getFieldValues (multiValuedVectorField ));
372+ break ;
373+ default :
374+ throw new SolrException (
375+ SolrException .ErrorCode .BAD_REQUEST , "Unsupported vector encoding: " + encoding );
376+ }
377+ multiValuedVectors .add (extractedVectors );
378+ }
379+ parent .setField (multiValuedVectorField , multiValuedVectors );
380+ }
381+ }
382+
383+ private List <Number > extractFloatVector (Collection <Object > fieldValues ) {
384+ List <Number > vector = new ArrayList <>(fieldValues .size ());
385+ for (Object fieldValue : fieldValues ) {
386+ StoredField storedVectorValue = (StoredField ) fieldValue ;
387+ vector .add (storedVectorValue .numericValue ());
388+ }
389+ return vector ;
390+ }
391+
392+ private List <Number > extractByteVector (Collection <Object > singleVector ) {
393+ StoredField vector = (StoredField ) singleVector .iterator ().next ();
394+ BytesRef byteVector = vector .binaryValue ();
395+ List <Number > extractedVector = new ArrayList <>(byteVector .length );
396+ for (Byte element : byteVector .bytes ) {
397+ extractedVector .add (element .byteValue ());
398+ }
399+ return extractedVector ;
400+ }
401+
288402 private static String getLastPath (String path ) {
289403 int lastIndexOfPathSepChar = path .lastIndexOf (PATH_SEP_CHAR );
290404 if (lastIndexOfPathSepChar == -1 ) {
0 commit comments