Skip to content

Commit 60a6b3e

Browse files
committed
add integration tests for index_options / defaults
1 parent 2afeb2b commit 60a6b3e

File tree

2 files changed

+209
-9
lines changed

2 files changed

+209
-9
lines changed

x-pack/plugin/core/src/internalClusterTest/java/org/elasticsearch/xpack/core/ml/search/SparseVectorIndexOptionsIT.java

Lines changed: 208 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,28 +9,47 @@
99

1010
import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
1111

12+
import org.apache.http.HttpStatus;
13+
import org.elasticsearch.client.Request;
14+
import org.elasticsearch.client.Response;
15+
import org.elasticsearch.cluster.metadata.IndexMetadata;
16+
import org.elasticsearch.common.Strings;
17+
import org.elasticsearch.common.settings.Settings;
18+
import org.elasticsearch.common.xcontent.XContentHelper;
19+
import org.elasticsearch.plugins.Plugin;
1220
import org.elasticsearch.test.ESIntegTestCase;
1321

22+
import org.elasticsearch.xcontent.XContentType;
23+
import org.elasticsearch.xpack.core.XPackClientPlugin;
24+
import org.hamcrest.Matchers;
25+
import org.junit.Before;
26+
27+
import java.io.IOException;
1428
import java.util.ArrayList;
29+
import java.util.Collection;
1530
import java.util.List;
31+
import java.util.Map;
32+
33+
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
1634

1735
public class SparseVectorIndexOptionsIT extends ESIntegTestCase {
36+
private static final String TEST_INDEX_NAME = "index_with_sparse_vector";
37+
private static final String SPARSE_VECTOR_FIELD = "sparse_vector_field";
38+
private static final int TEST_PRUNING_TOKENS_FREQ_THRESHOLD = 1;
39+
private static final float TEST_PRUNING_TOKENS_WEIGHT_THRESHOLD = 1.0f;
1840

1941
private final boolean testHasIndexOptions;
2042
private final boolean testIndexShouldPrune;
2143
private final boolean testQueryShouldNotPrune;
22-
private final boolean usePreviousIndexVersion;
2344

2445
public SparseVectorIndexOptionsIT(
2546
boolean setIndexOptions,
2647
boolean setIndexShouldPrune,
27-
boolean setQueryShouldNotPrune,
28-
boolean usePreviousIndexVersion
48+
boolean setQueryShouldNotPrune
2949
) {
3050
this.testHasIndexOptions = setIndexOptions;
3151
this.testIndexShouldPrune = setIndexShouldPrune;
3252
this.testQueryShouldNotPrune = setQueryShouldNotPrune;
33-
this.usePreviousIndexVersion = usePreviousIndexVersion;
3453
}
3554

3655
@ParametersFactory
@@ -39,13 +58,193 @@ public static Iterable<Object[]> parameters() throws Exception {
3958
// create a matrix of all combinations
4059
// of our first three parameters
4160
for (int i = 0; i < 8; i++) {
42-
params.add(new Object[] { (i & 1) == 0, (i & 2) == 0, (i & 4) == 0, false });
61+
params.add(new Object[] { (i & 1) == 0, (i & 2) == 0, (i & 4) == 0 });
4362
}
44-
// and add in overrides for the previous index versions
45-
params.add(new Object[] { false, false, false, true });
46-
params.add(new Object[] { false, false, true, true });
4763
return params;
4864
}
4965

50-
public void testItPrunesTokensIfIndexOptions() {}
66+
@Override
67+
protected boolean addMockHttpTransport() {
68+
return false; // enable http
69+
}
70+
71+
@Override
72+
protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) {
73+
Settings.Builder settings = Settings.builder().put(super.nodeSettings(nodeOrdinal, otherSettings));
74+
return settings.build();
75+
}
76+
77+
@Override
78+
protected Collection<Class<? extends Plugin>> nodePlugins() {
79+
return List.of(XPackClientPlugin.class);
80+
}
81+
82+
static final int INTERNAL_UNMANAGED_FLAG_VALUE = 2;
83+
static final String FlAG_SETTING_KEY = IndexMetadata.INDEX_PRIORITY_SETTING.getKey();
84+
85+
@Before
86+
public void setup() {
87+
assertAcked(prepareCreate(TEST_INDEX_NAME).setMapping(getTestIndexMapping()));
88+
ensureGreen(TEST_INDEX_NAME);
89+
90+
for (Map.Entry<String, String> doc : TEST_DOCUMENTS.entrySet()) {
91+
index(TEST_INDEX_NAME, doc.getKey(), doc.getValue());
92+
}
93+
flushAndRefresh(TEST_INDEX_NAME);
94+
}
95+
96+
public void testSparseVectorTokenPruning() throws IOException {
97+
Response response = performSearch(getBuilderForSearch());
98+
assertThat(response.getStatusLine().getStatusCode(), Matchers.equalTo(HttpStatus.SC_OK));
99+
final Map<String, Object> responseMap = XContentHelper.convertToMap(
100+
XContentType.JSON.xContent(),
101+
response.getEntity().getContent(),
102+
true
103+
);
104+
assertCorrectResponse(responseMap);
105+
}
106+
107+
@SuppressWarnings("unchecked")
108+
private void assertCorrectResponse(Map<String, Object> responseMap) {
109+
List<String> expectedIds = getTestExpectedDocIds();
110+
111+
Map<String, Object> mapHits = (Map<String, Object>) responseMap.get("hits");
112+
Map<String, Object> mapHitsTotal = (Map<String, Object>) mapHits.get("total");
113+
int actualTotalHits = (int) mapHitsTotal.get("value");
114+
int numHitsExpected = expectedIds.size();
115+
116+
// assertEquals(getAssertMessage("Search result total hits count mismatch"), numHitsExpected, actualTotalHits);
117+
118+
List<Map<String, Object>> hits = (List<Map<String, Object>>) mapHits.get("hits");
119+
List<String> actualDocIds = new ArrayList<>();
120+
for (Map<String, Object> doc : hits) {
121+
actualDocIds.add((String)doc.get("_id"));
122+
}
123+
124+
assertEquals(getAssertMessage("Result document ids mismatch"), expectedIds, actualDocIds);
125+
}
126+
127+
private String getTestIndexMapping() {
128+
if (isRunningAgainstOldCluster()) {
129+
return "{\"properties\":{\"" + SPARSE_VECTOR_FIELD + "\":{\"type\":\"sparse_vector\"}}}";
130+
}
131+
132+
String testPruningConfigMapping = "\"pruning_config\":{\"tokens_freq_ratio_threshold\":"
133+
+ TEST_PRUNING_TOKENS_FREQ_THRESHOLD
134+
+ ",\"tokens_weight_threshold\":"
135+
+ TEST_PRUNING_TOKENS_WEIGHT_THRESHOLD
136+
+ "}";
137+
138+
String pruningMappingString = testIndexShouldPrune
139+
? "\"prune\":true," + testPruningConfigMapping
140+
: "\"prune\":false";
141+
String indexOptionsString = testHasIndexOptions
142+
? ",\"index_options\":{" + pruningMappingString + "}"
143+
: "";
144+
145+
return "{\"properties\":{\""
146+
+ SPARSE_VECTOR_FIELD
147+
+ "\":{\"type\":\"sparse_vector\""
148+
+ indexOptionsString
149+
+ "}}}";
150+
}
151+
152+
private boolean isRunningAgainstOldCluster() {
153+
return false;
154+
}
155+
156+
private List<String> getTestExpectedDocIds() {
157+
if (testQueryShouldNotPrune) {
158+
// query overrides prune = false in all cases
159+
return EXPECTED_DOC_IDS_WITHOUT_PRUNING;
160+
}
161+
162+
if (testHasIndexOptions) {
163+
// index has set index options in the mapping
164+
return testIndexShouldPrune
165+
? EXPECTED_DOC_IDS_WITH_PRUNING
166+
: EXPECTED_DOC_IDS_WITHOUT_PRUNING;
167+
}
168+
169+
// default pruning should be true with default configuration
170+
return EXPECTED_DOC_IDS_WITH_DEFAULT_PRUNING;
171+
}
172+
173+
private Response performSearch(String source) throws IOException {
174+
Request request = new Request("GET", TEST_INDEX_NAME + "/_search");
175+
request.setJsonEntity(source);
176+
return getRestClient().performRequest(request);
177+
}
178+
179+
private String getBuilderForSearch() {
180+
boolean shouldUseDefaultTokens = (testQueryShouldNotPrune == false && testHasIndexOptions == false);
181+
SparseVectorQueryBuilder queryBuilder = new SparseVectorQueryBuilder(
182+
SPARSE_VECTOR_FIELD,
183+
shouldUseDefaultTokens ? SEARCH_WEIGHTED_TOKENS_WITH_DEFAULTS : SEARCH_WEIGHTED_TOKENS,
184+
null,
185+
null,
186+
testQueryShouldNotPrune ? false : null,
187+
null
188+
);
189+
190+
return "{\"query\":" + Strings.toString(queryBuilder) + "}";
191+
}
192+
193+
private String getAssertMessage(String message) {
194+
return message
195+
+ " (Params: hasIndexOptions="
196+
+ testHasIndexOptions
197+
+ ", indexShouldPrune="
198+
+ testIndexShouldPrune
199+
+ ", queryShouldNotPrune="
200+
+ testQueryShouldNotPrune
201+
+ "): "
202+
+ getDescriptiveTestType();
203+
}
204+
205+
private String getDescriptiveTestType() {
206+
String testDescription = "";
207+
if (testQueryShouldNotPrune) {
208+
testDescription = "query override prune=false:";
209+
}
210+
211+
if (testHasIndexOptions) {
212+
testDescription += " pruning index_options explicitly set:";
213+
} else {
214+
testDescription = " no index options set, tokens should be pruned by default:";
215+
}
216+
217+
if (testIndexShouldPrune == false) {
218+
testDescription += " index options has pruning set to false";
219+
}
220+
221+
return testDescription;
222+
}
223+
224+
private static final Map<String, String> TEST_DOCUMENTS = Map.of(
225+
"1", "{\"sparse_vector_field\":{\"cheese\": 2.671405,\"is\": 0.11809908,\"comet\": 0.26088917}}",
226+
"2", "{\"sparse_vector_field\":{\"planet\": 2.3438394,\"is\": 0.54600334,\"astronomy\": 0.36015007,\"moon\": 0.20022368}}",
227+
"3", "{\"sparse_vector_field\":{\"is\": 0.6891394,\"globe\": 0.484035,\"ocean\": 0.080102935,\"underground\": 0.053516876}}"
228+
);
229+
230+
private static final List<WeightedToken> SEARCH_WEIGHTED_TOKENS = List.of(
231+
new WeightedToken("cheese", 0.5f),
232+
new WeightedToken("comet", 0.5f),
233+
new WeightedToken("globe", 0.484035f),
234+
new WeightedToken("ocean", 0.080102935f),
235+
new WeightedToken("underground", 0.053516876f),
236+
new WeightedToken("is", 0.54600334f)
237+
);
238+
239+
private static final List<WeightedToken> SEARCH_WEIGHTED_TOKENS_WITH_DEFAULTS = List.of(
240+
new WeightedToken("planet", 0.2f)
241+
);
242+
243+
private static final List<String> EXPECTED_DOC_IDS_WITHOUT_PRUNING = List.of(
244+
"1", "3", "2"
245+
);
246+
247+
private static final List<String> EXPECTED_DOC_IDS_WITH_PRUNING = List.of("1");
248+
249+
private static final List<String> EXPECTED_DOC_IDS_WITH_DEFAULT_PRUNING = List.of("2");
51250
}

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/search/SparseVectorQueryBuilder.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,7 @@ private TokenPruningConfig getTokenPruningConfigForQuery(MappedFieldType ft, Sea
375375
pruningConfigToUse = pruningConfigToUse == null ? asSVFieldType.getIndexOptions().getPruningConfig() : pruningConfigToUse;
376376
}
377377

378+
// do not prune if shouldQueryPruneTokens is explicitly set to false
378379
if (shouldQueryPruneTokens != null && shouldQueryPruneTokens == false) {
379380
return null;
380381
}

0 commit comments

Comments
 (0)