3030import com .datastax .oss .driver .api .core .cql .BoundStatement ;
3131import com .datastax .oss .driver .api .core .cql .BoundStatementBuilder ;
3232import com .datastax .oss .driver .api .core .cql .PreparedStatement ;
33+ import com .datastax .oss .driver .api .core .cql .ResultSet ;
3334import com .datastax .oss .driver .api .core .cql .Row ;
3435import com .datastax .oss .driver .api .core .cql .SimpleStatement ;
3536import com .datastax .oss .driver .api .core .data .CqlVector ;
3637import com .datastax .oss .driver .api .core .metadata .schema .TableMetadata ;
3738import com .datastax .oss .driver .api .querybuilder .QueryBuilder ;
39+ import static com .datastax .oss .driver .api .querybuilder .QueryBuilder .literal ;
3840import com .datastax .oss .driver .api .querybuilder .delete .Delete ;
3941import com .datastax .oss .driver .api .querybuilder .delete .DeleteSelection ;
4042import com .datastax .oss .driver .api .querybuilder .insert .InsertInto ;
4143import com .datastax .oss .driver .api .querybuilder .insert .RegularInsert ;
44+ import com .datastax .oss .driver .api .querybuilder .select .Select ;
45+ import com .datastax .oss .driver .api .querybuilder .select .Selector ;
4246import com .datastax .oss .driver .shaded .guava .common .base .Preconditions ;
4347import io .micrometer .observation .ObservationRegistry ;
4448import org .slf4j .Logger ;
@@ -112,8 +116,6 @@ public class CassandraVectorStore extends AbstractObservationVectorStore impleme
112116
113117 public static final String DRIVER_PROFILE_SEARCH = "spring-ai-search" ;
114118
115- private static final String QUERY_FORMAT = "select %s,%s,%s%s from %s.%s ? order by %s ann of ? limit ?" ;
116-
117119 private static final Logger logger = LoggerFactory .getLogger (CassandraVectorStore .class );
118120
119121 private static Map <Similarity , VectorStoreSimilarityMetric > SIMILARITY_TYPE_MAPPING = Map .of (Similarity .COSINE ,
@@ -130,8 +132,6 @@ public class CassandraVectorStore extends AbstractObservationVectorStore impleme
130132
131133 private final PreparedStatement deleteStmt ;
132134
133- private final String similarityStmt ;
134-
135135 private final Similarity similarity ;
136136
137137 private final BatchingStrategy batchingStrategy ;
@@ -162,7 +162,6 @@ public CassandraVectorStore(CassandraVectorStoreConfig conf, EmbeddingModel embe
162162 .get ();
163163
164164 this .similarity = getIndexSimilarity (cassandraMetadata );
165- this .similarityStmt = similaritySearchStatement ();
166165
167166 this .filterExpressionConverter = new CassandraFilterExpressionConverter (
168167 cassandraMetadata .getColumns ().values ());
@@ -232,21 +231,14 @@ public List<Document> doSimilaritySearch(SearchRequest request) {
232231 Preconditions .checkArgument (request .getTopK () <= 1000 );
233232 var embedding = toFloatArray (this .embeddingModel .embed (request .getQuery ()));
234233 CqlVector <Float > cqlVector = CqlVector .newInstance (embedding );
235-
236- String whereClause = "" ;
237- if (request .hasFilterExpression ()) {
238- String expression = this .filterExpressionConverter .convertExpression (request .getFilterExpression ());
239- if (!expression .isBlank ()) {
240- whereClause = String .format ("where %s" , expression );
241- }
242- }
243-
244- String query = String .format (this .similarityStmt , cqlVector , whereClause , cqlVector , request .getTopK ());
234+ String cql = createSimilaritySearchCql (request , cqlVector , request .getTopK ());
245235 List <Document > documents = new ArrayList <>();
246- logger .trace ("Executing {}" , query );
247- SimpleStatement s = SimpleStatement .newInstance (query ).setExecutionProfileName (DRIVER_PROFILE_SEARCH );
236+ logger .trace ("Executing {}" , cql );
248237
249- for (Row row : this .conf .session .execute (s )) {
238+ ResultSet result = this .conf .session
239+ .execute (SimpleStatement .newInstance (cql ).setExecutionProfileName (DRIVER_PROFILE_SEARCH ));
240+
241+ for (Row row : result ) {
250242 float score = row .getFloat (0 );
251243 if (score < request .getSimilarityThreshold ()) {
252244 break ;
@@ -333,38 +325,36 @@ private PreparedStatement prepareAddStatement(Set<String> metadataFields) {
333325 });
334326 }
335327
336- private String similaritySearchStatement () {
337- StringBuilder ids = new StringBuilder ();
338- for (var m : this .conf .schema .partitionKeys ()) {
339- ids .append (m .name ()).append (',' );
340- }
341- for (var m : this .conf .schema .clusteringKeys ()) {
342- ids .append (m .name ()).append (',' );
343- }
344- ids .deleteCharAt (ids .length () - 1 );
328+ private String createSimilaritySearchCql (SearchRequest request , CqlVector <Float > cqlVector , int topK ) {
345329
346- String similarityFunction = new StringBuilder ("similarity_" ).append (this .similarity .toString ().toLowerCase ())
347- .append ('(' )
348- .append (this .conf .schema .embedding ())
349- .append (",?)" )
350- .toString ();
330+ Select stmt = QueryBuilder .selectFrom (this .conf .schema .keyspace (), this .conf .schema .table ())
331+ .function ("similarity_" + this .similarity .toString ().toLowerCase (),
332+ Selector .column (this .conf .schema .embedding ()), literal (cqlVector ));
351333
352- StringBuilder extraSelectFields = new StringBuilder ();
334+ for (var c : this .conf .schema .partitionKeys ()) {
335+ stmt = stmt .column (c .name ());
336+ }
337+ for (var c : this .conf .schema .clusteringKeys ()) {
338+ stmt = stmt .column (c .name ());
339+ }
340+ stmt = stmt .column (this .conf .schema .content ());
353341 for (var m : this .conf .schema .metadataColumns ()) {
354- extraSelectFields . append ( ',' ). append (m .name ());
342+ stmt = stmt . column (m .name ());
355343 }
356344 if (this .conf .returnEmbeddings ) {
357- extraSelectFields . append ( ',' ). append (this .conf .schema .embedding ());
345+ stmt = stmt . column (this .conf .schema .embedding ());
358346 }
359347
360- // java-driver-query-builder doesn't support orderByAnnOf yet
361- String query = String .format (QUERY_FORMAT , similarityFunction , ids .toString (), this .conf .schema .content (),
362- extraSelectFields .toString (), this .conf .schema .keyspace (), this .conf .schema .table (),
363- this .conf .schema .embedding ());
364-
365- query = query .replace ("?" , "%s" );
366- logger .debug ("preparing {}" , query );
367- return query ;
348+ // the filterExpression is a string so we go back to building a CQL string
349+ String whereClause = "" ;
350+ if (request .hasFilterExpression ()) {
351+ String expression = this .filterExpressionConverter .convertExpression (request .getFilterExpression ());
352+ if (!expression .isBlank ()) {
353+ whereClause = String .format ("WHERE %s" , expression );
354+ }
355+ }
356+ String cql = stmt .orderByAnnOf (this .conf .schema .embedding (), cqlVector ).limit (topK ).asCql ();
357+ return cql .replace (" ORDER " , whereClause + " ORDER " );
368358 }
369359
370360 private String getDocumentId (Row row ) {
0 commit comments