|
28 | 28 | import org.opensearch.index.query.QueryBuilder; |
29 | 29 | import org.opensearch.index.query.RangeQueryBuilder; |
30 | 30 | import org.opensearch.index.query.TermQueryBuilder; |
| 31 | +import org.opensearch.index.query.TermsQueryBuilder; |
31 | 32 | import org.opensearch.ingest.Processor; |
32 | 33 | import org.opensearch.ml.common.output.model.ModelTensor; |
33 | 34 | import org.opensearch.ml.common.output.model.ModelTensorOutput; |
@@ -247,6 +248,66 @@ public void onFailure(Exception e) { |
247 | 248 |
|
248 | 249 | } |
249 | 250 |
|
| 251 | + /** |
| 252 | + * Tests the successful rewriting of multiple string in terms query based on the model output. |
| 253 | + * |
| 254 | + * @throws Exception if an error occurs during the test |
| 255 | + */ |
| 256 | + public void testExecute_rewriteTermsQuerySuccess() throws Exception { |
| 257 | + /** |
| 258 | + * example term query: {"query":{"terms":{"text":["foo","bar],"boost":1.0}}} |
| 259 | + */ |
| 260 | + String modelInputField = "inputs"; |
| 261 | + String originalQueryField = "query.terms.text"; |
| 262 | + String newQueryField = "query.terms.text"; |
| 263 | + String modelOutputField = "response"; |
| 264 | + MLInferenceSearchRequestProcessor requestProcessor = getMlInferenceSearchRequestProcessor( |
| 265 | + null, |
| 266 | + modelInputField, |
| 267 | + originalQueryField, |
| 268 | + newQueryField, |
| 269 | + modelOutputField, |
| 270 | + false, |
| 271 | + false |
| 272 | + ); |
| 273 | + ModelTensor modelTensor = ModelTensor |
| 274 | + .builder() |
| 275 | + .dataAsMap(ImmutableMap.of("response", Arrays.asList("car", "vehicle", "truck"))) |
| 276 | + .build(); |
| 277 | + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); |
| 278 | + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); |
| 279 | + |
| 280 | + doAnswer(invocation -> { |
| 281 | + ActionListener<MLTaskResponse> actionListener = invocation.getArgument(2); |
| 282 | + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); |
| 283 | + return null; |
| 284 | + }).when(client).execute(any(), any(), any()); |
| 285 | + |
| 286 | + QueryBuilder incomingQuery = new TermsQueryBuilder("text", Arrays.asList("foo", "bar")); |
| 287 | + SearchSourceBuilder source = new SearchSourceBuilder().query(incomingQuery); |
| 288 | + SearchRequest request = new SearchRequest().source(source); |
| 289 | + /** |
| 290 | + * example terms query: {"query":{"terms":{"text":["car","vehicle","truck"],"boost":1.0}}} |
| 291 | + */ |
| 292 | + |
| 293 | + ActionListener<SearchRequest> Listener = new ActionListener<>() { |
| 294 | + @Override |
| 295 | + public void onResponse(SearchRequest newSearchRequest) { |
| 296 | + QueryBuilder expectedQuery = new TermsQueryBuilder("text", Arrays.asList("car", "vehicle", "truck")); |
| 297 | + assertEquals(expectedQuery, newSearchRequest.source().query()); |
| 298 | + assertEquals(request.toString(), newSearchRequest.toString()); |
| 299 | + } |
| 300 | + |
| 301 | + @Override |
| 302 | + public void onFailure(Exception e) { |
| 303 | + throw new RuntimeException("Failed in executing processRequestAsync."); |
| 304 | + } |
| 305 | + }; |
| 306 | + |
| 307 | + requestProcessor.processRequestAsync(request, requestContext, Listener); |
| 308 | + |
| 309 | + } |
| 310 | + |
250 | 311 | /** |
251 | 312 | * Tests the successful rewriting of a double in a term query based on the model output. |
252 | 313 | * |
@@ -444,7 +505,6 @@ public void onFailure(Exception e) { |
444 | 505 | * @throws Exception if an error occurs during the test |
445 | 506 | */ |
446 | 507 | public void testExecute_rewriteListFromTermQueryToGeometryQuerySuccess() throws Exception { |
447 | | - |
448 | 508 | String queryTemplate = "{\n" |
449 | 509 | + " \"query\": {\n" |
450 | 510 | + " \"geo_shape\" : {\n" |
|
0 commit comments