Skip to content

Commit 7da5942

Browse files
remove ignoreFailure and fix JsonArray Parsing Issue (#2770) (#2774)
1 parent 97d9f68 commit 7da5942

File tree

3 files changed

+65
-6
lines changed

3 files changed

+65
-6
lines changed

plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessor.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ public void onResponse(Map<Integer, MLOutput> multipleMLOutputs) {
251251
requestListener.onResponse(request);
252252
}
253253
} catch (Exception e) {
254-
if (ignoreMissing || ignoreFailure) {
254+
if (ignoreFailure) {
255255
logger.error("Failed in writing prediction outcomes to new query", e);
256256
requestListener.onResponse(request);
257257

@@ -348,7 +348,7 @@ private boolean validateQueryFieldInQueryString(
348348
for (Map.Entry<String, String> entry : inputMap.entrySet()) {
349349
// the inputMap takes in model input as keys and query fields as value
350350
String queryField = entry.getValue();
351-
String pathData = jsonData.read(queryField);
351+
Object pathData = jsonData.read(queryField);
352352
if (pathData == null) {
353353
throw new IllegalArgumentException("cannot find field: " + queryField + " in query string: " + jsonData.jsonString());
354354
}
@@ -358,7 +358,7 @@ private boolean validateQueryFieldInQueryString(
358358
for (Map<String, String> outputMap : processOutputMap) {
359359
for (Map.Entry<String, String> entry : outputMap.entrySet()) {
360360
String queryField = entry.getKey();
361-
String pathData = jsonData.read(queryField);
361+
Object pathData = jsonData.read(queryField);
362362
if (pathData == null) {
363363
throw new IllegalArgumentException(
364364
"cannot find field: " + queryField + " in query string: " + jsonData.jsonString()
@@ -402,7 +402,7 @@ private void processPredictions(
402402
// model field as key, query field name as value
403403
String modelInputFieldName = entry.getKey();
404404
String queryFieldName = entry.getValue();
405-
String queryFieldValue = JsonPath.parse(newQuery).read(queryFieldName);
405+
String queryFieldValue = StringUtils.toJson(JsonPath.parse(newQuery).read(queryFieldName));
406406
modelParameters.put(modelInputFieldName, queryFieldValue);
407407
}
408408
}

plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessor.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -478,7 +478,6 @@ public void onFailure(Exception e) {
478478
*/
479479
private boolean checkIsModelInputMissing(Map<String, Object> document, Map<String, String> inputMapping) {
480480
boolean isModelInputMissing = false;
481-
482481
for (Map.Entry<String, String> inputMapEntry : inputMapping.entrySet()) {
483482
String oldDocumentFieldName = inputMapEntry.getValue();
484483
boolean checkSingleModelInputPresent = hasField(document, oldDocumentFieldName);

plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchRequestProcessorTests.java

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import org.opensearch.index.query.QueryBuilder;
2929
import org.opensearch.index.query.RangeQueryBuilder;
3030
import org.opensearch.index.query.TermQueryBuilder;
31+
import org.opensearch.index.query.TermsQueryBuilder;
3132
import org.opensearch.ingest.Processor;
3233
import org.opensearch.ml.common.output.model.ModelTensor;
3334
import org.opensearch.ml.common.output.model.ModelTensorOutput;
@@ -247,6 +248,66 @@ public void onFailure(Exception e) {
247248

248249
}
249250

251+
/**
252+
* Tests the successful rewriting of multiple string in terms query based on the model output.
253+
*
254+
* @throws Exception if an error occurs during the test
255+
*/
256+
public void testExecute_rewriteTermsQuerySuccess() throws Exception {
257+
/**
258+
* example term query: {"query":{"terms":{"text":["foo","bar],"boost":1.0}}}
259+
*/
260+
String modelInputField = "inputs";
261+
String originalQueryField = "query.terms.text";
262+
String newQueryField = "query.terms.text";
263+
String modelOutputField = "response";
264+
MLInferenceSearchRequestProcessor requestProcessor = getMlInferenceSearchRequestProcessor(
265+
null,
266+
modelInputField,
267+
originalQueryField,
268+
newQueryField,
269+
modelOutputField,
270+
false,
271+
false
272+
);
273+
ModelTensor modelTensor = ModelTensor
274+
.builder()
275+
.dataAsMap(ImmutableMap.of("response", Arrays.asList("car", "vehicle", "truck")))
276+
.build();
277+
ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build();
278+
ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build();
279+
280+
doAnswer(invocation -> {
281+
ActionListener<MLTaskResponse> actionListener = invocation.getArgument(2);
282+
actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build());
283+
return null;
284+
}).when(client).execute(any(), any(), any());
285+
286+
QueryBuilder incomingQuery = new TermsQueryBuilder("text", Arrays.asList("foo", "bar"));
287+
SearchSourceBuilder source = new SearchSourceBuilder().query(incomingQuery);
288+
SearchRequest request = new SearchRequest().source(source);
289+
/**
290+
* example terms query: {"query":{"terms":{"text":["car","vehicle","truck"],"boost":1.0}}}
291+
*/
292+
293+
ActionListener<SearchRequest> Listener = new ActionListener<>() {
294+
@Override
295+
public void onResponse(SearchRequest newSearchRequest) {
296+
QueryBuilder expectedQuery = new TermsQueryBuilder("text", Arrays.asList("car", "vehicle", "truck"));
297+
assertEquals(expectedQuery, newSearchRequest.source().query());
298+
assertEquals(request.toString(), newSearchRequest.toString());
299+
}
300+
301+
@Override
302+
public void onFailure(Exception e) {
303+
throw new RuntimeException("Failed in executing processRequestAsync.");
304+
}
305+
};
306+
307+
requestProcessor.processRequestAsync(request, requestContext, Listener);
308+
309+
}
310+
250311
/**
251312
* Tests the successful rewriting of a double in a term query based on the model output.
252313
*
@@ -444,7 +505,6 @@ public void onFailure(Exception e) {
444505
* @throws Exception if an error occurs during the test
445506
*/
446507
public void testExecute_rewriteListFromTermQueryToGeometryQuerySuccess() throws Exception {
447-
448508
String queryTemplate = "{\n"
449509
+ " \"query\": {\n"
450510
+ " \"geo_shape\" : {\n"

0 commit comments

Comments
 (0)