Skip to content

Commit ae53b54

Browse files
committed
Support nullable vectors in Java SDK
Signed-off-by: marcelo-cjl <[email protected]>
1 parent 30117fd commit ae53b54

File tree

6 files changed

+832
-30
lines changed

6 files changed

+832
-30
lines changed
Lines changed: 299 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,299 @@
1+
package io.milvus.v2;
2+
3+
import com.google.gson.JsonNull;
4+
import com.google.gson.JsonObject;
5+
import io.milvus.common.utils.JsonUtils;
6+
import io.milvus.v2.client.ConnectConfig;
7+
import io.milvus.v2.client.MilvusClientV2;
8+
import io.milvus.v2.common.DataType;
9+
import io.milvus.v2.common.IndexParam;
10+
import io.milvus.v2.service.collection.request.AddCollectionFieldReq;
11+
import io.milvus.v2.service.collection.request.AddFieldReq;
12+
import io.milvus.v2.service.collection.request.CreateCollectionReq;
13+
import io.milvus.v2.service.collection.request.DropCollectionReq;
14+
import io.milvus.v2.service.collection.request.LoadCollectionReq;
15+
import io.milvus.v2.service.collection.request.ReleaseCollectionReq;
16+
import io.milvus.v2.service.index.request.CreateIndexReq;
17+
import io.milvus.v2.service.vector.request.InsertReq;
18+
import io.milvus.v2.service.vector.request.QueryReq;
19+
import io.milvus.v2.service.vector.request.SearchReq;
20+
import io.milvus.v2.service.vector.request.data.FloatVec;
21+
import io.milvus.v2.service.vector.response.InsertResp;
22+
import io.milvus.v2.service.vector.response.QueryResp;
23+
import io.milvus.v2.service.vector.response.SearchResp;
24+
25+
import java.util.*;
26+
27+
public class NullableVectorExample {
28+
private static final int DIMENSION = 8;
29+
private static final Random RANDOM = new Random();
30+
31+
private static List<Float> generateFloatVector() {
32+
List<Float> vector = new ArrayList<>();
33+
for (int i = 0; i < DIMENSION; i++) {
34+
vector.add(RANDOM.nextFloat());
35+
}
36+
return vector;
37+
}
38+
39+
public static void main(String[] args) throws InterruptedException {
40+
ConnectConfig config = ConnectConfig.builder()
41+
.uri("http://localhost:19530")
42+
.build();
43+
MilvusClientV2 client = new MilvusClientV2(config);
44+
System.out.println("Connected to Milvus\n");
45+
46+
insertNullVectors(client);
47+
addNullableVectorField(client);
48+
49+
client.close(5L);
50+
System.out.println("Done!");
51+
}
52+
53+
private static void insertNullVectors(MilvusClientV2 client) throws InterruptedException {
54+
String collectionName = "java_sdk_example_insert_null_vectors";
55+
System.out.println("=== Demo 1: Insert null vectors ===");
56+
57+
// Drop collection if exists
58+
client.dropCollection(DropCollectionReq.builder()
59+
.collectionName(collectionName)
60+
.build());
61+
62+
// Create collection with nullable vector field
63+
CreateCollectionReq.CollectionSchema schema = CreateCollectionReq.CollectionSchema.builder()
64+
.build();
65+
schema.addField(AddFieldReq.builder()
66+
.fieldName("id")
67+
.dataType(DataType.Int64)
68+
.isPrimaryKey(true)
69+
.autoID(false)
70+
.build());
71+
schema.addField(AddFieldReq.builder()
72+
.fieldName("name")
73+
.dataType(DataType.VarChar)
74+
.maxLength(100)
75+
.build());
76+
schema.addField(AddFieldReq.builder()
77+
.fieldName("embedding")
78+
.dataType(DataType.FloatVector)
79+
.dimension(DIMENSION)
80+
.isNullable(true) // Enable nullable for vector field
81+
.build());
82+
83+
client.createCollection(CreateCollectionReq.builder()
84+
.collectionName(collectionName)
85+
.collectionSchema(schema)
86+
.build());
87+
System.out.println("Created collection with nullable vector field");
88+
89+
// Create index
90+
IndexParam indexParam = IndexParam.builder()
91+
.fieldName("embedding")
92+
.metricType(IndexParam.MetricType.L2)
93+
.indexType(IndexParam.IndexType.FLAT)
94+
.build();
95+
client.createIndex(CreateIndexReq.builder()
96+
.collectionName(collectionName)
97+
.indexParams(Collections.singletonList(indexParam))
98+
.build());
99+
100+
// Load collection
101+
client.loadCollection(LoadCollectionReq.builder()
102+
.collectionName(collectionName)
103+
.build());
104+
105+
// Prepare test data: 100 rows, ~50% null vectors
106+
int totalRows = 100;
107+
int nullPercent = 50;
108+
List<JsonObject> data = new ArrayList<>();
109+
int nullCount = 0;
110+
int validCount = 0;
111+
112+
for (int i = 1; i <= totalRows; i++) {
113+
JsonObject row = new JsonObject();
114+
row.addProperty("id", (long) i);
115+
row.addProperty("name", "item_" + i);
116+
117+
boolean isNull = RANDOM.nextInt(100) < nullPercent;
118+
if (isNull) {
119+
row.add("embedding", JsonNull.INSTANCE);
120+
nullCount++;
121+
} else {
122+
row.add("embedding", JsonUtils.toJsonTree(generateFloatVector()));
123+
validCount++;
124+
}
125+
data.add(row);
126+
}
127+
128+
// Insert data
129+
InsertResp insertResp = client.insert(InsertReq.builder()
130+
.collectionName(collectionName)
131+
.data(data)
132+
.build());
133+
System.out.println("Inserted " + insertResp.getInsertCnt() + " rows: " + validCount + " valid, " + nullCount + " null");
134+
135+
Thread.sleep(1000);
136+
137+
// Query all data
138+
QueryResp queryResp = client.query(QueryReq.builder()
139+
.collectionName(collectionName)
140+
.filter("id >= 0")
141+
.outputFields(Arrays.asList("id", "embedding"))
142+
.limit(totalRows + 10)
143+
.build());
144+
145+
int queryNullCount = 0;
146+
int queryValidCount = 0;
147+
for (QueryResp.QueryResult result : queryResp.getQueryResults()) {
148+
Object embedding = result.getEntity().get("embedding");
149+
if (embedding == null) {
150+
queryNullCount++;
151+
} else {
152+
queryValidCount++;
153+
}
154+
}
155+
System.out.println("Query result: " + queryValidCount + " valid, " + queryNullCount + " null");
156+
157+
// Search - only returns non-null vectors
158+
SearchResp searchResp = client.search(SearchReq.builder()
159+
.collectionName(collectionName)
160+
.data(Collections.singletonList(new FloatVec(generateFloatVector())))
161+
.annsField("embedding")
162+
.topK(10)
163+
.outputFields(Arrays.asList("id", "embedding"))
164+
.build());
165+
166+
List<List<SearchResp.SearchResult>> searchResults = searchResp.getSearchResults();
167+
if (!searchResults.isEmpty()) {
168+
System.out.println("Search returned " + searchResults.get(0).size() + " hits (only non-null vectors)");
169+
}
170+
171+
// Cleanup
172+
client.dropCollection(DropCollectionReq.builder()
173+
.collectionName(collectionName)
174+
.build());
175+
System.out.println("Dropped collection\n");
176+
}
177+
178+
private static void addNullableVectorField(MilvusClientV2 client) throws InterruptedException {
179+
String collectionName = "java_sdk_example_add_vector_field";
180+
System.out.println("=== Demo 2: Add nullable vector field to existing collection ===");
181+
182+
// Drop collection if exists
183+
client.dropCollection(DropCollectionReq.builder()
184+
.collectionName(collectionName)
185+
.build());
186+
187+
// Create collection with one vector field (Milvus requires at least one)
188+
CreateCollectionReq.CollectionSchema schema = CreateCollectionReq.CollectionSchema.builder()
189+
.build();
190+
schema.addField(AddFieldReq.builder()
191+
.fieldName("id")
192+
.dataType(DataType.Int64)
193+
.isPrimaryKey(true)
194+
.autoID(false)
195+
.build());
196+
schema.addField(AddFieldReq.builder()
197+
.fieldName("name")
198+
.dataType(DataType.VarChar)
199+
.maxLength(100)
200+
.build());
201+
schema.addField(AddFieldReq.builder()
202+
.fieldName("embedding_v1")
203+
.dataType(DataType.FloatVector)
204+
.dimension(DIMENSION)
205+
.build());
206+
207+
client.createCollection(CreateCollectionReq.builder()
208+
.collectionName(collectionName)
209+
.collectionSchema(schema)
210+
.build());
211+
System.out.println("Created collection with one vector field");
212+
213+
// Create index and load
214+
IndexParam indexParam = IndexParam.builder()
215+
.fieldName("embedding_v1")
216+
.metricType(IndexParam.MetricType.L2)
217+
.indexType(IndexParam.IndexType.FLAT)
218+
.build();
219+
client.createIndex(CreateIndexReq.builder()
220+
.collectionName(collectionName)
221+
.indexParams(Collections.singletonList(indexParam))
222+
.build());
223+
client.loadCollection(LoadCollectionReq.builder()
224+
.collectionName(collectionName)
225+
.build());
226+
227+
// Insert some data first
228+
List<JsonObject> data = new ArrayList<>();
229+
for (int i = 1; i <= 10; i++) {
230+
JsonObject row = new JsonObject();
231+
row.addProperty("id", (long) i);
232+
row.addProperty("name", "item_" + i);
233+
row.add("embedding_v1", JsonUtils.toJsonTree(generateFloatVector()));
234+
data.add(row);
235+
}
236+
client.insert(InsertReq.builder()
237+
.collectionName(collectionName)
238+
.data(data)
239+
.build());
240+
System.out.println("Inserted 10 rows");
241+
242+
// Release before adding field
243+
client.releaseCollection(ReleaseCollectionReq.builder()
244+
.collectionName(collectionName)
245+
.build());
246+
247+
// Add a second nullable vector field to existing collection
248+
client.addCollectionField(AddCollectionFieldReq.builder()
249+
.collectionName(collectionName)
250+
.fieldName("embedding_v2")
251+
.dataType(DataType.FloatVector)
252+
.dimension(DIMENSION)
253+
.isNullable(true) // Must be nullable when adding to existing collection
254+
.build());
255+
System.out.println("Added nullable vector field 'embedding_v2'");
256+
257+
// Create index for the new field
258+
IndexParam newIndexParam = IndexParam.builder()
259+
.fieldName("embedding_v2")
260+
.metricType(IndexParam.MetricType.L2)
261+
.indexType(IndexParam.IndexType.FLAT)
262+
.build();
263+
client.createIndex(CreateIndexReq.builder()
264+
.collectionName(collectionName)
265+
.indexParams(Collections.singletonList(newIndexParam))
266+
.build());
267+
268+
// Load collection
269+
client.loadCollection(LoadCollectionReq.builder()
270+
.collectionName(collectionName)
271+
.build());
272+
273+
Thread.sleep(1000);
274+
275+
// Query to verify old rows have null for the new field
276+
QueryResp queryResp = client.query(QueryReq.builder()
277+
.collectionName(collectionName)
278+
.filter("id >= 0")
279+
.outputFields(Arrays.asList("id", "embedding_v1", "embedding_v2"))
280+
.limit(10)
281+
.build());
282+
283+
System.out.println("Query result (old rows have null for new field):");
284+
for (QueryResp.QueryResult result : queryResp.getQueryResults()) {
285+
Map<String, Object> entity = result.getEntity();
286+
long id = (Long) entity.get("id");
287+
Object v1 = entity.get("embedding_v1");
288+
Object v2 = entity.get("embedding_v2");
289+
System.out.println(" id=" + id + ", embedding_v1=" + (v1 == null ? "null" : "has value")
290+
+ ", embedding_v2=" + (v2 == null ? "null" : "has value"));
291+
}
292+
293+
// Cleanup
294+
client.dropCollection(DropCollectionReq.builder()
295+
.collectionName(collectionName)
296+
.build());
297+
System.out.println("Dropped collection\n");
298+
}
299+
}

sdk-core/src/main/java/io/milvus/param/ParamUtils.java

Lines changed: 26 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1204,6 +1204,16 @@ public static FieldData genFieldData(String fieldName, DataType dataType, DataTy
12041204

12051205
FieldData.Builder builder = FieldData.newBuilder();
12061206
if (isVectorDataType(dataType)) {
1207+
if (isNullable) {
1208+
List<Object> tempObjects = new ArrayList<>();
1209+
for (Object obj : objects) {
1210+
builder.addValidData(obj != null);
1211+
if (obj != null) {
1212+
tempObjects.add(obj);
1213+
}
1214+
}
1215+
objects = tempObjects;
1216+
}
12071217
VectorField vectorField = genVectorField(dataType, objects);
12081218
return builder.setFieldName(fieldName).setType(dataType).setVectors(vectorField).build();
12091219
} else {
@@ -1228,6 +1238,22 @@ public static FieldData genFieldData(String fieldName, DataType dataType, DataTy
12281238

12291239
@SuppressWarnings("unchecked")
12301240
public static VectorField genVectorField(DataType dataType, List<?> objects) {
1241+
if (objects.isEmpty()) {
1242+
if (dataType == DataType.FloatVector) {
1243+
return VectorField.newBuilder().setDim(0).setFloatVector(FloatArray.newBuilder().build()).build();
1244+
} else if (dataType == DataType.BinaryVector) {
1245+
return VectorField.newBuilder().setDim(0).setBinaryVector(ByteString.EMPTY).build();
1246+
} else if (dataType == DataType.Float16Vector) {
1247+
return VectorField.newBuilder().setDim(0).setFloat16Vector(ByteString.EMPTY).build();
1248+
} else if (dataType == DataType.BFloat16Vector) {
1249+
return VectorField.newBuilder().setDim(0).setBfloat16Vector(ByteString.EMPTY).build();
1250+
} else if (dataType == DataType.Int8Vector) {
1251+
return VectorField.newBuilder().setDim(0).setInt8Vector(ByteString.EMPTY).build();
1252+
} else if (dataType == DataType.SparseFloatVector) {
1253+
return VectorField.newBuilder().setDim(0).setSparseFloatVector(SparseFloatArray.newBuilder().build()).build();
1254+
}
1255+
}
1256+
12311257
if (dataType == DataType.FloatVector) {
12321258
List<Float> floats = new ArrayList<>();
12331259
// each object is List<Float>

0 commit comments

Comments
 (0)