11package com .datamate .rag .indexer .infrastructure .milvus ;
22
3+ import com .google .gson .*;
4+ import dev .langchain4j .data .embedding .Embedding ;
35import dev .langchain4j .data .segment .TextSegment ;
46import dev .langchain4j .model .embedding .EmbeddingModel ;
57import dev .langchain4j .store .embedding .EmbeddingStore ;
68import dev .langchain4j .store .embedding .milvus .MilvusEmbeddingStore ;
7- import io .milvus .client .MilvusClient ;
8- import io .milvus .client .MilvusServiceClient ;
9- import io .milvus .param .ConnectParam ;
9+ import io .milvus .common .clientenum .FunctionType ;
10+ import io .milvus .v2 .client .ConnectConfig ;
11+ import io .milvus .v2 .client .MilvusClientV2 ;
12+ import io .milvus .v2 .common .DataType ;
13+ import io .milvus .v2 .common .IndexParam ;
14+ import io .milvus .v2 .service .collection .request .AddFieldReq ;
15+ import io .milvus .v2 .service .collection .request .CreateCollectionReq ;
16+ import io .milvus .v2 .service .collection .request .HasCollectionReq ;
17+ import io .milvus .v2 .service .vector .request .InsertReq ;
1018import lombok .extern .slf4j .Slf4j ;
1119import org .springframework .beans .factory .annotation .Value ;
1220import org .springframework .stereotype .Component ;
1321
22+ import java .util .*;
23+
24+ import static dev .langchain4j .internal .Utils .randomUUID ;
25+
1426/**
1527 * Milvus 服务类
1628 *
@@ -24,13 +36,19 @@ public class MilvusService {
2436 private String milvusHost ;
2537 @ Value ("${datamate.rag.milvus-port:19530}" )
2638 private int milvusPort ;
39+ @ Value ("${datamate.rag.milvus-uri:http://milvus-standalone:19530}" )
40+ private String milvusUri ;
41+ private static final Gson GSON ;
42+
43+ static {
44+ GSON = (new GsonBuilder ()).setObjectToNumberStrategy (ToNumberPolicy .LONG_OR_DOUBLE ).create ();
45+ }
2746
28- private volatile MilvusClient milvusClient ;
47+ private volatile MilvusClientV2 milvusClient ;
2948
3049 public EmbeddingStore <TextSegment > embeddingStore (EmbeddingModel embeddingModel , String knowledgeBaseName ) {
3150 return MilvusEmbeddingStore .builder ()
32- .host (milvusHost )
33- .port (milvusPort )
51+ .uri (milvusUri )
3452 .collectionName (knowledgeBaseName )
3553 .dimension (embeddingModel .dimension ())
3654 .build ();
@@ -41,16 +59,15 @@ public EmbeddingStore<TextSegment> embeddingStore(EmbeddingModel embeddingModel,
4159 *
4260 * @return MilvusClient
4361 */
44- public MilvusClient getMilvusClient () {
62+ public MilvusClientV2 getMilvusClient () {
4563 if (milvusClient == null ) {
4664 synchronized (this ) {
4765 if (milvusClient == null ) {
4866 try {
49- ConnectParam connectParam = ConnectParam .newBuilder ()
50- .withHost (milvusHost )
51- .withPort (milvusPort )
67+ ConnectConfig connectConfig = ConnectConfig .builder ()
68+ .uri (milvusUri )
5269 .build ();
53- milvusClient = new MilvusServiceClient ( connectParam );
70+ milvusClient = new MilvusClientV2 ( connectConfig );
5471 log .info ("Milvus client connected successfully" );
5572 } catch (Exception e ) {
5673 log .error ("Milvus client connection failed: {}" , e .getMessage ());
@@ -61,4 +78,107 @@ public MilvusClient getMilvusClient() {
6178 }
6279 return milvusClient ;
6380 }
81+
82+
83+ public boolean hasCollection (String collectionName ) {
84+ HasCollectionReq request = HasCollectionReq .builder ().collectionName (collectionName ).build ();
85+ return getMilvusClient ().hasCollection (request );
86+ }
87+
88+ public void createCollection (String collectionName , int dimension ) {
89+ CreateCollectionReq .CollectionSchema schema = CreateCollectionReq .CollectionSchema .builder ()
90+ .build ();
91+ schema .addField (AddFieldReq .builder ()
92+ .fieldName ("id" )
93+ .dataType (DataType .VarChar )
94+ .maxLength (36 )
95+ .isPrimaryKey (true )
96+ .autoID (false )
97+ .build ());
98+ schema .addField (AddFieldReq .builder ()
99+ .fieldName ("text" )
100+ .dataType (DataType .VarChar )
101+ .maxLength (65535 )
102+ .enableAnalyzer (true )
103+ .build ());
104+ schema .addField (AddFieldReq .builder ()
105+ .fieldName ("metadata" )
106+ .dataType (DataType .JSON )
107+ .build ());
108+ schema .addField (AddFieldReq .builder ()
109+ .fieldName ("vector" )
110+ .dataType (DataType .FloatVector )
111+ .dimension (dimension )
112+ .build ());
113+ schema .addField (AddFieldReq .builder ()
114+ .fieldName ("sparse" )
115+ .dataType (DataType .SparseFloatVector )
116+ .build ());
117+ schema .addFunction (CreateCollectionReq .Function .builder ()
118+ .functionType (FunctionType .BM25 )
119+ .name ("text_bm25_emb" )
120+ .inputFieldNames (Collections .singletonList ("text" ))
121+ .outputFieldNames (Collections .singletonList ("sparse" ))
122+ .build ());
123+
124+ Map <String , Object > params = new HashMap <>();
125+ params .put ("inverted_index_algo" , "DAAT_MAXSCORE" );
126+ params .put ("bm25_k1" , 1.2 );
127+ params .put ("bm25_b" , 0.75 );
128+
129+ List <IndexParam > indexes = new ArrayList <>();
130+ indexes .add (IndexParam .builder ()
131+ .fieldName ("sparse" )
132+ .indexType (IndexParam .IndexType .SPARSE_INVERTED_INDEX )
133+ .metricType (IndexParam .MetricType .BM25 )
134+ .extraParams (params )
135+ .build ());
136+ indexes .add (IndexParam .builder ()
137+ .fieldName ("vector" )
138+ .indexType (IndexParam .IndexType .FLAT )
139+ .metricType (IndexParam .MetricType .COSINE )
140+ .extraParams (Map .of ())
141+ .build ());
142+
143+ CreateCollectionReq createCollectionReq = CreateCollectionReq .builder ()
144+ .collectionName (collectionName )
145+ .collectionSchema (schema )
146+ .indexParams (indexes )
147+ .build ();
148+ this .getMilvusClient ().createCollection (createCollectionReq );
149+ }
150+
151+ public void addAll (String collectionName , List <TextSegment > textSegments , List <Embedding > embeddings ) {
152+ List <JsonObject > data = convertToJsonObjects (textSegments , embeddings );
153+ InsertReq insertReq = InsertReq .builder ()
154+ .collectionName (collectionName )
155+ .data (data )
156+ .build ();
157+ this .getMilvusClient ().insert (insertReq );
158+ }
159+
160+ public List <JsonObject > convertToJsonObjects (List <TextSegment > textSegments , List <Embedding > embeddings ) {
161+ List <JsonObject > data = new ArrayList <>();
162+ for (int i = 0 ; i < textSegments .size (); i ++) {
163+ JsonObject jsonObject = new JsonObject ();
164+ jsonObject .addProperty ("id" , randomUUID ());
165+ jsonObject .addProperty ("text" , textSegments .get (i ).text ());
166+ jsonObject .add ("metadata" , GSON .toJsonTree (textSegments .get (i ).metadata ().toMap ()).getAsJsonObject ());
167+ JsonArray vectorArray = new JsonArray ();
168+ for (float f : embeddings .get (i ).vector ()) {
169+ vectorArray .add (f );
170+ }
171+ jsonObject .add ("vector" , vectorArray );
172+ data .add (jsonObject );
173+ }
174+ return data ;
175+ }
176+
177+ List <String > generateIds (int n ) {
178+ List <String > ids = new ArrayList <>();
179+ for (int i = 0 ; i < n ; i ++) {
180+ ids .add (randomUUID ());
181+ }
182+ return ids ;
183+ }
64184}
0 commit comments