Skip to content

Commit 99858ee

Browse files
committed
Support search by id
Signed-off-by: yhmo <yihua.mo@zilliz.com>
1 parent 7711927 commit 99858ee

File tree

5 files changed

+98
-24
lines changed

5 files changed

+98
-24
lines changed

docker-compose.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ version: '3.5'
33
services:
44
standalone:
55
container_name: milvus-javasdk-standalone-1
6-
image: milvusdb/milvus:v2.6.7
6+
image: milvusdb/milvus:v2.6.9
77
command: [ "milvus", "run", "standalone" ]
88
environment:
99
- COMMON_STORAGETYPE=local
@@ -24,7 +24,7 @@ services:
2424

2525
standaloneslave:
2626
container_name: milvus-javasdk-standalone-2
27-
image: milvusdb/milvus:v2.6.7
27+
image: milvusdb/milvus:v2.6.9
2828
command: [ "milvus", "run", "standalone" ]
2929
environment:
3030
- COMMON_STORAGETYPE=local

sdk-core/src/main/java/io/milvus/v2/service/vector/request/SearchReq.java

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ public class SearchReq {
4040
private String filter;
4141
private List<String> outputFields;
4242
private List<BaseVector> data;
43+
private List<Object> ids;
4344
private long offset;
4445
private long limit;
4546
private int roundDecimal;
@@ -78,6 +79,7 @@ private SearchReq(SearchReqBuilder builder) {
7879
this.filter = builder.filter;
7980
this.outputFields = builder.outputFields;
8081
this.data = builder.data;
82+
this.ids = builder.ids;
8183
this.offset = builder.offset;
8284
this.limit = builder.limit;
8385
this.roundDecimal = builder.roundDecimal;
@@ -171,6 +173,10 @@ public void setData(List<BaseVector> data) {
171173
this.data = data;
172174
}
173175

176+
public List<Object> getIds() {
177+
return ids;
178+
}
179+
174180
public long getOffset() {
175181
return offset;
176182
}
@@ -299,7 +305,7 @@ public String toString() {
299305
", topK=" + topK +
300306
", filter='" + filter + '\'' +
301307
", outputFields=" + outputFields +
302-
", data=" + data +
308+
(ids == null || ids.isEmpty() ? ", data=" + data : ", ids=" + ids) +
303309
", offset=" + offset +
304310
", limit=" + limit +
305311
", roundDecimal=" + roundDecimal +
@@ -332,6 +338,7 @@ public static class SearchReqBuilder {
332338
private String filter;
333339
private List<String> outputFields = new ArrayList<>(); // default value
334340
private List<BaseVector> data = new ArrayList<>(); // default value
341+
private List<Object> ids = new ArrayList<>();
335342
private long offset;
336343
private long limit = 0L; // default value
337344
private int roundDecimal = -1; // default value
@@ -399,6 +406,11 @@ public SearchReqBuilder data(List<BaseVector> data) {
399406
return this;
400407
}
401408

409+
public SearchReqBuilder ids(List<Object> ids) {
410+
this.ids = ids;
411+
return this;
412+
}
413+
402414
public SearchReqBuilder offset(long offset) {
403415
this.offset = offset;
404416
return this;

sdk-core/src/main/java/io/milvus/v2/utils/VectorUtils.java

Lines changed: 61 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,65 @@ private static ByteString convertPlaceholder(List<Object> data, PlaceholderType
172172
}
173173
}
174174

175+
private static void convertSearchTarget(SearchReq request, SearchRequest.Builder builder) {
176+
// prepare target, the input could be:
177+
// 1. vectors or string list for doc-in-doc-out
178+
// 2. ids list for search by primary keys
179+
List<BaseVector> vectors = request.getData();
180+
List<Object> ids = request.getIds();
181+
boolean vectorsIsEmpty = (vectors == null || vectors.isEmpty());
182+
boolean idsIsEmpty = (ids == null || ids.isEmpty());
183+
if (vectorsIsEmpty && idsIsEmpty) {
184+
throw new MilvusClientException(ErrorCode.INVALID_PARAMS, "Require either ids or vectors, but both are empty");
185+
}
186+
if (!vectorsIsEmpty && !idsIsEmpty) {
187+
throw new MilvusClientException(ErrorCode.INVALID_PARAMS, "Require either ids or vectors, but both are provided");
188+
}
189+
190+
if (!vectorsIsEmpty) {
191+
// the elements must be all-vector or all-string
192+
PlaceholderType plType = vectors.get(0).getPlaceholderType();
193+
List<Object> data = new ArrayList<>();
194+
for (BaseVector vector : vectors) {
195+
if (vector.getPlaceholderType() != plType) {
196+
throw new MilvusClientException(ErrorCode.INVALID_PARAMS,
197+
"Different types of target vectors in a search request is not allowed.");
198+
}
199+
data.add(vector.getData());
200+
}
201+
202+
ByteString byteStr = convertPlaceholder(data, plType);
203+
builder.setPlaceholderGroup(byteStr);
204+
builder.setNq(vectors.size());
205+
} else {
206+
Object val = ids.get(0);
207+
if (val instanceof String) {
208+
StringArray.Builder strBuilder = StringArray.newBuilder();
209+
for (Object obj : ids) {
210+
if (!(obj instanceof String)) {
211+
throw new MilvusClientException(ErrorCode.INVALID_PARAMS,
212+
"All IDs must be of type String if the first ID is a String.");
213+
}
214+
strBuilder.addData((String) obj);
215+
}
216+
builder.setIds(IDs.newBuilder().setStrId(strBuilder.build()).build());
217+
} else if (val instanceof Long) {
218+
LongArray.Builder longBuilder = LongArray.newBuilder();
219+
for (Object obj : ids) {
220+
if (!(obj instanceof Long)) {
221+
throw new MilvusClientException(ErrorCode.INVALID_PARAMS,
222+
"All IDs must be of type Long if the first ID is a Long.");
223+
}
224+
longBuilder.addData((Long) obj);
225+
}
226+
builder.setIds(IDs.newBuilder().setIntId(longBuilder.build()).build());
227+
} else {
228+
throw new MilvusClientException(ErrorCode.INVALID_PARAMS, "ID type must be String or Long.");
229+
}
230+
builder.setNq(ids.size());
231+
}
232+
}
233+
175234
public SearchRequest ConvertToGrpcSearchRequest(SearchReq request) {
176235
String dbName = request.getDatabaseName();
177236
String collectionName = request.getCollectionName();
@@ -185,26 +244,8 @@ public SearchRequest ConvertToGrpcSearchRequest(SearchReq request) {
185244
builder.setDbName(dbName);
186245
}
187246

188-
// prepare target, the input could be vectors or string list for doc-in-doc-out
189-
List<BaseVector> vectors = request.getData();
190-
if (vectors == null || vectors.isEmpty()) {
191-
throw new MilvusClientException(ErrorCode.INVALID_PARAMS, "Target data list of search request is empty.");
192-
}
193-
194-
// the elements must be all-vector or all-string
195-
PlaceholderType plType = vectors.get(0).getPlaceholderType();
196-
List<Object> data = new ArrayList<>();
197-
for (BaseVector vector : vectors) {
198-
if (vector.getPlaceholderType() != plType) {
199-
throw new MilvusClientException(ErrorCode.INVALID_PARAMS,
200-
"Different types of target vectors in a search request is not allowed.");
201-
}
202-
data.add(vector.getData());
203-
}
204-
205-
ByteString byteStr = convertPlaceholder(data, plType);
206-
builder.setPlaceholderGroup(byteStr);
207-
builder.setNq(vectors.size());
247+
// target vectors or ids
248+
convertSearchTarget(request, builder);
208249

209250
// search parameters
210251
// tries to fit the compatibility between v2.5.1 and older versions

sdk-core/src/test/java/io/milvus/TestUtils.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ public class TestUtils {
1111
private int dimension = 256;
1212
private static final Random RANDOM = new Random();
1313

14-
public static final String MilvusDockerImageID = "milvusdb/milvus:v2.6.7";
14+
public static final String MilvusDockerImageID = "milvusdb/milvus:v2.6.9";
1515

1616
public TestUtils(int dimension) {
1717
this.dimension = dimension;

sdk-core/src/test/java/io/milvus/v2/client/MilvusClientV2DockerTest.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -827,6 +827,27 @@ void testFloat16Vectors() {
827827
// System.out.println("Output bfloat16 vector: " + outputVector);
828828
}
829829

830+
// search by ids
831+
{
832+
List<Object> ids = Arrays.asList(5L, 88L, 100L);
833+
SearchResp searchResp = client.search(SearchReq.builder()
834+
.collectionName(randomCollectionName)
835+
.annsField(bfloat16Field)
836+
.ids(ids)
837+
.limit(topk)
838+
.consistencyLevel(ConsistencyLevel.STRONG)
839+
.outputFields(Collections.singletonList(bfloat16Field))
840+
.build());
841+
List<List<SearchResp.SearchResult>> searchResults = searchResp.getSearchResults();
842+
Assertions.assertEquals(3, searchResults.size());
843+
for (int i = 0; i < searchResults.size(); i++) {
844+
List<SearchResp.SearchResult> results = searchResults.get(i);
845+
Assertions.assertEquals(topk, results.size());
846+
SearchResp.SearchResult firstResult = results.get(0);
847+
Assertions.assertEquals(ids.get(i), firstResult.getId());
848+
}
849+
}
850+
830851
// get row count
831852
long rowCount = getRowCount("", randomCollectionName);
832853
Assertions.assertEquals(count, rowCount);

0 commit comments

Comments
 (0)