Skip to content

Commit 20e0978

Browse files
modify error message when model group not unique is provided (#1078) (#1088)
* modify error message when model group not unique is provided Signed-off-by: Bhavana Ramaram <[email protected]> * fix unique model group name unit test Signed-off-by: Bhavana Ramaram <[email protected]> --------- Signed-off-by: Bhavana Ramaram <[email protected]> (cherry picked from commit df2a0f5) Co-authored-by: Bhavana Ramaram <[email protected]>
1 parent f67d3d2 commit 20e0978

File tree

5 files changed

+142
-51
lines changed

5 files changed

+142
-51
lines changed

plugin/src/main/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupAction.java

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
import java.util.HashMap;
1313
import java.util.HashSet;
14+
import java.util.Iterator;
1415
import java.util.Map;
1516

1617
import org.apache.commons.lang3.StringUtils;
@@ -43,6 +44,7 @@
4344
import org.opensearch.ml.utils.MLNodeUtils;
4445
import org.opensearch.ml.utils.RestActionUtils;
4546
import org.opensearch.rest.RestStatus;
47+
import org.opensearch.search.SearchHit;
4648
import org.opensearch.tasks.Task;
4749
import org.opensearch.transport.TransportService;
4850

@@ -150,23 +152,30 @@ private void updateModelGroup(
150152
source.put(MLModelGroup.DESCRIPTION_FIELD, updateModelGroupInput.getDescription());
151153
}
152154
if (StringUtils.isNotBlank(updateModelGroupInput.getName()) && !updateModelGroupInput.getName().equals(modelGroupName)) {
153-
mlModelGroupManager
154-
.validateUniqueModelGroupName(updateModelGroupInput.getName(), ActionListener.wrap(isModelGroupNameUnique -> {
155-
if (Boolean.FALSE.equals(isModelGroupNameUnique)) {
155+
mlModelGroupManager.validateUniqueModelGroupName(updateModelGroupInput.getName(), ActionListener.wrap(modelGroups -> {
156+
if (modelGroups != null
157+
&& modelGroups.getHits().getTotalHits() != null
158+
&& modelGroups.getHits().getTotalHits().value != 0) {
159+
Iterator<SearchHit> iterator = modelGroups.getHits().iterator();
160+
while (iterator.hasNext()) {
161+
String id = iterator.next().getId();
156162
listener
157163
.onFailure(
158164
new IllegalArgumentException(
159-
"The name you provided is already being used by another model group. Please provide a different name."
165+
"The name you provided is already being used by another model with ID: "
166+
+ id
167+
+ ". Please provide a different name"
160168
)
161169
);
162-
} else {
163-
source.put(MLModelGroup.MODEL_GROUP_NAME_FIELD, updateModelGroupInput.getName());
164-
updateModelGroup(modelGroupId, source, listener);
165170
}
166-
}, e -> {
167-
log.error("Failed to search model group index", e);
168-
listener.onFailure(e);
169-
}));
171+
} else {
172+
source.put(MLModelGroup.MODEL_GROUP_NAME_FIELD, updateModelGroupInput.getName());
173+
updateModelGroup(modelGroupId, source, listener);
174+
}
175+
}, e -> {
176+
log.error("Failed to search model group index", e);
177+
listener.onFailure(e);
178+
}));
170179
} else {
171180
updateModelGroup(modelGroupId, source, listener);
172181
}

plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,12 @@
99

1010
import java.time.Instant;
1111
import java.util.HashSet;
12+
import java.util.Iterator;
1213

1314
import org.opensearch.action.ActionListener;
1415
import org.opensearch.action.index.IndexRequest;
1516
import org.opensearch.action.search.SearchRequest;
17+
import org.opensearch.action.search.SearchResponse;
1618
import org.opensearch.action.support.WriteRequest;
1719
import org.opensearch.client.Client;
1820
import org.opensearch.cluster.service.ClusterService;
@@ -32,6 +34,7 @@
3234
import org.opensearch.ml.helper.ModelAccessControlHelper;
3335
import org.opensearch.ml.indices.MLIndicesHandler;
3436
import org.opensearch.ml.utils.RestActionUtils;
37+
import org.opensearch.search.SearchHit;
3538
import org.opensearch.search.builder.SearchSourceBuilder;
3639

3740
import lombok.extern.log4j.Log4j2;
@@ -62,11 +65,22 @@ public void createModelGroup(MLRegisterModelGroupInput input, ActionListener<Str
6265
String modelName = input.getName();
6366
User user = RestActionUtils.getUserContext(client);
6467
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
65-
validateUniqueModelGroupName(input.getName(), ActionListener.wrap(isUniqueModelGroupName -> {
66-
if (Boolean.FALSE.equals(isUniqueModelGroupName)) {
67-
throw new IllegalArgumentException(
68-
"The name you provided is already being used by another model group. Please provide a different name"
69-
);
68+
validateUniqueModelGroupName(input.getName(), ActionListener.wrap(modelGroups -> {
69+
if (modelGroups != null
70+
&& modelGroups.getHits().getTotalHits() != null
71+
&& modelGroups.getHits().getTotalHits().value != 0) {
72+
Iterator<SearchHit> iterator = modelGroups.getHits().iterator();
73+
while (iterator.hasNext()) {
74+
String id = iterator.next().getId();
75+
listener
76+
.onFailure(
77+
new IllegalArgumentException(
78+
"The name you provided is already being used by another model with ID: "
79+
+ id
80+
+ ". Please provide a different name"
81+
)
82+
);
83+
}
7084
} else {
7185
MLModelGroup.MLModelGroupBuilder builder = MLModelGroup.builder();
7286
MLModelGroup mlModelGroup;
@@ -170,21 +184,16 @@ private void validateRequestForAccessControl(MLRegisterModelGroupInput input, Us
170184
}
171185
}
172186

173-
public void validateUniqueModelGroupName(String name, ActionListener<Boolean> listener) throws IllegalArgumentException {
187+
public void validateUniqueModelGroupName(String name, ActionListener<SearchResponse> listener) throws IllegalArgumentException {
174188
BoolQueryBuilder query = new BoolQueryBuilder();
175189
query.filter(new TermQueryBuilder(MLRegisterModelGroupInput.NAME_FIELD + ".keyword", name));
176190

177191
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query);
178192
SearchRequest searchRequest = new SearchRequest(ML_MODEL_GROUP_INDEX).source(searchSourceBuilder);
179193

180-
client.search(searchRequest, ActionListener.wrap(modelGroups -> {
181-
listener
182-
.onResponse(
183-
modelGroups == null || modelGroups.getHits().getTotalHits() == null || modelGroups.getHits().getTotalHits().value == 0
184-
);
185-
}, e -> {
194+
client.search(searchRequest, ActionListener.wrap(modelGroups -> { listener.onResponse(modelGroups); }, e -> {
186195
if (e instanceof IndexNotFoundException) {
187-
listener.onResponse(true);
196+
listener.onResponse(null);
188197
} else {
189198
log.error("Failed to search model group index", e);
190199
listener.onFailure(e);

plugin/src/test/java/org/opensearch/ml/action/model_group/TransportUpdateModelGroupActionTests.java

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,15 @@
77

88
import static org.mockito.ArgumentMatchers.any;
99
import static org.mockito.Mockito.doAnswer;
10+
import static org.mockito.Mockito.mock;
1011
import static org.mockito.Mockito.verify;
1112
import static org.mockito.Mockito.when;
1213

1314
import java.io.IOException;
1415
import java.util.Arrays;
1516
import java.util.List;
1617

18+
import org.apache.lucene.search.TotalHits;
1719
import org.junit.Before;
1820
import org.junit.Rule;
1921
import org.junit.rules.ExpectedException;
@@ -22,6 +24,7 @@
2224
import org.mockito.MockitoAnnotations;
2325
import org.opensearch.action.ActionListener;
2426
import org.opensearch.action.get.GetResponse;
27+
import org.opensearch.action.search.SearchResponse;
2528
import org.opensearch.action.support.ActionFilters;
2629
import org.opensearch.action.update.UpdateResponse;
2730
import org.opensearch.client.Client;
@@ -45,6 +48,9 @@
4548
import org.opensearch.ml.common.transport.model_group.MLUpdateModelGroupResponse;
4649
import org.opensearch.ml.helper.ModelAccessControlHelper;
4750
import org.opensearch.ml.model.MLModelGroupManager;
51+
import org.opensearch.ml.utils.TestHelper;
52+
import org.opensearch.search.SearchHit;
53+
import org.opensearch.search.SearchHits;
4854
import org.opensearch.tasks.Task;
4955
import org.opensearch.test.OpenSearchTestCase;
5056
import org.opensearch.threadpool.ThreadPool;
@@ -134,9 +140,10 @@ public void setup() throws IOException {
134140
return null;
135141
}).when(client).get(any(), any());
136142

143+
SearchResponse searchResponse = createModelGroupSearchResponse(0);
137144
doAnswer(invocation -> {
138-
ActionListener<Boolean> listener = invocation.getArgument(1);
139-
listener.onResponse(true);
145+
ActionListener<SearchResponse> listener = invocation.getArgument(1);
146+
listener.onResponse(searchResponse);
140147
return null;
141148
}).when(mlModelGroupManager).validateUniqueModelGroupName(any(), any());
142149

@@ -385,25 +392,24 @@ public void test_SuccessSecurityDisabledCluster() {
385392
verify(actionListener).onResponse(argumentCaptor.capture());
386393
}
387394

388-
public void test_ModelGroupNameNotUnique() {
395+
public void test_ModelGroupNameNotUnique() throws IOException {
389396

397+
when(modelAccessControlHelper.isSecurityEnabledAndModelAccessControlEnabled(any())).thenReturn(false);
398+
399+
SearchResponse searchResponse = createModelGroupSearchResponse(1);
390400
doAnswer(invocation -> {
391-
ActionListener<Boolean> listener = invocation.getArgument(1);
392-
listener.onResponse(false);
401+
ActionListener<SearchResponse> listener = invocation.getArgument(1);
402+
listener.onResponse(searchResponse);
393403
return null;
394404
}).when(mlModelGroupManager).validateUniqueModelGroupName(any(), any());
395405

396-
when(modelAccessControlHelper.isOwner(any(), any())).thenReturn(true);
397-
when(modelAccessControlHelper.isOwnerStillHasPermission(any(), any())).thenReturn(true);
398-
when(modelAccessControlHelper.isAdmin(any())).thenReturn(false);
399-
400406
MLUpdateModelGroupRequest actionRequest = prepareRequest(null, null, null);
401407
transportUpdateModelGroupAction.doExecute(task, actionRequest, actionListener);
402408
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
403409
verify(actionListener).onFailure(argumentCaptor.capture());
404410
assertEquals(
405-
"The name you provided is already being used by another model group. Please provide a different name.",
406-
argumentCaptor.getValue().getMessage()
411+
"The name you provided is already being used by another model with ID: model_group_ID. Please provide a different name",
412+
argumentCaptor.getValue().getMessage()
407413
);
408414
}
409415

@@ -432,4 +438,21 @@ private MLUpdateModelGroupRequest prepareRequest(List<String> backendRoles, Acce
432438
return new MLUpdateModelGroupRequest(UpdateModelGroupInput);
433439
}
434440

441+
private SearchResponse createModelGroupSearchResponse(long totalHits) throws IOException {
442+
SearchResponse searchResponse = mock(SearchResponse.class);
443+
String modelContent = "{\n"
444+
+ " \"created_time\": 1684981986069,\n"
445+
+ " \"access\": \"public\",\n"
446+
+ " \"latest_version\": 0,\n"
447+
+ " \"last_updated_time\": 1684981986069,\n"
448+
+ " \"_id\": \"model_group_ID\",\n"
449+
+ " \"name\": \"model_group_IT\",\n"
450+
+ " \"description\": \"This is an example description\"\n"
451+
+ " }";
452+
SearchHit modelGroup = SearchHit.fromXContent(TestHelper.parser(modelContent));
453+
SearchHits hits = new SearchHits(new SearchHit[] { modelGroup }, new TotalHits(totalHits, TotalHits.Relation.EQUAL_TO), Float.NaN);
454+
when(searchResponse.getHits()).thenReturn(hits);
455+
return searchResponse;
456+
}
457+
435458
}

plugin/src/test/java/org/opensearch/ml/helper/ModelAccessControlHelperTests.java

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
import java.util.List;
1818

1919
import org.junit.Before;
20-
import org.junit.Ignore;
2120
import org.mockito.ArgumentCaptor;
2221
import org.mockito.Mock;
2322
import org.mockito.MockitoAnnotations;
@@ -108,7 +107,6 @@ public void test_UndefinedOwner() throws IOException {
108107
assertTrue(argumentCaptor.getValue());
109108
}
110109

111-
@Ignore
112110
public void test_ExceptionEmptyBackendRoles() throws IOException {
113111
String owner = "owner|IT,HR|myTenant";
114112
User user = User.parse("owner|IT,HR|myTenant");
@@ -119,7 +117,6 @@ public void test_ExceptionEmptyBackendRoles() throws IOException {
119117
assertEquals("Backend roles shouldn't be null", argumentCaptor.getValue().getMessage());
120118
}
121119

122-
@Ignore
123120
public void test_MatchingBackendRoles() throws IOException {
124121
String owner = "owner|IT,HR|myTenant";
125122
List<String> backendRoles = Arrays.asList("IT", "HR");
@@ -131,7 +128,6 @@ public void test_MatchingBackendRoles() throws IOException {
131128
assertTrue(argumentCaptor.getValue());
132129
}
133130

134-
@Ignore
135131
public void test_PublicModelGroup() throws IOException {
136132
String owner = "owner|IT,HR|myTenant";
137133
List<String> backendRoles = Arrays.asList("IT", "HR");
@@ -143,7 +139,6 @@ public void test_PublicModelGroup() throws IOException {
143139
assertTrue(argumentCaptor.getValue());
144140
}
145141

146-
@Ignore
147142
public void test_PrivateModelGroupWithSameOwner() throws IOException {
148143
String owner = "owner|IT,HR|myTenant";
149144
List<String> backendRoles = Arrays.asList("IT", "HR");
@@ -155,7 +150,6 @@ public void test_PrivateModelGroupWithSameOwner() throws IOException {
155150
assertTrue(argumentCaptor.getValue());
156151
}
157152

158-
@Ignore
159153
public void test_PrivateModelGroupWithDifferentOwner() throws IOException {
160154
String owner = "owner|IT,HR|myTenant";
161155
List<String> backendRoles = Arrays.asList("IT", "HR");

0 commit comments

Comments
 (0)