Skip to content

Commit 84db0a2

Browse files
fix model access mode upper case bug (#937) (#939)
* fix upper case bug Signed-off-by: Yaliang Wu <[email protected]> * create shared public model group in SecureMLRestIT Signed-off-by: Yaliang Wu <[email protected]> --------- Signed-off-by: Yaliang Wu <[email protected]> (cherry picked from commit 7c39b1b) Co-authored-by: Yaliang Wu <[email protected]>
1 parent b8c7daf commit 84db0a2

File tree

3 files changed

+33
-28
lines changed

3 files changed

+33
-28
lines changed

common/src/main/java/org/opensearch/ml/common/transport/model_group/MLRegisterModelGroupInput.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import java.io.IOException;
1919
import java.util.ArrayList;
2020
import java.util.List;
21+
import java.util.Locale;
2122

2223
import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
2324

@@ -123,7 +124,7 @@ public static MLRegisterModelGroupInput parse(XContentParser parser) throws IOEx
123124
}
124125
break;
125126
case MODEL_ACCESS_MODE:
126-
modelAccessMode = ModelAccessMode.from(parser.text());
127+
modelAccessMode = ModelAccessMode.from(parser.text().toLowerCase(Locale.ROOT));
127128
break;
128129
case ADD_ALL_BACKEND_ROLES:
129130
isAddAllBackendRoles = parser.booleanValue();

common/src/main/java/org/opensearch/ml/common/transport/model_group/MLUpdateModelGroupInput.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import java.io.IOException;
1919
import java.util.ArrayList;
2020
import java.util.List;
21+
import java.util.Locale;
2122

2223
import static org.opensearch.common.xcontent.XContentParserUtils.ensureExpectedToken;
2324

@@ -133,7 +134,7 @@ public static MLUpdateModelGroupInput parse(XContentParser parser) throws IOExce
133134
}
134135
break;
135136
case MODEL_ACCESS_MODE:
136-
modelAccessMode = ModelAccessMode.from(parser.text());
137+
modelAccessMode = ModelAccessMode.from(parser.text().toLowerCase(Locale.ROOT));
137138
break;
138139
case ADD_ALL_BACKEND_ROLES_FIELD:
139140
isAddAllBackendRoles = parser.booleanValue();

plugin/src/test/java/org/opensearch/ml/rest/SecureMLRestIT.java

Lines changed: 29 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
import org.opensearch.search.builder.SearchSourceBuilder;
3030

3131
import com.google.common.base.Throwables;
32-
import com.google.common.collect.ImmutableList;
3332

3433
public class SecureMLRestIT extends MLCommonsRestTestCase {
3534
private String irisIndex = "iris_data_secure_ml_it";
@@ -53,6 +52,8 @@ public class SecureMLRestIT extends MLCommonsRestTestCase {
5352
@Rule
5453
public ExpectedException exceptionRule = ExpectedException.none();
5554

55+
private String modelGroupId;
56+
5657
@Before
5758
public void setup() throws IOException {
5859
if (!isHttps()) {
@@ -102,8 +103,13 @@ public void setup() throws IOException {
102103
searchSourceBuilder.size(1000);
103104
searchSourceBuilder.fetchSource(new String[] { "petal_length_in_cm", "petal_width_in_cm" }, null);
104105

105-
mlRegisterModelInput = createRegisterModelInput("testModelGroupID");
106-
mlRegisterModelGroupInput = createRegisterModelGroupInput(ImmutableList.of("role-1"), ModelAccessMode.RESTRICTED, false);
106+
// Create public model group
107+
mlRegisterModelGroupInput = createRegisterModelGroupInput(null, ModelAccessMode.PUBLIC, false);
108+
109+
registerModelGroup(mlFullAccessClient, TestHelper.toJsonString(mlRegisterModelGroupInput), registerModelGroupResult -> {
110+
this.modelGroupId = (String) registerModelGroupResult.get("model_group_id");
111+
});
112+
mlRegisterModelInput = createRegisterModelInput(modelGroupId);
107113
}
108114

109115
@After
@@ -157,29 +163,26 @@ public void testRegisterModelWithReadOnlyMLAccess() throws IOException {
157163
}
158164

159165
public void testRegisterModelWithFullAccess() throws IOException {
160-
registerModelGroup(mlFullAccessClient, TestHelper.toJsonString(mlRegisterModelGroupInput), registerModelGroupResult -> {
161-
try {
162-
String modelGroupId = (String) registerModelGroupResult.get("model_group_id");
163-
MLRegisterModelInput mlRegisterModelInput = createRegisterModelInput(modelGroupId);
164-
registerModel(mlFullAccessClient, TestHelper.toJsonString(mlRegisterModelInput), registerModelResult -> {
165-
assertFalse(registerModelResult.containsKey("model_id"));
166-
String taskId = (String) registerModelResult.get("task_id");
167-
assertNotNull(taskId);
168-
String status = (String) registerModelResult.get("status");
169-
assertEquals(MLTaskState.CREATED.name(), status);
170-
try {
171-
getTask(mlFullAccessClient, taskId, task -> {
172-
String algorithm = (String) task.get("function_name");
173-
assertEquals(FunctionName.TEXT_EMBEDDING.name(), algorithm);
174-
});
175-
} catch (IOException e) {
176-
assertNull(e);
177-
}
178-
});
179-
} catch (IOException e) {
180-
throw new RuntimeException(e);
181-
}
182-
});
166+
try {
167+
MLRegisterModelInput mlRegisterModelInput = createRegisterModelInput(modelGroupId);
168+
registerModel(mlFullAccessClient, TestHelper.toJsonString(mlRegisterModelInput), registerModelResult -> {
169+
assertFalse(registerModelResult.containsKey("model_id"));
170+
String taskId = (String) registerModelResult.get("task_id");
171+
assertNotNull(taskId);
172+
String status = (String) registerModelResult.get("status");
173+
assertEquals(MLTaskState.CREATED.name(), status);
174+
try {
175+
getTask(mlFullAccessClient, taskId, task -> {
176+
String algorithm = (String) task.get("function_name");
177+
assertEquals(FunctionName.TEXT_EMBEDDING.name(), algorithm);
178+
});
179+
} catch (IOException e) {
180+
assertNull(e);
181+
}
182+
});
183+
} catch (IOException e) {
184+
throw new RuntimeException(e);
185+
}
183186
}
184187

185188
public void testDeployModelWithNoAccess() throws IOException, InterruptedException {

0 commit comments

Comments
 (0)