55
66package org .opensearch .ml .common .model ;
77
8+ import com .google .common .collect .ImmutableSet ;
89import lombok .Getter ;
910import lombok .NonNull ;
1011import lombok .extern .log4j .Log4j2 ;
11- import org .opensearch .ResourceNotFoundException ;
1212import org .opensearch .action .LatchedActionListener ;
13- import org .opensearch .action .get .GetResponse ;
1413import org .opensearch .action .search .SearchRequest ;
1514import org .opensearch .action .search .SearchResponse ;
1615import org .opensearch .client .Client ;
2019import org .opensearch .core .action .ActionListener ;
2120import org .opensearch .core .xcontent .NamedXContentRegistry ;
2221import org .opensearch .core .xcontent .XContentParser ;
23- import org .opensearch .search .SearchHit ;
2422import org .opensearch .search .builder .SearchSourceBuilder ;
2523
2624import java .security .AccessController ;
3028import java .util .HashMap ;
3129import java .util .List ;
3230import java .util .Map ;
31+ import java .util .Set ;
3332import java .util .concurrent .CountDownLatch ;
3433import java .util .concurrent .atomic .AtomicBoolean ;
35- import java .util .concurrent .atomic .AtomicReference ;
3634import java .util .regex .Matcher ;
3735import java .util .regex .Pattern ;
3836import java .util .stream .Collectors ;
3937
4038import static java .util .concurrent .TimeUnit .SECONDS ;
41- import static org .opensearch .ml .common .CommonValue .MASTER_KEY ;
4239import static org .opensearch .ml .common .utils .StringUtils .gson ;
4340
4441@ Log4j2
@@ -52,6 +49,7 @@ public class MLGuard {
5249 private List <Pattern > outputRegexPattern ;
5350 private NamedXContentRegistry xContentRegistry ;
5451 private Client client ;
52+ private Set <String > stopWordsIndices = ImmutableSet .of (".plugins-ml-stop-words" );
5553
5654 public MLGuard (Guardrails guardrails , NamedXContentRegistry xContentRegistry , Client client ) {
5755 this .xContentRegistry = xContentRegistry ;
@@ -128,27 +126,44 @@ public Boolean validateStopWordsSingleIndex(String input, String indexName, List
128126 Map <String , Object > queryBodyMap = Map
129127 .of ("query" , Map .of ("percolate" , Map .of ("field" , "query" , "document" , documentMap )));
130128 CountDownLatch latch = new CountDownLatch (1 );
129+ ThreadContext .StoredContext context = null ;
131130
132- try ( ThreadContext . StoredContext context = client . threadPool (). getThreadContext (). stashContext ()) {
131+ try {
133132 queryBody = AccessController .doPrivileged ((PrivilegedExceptionAction <String >) () -> gson .toJson (queryBodyMap ));
134133 SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder ();
135134 XContentParser queryParser = XContentType .JSON .xContent ().createParser (xContentRegistry , LoggingDeprecationHandler .INSTANCE , queryBody );
136135 searchSourceBuilder .parseXContent (queryParser );
137136 searchSourceBuilder .size (1 ); //Only need 1 doc returned, if hit.
138137 searchRequest = new SearchRequest ().source (searchSourceBuilder ).indices (indexName );
139- context .restore ();
140- client .search (searchRequest , ActionListener .runBefore (new LatchedActionListener (ActionListener .<SearchResponse >wrap (r -> {
141- if (r == null || r .getHits () == null || r .getHits ().getTotalHits () == null || r .getHits ().getTotalHits ().value == 0 ) {
138+ if (isStopWordsSystemIndex (indexName )) {
139+ context = client .threadPool ().getThreadContext ().stashContext ();
140+ ThreadContext .StoredContext finalContext = context ;
141+ client .search (searchRequest , ActionListener .runBefore (new LatchedActionListener (ActionListener .<SearchResponse >wrap (r -> {
142+ if (r == null || r .getHits () == null || r .getHits ().getTotalHits () == null || r .getHits ().getTotalHits ().value == 0 ) {
143+ hitStopWords .set (true );
144+ }
145+ }, e -> {
146+ log .error ("Failed to search stop words index {}" , indexName , e );
147+ hitStopWords .set (true );
148+ }), latch ), () -> finalContext .restore ()));
149+ } else {
150+ client .search (searchRequest , new LatchedActionListener (ActionListener .<SearchResponse >wrap (r -> {
151+ if (r == null || r .getHits () == null || r .getHits ().getTotalHits () == null || r .getHits ().getTotalHits ().value == 0 ) {
152+ hitStopWords .set (true );
153+ }
154+ }, e -> {
155+ log .error ("Failed to search stop words index {}" , indexName , e );
142156 hitStopWords .set (true );
143- }
144- }, e -> {
145- log .error ("Failed to search stop words index {}" , indexName , e );
146- hitStopWords .set (true );
147- }), latch ), () -> context .restore ()));
157+ }), latch ));
158+ }
148159 } catch (Exception e ) {
149160 log .error ("[validateStopWords] Searching stop words index failed." , e );
150161 latch .countDown ();
151162 hitStopWords .set (true );
163+ } finally {
164+ if (context != null ) {
165+ context .close ();
166+ }
152167 }
153168
154169 try {
@@ -160,6 +175,10 @@ public Boolean validateStopWordsSingleIndex(String input, String indexName, List
160175 return hitStopWords .get ();
161176 }
162177
178+ private boolean isStopWordsSystemIndex (String index ) {
179+ return stopWordsIndices .contains (index );
180+ }
181+
163182 public enum Type {
164183 INPUT ,
165184 OUTPUT
0 commit comments