Skip to content

Commit 949b28e

Browse files
Support for multi valued dense vector fields (through nested vectors and diversifying children query) (#4051)
1 parent 6cc80a0 commit 949b28e

File tree

11 files changed

+1491
-60
lines changed

11 files changed

+1491
-60
lines changed
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# See https://github.com/apache/solr/blob/main/dev-docs/changelog.adoc
2+
title: Introducing support for multi-valued dense vector fields in documents through nested vectors
3+
type: added # added, changed, fixed, deprecated, removed, dependency_update, security, other
4+
authors:
5+
- name: Alessandro Benedetti
6+
links:
7+
- name: SOLR-18074
8+
url: https://issues.apache.org/jira/browse/SOLR-18074

solr/core/src/java/org/apache/solr/response/transform/ChildDocTransformer.java

Lines changed: 117 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,13 @@
2424
import java.io.IOException;
2525
import java.lang.invoke.MethodHandles;
2626
import java.util.ArrayList;
27+
import java.util.Collection;
2728
import java.util.HashMap;
29+
import java.util.HashSet;
2830
import java.util.List;
2931
import java.util.Map;
32+
import java.util.Set;
33+
import org.apache.lucene.document.StoredField;
3034
import org.apache.lucene.index.DocValues;
3135
import org.apache.lucene.index.LeafReader;
3236
import org.apache.lucene.index.LeafReaderContext;
@@ -35,14 +39,17 @@
3539
import org.apache.lucene.index.SortedDocValues;
3640
import org.apache.lucene.index.Terms;
3741
import org.apache.lucene.index.TermsEnum;
42+
import org.apache.lucene.index.VectorEncoding;
3843
import org.apache.lucene.search.DocIdSetIterator;
3944
import org.apache.lucene.search.join.BitSetProducer;
4045
import org.apache.lucene.util.BitSet;
4146
import org.apache.lucene.util.Bits;
4247
import org.apache.lucene.util.BytesRef;
4348
import org.apache.solr.common.SolrDocument;
4449
import org.apache.solr.common.SolrException;
50+
import org.apache.solr.schema.DenseVectorField;
4551
import org.apache.solr.schema.IndexSchema;
52+
import org.apache.solr.schema.SchemaField;
4653
import org.apache.solr.search.BitsFilteredPostingsEnum;
4754
import org.apache.solr.search.DocIterationInfo;
4855
import org.apache.solr.search.DocSet;
@@ -138,6 +145,20 @@ public void transform(SolrDocument rootDoc, int rootDocId, DocIterationInfo docI
138145
final Bits liveDocs = leafReaderContext.reader().getLiveDocs();
139146
final int segBaseId = leafReaderContext.docBase;
140147
final int segRootId = rootDocId - segBaseId;
148+
Set<String> multiValuedFLoatVectorFields =
149+
this.getMultiValuedVectorFields(
150+
searcher.getSchema(), childReturnFields, VectorEncoding.FLOAT32);
151+
Set<String> multiValuedByteVectorFields =
152+
this.getMultiValuedVectorFields(
153+
searcher.getSchema(), childReturnFields, VectorEncoding.BYTE);
154+
if ((multiValuedFLoatVectorFields.size() + multiValuedByteVectorFields.size()) > 0
155+
&& (multiValuedFLoatVectorFields.size() + multiValuedByteVectorFields.size())
156+
!= childReturnFields.getExplicitlyRequestedFieldNames().size()) {
157+
throw new SolrException(
158+
SolrException.ErrorCode.BAD_REQUEST,
159+
"When using the Child transformer to flatten nested vectors, all 'fl' must be "
160+
+ "multivalued vector fields");
161+
}
141162

142163
// can return be -1 and that's okay (happens for very first block)
143164
final int segPrevRootId;
@@ -219,8 +240,21 @@ public void transform(SolrDocument rootDoc, int rootDocId, DocIterationInfo docI
219240

220241
if (isAncestor) {
221242
// if this path has pending child docs, add them.
222-
addChildrenToParent(
223-
doc, pendingParentPathsToChildren.remove(fullDocPath)); // no longer pending
243+
if (!multiValuedFLoatVectorFields.isEmpty() || !multiValuedByteVectorFields.isEmpty()) {
244+
addFlatMultiValuedVectorsToParent(
245+
rootDoc,
246+
pendingParentPathsToChildren.values().iterator().next(),
247+
multiValuedFLoatVectorFields,
248+
VectorEncoding.FLOAT32);
249+
addFlatMultiValuedVectorsToParent(
250+
rootDoc,
251+
pendingParentPathsToChildren.values().iterator().next(),
252+
multiValuedByteVectorFields,
253+
VectorEncoding.BYTE);
254+
} else {
255+
addChildrenToParent(
256+
doc, pendingParentPathsToChildren.remove(fullDocPath)); // no longer pending
257+
}
224258
}
225259

226260
// get parent path
@@ -248,7 +282,20 @@ public void transform(SolrDocument rootDoc, int rootDocId, DocIterationInfo docI
248282
assert pendingParentPathsToChildren.keySet().size() == 1;
249283

250284
// size == 1, so get the last remaining entry
251-
addChildrenToParent(rootDoc, pendingParentPathsToChildren.values().iterator().next());
285+
if (!multiValuedFLoatVectorFields.isEmpty() || !multiValuedByteVectorFields.isEmpty()) {
286+
addFlatMultiValuedVectorsToParent(
287+
rootDoc,
288+
pendingParentPathsToChildren.values().iterator().next(),
289+
multiValuedFLoatVectorFields,
290+
VectorEncoding.FLOAT32);
291+
addFlatMultiValuedVectorsToParent(
292+
rootDoc,
293+
pendingParentPathsToChildren.values().iterator().next(),
294+
multiValuedByteVectorFields,
295+
VectorEncoding.BYTE);
296+
} else {
297+
addChildrenToParent(rootDoc, pendingParentPathsToChildren.values().iterator().next());
298+
}
252299

253300
} catch (IOException e) {
254301
// TODO DWS: reconsider this unusual error handling approach; shouldn't we rethrow?
@@ -257,6 +304,25 @@ public void transform(SolrDocument rootDoc, int rootDocId, DocIterationInfo docI
257304
}
258305
}
259306

307+
private Set<String> getMultiValuedVectorFields(
308+
IndexSchema schema, SolrReturnFields childReturnFields, VectorEncoding encoding) {
309+
Set<String> multiValuedVectorsFields = new HashSet<>();
310+
Set<String> explicitlyRequestedFieldNames =
311+
childReturnFields.getExplicitlyRequestedFieldNames();
312+
if (explicitlyRequestedFieldNames != null) {
313+
for (String fieldName : explicitlyRequestedFieldNames) {
314+
SchemaField sfield = schema.getFieldOrNull(fieldName);
315+
if (sfield != null
316+
&& sfield.getType() instanceof DenseVectorField
317+
&& sfield.multiValued()
318+
&& ((DenseVectorField) sfield.getType()).getVectorEncoding() == encoding) {
319+
multiValuedVectorsFields.add(fieldName);
320+
}
321+
}
322+
}
323+
return multiValuedVectorsFields;
324+
}
325+
260326
private static void addChildrenToParent(
261327
SolrDocument parent, Map<String, List<SolrDocument>> children) {
262328
for (Map.Entry<String, List<SolrDocument>> entry : children.entrySet()) {
@@ -285,6 +351,54 @@ private static void addChildrenToParent(
285351
parent.setField(trimmedPath, children.get(0));
286352
}
287353

354+
private void addFlatMultiValuedVectorsToParent(
355+
SolrDocument parent,
356+
Map<String, List<SolrDocument>> children,
357+
Set<String> multiValuedVectorFields,
358+
VectorEncoding encoding) {
359+
for (String multiValuedVectorField : multiValuedVectorFields) {
360+
List<SolrDocument> solrDocuments = children.get(multiValuedVectorField);
361+
List<List<Number>> multiValuedVectors = new ArrayList<>(solrDocuments.size());
362+
for (SolrDocument singleVector : solrDocuments) {
363+
List<Number> extractedVectors;
364+
switch (encoding) {
365+
case FLOAT32:
366+
extractedVectors =
367+
this.extractFloatVector(singleVector.getFieldValues(multiValuedVectorField));
368+
break;
369+
case BYTE:
370+
extractedVectors =
371+
this.extractByteVector(singleVector.getFieldValues(multiValuedVectorField));
372+
break;
373+
default:
374+
throw new SolrException(
375+
SolrException.ErrorCode.BAD_REQUEST, "Unsupported vector encoding: " + encoding);
376+
}
377+
multiValuedVectors.add(extractedVectors);
378+
}
379+
parent.setField(multiValuedVectorField, multiValuedVectors);
380+
}
381+
}
382+
383+
private List<Number> extractFloatVector(Collection<Object> fieldValues) {
384+
List<Number> vector = new ArrayList<>(fieldValues.size());
385+
for (Object fieldValue : fieldValues) {
386+
StoredField storedVectorValue = (StoredField) fieldValue;
387+
vector.add(storedVectorValue.numericValue());
388+
}
389+
return vector;
390+
}
391+
392+
private List<Number> extractByteVector(Collection<Object> singleVector) {
393+
StoredField vector = (StoredField) singleVector.iterator().next();
394+
BytesRef byteVector = vector.binaryValue();
395+
List<Number> extractedVector = new ArrayList<>(byteVector.length);
396+
for (Byte element : byteVector.bytes) {
397+
extractedVector.add(element.byteValue());
398+
}
399+
return extractedVector;
400+
}
401+
288402
private static String getLastPath(String path) {
289403
int lastIndexOfPathSepChar = path.lastIndexOf(PATH_SEP_CHAR);
290404
if (lastIndexOfPathSepChar == -1) {

solr/core/src/java/org/apache/solr/schema/DenseVectorField.java

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -316,11 +316,6 @@ protected boolean enableDocValuesByDefault() {
316316
@Override
317317
public void checkSchemaField(final SchemaField field) throws SolrException {
318318
super.checkSchemaField(field);
319-
if (field.multiValued()) {
320-
throw new SolrException(
321-
SolrException.ErrorCode.SERVER_ERROR,
322-
getClass().getSimpleName() + " fields can not be multiValued: " + field.getName());
323-
}
324319

325320
if (field.hasDocValues()) {
326321
throw new SolrException(

solr/core/src/java/org/apache/solr/schema/IndexSchema.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ public class IndexSchema {
106106
public static final String NAME = "name";
107107
public static final String NEST_PARENT_FIELD_NAME = "_nest_parent_";
108108
public static final String NEST_PATH_FIELD_NAME = "_nest_path_";
109+
public static final String NESTED_VECTORS_PSEUDO_FIELD_NAME = "_nested_vectors_";
109110
public static final String REQUIRED = "required";
110111
public static final String SCHEMA = "schema";
111112
public static final String SIMILARITY = "similarity";

solr/core/src/java/org/apache/solr/update/processor/NestedUpdateProcessorFactory.java

Lines changed: 88 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,20 @@
1717

1818
package org.apache.solr.update.processor;
1919

20+
import static org.apache.solr.schema.IndexSchema.NESTED_VECTORS_PSEUDO_FIELD_NAME;
21+
2022
import java.io.IOException;
23+
import java.util.ArrayList;
2124
import java.util.Collection;
25+
import java.util.List;
2226
import org.apache.solr.common.SolrException;
2327
import org.apache.solr.common.SolrInputDocument;
2428
import org.apache.solr.common.SolrInputField;
2529
import org.apache.solr.request.SolrQueryRequest;
2630
import org.apache.solr.response.SolrQueryResponse;
31+
import org.apache.solr.schema.DenseVectorField;
2732
import org.apache.solr.schema.IndexSchema;
33+
import org.apache.solr.schema.SchemaField;
2834
import org.apache.solr.update.AddUpdateCommand;
2935

3036
/**
@@ -63,13 +69,15 @@ private static class NestedUpdateProcessor extends UpdateRequestProcessor {
6369
private boolean storePath;
6470
private boolean storeParent;
6571
private String uniqueKeyFieldName;
72+
private IndexSchema schema;
6673

6774
NestedUpdateProcessor(
6875
SolrQueryRequest req, boolean storeParent, boolean storePath, UpdateRequestProcessor next) {
6976
super(next);
7077
this.storeParent = storeParent;
7178
this.storePath = storePath;
7279
this.uniqueKeyFieldName = req.getSchema().getUniqueKeyField().getName();
80+
this.schema = req.getSchema();
7381
}
7482

7583
@Override
@@ -81,66 +89,111 @@ public void processAdd(AddUpdateCommand cmd) throws IOException {
8189

8290
private boolean processDocChildren(SolrInputDocument doc, String fullPath) {
8391
boolean isNested = false;
92+
List<String> originalVectorFieldsToRemove = new ArrayList<>();
93+
ArrayList<SolrInputDocument> vectors = new ArrayList<>();
8494
for (SolrInputField field : doc.values()) {
95+
SchemaField sfield = schema.getFieldOrNull(field.getName());
8596
int childNum = 0;
8697
boolean isSingleVal = !(field.getValue() instanceof Collection);
87-
for (Object val : field) {
88-
if (!(val instanceof SolrInputDocument cDoc)) {
89-
// either all collection items are child docs or none are.
90-
break;
91-
}
92-
final String fieldName = field.getName();
93-
94-
if (fieldName.contains(PATH_SEP_CHAR)) {
95-
throw new SolrException(
96-
SolrException.ErrorCode.BAD_REQUEST,
97-
"Field name: '"
98-
+ fieldName
99-
+ "' contains: '"
100-
+ PATH_SEP_CHAR
101-
+ "' , which is reserved for the nested URP");
102-
}
103-
final String sChildNum = isSingleVal ? SINGULAR_VALUE_CHAR : String.valueOf(childNum);
104-
if (!cDoc.containsKey(uniqueKeyFieldName)) {
98+
boolean firstLevelChildren = fullPath == null;
99+
if (firstLevelChildren && sfield != null && isMultiValuedVectorField(sfield)) {
100+
for (Object vectorValue : field.getValues()) {
101+
SolrInputDocument singleVectorNestedDoc = new SolrInputDocument();
102+
singleVectorNestedDoc.setField(field.getName(), vectorValue);
103+
final String sChildNum = isSingleVal ? SINGULAR_VALUE_CHAR : String.valueOf(childNum);
105104
String parentDocId = doc.getField(uniqueKeyFieldName).getFirstValue().toString();
106-
cDoc.setField(
107-
uniqueKeyFieldName, generateChildUniqueId(parentDocId, fieldName, sChildNum));
105+
singleVectorNestedDoc.setField(
106+
uniqueKeyFieldName, generateChildUniqueId(parentDocId, field.getName(), sChildNum));
107+
108+
if (!isNested) {
109+
isNested = true;
110+
}
111+
final String lastKeyPath = PATH_SEP_CHAR + field.getName() + NUM_SEP_CHAR + sChildNum;
112+
final String childDocPath = firstLevelChildren ? lastKeyPath : fullPath + lastKeyPath;
113+
if (storePath) {
114+
setPathField(singleVectorNestedDoc, childDocPath);
115+
}
116+
if (storeParent) {
117+
setParentKey(singleVectorNestedDoc, doc);
118+
}
119+
++childNum;
120+
vectors.add(singleVectorNestedDoc);
108121
}
109-
if (!isNested) {
110-
isNested = true;
122+
originalVectorFieldsToRemove.add(field.getName());
123+
} else {
124+
for (Object val : field) {
125+
if (!(val instanceof SolrInputDocument cDoc)) {
126+
// either all collection items are child docs or none are.
127+
break;
128+
}
129+
final String fieldName = field.getName();
130+
131+
if (fieldName.contains(PATH_SEP_CHAR)) {
132+
throw new SolrException(
133+
SolrException.ErrorCode.BAD_REQUEST,
134+
"Field name: '"
135+
+ fieldName
136+
+ "' contains: '"
137+
+ PATH_SEP_CHAR
138+
+ "' , which is reserved for the nested URP");
139+
}
140+
final String sChildNum = isSingleVal ? SINGULAR_VALUE_CHAR : String.valueOf(childNum);
141+
if (!cDoc.containsKey(uniqueKeyFieldName)) {
142+
String parentDocId = doc.getField(uniqueKeyFieldName).getFirstValue().toString();
143+
cDoc.setField(
144+
uniqueKeyFieldName, generateChildUniqueId(parentDocId, fieldName, sChildNum));
145+
}
146+
if (!isNested) {
147+
isNested = true;
148+
}
149+
final String lastKeyPath = PATH_SEP_CHAR + fieldName + NUM_SEP_CHAR + sChildNum;
150+
// concat of all paths children.grandChild => /children#1/grandChild#
151+
final String childDocPath = firstLevelChildren ? lastKeyPath : fullPath + lastKeyPath;
152+
processChildDoc(cDoc, doc, childDocPath);
153+
++childNum;
111154
}
112-
final String lastKeyPath = PATH_SEP_CHAR + fieldName + NUM_SEP_CHAR + sChildNum;
113-
// concat of all paths children.grandChild => /children#1/grandChild#
114-
final String childDocPath = fullPath == null ? lastKeyPath : fullPath + lastKeyPath;
115-
processChildDoc(cDoc, doc, childDocPath);
116-
++childNum;
117155
}
118156
}
157+
this.cleanOriginalVectorFields(doc, originalVectorFieldsToRemove);
158+
if (vectors.size() > 0) {
159+
doc.setField(NESTED_VECTORS_PSEUDO_FIELD_NAME, vectors);
160+
}
119161
return isNested;
120162
}
121163

164+
private void cleanOriginalVectorFields(
165+
SolrInputDocument doc, List<String> originalVectorFieldsToRemove) {
166+
for (String fieldName : originalVectorFieldsToRemove) {
167+
doc.removeField(fieldName);
168+
}
169+
}
170+
171+
private static boolean isMultiValuedVectorField(SchemaField sfield) {
172+
return sfield.getType() instanceof DenseVectorField && sfield.multiValued();
173+
}
174+
122175
private void processChildDoc(
123-
SolrInputDocument sdoc, SolrInputDocument parent, String fullPath) {
176+
SolrInputDocument child, SolrInputDocument parent, String fullPath) {
124177
if (storePath) {
125-
setPathField(sdoc, fullPath);
178+
setPathField(child, fullPath);
126179
}
127180
if (storeParent) {
128-
setParentKey(sdoc, parent);
181+
setParentKey(child, parent);
129182
}
130-
processDocChildren(sdoc, fullPath);
183+
processDocChildren(child, fullPath);
131184
}
132185

133186
private String generateChildUniqueId(String parentId, String childKey, String childNum) {
134187
// combines parentId with the child's key and childNum. e.g. "10/footnote#1"
135188
return parentId + PATH_SEP_CHAR + childKey + NUM_SEP_CHAR + childNum;
136189
}
137190

138-
private void setParentKey(SolrInputDocument sdoc, SolrInputDocument parent) {
139-
sdoc.setField(IndexSchema.NEST_PARENT_FIELD_NAME, parent.getFieldValue(uniqueKeyFieldName));
191+
private void setParentKey(SolrInputDocument child, SolrInputDocument parent) {
192+
child.setField(IndexSchema.NEST_PARENT_FIELD_NAME, parent.getFieldValue(uniqueKeyFieldName));
140193
}
141194

142-
private void setPathField(SolrInputDocument sdoc, String fullPath) {
143-
sdoc.setField(IndexSchema.NEST_PATH_FIELD_NAME, fullPath);
195+
private void setPathField(SolrInputDocument child, String fullPath) {
196+
child.setField(IndexSchema.NEST_PATH_FIELD_NAME, fullPath);
144197
}
145198
}
146199
}

0 commit comments

Comments
 (0)