Skip to content

Commit 6db14b1

Browse files
authored
fill null parameters in connector body template (#1192) (#1219)
Signed-off-by: Yaliang Wu <[email protected]>
1 parent c93005f commit 6db14b1

File tree

4 files changed

+49
-3
lines changed

4 files changed

+49
-3
lines changed

common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@
1919
import java.util.Map;
2020
import java.util.Optional;
2121
import java.util.function.Function;
22+
import java.util.regex.Matcher;
23+
import java.util.regex.Pattern;
24+
2225
import lombok.Builder;
2326
import lombok.EqualsAndHashCode;
2427
import lombok.NoArgsConstructor;
@@ -81,7 +84,7 @@ public HttpConnector(String protocol, XContentParser parser) throws IOException
8184
description = parser.text();
8285
break;
8386
case PROTOCOL_FIELD:
84-
protocol = parser.text();
87+
this.protocol = parser.text();
8588
break;
8689
case PARAMETERS_FIELD:
8790
Map<String, Object> map = parser.map();
@@ -250,6 +253,7 @@ public <T> T createPredictPayload(Map<String, String> parameters) {
250253
Optional<ConnectorAction> predictAction = findPredictAction();
251254
if (predictAction.isPresent() && predictAction.get().getRequestBody() != null) {
252255
String payload = predictAction.get().getRequestBody();
256+
payload = fillNullParameters(parameters, payload);
253257
StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}");
254258
payload = substitutor.replace(payload);
255259

@@ -261,6 +265,30 @@ public <T> T createPredictPayload(Map<String, String> parameters) {
261265
return (T) parameters.get("http_body");
262266
}
263267

268+
protected String fillNullParameters(Map<String, String> parameters, String payload) {
269+
List<String> bodyParams = findStringParametersWithNullDefaultValue(payload);
270+
String newPayload = payload;
271+
for (String key : bodyParams) {
272+
if (!parameters.containsKey(key) || parameters.get(key) == null) {
273+
newPayload = newPayload.replace("\"${parameters." + key + ":-null}\"", "null");
274+
}
275+
}
276+
return newPayload;
277+
}
278+
279+
private List<String> findStringParametersWithNullDefaultValue(String input) {
280+
String regex = "\"\\$\\{parameters\\.(\\w+):-null}\"";
281+
Pattern pattern = Pattern.compile(regex);
282+
Matcher matcher = pattern.matcher(input);
283+
284+
List<String> paramList = new ArrayList<>();
285+
while (matcher.find()) {
286+
String parameterValue = matcher.group(1);
287+
paramList.add(parameterValue);
288+
}
289+
return paramList;
290+
}
291+
264292
@Override
265293
public void decrypt(Function<String, String> function) {
266294
Map<String, String> decrypted = new HashMap<>();

common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,14 @@ public void parseResponse_NonJsonString() throws IOException {
259259
Assert.assertEquals("test output", modelTensors.get(0).getDataAsMap().get("response"));
260260
}
261261

262+
@Test
263+
public void fillNullParameters() {
264+
HttpConnector connector = createHttpConnector();
265+
Map<String, String> parameters = new HashMap<>();
266+
String output = connector.fillNullParameters(parameters, "{\"input1\": \"${parameters.input1:-null}\"}");
267+
Assert.assertEquals("{\"input1\": null}", output);
268+
}
269+
262270
public static HttpConnector createHttpConnector() {
263271
ConnectorAction.ActionType actionType = ConnectorAction.ActionType.PREDICT;
264272
String method = "POST";

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,9 @@ public static RemoteInferenceInputDataSet processInput(MLInput mlInput, Connecto
9999
if (inputData.getParameters() != null) {
100100
Map<String, String> newParameters = new HashMap<>();
101101
inputData.getParameters().entrySet().forEach(entry -> {
102-
if (StringUtils.isJson(entry.getValue())) {
102+
if (entry.getValue() == null) {
103+
newParameters.put(entry.getKey(), entry.getValue());
104+
} else if (StringUtils.isJson(entry.getValue())) {
103105
// no need to escape if it's already valid json
104106
newParameters.put(entry.getKey(), entry.getValue());
105107
} else {

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,16 @@ public void processInput_RemoteInferenceInputDataSet_NotEscapeJsonString() {
9696
processInput_RemoteInferenceInputDataSet(input, input);
9797
}
9898

99+
@Test
100+
public void processInput_RemoteInferenceInputDataSet_NullParam() {
101+
String input = null;
102+
processInput_RemoteInferenceInputDataSet(input, input);
103+
}
104+
99105
private void processInput_RemoteInferenceInputDataSet(String input, String expectedInput) {
100-
RemoteInferenceInputDataSet dataSet = RemoteInferenceInputDataSet.builder().parameters(ImmutableMap.of("input", input)).build();
106+
Map<String, String> params = new HashMap<>();
107+
params.put("input", input);
108+
RemoteInferenceInputDataSet dataSet = RemoteInferenceInputDataSet.builder().parameters(params).build();
101109
MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(dataSet).build();
102110

103111
ConnectorAction predictAction = ConnectorAction.builder()

0 commit comments

Comments
 (0)