Skip to content

Commit aeb9063

Browse files
authored
SOLR-16667: LTR Add feature vector caching for ranking (#3433)
by Anna and Alessandro
1 parent 5d8a539 commit aeb9063

33 files changed

+1471
-542
lines changed

solr/CHANGES.txt

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -83,6 +83,8 @@ New Features
8383

8484
* SOLR-17813: Add support for SeededKnnVectorQuery (Ilaria Petreti via Alessandro Benedetti)
8585

86+
* SOLR-16667: LTR Add feature vector caching for ranking. (Anna Ruggero, Alessandro Benedetti)
87+
8688
Improvements
8789
---------------------
8890

@@ -572,7 +574,7 @@ Bug Fixes
572574

573575
* SOLR-17726: MoreLikeThis to support copy-fields (Ilaria Petreti via Alessandro Benedetti)
574576

575-
* SOLR-16667: Fixed dense/sparse representation in LTR module. (Anna Ruggero, Alessandro Benedetti)
577+
* SOLR-17760: Fixed dense/sparse representation in LTR module. (Anna Ruggero, Alessandro Benedetti)
576578

577579
* SOLR-17800: Security Manager should handle symlink on /tmp (Kevin Risden)
578580

solr/core/src/java/org/apache/solr/core/SolrConfig.java

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,9 @@ private SolrConfig(SolrResourceLoader loader, String name, Properties substituta
301301
queryResultCacheConfig =
302302
CacheConfig.getConfig(
303303
this, get("query").get("queryResultCache"), "query/queryResultCache");
304+
featureVectorCacheConfig =
305+
CacheConfig.getConfig(
306+
this, get("query").get("featureVectorCache"), "query/featureVectorCache");
304307
documentCacheConfig =
305308
CacheConfig.getConfig(this, get("query").get("documentCache"), "query/documentCache");
306309
CacheConfig conf =
@@ -662,6 +665,7 @@ public SolrRequestParsers getRequestParsers() {
662665
public final CacheConfig queryResultCacheConfig;
663666
public final CacheConfig documentCacheConfig;
664667
public final CacheConfig fieldValueCacheConfig;
668+
public final CacheConfig featureVectorCacheConfig;
665669
public final Map<String, CacheConfig> userCacheConfigs;
666670
// SolrIndexSearcher - more...
667671
public final boolean useFilterForSortedQuery;
@@ -998,7 +1002,12 @@ public Map<String, Object> toMap(Map<String, Object> result) {
9981002
}
9991003

10001004
addCacheConfig(
1001-
m, filterCacheConfig, queryResultCacheConfig, documentCacheConfig, fieldValueCacheConfig);
1005+
m,
1006+
filterCacheConfig,
1007+
queryResultCacheConfig,
1008+
documentCacheConfig,
1009+
fieldValueCacheConfig,
1010+
featureVectorCacheConfig);
10021011
m = new LinkedHashMap<>();
10031012
result.put("requestDispatcher", m);
10041013
if (httpCachingConfig != null) m.put("httpCaching", httpCachingConfig);

solr/core/src/java/org/apache/solr/search/SolrIndexSearcher.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,7 @@ public class SolrIndexSearcher extends IndexSearcher implements Closeable, SolrI
163163
private final SolrCache<Query, DocSet> filterCache;
164164
private final SolrCache<QueryResultKey, DocList> queryResultCache;
165165
private final SolrCache<String, UnInvertedField> fieldValueCache;
166+
private final SolrCache<Integer, float[]> featureVectorCache;
166167
private final LongAdder fullSortCount = new LongAdder();
167168
private final LongAdder skipSortCount = new LongAdder();
168169
private final LongAdder liveDocsNaiveCacheHitCount = new LongAdder();
@@ -448,6 +449,11 @@ public SolrIndexSearcher(
448449
? null
449450
: solrConfig.queryResultCacheConfig.newInstance();
450451
if (queryResultCache != null) clist.add(queryResultCache);
452+
featureVectorCache =
453+
solrConfig.featureVectorCacheConfig == null
454+
? null
455+
: solrConfig.featureVectorCacheConfig.newInstance();
456+
if (featureVectorCache != null) clist.add(featureVectorCache);
451457
SolrCache<Integer, Document> documentCache = docFetcher.getDocumentCache();
452458
if (documentCache != null) clist.add(documentCache);
453459

@@ -469,6 +475,7 @@ public SolrIndexSearcher(
469475
this.filterCache = null;
470476
this.queryResultCache = null;
471477
this.fieldValueCache = null;
478+
this.featureVectorCache = null;
472479
this.cacheMap = NO_GENERIC_CACHES;
473480
this.cacheList = NO_CACHES;
474481
}
@@ -689,6 +696,10 @@ public SolrCache<Query, DocSet> getFilterCache() {
689696
return filterCache;
690697
}
691698

699+
public SolrCache<Integer, float[]> getFeatureVectorCache() {
700+
return featureVectorCache;
701+
}
702+
692703
//
693704
// Set default regenerators on filter and query caches if they don't have any
694705
//

solr/modules/ltr/build.gradle

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ dependencies {
5555
testImplementation libs.junit.junit
5656
testImplementation libs.hamcrest.hamcrest
5757

58+
testImplementation libs.prometheus.metrics.model
59+
5860
testImplementation libs.commonsio.commonsio
5961
}
6062

solr/modules/ltr/gradle.lockfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ io.opentelemetry:opentelemetry-sdk-metrics:1.53.0=jarValidation,runtimeClasspath
6565
io.opentelemetry:opentelemetry-sdk-trace:1.53.0=jarValidation,runtimeClasspath,runtimeLibs,solrPlatformLibs,testRuntimeClasspath
6666
io.opentelemetry:opentelemetry-sdk:1.53.0=jarValidation,runtimeClasspath,runtimeLibs,solrPlatformLibs,testRuntimeClasspath
6767
io.prometheus:prometheus-metrics-exposition-formats:1.1.0=jarValidation,runtimeClasspath,runtimeLibs,solrPlatformLibs,testRuntimeClasspath
68-
io.prometheus:prometheus-metrics-model:1.1.0=jarValidation,runtimeClasspath,runtimeLibs,solrPlatformLibs,testRuntimeClasspath
68+
io.prometheus:prometheus-metrics-model:1.1.0=jarValidation,runtimeClasspath,runtimeLibs,solrPlatformLibs,testCompileClasspath,testRuntimeClasspath
6969
io.sgr:s2-geometry-library-java:1.0.0=jarValidation,runtimeClasspath,runtimeLibs,solrPlatformLibs,testRuntimeClasspath
7070
io.swagger.core.v3:swagger-annotations-jakarta:2.2.22=compileClasspath,jarValidation,runtimeClasspath,runtimeLibs,solrPlatformLibs,testCompileClasspath,testRuntimeClasspath
7171
jakarta.annotation:jakarta.annotation-api:2.1.1=jarValidation,runtimeClasspath,runtimeLibs,solrPlatformLibs,testRuntimeClasspath

solr/modules/ltr/src/java/org/apache/solr/ltr/CSVFeatureLogger.java

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,21 +23,20 @@ public class CSVFeatureLogger extends FeatureLogger {
2323
private final char keyValueSep;
2424
private final char featureSep;
2525

26-
public CSVFeatureLogger(String fvCacheName, FeatureFormat f, Boolean logAll) {
27-
super(fvCacheName, f, logAll);
26+
public CSVFeatureLogger(FeatureFormat f, Boolean logAll) {
27+
super(f, logAll);
2828
this.keyValueSep = DEFAULT_KEY_VALUE_SEPARATOR;
2929
this.featureSep = DEFAULT_FEATURE_SEPARATOR;
3030
}
3131

32-
public CSVFeatureLogger(
33-
String fvCacheName, FeatureFormat f, Boolean logAll, char keyValueSep, char featureSep) {
34-
super(fvCacheName, f, logAll);
32+
public CSVFeatureLogger(FeatureFormat f, Boolean logAll, char keyValueSep, char featureSep) {
33+
super(f, logAll);
3534
this.keyValueSep = keyValueSep;
3635
this.featureSep = featureSep;
3736
}
3837

3938
@Override
40-
public String makeFeatureVector(LTRScoringQuery.FeatureInfo[] featuresInfo) {
39+
public String printFeatureVector(LTRScoringQuery.FeatureInfo[] featuresInfo) {
4140
// Allocate the buffer to a size based on the number of features instead of the
4241
// default 16. You need space for the name, value, and two separators per feature,
4342
// but not all the features are expected to fire, so this is just a naive estimate.

solr/modules/ltr/src/java/org/apache/solr/ltr/DocInfo.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ public class DocInfo extends HashMap<String, Object> {
2222

2323
// Name of key used to store the original score of a doc
2424
private static final String ORIGINAL_DOC_SCORE = "ORIGINAL_DOC_SCORE";
25+
// Name of key used to store the original id of a doc
26+
private static final String ORIGINAL_DOC_ID = "ORIGINAL_DOC_ID";
2527

2628
public DocInfo() {
2729
super();
@@ -38,4 +40,12 @@ public Float getOriginalDocScore() {
3840
public boolean hasOriginalDocScore() {
3941
return containsKey(ORIGINAL_DOC_SCORE);
4042
}
43+
44+
public void setOriginalDocId(int docId) {
45+
put(ORIGINAL_DOC_ID, docId);
46+
}
47+
48+
public int getOriginalDocId() {
49+
return (int) get(ORIGINAL_DOC_ID);
50+
}
4151
}

solr/modules/ltr/src/java/org/apache/solr/ltr/FeatureLogger.java

Lines changed: 13 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,10 @@
1616
*/
1717
package org.apache.solr.ltr;
1818

19-
import org.apache.solr.search.SolrIndexSearcher;
20-
2119
/**
2220
* FeatureLogger can be registered in a model and provide a strategy for logging the feature values.
2321
*/
2422
public abstract class FeatureLogger {
25-
26-
/** the name of the cache using for storing the feature value */
27-
private final String fvCacheName;
28-
2923
public enum FeatureFormat {
3024
DENSE,
3125
SPARSE
@@ -35,54 +29,15 @@ public enum FeatureFormat {
3529

3630
protected Boolean logAll;
3731

38-
protected FeatureLogger(String fvCacheName, FeatureFormat f, Boolean logAll) {
39-
this.fvCacheName = fvCacheName;
32+
protected boolean logFeatures;
33+
34+
protected FeatureLogger(FeatureFormat f, Boolean logAll) {
4035
this.featureFormat = f;
4136
this.logAll = logAll;
37+
this.logFeatures = false;
4238
}
4339

44-
/**
45-
* Log will be called every time that the model generates the feature values for a document and a
46-
* query.
47-
*
48-
* @param docid Solr document id whose features we are saving
49-
* @param featuresInfo List of all the {@link LTRScoringQuery.FeatureInfo} objects which contain
50-
* name and value for all the features triggered by the result set
51-
* @return true if the logger successfully logged the features, false otherwise.
52-
*/
53-
public boolean log(
54-
int docid,
55-
LTRScoringQuery scoringQuery,
56-
SolrIndexSearcher searcher,
57-
LTRScoringQuery.FeatureInfo[] featuresInfo) {
58-
final String featureVector = makeFeatureVector(featuresInfo);
59-
if (featureVector == null) {
60-
return false;
61-
}
62-
63-
if (null == searcher.cacheInsert(fvCacheName, fvCacheKey(scoringQuery, docid), featureVector)) {
64-
return false;
65-
}
66-
67-
return true;
68-
}
69-
70-
public abstract String makeFeatureVector(LTRScoringQuery.FeatureInfo[] featuresInfo);
71-
72-
private static int fvCacheKey(LTRScoringQuery scoringQuery, int docid) {
73-
return scoringQuery.hashCode() + (31 * docid);
74-
}
75-
76-
/**
77-
* populate the document with its feature vector
78-
*
79-
* @param docid Solr document id
80-
* @return String representation of the list of features calculated for docid
81-
*/
82-
public String getFeatureVector(
83-
int docid, LTRScoringQuery scoringQuery, SolrIndexSearcher searcher) {
84-
return (String) searcher.cacheLookup(fvCacheName, fvCacheKey(scoringQuery, docid));
85-
}
40+
public abstract String printFeatureVector(LTRScoringQuery.FeatureInfo[] featuresInfo);
8641

8742
public Boolean isLoggingAll() {
8843
return logAll;
@@ -91,4 +46,12 @@ public Boolean isLoggingAll() {
9146
public void setLogAll(Boolean logAll) {
9247
this.logAll = logAll;
9348
}
49+
50+
public void setLogFeatures(boolean logFeatures) {
51+
this.logFeatures = logFeatures;
52+
}
53+
54+
public boolean isLogFeatures() {
55+
return logFeatures;
56+
}
9457
}

0 commit comments

Comments
 (0)