Skip to content

Commit 1841810

Browse files
guardrails bug fixes and IT for creating guardrails (#2269) (#2275)
Signed-off-by: Jing Zhang <[email protected]> (cherry picked from commit be56bcf) Co-authored-by: Jing Zhang <[email protected]>
1 parent 1babe1f commit 1841810

File tree

9 files changed

+290
-29
lines changed

9 files changed

+290
-29
lines changed

common/src/main/java/org/opensearch/ml/common/model/Guardrails.java

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -22,26 +22,22 @@
2222
@Getter
2323
public class Guardrails implements ToXContentObject {
2424
public static final String TYPE_FIELD = "type";
25-
public static final String ENGLISH_DETECTION_ENABLED_FIELD = "english_detection_enabled";
2625
public static final String INPUT_GUARDRAIL_FIELD = "input_guardrail";
2726
public static final String OUTPUT_GUARDRAIL_FIELD = "output_guardrail";
2827

2928
private String type;
30-
private Boolean engDetectionEnabled;
3129
private Guardrail inputGuardrail;
3230
private Guardrail outputGuardrail;
3331

3432
@Builder(toBuilder = true)
35-
public Guardrails(String type, Boolean engDetectionEnabled, Guardrail inputGuardrail, Guardrail outputGuardrail) {
33+
public Guardrails(String type, Guardrail inputGuardrail, Guardrail outputGuardrail) {
3634
this.type = type;
37-
this.engDetectionEnabled = engDetectionEnabled;
3835
this.inputGuardrail = inputGuardrail;
3936
this.outputGuardrail = outputGuardrail;
4037
}
4138

4239
public Guardrails(StreamInput input) throws IOException {
4340
type = input.readString();
44-
engDetectionEnabled = input.readBoolean();
4541
if (input.readBoolean()) {
4642
inputGuardrail = new Guardrail(input);
4743
}
@@ -52,7 +48,6 @@ public Guardrails(StreamInput input) throws IOException {
5248

5349
public void writeTo(StreamOutput out) throws IOException {
5450
out.writeString(type);
55-
out.writeBoolean(engDetectionEnabled);
5651
if (inputGuardrail != null) {
5752
out.writeBoolean(true);
5853
inputGuardrail.writeTo(out);
@@ -73,9 +68,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
7368
if (type != null) {
7469
builder.field(TYPE_FIELD, type);
7570
}
76-
if (engDetectionEnabled != null) {
77-
builder.field(ENGLISH_DETECTION_ENABLED_FIELD, engDetectionEnabled);
78-
}
7971
if (inputGuardrail != null) {
8072
builder.field(INPUT_GUARDRAIL_FIELD, inputGuardrail);
8173
}
@@ -88,7 +80,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
8880

8981
public static Guardrails parse(XContentParser parser) throws IOException {
9082
String type = null;
91-
Boolean engDetectionEnabled = null;
9283
Guardrail inputGuardrail = null;
9384
Guardrail outputGuardrail = null;
9485

@@ -101,9 +92,6 @@ public static Guardrails parse(XContentParser parser) throws IOException {
10192
case TYPE_FIELD:
10293
type = parser.text();
10394
break;
104-
case ENGLISH_DETECTION_ENABLED_FIELD:
105-
engDetectionEnabled = parser.booleanValue();
106-
break;
10795
case INPUT_GUARDRAIL_FIELD:
10896
inputGuardrail = Guardrail.parse(parser);
10997
break;
@@ -117,7 +105,6 @@ public static Guardrails parse(XContentParser parser) throws IOException {
117105
}
118106
return Guardrails.builder()
119107
.type(type)
120-
.engDetectionEnabled(engDetectionEnabled)
121108
.inputGuardrail(inputGuardrail)
122109
.outputGuardrail(outputGuardrail)
123110
.build();

common/src/main/java/org/opensearch/ml/common/model/MLGuard.java

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,11 +83,11 @@ private void fillStopWordsToMap(@NonNull Guardrail guardrail, Map<String, List<S
8383
}
8484
}
8585

86-
public Boolean validate(String input, int type) {
86+
public Boolean validate(String input, Type type) {
8787
switch (type) {
88-
case 0: // validate input
88+
case INPUT: // validate input
8989
return validateRegexList(input, inputRegexPattern) && validateStopWords(input, stopWordsIndicesInput);
90-
case 1: // validate output
90+
case OUTPUT: // validate output
9191
return validateRegexList(input, outputRegexPattern) && validateStopWords(input, stopWordsIndicesOutput);
9292
default:
9393
throw new IllegalArgumentException("Unsupported type to validate for guardrails.");
@@ -159,4 +159,9 @@ public Boolean validateStopWordsSingleIndex(String input, String indexName, List
159159
}
160160
return hitStopWords.get();
161161
}
162+
163+
public enum Type {
164+
INPUT,
165+
OUTPUT
166+
}
162167
}

common/src/test/java/org/opensearch/ml/common/model/GuardrailsTests.java

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,26 +40,24 @@ public void setUp() {
4040

4141
@Test
4242
public void writeTo() throws IOException {
43-
Guardrails guardrails = new Guardrails("test_type", false, inputGuardrail, outputGuardrail);
43+
Guardrails guardrails = new Guardrails("test_type", inputGuardrail, outputGuardrail);
4444
BytesStreamOutput output = new BytesStreamOutput();
4545
guardrails.writeTo(output);
4646
Guardrails guardrails1 = new Guardrails(output.bytes().streamInput());
4747

4848
Assert.assertEquals(guardrails.getType(), guardrails1.getType());
49-
Assert.assertEquals(guardrails.getEngDetectionEnabled(), guardrails1.getEngDetectionEnabled());
5049
Assert.assertEquals(guardrails.getInputGuardrail(), guardrails1.getInputGuardrail());
5150
Assert.assertEquals(guardrails.getOutputGuardrail(), guardrails1.getOutputGuardrail());
5251
}
5352

5453
@Test
5554
public void toXContent() throws IOException {
56-
Guardrails guardrails = new Guardrails("test_type", false, inputGuardrail, outputGuardrail);
55+
Guardrails guardrails = new Guardrails("test_type", inputGuardrail, outputGuardrail);
5756
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
5857
guardrails.toXContent(builder, ToXContent.EMPTY_PARAMS);
5958
String content = TestHelper.xContentBuilderToString(builder);
6059

6160
Assert.assertEquals("{\"type\":\"test_type\"," +
62-
"\"english_detection_enabled\":false," +
6361
"\"input_guardrail\":{\"stop_words\":[{\"index_name\":\"test_index\",\"source_fields\":[\"test_field\"]}],\"regex\":[\"regex1\"]}," +
6462
"\"output_guardrail\":{\"stop_words\":[{\"index_name\":\"test_index\",\"source_fields\":[\"test_field\"]}],\"regex\":[\"regex1\"]}}",
6563
content);
@@ -68,7 +66,6 @@ public void toXContent() throws IOException {
6866
@Test
6967
public void parse() throws IOException {
7068
String jsonStr = "{\"type\":\"test_type\"," +
71-
"\"english_detection_enabled\":false," +
7269
"\"input_guardrail\":{\"stop_words\":[{\"index_name\":\"test_index\",\"source_fields\":[\"test_field\"]}],\"regex\":[\"regex1\"]}," +
7370
"\"output_guardrail\":{\"stop_words\":[{\"index_name\":\"test_index\",\"source_fields\":[\"test_field\"]}],\"regex\":[\"regex1\"]}}";
7471
XContentParser parser = XContentType.JSON.xContent().createParser(new NamedXContentRegistry(new SearchModule(Settings.EMPTY,
@@ -77,7 +74,6 @@ public void parse() throws IOException {
7774
Guardrails guardrails = Guardrails.parse(parser);
7875

7976
Assert.assertEquals(guardrails.getType(), "test_type");
80-
Assert.assertEquals(guardrails.getEngDetectionEnabled(), false);
8177
Assert.assertEquals(guardrails.getInputGuardrail(), inputGuardrail);
8278
Assert.assertEquals(guardrails.getOutputGuardrail(), outputGuardrail);
8379
}

common/src/test/java/org/opensearch/ml/common/model/MLGuardTests.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,22 +75,22 @@ public void setUp() {
7575
regexPatterns = List.of(Pattern.compile("(.|\n)*stop words(.|\n)*"));
7676
inputGuardrail = new Guardrail(List.of(stopWords), regex);
7777
outputGuardrail = new Guardrail(List.of(stopWords), regex);
78-
guardrails = new Guardrails("test_type", false, inputGuardrail, outputGuardrail);
78+
guardrails = new Guardrails("test_type", inputGuardrail, outputGuardrail);
7979
mlGuard = new MLGuard(guardrails, xContentRegistry, client);
8080
}
8181

8282
@Test
8383
public void validateInput() {
8484
String input = "\n\nHuman:hello stop words.\n\nAssistant:";
85-
Boolean res = mlGuard.validate(input, 0);
85+
Boolean res = mlGuard.validate(input, MLGuard.Type.INPUT);
8686

8787
Assert.assertFalse(res);
8888
}
8989

9090
@Test
9191
public void validateOutput() {
9292
String input = "\n\nHuman:hello stop words.\n\nAssistant:";
93-
Boolean res = mlGuard.validate(input, 1);
93+
Boolean res = mlGuard.validate(input, MLGuard.Type.OUTPUT);
9494

9595
Assert.assertFalse(res);
9696
}

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ public void invokeRemoteModel(MLInput mlInput, Map<String, String> parameters, S
150150
throw new OpenSearchStatusException("No response from model", RestStatus.BAD_REQUEST);
151151
}
152152
String modelResponse = responseBuilder.toString();
153-
if (getMlGuard() != null && !getMlGuard().validate(modelResponse, 1)) {
153+
if (getMlGuard() != null && !getMlGuard().validate(modelResponse, MLGuard.Type.OUTPUT)) {
154154
throw new IllegalArgumentException("guardrails triggered for LLM output");
155155
}
156156
if (statusCode < 200 || statusCode >= 300) {

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ public void invokeRemoteModel(MLInput mlInput, Map<String, String> parameters, S
139139
return null;
140140
});
141141
String modelResponse = responseRef.get();
142-
if (getMlGuard() != null && !getMlGuard().validate(modelResponse, 1)) {
142+
if (getMlGuard() != null && !getMlGuard().validate(modelResponse, MLGuard.Type.OUTPUT)) {
143143
throw new IllegalArgumentException("guardrails triggered for LLM output");
144144
}
145145
Integer statusCode = statusCodeRef.get();

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ && getUserRateLimiterMap().get(user.getName()) != null
142142
RestStatus.TOO_MANY_REQUESTS
143143
);
144144
} else {
145-
if (getMlGuard() != null && !getMlGuard().validate(payload, 0)) {
145+
if (getMlGuard() != null && !getMlGuard().validate(payload, MLGuard.Type.INPUT)) {
146146
throw new IllegalArgumentException("guardrails triggered for user input");
147147
}
148148
invokeRemoteModel(mlInput, parameters, payload, tensorOutputs);

0 commit comments

Comments
 (0)