Skip to content

Commit 4dbeb28

Browse files
restrict stash context only for stop words system index (#2283) (#2285)
Signed-off-by: Jing Zhang <[email protected]> (cherry picked from commit 5f9026d) Co-authored-by: Jing Zhang <[email protected]>
1 parent 2365afc commit 4dbeb28

File tree

1 file changed

+33
-14
lines changed
  • common/src/main/java/org/opensearch/ml/common/model

1 file changed

+33
-14
lines changed

common/src/main/java/org/opensearch/ml/common/model/MLGuard.java

Lines changed: 33 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,11 @@
55

66
package org.opensearch.ml.common.model;
77

8+
import com.google.common.collect.ImmutableSet;
89
import lombok.Getter;
910
import lombok.NonNull;
1011
import lombok.extern.log4j.Log4j2;
11-
import org.opensearch.ResourceNotFoundException;
1212
import org.opensearch.action.LatchedActionListener;
13-
import org.opensearch.action.get.GetResponse;
1413
import org.opensearch.action.search.SearchRequest;
1514
import org.opensearch.action.search.SearchResponse;
1615
import org.opensearch.client.Client;
@@ -20,7 +19,6 @@
2019
import org.opensearch.core.action.ActionListener;
2120
import org.opensearch.core.xcontent.NamedXContentRegistry;
2221
import org.opensearch.core.xcontent.XContentParser;
23-
import org.opensearch.search.SearchHit;
2422
import org.opensearch.search.builder.SearchSourceBuilder;
2523

2624
import java.security.AccessController;
@@ -30,15 +28,14 @@
3028
import java.util.HashMap;
3129
import java.util.List;
3230
import java.util.Map;
31+
import java.util.Set;
3332
import java.util.concurrent.CountDownLatch;
3433
import java.util.concurrent.atomic.AtomicBoolean;
35-
import java.util.concurrent.atomic.AtomicReference;
3634
import java.util.regex.Matcher;
3735
import java.util.regex.Pattern;
3836
import java.util.stream.Collectors;
3937

4038
import static java.util.concurrent.TimeUnit.SECONDS;
41-
import static org.opensearch.ml.common.CommonValue.MASTER_KEY;
4239
import static org.opensearch.ml.common.utils.StringUtils.gson;
4340

4441
@Log4j2
@@ -52,6 +49,7 @@ public class MLGuard {
5249
private List<Pattern> outputRegexPattern;
5350
private NamedXContentRegistry xContentRegistry;
5451
private Client client;
52+
private Set<String> stopWordsIndices = ImmutableSet.of(".plugins-ml-stop-words");
5553

5654
public MLGuard(Guardrails guardrails, NamedXContentRegistry xContentRegistry, Client client) {
5755
this.xContentRegistry = xContentRegistry;
@@ -128,27 +126,44 @@ public Boolean validateStopWordsSingleIndex(String input, String indexName, List
128126
Map<String, Object> queryBodyMap = Map
129127
.of("query", Map.of("percolate", Map.of("field", "query", "document", documentMap)));
130128
CountDownLatch latch = new CountDownLatch(1);
129+
ThreadContext.StoredContext context = null;
131130

132-
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
131+
try {
133132
queryBody = AccessController.doPrivileged((PrivilegedExceptionAction<String>) () -> gson.toJson(queryBodyMap));
134133
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
135134
XContentParser queryParser = XContentType.JSON.xContent().createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, queryBody);
136135
searchSourceBuilder.parseXContent(queryParser);
137136
searchSourceBuilder.size(1); //Only need 1 doc returned, if hit.
138137
searchRequest = new SearchRequest().source(searchSourceBuilder).indices(indexName);
139-
context.restore();
140-
client.search(searchRequest, ActionListener.runBefore(new LatchedActionListener(ActionListener.<SearchResponse>wrap(r -> {
141-
if (r == null || r.getHits() == null || r.getHits().getTotalHits() == null || r.getHits().getTotalHits().value == 0) {
138+
if (isStopWordsSystemIndex(indexName)) {
139+
context = client.threadPool().getThreadContext().stashContext();
140+
ThreadContext.StoredContext finalContext = context;
141+
client.search(searchRequest, ActionListener.runBefore(new LatchedActionListener(ActionListener.<SearchResponse>wrap(r -> {
142+
if (r == null || r.getHits() == null || r.getHits().getTotalHits() == null || r.getHits().getTotalHits().value == 0) {
143+
hitStopWords.set(true);
144+
}
145+
}, e -> {
146+
log.error("Failed to search stop words index {}", indexName, e);
147+
hitStopWords.set(true);
148+
}), latch), () -> finalContext.restore()));
149+
} else {
150+
client.search(searchRequest, new LatchedActionListener(ActionListener.<SearchResponse>wrap(r -> {
151+
if (r == null || r.getHits() == null || r.getHits().getTotalHits() == null || r.getHits().getTotalHits().value == 0) {
152+
hitStopWords.set(true);
153+
}
154+
}, e -> {
155+
log.error("Failed to search stop words index {}", indexName, e);
142156
hitStopWords.set(true);
143-
}
144-
}, e -> {
145-
log.error("Failed to search stop words index {}", indexName, e);
146-
hitStopWords.set(true);
147-
}), latch), () -> context.restore()));
157+
}), latch));
158+
}
148159
} catch (Exception e) {
149160
log.error("[validateStopWords] Searching stop words index failed.", e);
150161
latch.countDown();
151162
hitStopWords.set(true);
163+
} finally {
164+
if (context != null) {
165+
context.close();
166+
}
152167
}
153168

154169
try {
@@ -160,6 +175,10 @@ public Boolean validateStopWordsSingleIndex(String input, String indexName, List
160175
return hitStopWords.get();
161176
}
162177

178+
private boolean isStopWordsSystemIndex(String index) {
179+
return stopWordsIndices.contains(index);
180+
}
181+
163182
public enum Type {
164183
INPUT,
165184
OUTPUT

0 commit comments

Comments
 (0)