2929import org .opensearch .search .builder .SearchSourceBuilder ;
3030
3131import com .google .common .base .Throwables ;
32- import com .google .common .collect .ImmutableList ;
3332
3433public class SecureMLRestIT extends MLCommonsRestTestCase {
3534 private String irisIndex = "iris_data_secure_ml_it" ;
@@ -53,6 +52,8 @@ public class SecureMLRestIT extends MLCommonsRestTestCase {
5352 @ Rule
5453 public ExpectedException exceptionRule = ExpectedException .none ();
5554
55+ private String modelGroupId ;
56+
5657 @ Before
5758 public void setup () throws IOException {
5859 if (!isHttps ()) {
@@ -102,8 +103,13 @@ public void setup() throws IOException {
102103 searchSourceBuilder .size (1000 );
103104 searchSourceBuilder .fetchSource (new String [] { "petal_length_in_cm" , "petal_width_in_cm" }, null );
104105
105- mlRegisterModelInput = createRegisterModelInput ("testModelGroupID" );
106- mlRegisterModelGroupInput = createRegisterModelGroupInput (ImmutableList .of ("role-1" ), ModelAccessMode .RESTRICTED , false );
106+ // Create public model group
107+ mlRegisterModelGroupInput = createRegisterModelGroupInput (null , ModelAccessMode .PUBLIC , false );
108+
109+ registerModelGroup (mlFullAccessClient , TestHelper .toJsonString (mlRegisterModelGroupInput ), registerModelGroupResult -> {
110+ this .modelGroupId = (String ) registerModelGroupResult .get ("model_group_id" );
111+ });
112+ mlRegisterModelInput = createRegisterModelInput (modelGroupId );
107113 }
108114
109115 @ After
@@ -157,29 +163,26 @@ public void testRegisterModelWithReadOnlyMLAccess() throws IOException {
157163 }
158164
159165 public void testRegisterModelWithFullAccess () throws IOException {
160- registerModelGroup (mlFullAccessClient , TestHelper .toJsonString (mlRegisterModelGroupInput ), registerModelGroupResult -> {
161- try {
162- String modelGroupId = (String ) registerModelGroupResult .get ("model_group_id" );
163- MLRegisterModelInput mlRegisterModelInput = createRegisterModelInput (modelGroupId );
164- registerModel (mlFullAccessClient , TestHelper .toJsonString (mlRegisterModelInput ), registerModelResult -> {
165- assertFalse (registerModelResult .containsKey ("model_id" ));
166- String taskId = (String ) registerModelResult .get ("task_id" );
167- assertNotNull (taskId );
168- String status = (String ) registerModelResult .get ("status" );
169- assertEquals (MLTaskState .CREATED .name (), status );
170- try {
171- getTask (mlFullAccessClient , taskId , task -> {
172- String algorithm = (String ) task .get ("function_name" );
173- assertEquals (FunctionName .TEXT_EMBEDDING .name (), algorithm );
174- });
175- } catch (IOException e ) {
176- assertNull (e );
177- }
178- });
179- } catch (IOException e ) {
180- throw new RuntimeException (e );
181- }
182- });
166+ try {
167+ MLRegisterModelInput mlRegisterModelInput = createRegisterModelInput (modelGroupId );
168+ registerModel (mlFullAccessClient , TestHelper .toJsonString (mlRegisterModelInput ), registerModelResult -> {
169+ assertFalse (registerModelResult .containsKey ("model_id" ));
170+ String taskId = (String ) registerModelResult .get ("task_id" );
171+ assertNotNull (taskId );
172+ String status = (String ) registerModelResult .get ("status" );
173+ assertEquals (MLTaskState .CREATED .name (), status );
174+ try {
175+ getTask (mlFullAccessClient , taskId , task -> {
176+ String algorithm = (String ) task .get ("function_name" );
177+ assertEquals (FunctionName .TEXT_EMBEDDING .name (), algorithm );
178+ });
179+ } catch (IOException e ) {
180+ assertNull (e );
181+ }
182+ });
183+ } catch (IOException e ) {
184+ throw new RuntimeException (e );
185+ }
183186 }
184187
185188 public void testDeployModelWithNoAccess () throws IOException , InterruptedException {
0 commit comments