Nullable vector field support for the Milvus Java SDK.

Summary of this patch:
  * examples: NullableVectorExample inserts rows with null vectors and adds a
    nullable vector field to an existing collection.
  * ParamUtils.genFieldData: for nullable vector fields, records one valid_data
    flag per row and sends only the non-null vectors.
  * ParamUtils.genVectorField: returns an empty, dim-0 VectorField when there
    are no vectors to encode.
  * FieldDataWrapper: row count and returned vector list follow valid_data, with
    null placed at the positions flagged invalid. A valid_data list that flags
    more non-null rows than vectors were returned is reported as an
    IllegalResponseException.
  * SchemaUtils/CollectionService: addCollectionField rejects a non-nullable
    vector field on the client side.
  * tests: NullableVectorTest covers every vector type.

diff --git a/examples/src/main/java/io/milvus/v2/NullableVectorExample.java b/examples/src/main/java/io/milvus/v2/NullableVectorExample.java
new file mode 100644
index 000000000..84d2bee8a
--- /dev/null
+++ b/examples/src/main/java/io/milvus/v2/NullableVectorExample.java
@@ -0,0 +1,299 @@
+package io.milvus.v2;
+
+import com.google.gson.JsonNull;
+import com.google.gson.JsonObject;
+import io.milvus.common.utils.JsonUtils;
+import io.milvus.v2.client.ConnectConfig;
+import io.milvus.v2.client.MilvusClientV2;
+import io.milvus.v2.common.DataType;
+import io.milvus.v2.common.IndexParam;
+import io.milvus.v2.service.collection.request.AddCollectionFieldReq;
+import io.milvus.v2.service.collection.request.AddFieldReq;
+import io.milvus.v2.service.collection.request.CreateCollectionReq;
+import io.milvus.v2.service.collection.request.DropCollectionReq;
+import io.milvus.v2.service.collection.request.LoadCollectionReq;
+import io.milvus.v2.service.collection.request.ReleaseCollectionReq;
+import io.milvus.v2.service.index.request.CreateIndexReq;
+import io.milvus.v2.service.vector.request.InsertReq;
+import io.milvus.v2.service.vector.request.QueryReq;
+import io.milvus.v2.service.vector.request.SearchReq;
+import io.milvus.v2.service.vector.request.data.FloatVec;
+import io.milvus.v2.service.vector.response.InsertResp;
+import io.milvus.v2.service.vector.response.QueryResp;
+import io.milvus.v2.service.vector.response.SearchResp;
+
+import java.util.*;
+
+public class NullableVectorExample {
+    private static final int DIMENSION = 8;
+    private static final Random RANDOM = new Random();
+
+    private static List<Float> generateFloatVector() {
+        List<Float> vector = new ArrayList<>();
+        for (int i = 0; i < DIMENSION; i++) {
+            vector.add(RANDOM.nextFloat());
+        }
+        return vector;
+    }
+
+    public static void main(String[] args) throws InterruptedException {
+        ConnectConfig config = ConnectConfig.builder()
+                .uri("http://localhost:19530")
+                .build();
+        MilvusClientV2 client = new MilvusClientV2(config);
+        System.out.println("Connected to Milvus\n");
+
+        insertNullVectors(client);
+        addNullableVectorField(client);
+
+        client.close(5L);
+        System.out.println("Done!");
+    }
+
+    private static void insertNullVectors(MilvusClientV2 client) throws InterruptedException {
+        String collectionName = "java_sdk_example_insert_null_vectors";
+        System.out.println("=== Demo 1: Insert null vectors ===");
+
+        // Drop collection if exists
+        client.dropCollection(DropCollectionReq.builder()
+                .collectionName(collectionName)
+                .build());
+
+        // Create collection with nullable vector field
+        CreateCollectionReq.CollectionSchema schema = CreateCollectionReq.CollectionSchema.builder()
+                .build();
+        schema.addField(AddFieldReq.builder()
+                .fieldName("id")
+                .dataType(DataType.Int64)
+                .isPrimaryKey(true)
+                .autoID(false)
+                .build());
+        schema.addField(AddFieldReq.builder()
+                .fieldName("name")
+                .dataType(DataType.VarChar)
+                .maxLength(100)
+                .build());
+        schema.addField(AddFieldReq.builder()
+                .fieldName("embedding")
+                .dataType(DataType.FloatVector)
+                .dimension(DIMENSION)
+                .isNullable(true) // Enable nullable for vector field
+                .build());
+
+        client.createCollection(CreateCollectionReq.builder()
+                .collectionName(collectionName)
+                .collectionSchema(schema)
+                .build());
+        System.out.println("Created collection with nullable vector field");
+
+        // Create index
+        IndexParam indexParam = IndexParam.builder()
+                .fieldName("embedding")
+                .metricType(IndexParam.MetricType.L2)
+                .indexType(IndexParam.IndexType.FLAT)
+                .build();
+        client.createIndex(CreateIndexReq.builder()
+                .collectionName(collectionName)
+                .indexParams(Collections.singletonList(indexParam))
+                .build());
+
+        // Load collection
+        client.loadCollection(LoadCollectionReq.builder()
+                .collectionName(collectionName)
+                .build());
+
+        // Prepare test data: 100 rows, ~50% null vectors
+        int totalRows = 100;
+        int nullPercent = 50;
+        List<JsonObject> data = new ArrayList<>();
+        int nullCount = 0;
+        int validCount = 0;
+
+        for (int i = 1; i <= totalRows; i++) {
+            JsonObject row = new JsonObject();
+            row.addProperty("id", (long) i);
+            row.addProperty("name", "item_" + i);
+
+            boolean isNull = RANDOM.nextInt(100) < nullPercent;
+            if (isNull) {
+                row.add("embedding", JsonNull.INSTANCE);
+                nullCount++;
+            } else {
+                row.add("embedding", JsonUtils.toJsonTree(generateFloatVector()));
+                validCount++;
+            }
+            data.add(row);
+        }
+
+        // Insert data
+        InsertResp insertResp = client.insert(InsertReq.builder()
+                .collectionName(collectionName)
+                .data(data)
+                .build());
+        System.out.println("Inserted " + insertResp.getInsertCnt() + " rows: " + validCount + " valid, " + nullCount + " null");
+
+        Thread.sleep(1000);
+
+        // Query all data
+        QueryResp queryResp = client.query(QueryReq.builder()
+                .collectionName(collectionName)
+                .filter("id >= 0")
+                .outputFields(Arrays.asList("id", "embedding"))
+                .limit(totalRows + 10)
+                .build());
+
+        int queryNullCount = 0;
+        int queryValidCount = 0;
+        for (QueryResp.QueryResult result : queryResp.getQueryResults()) {
+            Object embedding = result.getEntity().get("embedding");
+            if (embedding == null) {
+                queryNullCount++;
+            } else {
+                queryValidCount++;
+            }
+        }
+        System.out.println("Query result: " + queryValidCount + " valid, " + queryNullCount + " null");
+
+        // Search - only returns non-null vectors
+        SearchResp searchResp = client.search(SearchReq.builder()
+                .collectionName(collectionName)
+                .data(Collections.singletonList(new FloatVec(generateFloatVector())))
+                .annsField("embedding")
+                .topK(10)
+                .outputFields(Arrays.asList("id", "embedding"))
+                .build());
+
+        List<List<SearchResp.SearchResult>> searchResults = searchResp.getSearchResults();
+        if (!searchResults.isEmpty()) {
+            System.out.println("Search returned " + searchResults.get(0).size() + " hits (only non-null vectors)");
+        }
+
+        // Cleanup
+        client.dropCollection(DropCollectionReq.builder()
+                .collectionName(collectionName)
+                .build());
+        System.out.println("Dropped collection\n");
+    }
+
+    private static void addNullableVectorField(MilvusClientV2 client) throws InterruptedException {
+        String collectionName = "java_sdk_example_add_vector_field";
+        System.out.println("=== Demo 2: Add nullable vector field to existing collection ===");
+
+        // Drop collection if exists
+        client.dropCollection(DropCollectionReq.builder()
+                .collectionName(collectionName)
+                .build());
+
+        // Create collection with one vector field (Milvus requires at least one)
+        CreateCollectionReq.CollectionSchema schema = CreateCollectionReq.CollectionSchema.builder()
+                .build();
+        schema.addField(AddFieldReq.builder()
+                .fieldName("id")
+                .dataType(DataType.Int64)
+                .isPrimaryKey(true)
+                .autoID(false)
+                .build());
+        schema.addField(AddFieldReq.builder()
+                .fieldName("name")
+                .dataType(DataType.VarChar)
+                .maxLength(100)
+                .build());
+        schema.addField(AddFieldReq.builder()
+                .fieldName("embedding_v1")
+                .dataType(DataType.FloatVector)
+                .dimension(DIMENSION)
+                .build());
+
+        client.createCollection(CreateCollectionReq.builder()
+                .collectionName(collectionName)
+                .collectionSchema(schema)
+                .build());
+        System.out.println("Created collection with one vector field");
+
+        // Create index and load
+        IndexParam indexParam = IndexParam.builder()
+                .fieldName("embedding_v1")
+                .metricType(IndexParam.MetricType.L2)
+                .indexType(IndexParam.IndexType.FLAT)
+                .build();
+        client.createIndex(CreateIndexReq.builder()
+                .collectionName(collectionName)
+                .indexParams(Collections.singletonList(indexParam))
+                .build());
+        client.loadCollection(LoadCollectionReq.builder()
+                .collectionName(collectionName)
+                .build());
+
+        // Insert some data first
+        List<JsonObject> data = new ArrayList<>();
+        for (int i = 1; i <= 10; i++) {
+            JsonObject row = new JsonObject();
+            row.addProperty("id", (long) i);
+            row.addProperty("name", "item_" + i);
+            row.add("embedding_v1", JsonUtils.toJsonTree(generateFloatVector()));
+            data.add(row);
+        }
+        client.insert(InsertReq.builder()
+                .collectionName(collectionName)
+                .data(data)
+                .build());
+        System.out.println("Inserted 10 rows");
+
+        // Release before adding field
+        client.releaseCollection(ReleaseCollectionReq.builder()
+                .collectionName(collectionName)
+                .build());
+
+        // Add a second nullable vector field to existing collection
+        client.addCollectionField(AddCollectionFieldReq.builder()
+                .collectionName(collectionName)
+                .fieldName("embedding_v2")
+                .dataType(DataType.FloatVector)
+                .dimension(DIMENSION)
+                .isNullable(true) // Must be nullable when adding to existing collection
+                .build());
+        System.out.println("Added nullable vector field 'embedding_v2'");
+
+        // Create index for the new field
+        IndexParam newIndexParam = IndexParam.builder()
+                .fieldName("embedding_v2")
+                .metricType(IndexParam.MetricType.L2)
+                .indexType(IndexParam.IndexType.FLAT)
+                .build();
+        client.createIndex(CreateIndexReq.builder()
+                .collectionName(collectionName)
+                .indexParams(Collections.singletonList(newIndexParam))
+                .build());
+
+        // Load collection
+        client.loadCollection(LoadCollectionReq.builder()
+                .collectionName(collectionName)
+                .build());
+
+        Thread.sleep(1000);
+
+        // Query to verify old rows have null for the new field
+        QueryResp queryResp = client.query(QueryReq.builder()
+                .collectionName(collectionName)
+                .filter("id >= 0")
+                .outputFields(Arrays.asList("id", "embedding_v1", "embedding_v2"))
+                .limit(10)
+                .build());
+
+        System.out.println("Query result (old rows have null for new field):");
+        for (QueryResp.QueryResult result : queryResp.getQueryResults()) {
+            Map<String, Object> entity = result.getEntity();
+            long id = (Long) entity.get("id");
+            Object v1 = entity.get("embedding_v1");
+            Object v2 = entity.get("embedding_v2");
+            System.out.println(" id=" + id + ", embedding_v1=" + (v1 == null ? "null" : "has value")
+                    + ", embedding_v2=" + (v2 == null ? "null" : "has value"));
+        }
+
+        // Cleanup
+        client.dropCollection(DropCollectionReq.builder()
+                .collectionName(collectionName)
+                .build());
+        System.out.println("Dropped collection\n");
+    }
+}
diff --git a/sdk-core/src/main/java/io/milvus/param/ParamUtils.java b/sdk-core/src/main/java/io/milvus/param/ParamUtils.java
index 0bdba579c..f5f8b916c 100644
--- a/sdk-core/src/main/java/io/milvus/param/ParamUtils.java
+++ b/sdk-core/src/main/java/io/milvus/param/ParamUtils.java
@@ -1204,6 +1204,16 @@ public static FieldData genFieldData(String fieldName, DataType dataType, DataTy
 
         FieldData.Builder builder = FieldData.newBuilder();
         if (isVectorDataType(dataType)) {
+            if (isNullable) {
+                List<Object> tempObjects = new ArrayList<>();
+                for (Object obj : objects) {
+                    builder.addValidData(obj != null);
+                    if (obj != null) {
+                        tempObjects.add(obj);
+                    }
+                }
+                objects = tempObjects;
+            }
             VectorField vectorField = genVectorField(dataType, objects);
             return builder.setFieldName(fieldName).setType(dataType).setVectors(vectorField).build();
         } else {
@@ -1228,6 +1238,22 @@ public static FieldData genFieldData(String fieldName, DataType dataType, DataTy
 
     @SuppressWarnings("unchecked")
     public static VectorField genVectorField(DataType dataType, List<?> objects) {
+        if (objects.isEmpty()) {
+            if (dataType == DataType.FloatVector) {
+                return VectorField.newBuilder().setDim(0).setFloatVector(FloatArray.newBuilder().build()).build();
+            } else if (dataType == DataType.BinaryVector) {
+                return VectorField.newBuilder().setDim(0).setBinaryVector(ByteString.EMPTY).build();
+            } else if (dataType == DataType.Float16Vector) {
+                return VectorField.newBuilder().setDim(0).setFloat16Vector(ByteString.EMPTY).build();
+            } else if (dataType == DataType.BFloat16Vector) {
+                return VectorField.newBuilder().setDim(0).setBfloat16Vector(ByteString.EMPTY).build();
+            } else if (dataType == DataType.Int8Vector) {
+                return VectorField.newBuilder().setDim(0).setInt8Vector(ByteString.EMPTY).build();
+            } else if (dataType == DataType.SparseFloatVector) {
+                return VectorField.newBuilder().setDim(0).setSparseFloatVector(SparseFloatArray.newBuilder().build()).build();
+            }
+        }
+
         if (dataType == DataType.FloatVector) {
             List<Float> floats = new ArrayList<>();
             // each object is List<?>
diff --git a/sdk-core/src/main/java/io/milvus/response/FieldDataWrapper.java b/sdk-core/src/main/java/io/milvus/response/FieldDataWrapper.java
index 91a6de91d..3c799679b 100644
--- a/sdk-core/src/main/java/io/milvus/response/FieldDataWrapper.java
+++ b/sdk-core/src/main/java/io/milvus/response/FieldDataWrapper.java
@@ -138,6 +138,10 @@ public long getRowCount() throws IllegalResponseException {
         DataType dt = fieldData.getType();
         switch (dt) {
             case FloatVector: {
+                List<Boolean> validData = fieldData.getValidDataList();
+                if (validData != null && !validData.isEmpty()) {
+                    return validData.size();
+                }
                 int dim = getDim();
                 List<Float> data = fieldData.getVectors().getFloatVector().getDataList();
                 if (data.size() % dim != 0) {
@@ -152,6 +156,10 @@ public long getRowCount() throws IllegalResponseException {
             case Float16Vector:
             case BFloat16Vector:
             case Int8Vector: {
+                List<Boolean> validData = fieldData.getValidDataList();
+                if (validData != null && !validData.isEmpty()) {
+                    return validData.size();
+                }
                 int dim = getDim();
                 ByteString data = getVectorBytes(fieldData.getVectors(), dt);
                 int bytePerVec = checkDim(dt, data, dim);
@@ -159,6 +167,10 @@ public long getRowCount() throws IllegalResponseException {
                 return data.size() / bytePerVec;
             }
             case SparseFloatVector: {
+                List<Boolean> validData = fieldData.getValidDataList();
+                if (validData != null && !validData.isEmpty()) {
+                    return validData.size();
+                }
                 // for sparse vector, each content is a vector
                 return fieldData.getVectors().getSparseFloatVector().getContentsCount();
             }
@@ -241,7 +253,7 @@ private List getFieldDataInternal() throws IllegalResponseException {
             case BFloat16Vector:
             case Int8Vector:
             case SparseFloatVector:
-                return getVectorData(dt, fieldData.getVectors());
+                return getVectorData(dt, fieldData.getVectors(), fieldData.getValidDataList());
             case Array:
             case Int64:
             case Int32:
@@ -276,23 +288,26 @@ private List setNoneData(List data, List validData) {
         return data;
     }
 
-    private List<?> getVectorData(DataType dt, VectorField vector) {
+    private List<?> getVectorData(DataType dt, VectorField vector, List<Boolean> validData) {
+        List<?> packData;
         switch (dt) {
             case FloatVector: {
                 int dim = getDimInternal(vector);
                 List<Float> data = vector.getFloatVector().getDataList();
-                if (data.size() % dim != 0) {
-                    String msg = String.format("Returned float vector data array size %d doesn't match dimension %d",
-                            data.size(), dim);
-                    throw new IllegalResponseException(msg);
-                }
-
-                List<List<Float>> packData = new ArrayList<>();
-                int count = data.size() / dim;
-                for (int i = 0; i < count; ++i) {
-                    packData.add(data.subList(i * dim, (i + 1) * dim));
+                List<List<Float>> floatPackData = new ArrayList<>();
+                if (dim > 0) {
+                    if (data.size() % dim != 0) {
+                        String msg = String.format("Returned float vector data array size %d doesn't match dimension %d",
+                                data.size(), dim);
+                        throw new IllegalResponseException(msg);
+                    }
+                    int count = data.size() / dim;
+                    for (int i = 0; i < count; ++i) {
+                        floatPackData.add(data.subList(i * dim, (i + 1) * dim));
+                    }
                 }
-                return packData;
+                packData = floatPackData;
+                break;
             }
             case BinaryVector:
             case Float16Vector:
@@ -300,37 +315,60 @@ private List getVectorData(DataType dt, VectorField vector) {
             case Int8Vector: {
                 int dim = getDimInternal(vector);
                 ByteString data = getVectorBytes(vector, dt);
-                int bytePerVec = checkDim(dt, data, dim);
-                int count = data.size() / bytePerVec;
-                List<ByteBuffer> packData = new ArrayList<>();
-                for (int i = 0; i < count; ++i) {
-                    ByteBuffer bf = ByteBuffer.allocate(bytePerVec);
-                    // binary vector doesn't care endian since each byte is independent
-                    // fp16/bf16/int8 vector is sensitive to endian because each dim occupies 1~2 bytes,
-                    // milvus server stores fp16/bf16/int8 vector as little endian
-                    bf.order(ByteOrder.LITTLE_ENDIAN);
-                    bf.put(data.substring(i * bytePerVec, (i + 1) * bytePerVec).toByteArray());
-                    packData.add(bf);
+                List<ByteBuffer> bytePackData = new ArrayList<>();
+                if (dim > 0) {
+                    int bytePerVec = checkDim(dt, data, dim);
+                    int count = data.size() / bytePerVec;
+                    for (int i = 0; i < count; ++i) {
+                        ByteBuffer bf = ByteBuffer.allocate(bytePerVec);
+                        // binary vector doesn't care endian since each byte is independent
+                        // fp16/bf16/int8 vector is sensitive to endian because each dim occupies 1~2 bytes,
+                        // milvus server stores fp16/bf16/int8 vector as little endian
+                        bf.order(ByteOrder.LITTLE_ENDIAN);
+                        bf.put(data.substring(i * bytePerVec, (i + 1) * bytePerVec).toByteArray());
+                        bytePackData.add(bf);
+                    }
                 }
-                return packData;
+                packData = bytePackData;
+                break;
             }
             case SparseFloatVector: {
                 // in Java sdk, each sparse vector is pairs of long+float
                 // in server side, each sparse vector is stored as uint+float (8 bytes)
                 // don't use sparseArray.getDim() because the dim is the max index of each rows
                 SparseFloatArray sparseArray = vector.getSparseFloatVector();
-                List<SortedMap<Long, Float>> packData = new ArrayList<>();
+                List<SortedMap<Long, Float>> sparsePackData = new ArrayList<>();
                 for (int i = 0; i < sparseArray.getContentsCount(); ++i) {
                     ByteString bs = sparseArray.getContents(i);
                     ByteBuffer bf = ByteBuffer.wrap(bs.toByteArray());
                     SortedMap<Long, Float> sparse = ParamUtils.decodeSparseFloatVector(bf);
-                    packData.add(sparse);
+                    sparsePackData.add(sparse);
                 }
-                return packData;
+                packData = sparsePackData;
+                break;
             }
             default:
                 return new ArrayList<>();
         }
+
+        // Handle nullable vectors - insert null values at positions where validData is false
+        if (validData != null && !validData.isEmpty()) {
+            List<Object> result = new ArrayList<>();
+            int dataIdx = 0;
+            for (Boolean valid : validData) {
+                if (valid) {
+                    if (dataIdx >= packData.size()) {
+                        throw new IllegalResponseException("Returned valid_data marks more non-null vectors than the vector data contains");
+                    }
+                    result.add(packData.get(dataIdx++));
+                } else {
+                    result.add(null);
+                }
+            }
+            return result;
+        }
+
+        return packData;
     }
 
     private List<?> getScalarData(DataType dt, ScalarField scalar, List<Boolean> validData) {
@@ -404,7 +442,7 @@ private List getStructData(StructArrayField field, String fieldName) {
             } else if (fd.getType() == DataType.ArrayOfVector) {
                 VectorArray vecArr = fd.getVectors().getVectorArray();
                 for (VectorField vf : vecArr.getDataList()) {
-                    List<?> vector = getVectorData(vecArr.getElementType(), vf);
+                    List<?> vector = getVectorData(vecArr.getElementType(), vf, null);
                     column.add(vector);
                 }
                 rowCount = column.size();
diff --git a/sdk-core/src/main/java/io/milvus/v2/service/collection/CollectionService.java b/sdk-core/src/main/java/io/milvus/v2/service/collection/CollectionService.java
index f653ab138..d11a49352 100644
--- a/sdk-core/src/main/java/io/milvus/v2/service/collection/CollectionService.java
+++ b/sdk-core/src/main/java/io/milvus/v2/service/collection/CollectionService.java
@@ -262,6 +262,7 @@ public Void addCollectionField(MilvusServiceGrpc.MilvusServiceBlockingStub block
         String dbName = request.getDatabaseName();
         String collectionName = request.getCollectionName();
         String title = String.format("Add field to collection: '%s' in database: '%s'", collectionName, dbName);
+
         AddCollectionFieldRequest.Builder builder = AddCollectionFieldRequest.newBuilder()
                 .setCollectionName(collectionName);
         if (StringUtils.isNotEmpty(dbName)) {
@@ -269,7 +270,7 @@ public Void addCollectionField(MilvusServiceGrpc.MilvusServiceBlockingStub block
         }
 
         CreateCollectionReq.FieldSchema fieldSchema = SchemaUtils.convertFieldReqToFieldSchema(request);
-        FieldSchema grpcFieldSchema = SchemaUtils.convertToGrpcFieldSchema(fieldSchema);
+        FieldSchema grpcFieldSchema = SchemaUtils.convertToGrpcFieldSchema(fieldSchema, true);
         builder.setSchema(grpcFieldSchema.toByteString());
 
         Status response = blockingStub.addCollectionField(builder.build());
@@ -645,4 +646,5 @@ private void WaitForLoadCollection(MilvusServiceGrpc.MilvusServiceBlockingStub b
             }
         }
     }
+
 }
diff --git a/sdk-core/src/main/java/io/milvus/v2/utils/SchemaUtils.java b/sdk-core/src/main/java/io/milvus/v2/utils/SchemaUtils.java
index cd6f8e44a..74e2c61f2 100644
--- a/sdk-core/src/main/java/io/milvus/v2/utils/SchemaUtils.java
+++ b/sdk-core/src/main/java/io/milvus/v2/utils/SchemaUtils.java
@@ -50,9 +50,19 @@ public static void checkNullEmptyString(String target, String title) {
     }
 
     public static FieldSchema convertToGrpcFieldSchema(CreateCollectionReq.FieldSchema fieldSchema) {
+        return convertToGrpcFieldSchema(fieldSchema, false);
+    }
+
+    public static FieldSchema convertToGrpcFieldSchema(CreateCollectionReq.FieldSchema fieldSchema, boolean forAddField) {
         checkNullEmptyString(fieldSchema.getName(), "Field name");
 
         DataType dType = DataType.valueOf(fieldSchema.getDataType().name());
+
+        // Vector field must be nullable when adding to existing collection
+        if (forAddField && ParamUtils.isVectorDataType(dType) && !fieldSchema.getIsNullable()) {
+            throw new MilvusClientException(ErrorCode.INVALID_PARAMS,
+                    "Vector field must be nullable when adding to existing collection, field name: " + fieldSchema.getName());
+        }
         FieldSchema.Builder builder = FieldSchema.newBuilder()
                 .setName(fieldSchema.getName())
                 .setDescription(fieldSchema.getDescription())
diff --git a/sdk-core/src/test/java/io/milvus/v2/service/vector/NullableVectorTest.java b/sdk-core/src/test/java/io/milvus/v2/service/vector/NullableVectorTest.java
new file mode 100644
index 000000000..5c4fc8c98
--- /dev/null
+++ b/sdk-core/src/test/java/io/milvus/v2/service/vector/NullableVectorTest.java
@@ -0,0 +1,430 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package io.milvus.v2.service.vector;
+
+import com.google.gson.JsonNull;
+import com.google.gson.JsonObject;
+import io.milvus.common.utils.Float16Utils;
+import io.milvus.common.utils.JsonUtils;
+import io.milvus.v2.client.ConnectConfig;
+import io.milvus.v2.client.MilvusClientV2;
+import io.milvus.v2.common.DataType;
+import io.milvus.v2.common.IndexParam;
+import io.milvus.v2.service.collection.request.*;
+import io.milvus.v2.service.index.request.CreateIndexReq;
+import io.milvus.v2.service.vector.request.*;
+import io.milvus.v2.service.vector.request.data.*;
+import io.milvus.v2.service.vector.response.*;
+import org.junit.jupiter.api.*;
+import org.testcontainers.junit.jupiter.Container;
+import org.testcontainers.junit.jupiter.Testcontainers;
+import org.testcontainers.milvus.MilvusContainer;
+
+import java.nio.ByteBuffer;
+import java.util.*;
+
+@Testcontainers(disabledWithoutDocker = true)
+@TestMethodOrder(MethodOrderer.OrderAnnotation.class)
+class NullableVectorTest {
+
+    @Container
+    private static final MilvusContainer milvus = new MilvusContainer(io.milvus.TestUtils.MilvusDockerImageID)
+            .withEnv("DEPLOY_MODE", "STANDALONE");
+
+    private static MilvusClientV2 client;
+    private static final int DIMENSION = 8;
+    private static final Random RANDOM = new Random();
+    private static final String COLLECTION_PREFIX = "test_nullable_vec_";
+
+    // Vector type configurations
+    private static final List<VectorTypeConfig> VECTOR_TYPES = Arrays.asList(
+            new VectorTypeConfig("float_vector", DataType.FloatVector, DIMENSION, "L2", "FLAT"),
+            new VectorTypeConfig("binary_vector", DataType.BinaryVector, DIMENSION * 8, "HAMMING", "BIN_FLAT"),
+            new VectorTypeConfig("float16_vector", DataType.Float16Vector, DIMENSION, "L2", "FLAT"),
+            new VectorTypeConfig("bfloat16_vector", DataType.BFloat16Vector, DIMENSION, "L2", "FLAT"),
+            new VectorTypeConfig("sparse_float_vector", DataType.SparseFloatVector, 0, "IP", "SPARSE_INVERTED_INDEX"),
+            new VectorTypeConfig("int8_vector", DataType.Int8Vector, DIMENSION, "L2", "HNSW")
+    );
+
+    static class VectorTypeConfig {
+        String name;
+        DataType dataType;
+        int dimension;
+        String metricType;
+        String indexType;
+
+        VectorTypeConfig(String name, DataType dataType, int dimension, String metricType, String indexType) {
+            this.name = name;
+            this.dataType = dataType;
+            this.dimension = dimension;
+            this.metricType = metricType;
+            this.indexType = indexType;
+        }
+    }
+
+    @BeforeAll
+    public static void setUp() {
+        ConnectConfig config = ConnectConfig.builder()
+                .uri(milvus.getEndpoint())
+                .build();
+        client = new MilvusClientV2(config);
+    }
+
+    @AfterAll
+    public static void tearDown() {
+        if (client != null) {
+            // Cleanup collections
+            for (VectorTypeConfig vtc : VECTOR_TYPES) {
+                String collectionName = COLLECTION_PREFIX + vtc.name;
+                try {
+                    client.dropCollection(DropCollectionReq.builder()
+                            .collectionName(collectionName)
+                            .build());
+                } catch (Exception ignored) {
+                }
+            }
+            try {
+                client.close(5L);
+            } catch (InterruptedException ignored) {
+            }
+        }
+    }
+
+    private Object generateVector(VectorTypeConfig config) {
+        switch (config.dataType) {
+            case FloatVector: {
+                List<Float> vector = new ArrayList<>();
+                for (int i = 0; i < config.dimension; i++) {
+                    vector.add(RANDOM.nextFloat());
+                }
+                return vector;
+            }
+            case BinaryVector: {
+                int byteCount = config.dimension / 8;
+                byte[] bytes = new byte[byteCount];
+                RANDOM.nextBytes(bytes);
+                ByteBuffer buf = ByteBuffer.wrap(bytes);
+                return buf;
+            }
+            case Float16Vector: {
+                // Generate float values first, then convert to fp16
+                List<Float> floatVec = new ArrayList<>();
+                for (int i = 0; i < config.dimension; i++) {
+                    floatVec.add(RANDOM.nextFloat());
+                }
+                ByteBuffer buf = Float16Utils.f32VectorToFp16Buffer(floatVec);
+                buf.rewind();
+                return buf;
+            }
+            case BFloat16Vector: {
+                // Generate float values first, then convert to bf16
+                List<Float> floatVec = new ArrayList<>();
+                for (int i = 0; i < config.dimension; i++) {
+                    floatVec.add(RANDOM.nextFloat());
+                }
+                ByteBuffer buf = Float16Utils.f32VectorToBf16Buffer(floatVec);
+                buf.rewind();
+                return buf;
+            }
+            case Int8Vector: {
+                ByteBuffer buf = ByteBuffer.allocate(config.dimension);
+                for (int i = 0; i < config.dimension; i++) {
+                    buf.put((byte) (RANDOM.nextInt(256) - 128));
+                }
+                buf.rewind();
+                return buf;
+            }
+            case SparseFloatVector: {
+                SortedMap<Long, Float> sparse = new TreeMap<>();
+                int nnz = RANDOM.nextInt(10) + 2;
+                for (int i = 0; i < nnz; i++) {
+                    sparse.put((long) i, RANDOM.nextFloat());
+                }
+                return sparse;
+            }
+            default:
+                throw new IllegalArgumentException("Unknown vector type: " + config.dataType);
+        }
+    }
+
+    @Test
+    @Order(1)
+    void testNullableFloatVector() throws Exception {
+        testNullableVector(VECTOR_TYPES.get(0));
+    }
+
+    @Test
+    @Order(2)
+    void testNullableBinaryVector() throws Exception {
+        testNullableVector(VECTOR_TYPES.get(1));
+    }
+
+    @Test
+    @Order(3)
+    void testNullableFloat16Vector() throws Exception {
+        testNullableVector(VECTOR_TYPES.get(2));
+    }
+
+    @Test
+    @Order(4)
+    void testNullableBFloat16Vector() throws Exception {
+        testNullableVector(VECTOR_TYPES.get(3));
+    }
+
+    @Test
+    @Order(5)
+    void testNullableSparseFloatVector() throws Exception {
+        testNullableVector(VECTOR_TYPES.get(4));
+    }
+
+    @Test
+    @Order(6)
+    void testNullableInt8Vector() throws Exception {
+        testNullableVector(VECTOR_TYPES.get(5));
+    }
+
+    /**
+     * Test that adding a non-nullable vector field to existing collection should fail.
+     */
+    @Test
+    @Order(7)
+    void testAddVectorFieldMustBeNullable() throws Exception {
+        String collectionName = COLLECTION_PREFIX + "add_field_test";
+        System.out.println("\n[Test] add_vector_field_must_be_nullable");
+
+        // Drop if exists
+        try {
+            client.dropCollection(DropCollectionReq.builder()
+                    .collectionName(collectionName)
+                    .build());
+        } catch (Exception ignored) {
+        }
+
+        // Create collection with one vector field
+        CreateCollectionReq.CollectionSchema schema = CreateCollectionReq.CollectionSchema.builder()
+                .build();
+        schema.addField(AddFieldReq.builder()
+                .fieldName("id")
+                .dataType(DataType.Int64)
+                .isPrimaryKey(true)
+                .build());
+        schema.addField(AddFieldReq.builder()
+                .fieldName("embedding")
+                .dataType(DataType.FloatVector)
+                .dimension(DIMENSION)
+                .build());
+
+        client.createCollection(CreateCollectionReq.builder()
+                .collectionName(collectionName)
+                .collectionSchema(schema)
+                .build());
+
+        // Try to add a non-nullable vector field - should fail
+        Exception exception = Assertions.assertThrows(Exception.class, () -> {
+            client.addCollectionField(AddCollectionFieldReq.builder()
+                    .collectionName(collectionName)
+                    .fieldName("embedding_v2")
+                    .dataType(DataType.FloatVector)
+                    .dimension(DIMENSION)
+                    .isNullable(false) // Non-nullable should fail
+                    .build());
+        });
+        System.out.println(" Expected error: " + exception.getMessage());
+        Assertions.assertTrue(exception.getMessage().toLowerCase().contains("nullable"),
+                "Error should mention nullable requirement");
+
+        // Cleanup
+        client.dropCollection(DropCollectionReq.builder()
+                .collectionName(collectionName)
+                .build());
+
+        System.out.println(" PASSED");
+    }
+
+    private void testNullableVector(VectorTypeConfig config) throws Exception {
+        String collectionName = COLLECTION_PREFIX + config.name;
+        System.out.println("\n[Test] " + config.name);
+
+        // Drop if exists
+        try {
+            client.dropCollection(DropCollectionReq.builder()
+                    .collectionName(collectionName)
+                    .build());
+        } catch (Exception ignored) {
+        }
+
+        // Create schema with nullable vector field
+        CreateCollectionReq.CollectionSchema schema = CreateCollectionReq.CollectionSchema.builder()
+                .build();
+        schema.addField(AddFieldReq.builder()
+                .fieldName("id")
+                .dataType(DataType.Int64)
+                .isPrimaryKey(true)
+                .build());
+        schema.addField(AddFieldReq.builder()
+                .fieldName("name")
+                .dataType(DataType.VarChar)
+                .maxLength(100)
+                .build());
+
+        // Add nullable vector field
+        AddFieldReq.AddFieldReqBuilder fieldBuilder = AddFieldReq.builder()
+                .fieldName("embedding")
+                .dataType(config.dataType)
+                .isNullable(true);
+        if (config.dimension > 0) {
+            fieldBuilder.dimension(config.dimension);
+        }
+        schema.addField(fieldBuilder.build());
+
+        // Create collection
+        client.createCollection(CreateCollectionReq.builder()
+                .collectionName(collectionName)
+                .collectionSchema(schema)
+                .build());
+
+        // Create index
+        IndexParam indexParam = IndexParam.builder()
+                .fieldName("embedding")
+                .metricType(IndexParam.MetricType.valueOf(config.metricType))
+                .indexType(IndexParam.IndexType.valueOf(config.indexType))
+                .build();
+        client.createIndex(CreateIndexReq.builder()
+                .collectionName(collectionName)
+                .indexParams(Collections.singletonList(indexParam))
+                .build());
+
+        // Load collection
+        client.loadCollection(LoadCollectionReq.builder()
+                .collectionName(collectionName)
+                .build());
+
+        // Prepare test data: 100 rows, ~50% null vectors
+        int totalRows = 100;
+        int nullPercent = 50;
+        List<JsonObject> data = new ArrayList<>();
+        int expectedNullCount = 0;
+        int expectedValidCount = 0;
+
+        for (int i = 1; i <= totalRows; i++) {
+            JsonObject row = new JsonObject();
+            row.addProperty("id", (long) i);
+            row.addProperty("name", "row_" + i);
+
+            boolean isNull = RANDOM.nextInt(100) < nullPercent;
+            if (isNull) {
+                row.add("embedding", JsonNull.INSTANCE);
+                expectedNullCount++;
+            } else {
+                Object vector = generateVector(config);
+                expectedValidCount++;
+                if (config.dataType == DataType.FloatVector) {
+                    row.add("embedding", JsonUtils.toJsonTree(vector));
+                } else if (config.dataType == DataType.SparseFloatVector) {
+                    row.add("embedding", JsonUtils.toJsonTree(vector));
+                } else {
+                    // For binary/fp16/bf16/int8, encode as base64 or byte array
+                    ByteBuffer buf = (ByteBuffer) vector;
+                    byte[] bytes = new byte[buf.remaining()];
+                    buf.get(bytes);
+                    buf.rewind();
+                    row.add("embedding", JsonUtils.toJsonTree(bytes));
+                }
+            }
+            data.add(row);
+        }
+
+        // Insert data
+        InsertResp insertResp = client.insert(InsertReq.builder()
+                .collectionName(collectionName)
+                .data(data)
+                .build());
+        Assertions.assertEquals(totalRows, insertResp.getInsertCnt());
+
+        // Wait for data to be available
+        Thread.sleep(1000);
+
+        // Query all data
+        QueryResp queryResp = client.query(QueryReq.builder()
+                .collectionName(collectionName)
+                .filter("id >= 0")
+                .outputFields(Arrays.asList("id", "name", "embedding"))
+                .limit(totalRows + 10)
+                .build());
+
+        List<QueryResp.QueryResult> results = queryResp.getQueryResults();
+        Assertions.assertEquals(totalRows, results.size());
+
+        int nullCount = 0;
+        int validCount = 0;
+        for (QueryResp.QueryResult result : results) {
+            Map<String, Object> entity = result.getEntity();
+            Object embedding = entity.get("embedding");
+            if (embedding == null) {
+                nullCount++;
+            } else {
+                validCount++;
+            }
+        }
+        Assertions.assertEquals(expectedNullCount, nullCount);
+        Assertions.assertEquals(expectedValidCount, validCount);
+
+        // Search - should only return non-null vectors
+        Object searchVector = generateVector(config);
+        BaseVector searchVec;
+        if (config.dataType == DataType.FloatVector) {
+            searchVec = new FloatVec((List<Float>) searchVector);
+        } else if (config.dataType == DataType.BinaryVector) {
+            searchVec = new BinaryVec((ByteBuffer) searchVector);
+        } else if (config.dataType == DataType.Float16Vector) {
+            searchVec = new Float16Vec((ByteBuffer) searchVector);
+        } else if (config.dataType == DataType.BFloat16Vector) {
+            searchVec = new BFloat16Vec((ByteBuffer) searchVector);
+        } else if (config.dataType == DataType.SparseFloatVector) {
+            searchVec = new SparseFloatVec((SortedMap<Long, Float>) searchVector);
+        } else {
+            // Int8Vector
+            searchVec = new Int8Vec((ByteBuffer) searchVector);
+        }
+
+        int searchLimit = Math.min(50, expectedValidCount);
+        SearchResp searchResp = client.search(SearchReq.builder()
+                .collectionName(collectionName)
+                .data(Collections.singletonList(searchVec))
+                .annsField("embedding")
+                .topK(searchLimit)
+                .outputFields(Arrays.asList("id", "name", "embedding"))
+                .build());
+
+        List<List<SearchResp.SearchResult>> searchResults = searchResp.getSearchResults();
+        Assertions.assertFalse(searchResults.isEmpty());
+        List<SearchResp.SearchResult> hits = searchResults.get(0);
+
+        // Search should only return non-null vectors
+        Assertions.assertTrue(hits.size() <= expectedValidCount, "Search should return at most expectedValidCount results");
+        for (SearchResp.SearchResult hit : hits) {
+            Map<String, Object> entity = hit.getEntity();
+            Object embedding = entity.get("embedding");
+            Assertions.assertNotNull(embedding, "Search should not return null vectors");
+        }
+
+        System.out.println(" PASSED");
+    }
+}