77
88import static org .opensearch .core .rest .RestStatus .BAD_REQUEST ;
99import static org .opensearch .core .rest .RestStatus .INTERNAL_SERVER_ERROR ;
10- import static org .opensearch .ml .common .CommonValue .TENANT_ID_FIELD ;
1110import static org .opensearch .ml .utils .RestActionUtils .wrapListenerToHandleSearchIndexNotFound ;
1211
1312import java .util .ArrayList ;
4342import org .opensearch .ml .common .exception .MLResourceNotFoundException ;
4443import org .opensearch .ml .helper .ModelAccessControlHelper ;
4544import org .opensearch .ml .utils .RestActionUtils ;
45+ import org .opensearch .remote .metadata .client .SdkClient ;
46+ import org .opensearch .remote .metadata .client .SearchDataObjectRequest ;
47+ import org .opensearch .remote .metadata .common .SdkClientUtils ;
4648import org .opensearch .search .SearchHits ;
4749import org .opensearch .search .builder .SearchSourceBuilder ;
4850import org .opensearch .search .fetch .subphase .FetchSourceContext ;
@@ -77,10 +79,11 @@ public MLSearchHandler(
7779
7880 /**
7981 * Fetch all the models from the model group index, and then create a combined query to model version index.
82+ * @param sdkClient sdkclient a wrapper of the client
8083 * @param request
8184 * @param actionListener
8285 */
83- public void search (SearchRequest request , String tenantId , ActionListener <SearchResponse > actionListener ) {
86+ public void search (SdkClient sdkClient , SearchRequest request , String tenantId , ActionListener <SearchResponse > actionListener ) {
8487 User user = RestActionUtils .getUserContext (client );
8588 ActionListener <SearchResponse > listener = wrapRestActionListener (actionListener , "Fail to search model version" );
8689 try (ThreadContext .StoredContext context = client .threadPool ().getThreadContext ().stashContext ()) {
@@ -114,11 +117,6 @@ public void search(SearchRequest request, String tenantId, ActionListener<Search
114117 // Add a should clause to include documents where IS_HIDDEN_FIELD is false
115118 shouldQuery .should (QueryBuilders .termQuery (MLModel .IS_HIDDEN_FIELD , false ));
116119
117- // For multi-tenancy
118- if (tenantId != null ) {
119- shouldQuery .should (QueryBuilders .termQuery (TENANT_ID_FIELD , tenantId ));
120- }
121-
122120 // Add a should clause to include documents where IS_HIDDEN_FIELD does not exist or is null
123121 shouldQuery .should (QueryBuilders .boolQuery ().mustNot (QueryBuilders .existsQuery (MLModel .IS_HIDDEN_FIELD )));
124122
@@ -132,10 +130,29 @@ public void search(SearchRequest request, String tenantId, ActionListener<Search
132130 request .source ().fetchSource (rebuiltFetchSourceContext );
133131 final ActionListener <SearchResponse > doubleWrapperListener = ActionListener
134132 .wrap (wrappedListener ::onResponse , e -> wrapListenerToHandleSearchIndexNotFound (e , wrappedListener ));
135- if (modelAccessControlHelper .skipModelAccessControl (user )) {
136- client .search (request , doubleWrapperListener );
137- } else if (!clusterService .state ().metadata ().hasIndex (CommonValue .ML_MODEL_GROUP_INDEX )) {
138- client .search (request , doubleWrapperListener );
133+ if (modelAccessControlHelper .skipModelAccessControl (user )
134+ || !clusterService .state ().metadata ().hasIndex (CommonValue .ML_MODEL_GROUP_INDEX )) {
135+
136+ SearchDataObjectRequest searchDataObjectRequest = SearchDataObjectRequest
137+ .builder ()
138+ .indices (request .indices ())
139+ .searchSourceBuilder (request .source ())
140+ .tenantId (tenantId )
141+ .build ();
142+ sdkClient .searchDataObjectAsync (searchDataObjectRequest ).whenComplete ((r , throwable ) -> {
143+ if (throwable == null ) {
144+ try {
145+ SearchResponse searchResponse = SearchResponse .fromXContent (r .parser ());
146+ log .info ("Model search complete: {}" , searchResponse .getHits ().getTotalHits ());
147+ doubleWrapperListener .onResponse (searchResponse );
148+ } catch (Exception e ) {
149+ doubleWrapperListener .onFailure (e );
150+ }
151+ } else {
152+ Exception e = SdkClientUtils .unwrapAndConvertToException (throwable , OpenSearchStatusException .class );
153+ doubleWrapperListener .onFailure (e );
154+ }
155+ });
139156 } else {
140157 SearchSourceBuilder sourceBuilder = modelAccessControlHelper .createSearchSourceBuilder (user );
141158 SearchRequest modelGroupSearchRequest = new SearchRequest ();
@@ -154,17 +171,54 @@ public void search(SearchRequest request, String tenantId, ActionListener<Search
154171 Arrays .stream (r .getHits ().getHits ()).forEach (hit -> { modelGroupIds .add (hit .getId ()); });
155172
156173 request .source ().query (rewriteQueryBuilder (request .source ().query (), modelGroupIds ));
157- client .search (request , doubleWrapperListener );
158174 } else {
159175 log .debug ("No model group found" );
160176 request .source ().query (rewriteQueryBuilder (request .source ().query (), null ));
161- client .search (request , doubleWrapperListener );
162177 }
178+ SearchDataObjectRequest searchDataObjectRequest = SearchDataObjectRequest
179+ .builder ()
180+ .indices (request .indices ())
181+ .searchSourceBuilder (request .source ())
182+ .tenantId (tenantId )
183+ .build ();
184+ sdkClient .searchDataObjectAsync (searchDataObjectRequest ).whenComplete ((sr , throwable ) -> {
185+ if (throwable == null ) {
186+ try {
187+ SearchResponse searchResponse = SearchResponse .fromXContent (sr .parser ());
188+ log .info ("Model search complete: {}" , searchResponse .getHits ().getTotalHits ());
189+ doubleWrapperListener .onResponse (searchResponse );
190+ } catch (Exception e ) {
191+ doubleWrapperListener .onFailure (e );
192+ }
193+ } else {
194+ Exception e = SdkClientUtils .unwrapAndConvertToException (throwable , OpenSearchStatusException .class );
195+ doubleWrapperListener .onFailure (e );
196+ }
197+ });
163198 }, e -> {
164199 log .error ("Fail to search model groups!" , e );
165200 wrappedListener .onFailure (e );
166201 });
167- client .search (modelGroupSearchRequest , modelGroupSearchActionListener );
202+ SearchDataObjectRequest searchDataObjectRequest = SearchDataObjectRequest
203+ .builder ()
204+ .indices (modelGroupSearchRequest .indices ())
205+ .searchSourceBuilder (modelGroupSearchRequest .source ())
206+ .tenantId (tenantId )
207+ .build ();
208+ sdkClient .searchDataObjectAsync (searchDataObjectRequest ).whenComplete ((r , throwable ) -> {
209+ if (throwable == null ) {
210+ try {
211+ SearchResponse searchResponse = SearchResponse .fromXContent (r .parser ());
212+ log .info ("Model search complete: {}" , searchResponse .getHits ().getTotalHits ());
213+ modelGroupSearchActionListener .onResponse (searchResponse );
214+ } catch (Exception e ) {
215+ modelGroupSearchActionListener .onFailure (e );
216+ }
217+ } else {
218+ Exception e = SdkClientUtils .unwrapAndConvertToException (throwable , OpenSearchStatusException .class );
219+ modelGroupSearchActionListener .onFailure (e );
220+ }
221+ });
168222 }
169223 } catch (Exception e ) {
170224 log .error (e .getMessage (), e );
0 commit comments