diff --git a/solr/CHANGES.txt b/solr/CHANGES.txt index 5b91accaedc..13c22d7898e 100644 --- a/solr/CHANGES.txt +++ b/solr/CHANGES.txt @@ -29,6 +29,8 @@ New Features * SOLR-17814: Add support for PatienceKnnVectorQuery. (Ilaria Petreti via Alessandro Benedetti) +* SOLR-17948: Support indexing primitive float[] values for DenseVectorField via JavaBin (Puneet Ahuja, Noble Paul) + Improvements --------------------- diff --git a/solr/core/src/java/org/apache/solr/util/vector/DenseVectorParser.java b/solr/core/src/java/org/apache/solr/util/vector/DenseVectorParser.java index 14c4b4ff653..82d57bc29c6 100644 --- a/solr/core/src/java/org/apache/solr/util/vector/DenseVectorParser.java +++ b/solr/core/src/java/org/apache/solr/util/vector/DenseVectorParser.java @@ -51,6 +51,20 @@ protected void parseVector() { } protected void parseIndexVector() { + if (inputValue instanceof float[] fa) { + checkVectorDimension(fa.length); + for (int i = 0; i < dimension; i++) { + addNumberElement(fa[i]); + } + return; + } + if (inputValue instanceof double[] da) { + checkVectorDimension(da.length); + for (int i = 0; i < dimension; i++) { + addNumberElement((float) da[i]); + } + return; + } if (!(inputValue instanceof List inputVector)) { throw new SolrException( SolrException.ErrorCode.BAD_REQUEST, "incorrect vector format. " + errorMessage()); diff --git a/solr/core/src/test/org/apache/solr/schema/DenseVectorFieldTest.java b/solr/core/src/test/org/apache/solr/schema/DenseVectorFieldTest.java index 7173bffbb9b..6b3c63ca331 100644 --- a/solr/core/src/test/org/apache/solr/schema/DenseVectorFieldTest.java +++ b/solr/core/src/test/org/apache/solr/schema/DenseVectorFieldTest.java @@ -18,15 +18,26 @@ import static org.hamcrest.core.Is.is; +import java.io.ByteArrayOutputStream; import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.Map; import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.solr.client.solrj.request.JavaBinUpdateRequestCodec; +import org.apache.solr.client.solrj.request.UpdateRequest; import org.apache.solr.common.SolrException; import org.apache.solr.common.SolrInputDocument; +import org.apache.solr.common.params.ModifiableSolrParams; +import org.apache.solr.common.util.ContentStreamBase; import org.apache.solr.core.AbstractBadConfigTestBase; +import org.apache.solr.handler.loader.JavabinLoader; +import org.apache.solr.request.SolrQueryRequest; +import org.apache.solr.response.SolrQueryResponse; +import org.apache.solr.update.CommitUpdateCommand; +import org.apache.solr.update.processor.UpdateRequestProcessor; +import org.apache.solr.update.processor.UpdateRequestProcessorChain; import org.apache.solr.util.vector.DenseVectorParser; import org.junit.Before; import org.junit.Test; @@ -760,4 +771,71 @@ public void denseVectorByteEncoding_shouldRaiseExceptionWithFloatValues() throws deleteCore(); } } + + private void addDocWithJavaBin(SolrInputDocument doc) throws Exception { + UpdateRequest ur = new UpdateRequest(); + ur.add(doc); + + ByteArrayOutputStream bos = new ByteArrayOutputStream(); + JavaBinUpdateRequestCodec codec = new JavaBinUpdateRequestCodec(); + codec.marshal(ur, bos); + + byte[] payload = bos.toByteArray(); + + ContentStreamBase.ByteArrayStream cs = + new ContentStreamBase.ByteArrayStream(payload, "application/javabin"); + + try (SolrQueryRequest sreq = req()) { + SolrQueryResponse srsp = new SolrQueryResponse(); + UpdateRequestProcessorChain chain = + h.getCore().getUpdateProcessorChain(new ModifiableSolrParams()); + try (UpdateRequestProcessor proc = chain.createProcessor(sreq, srsp)) { + new JavabinLoader().load(sreq, srsp, cs, proc); + proc.finish(); + } + h.getCore().getUpdateHandler().commit(new CommitUpdateCommand(sreq, false)); + } + } + + @Test + public void testIndexingViaJavaBin() throws Exception { + try { + initCore("solrconfig-basic.xml", "schema-densevector.xml"); + + int variant = random().nextInt(3); + Object vector; + String id; + + switch (variant) { + case 0: + vector = new float[] {1.1f, 2.2f, 3.3f, 4.4f}; + id = "pf_jb"; + break; + case 1: + vector = new double[] {1.1d, 2.2d, 3.3d, 4.4d}; + id = "pd_jb"; + break; + default: + vector = Arrays.asList(1.1f, 2.2f, 3.3f, 4.4f); + id = "lf_jb"; + break; + } + + SolrInputDocument doc = new SolrInputDocument(); + doc.addField("id", id); + doc.addField("vector", vector); + + addDocWithJavaBin(doc); + + assertJQ( + req("q", "id:" + id, "fl", "vector"), "/response/docs/[0]/vector==[1.1,2.2,3.3,4.4]"); + + assertJQ( + req("q", "{!knn f=vector topK=1}[1.1,2.2,3.3,4.4]", "fl", "id"), + "/response/numFound==1", + "/response/docs/[0]/id==\"" + id + "\""); + } finally { + deleteCore(); + } + } }