Skip to content

Commit 5fccfd5

Browse files
fix system index access bug #1272 (#1320) (#1322)
Signed-off-by: HenryL27 <[email protected]> (cherry picked from commit 8cdac91) Co-authored-by: HenryL27 <[email protected]>
1 parent 5b24f61 commit 5fccfd5

File tree

3 files changed

+42
-64
lines changed

3 files changed

+42
-64
lines changed

memory/src/main/java/org/opensearch/ml/memory/index/ConversationMetaIndex.java

Lines changed: 32 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
import org.opensearch.index.query.MatchAllQueryBuilder;
4949
import org.opensearch.index.query.QueryBuilder;
5050
import org.opensearch.index.query.TermQueryBuilder;
51+
import org.opensearch.ml.common.conversation.ActionConstants;
5152
import org.opensearch.ml.common.conversation.ConversationMeta;
5253
import org.opensearch.ml.common.conversation.ConversationalIndexConstants;
5354
import org.opensearch.search.SearchHit;
@@ -75,7 +76,7 @@ public void initConversationMetaIndexIfAbsent(ActionListener<Boolean> listener)
7576
if (!clusterService.state().metadata().hasIndex(indexName)) {
7677
log.debug("No conversational meta index found. Adding it");
7778
CreateIndexRequest request = Requests.createIndexRequest(indexName).mapping(ConversationalIndexConstants.META_MAPPING);
78-
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().newStoredContext(true)) {
79+
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
7980
ActionListener<Boolean> internalListener = ActionListener.runBefore(listener, () -> threadContext.restore());
8081
ActionListener<CreateIndexResponse> al = ActionListener.wrap(createIndexResponse -> {
8182
if (createIndexResponse.equals(new CreateIndexResponse(true, true, indexName))) {
@@ -130,7 +131,7 @@ public void createConversation(String name, ActionListener<String> listener) {
130131
ConversationalIndexConstants.USER_FIELD,
131132
userstr == null ? null : User.parse(userstr).getName()
132133
);
133-
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().newStoredContext(true)) {
134+
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
134135
ActionListener<String> internalListener = ActionListener.runBefore(listener, () -> threadContext.restore());
135136
ActionListener<IndexResponse> al = ActionListener.wrap(resp -> {
136137
if (resp.status() == RestStatus.CREATED) {
@@ -181,7 +182,7 @@ public void getConversations(int from, int maxResults, ActionListener<List<Conve
181182
request.source().query(queryBuilder);
182183
request.source().from(from).size(maxResults);
183184
request.source().sort(ConversationalIndexConstants.META_CREATED_FIELD, SortOrder.DESC);
184-
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().newStoredContext(true)) {
185+
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
185186
ActionListener<List<ConversationMeta>> internalListener = ActionListener.runBefore(listener, () -> threadContext.restore());
186187
ActionListener<SearchResponse> al = ActionListener.wrap(searchResponse -> {
187188
List<ConversationMeta> result = new LinkedList<ConversationMeta>();
@@ -225,37 +226,34 @@ public void deleteConversation(String conversationId, ActionListener<Boolean> li
225226
listener.onResponse(true);
226227
}
227228
DeleteRequest delRequest = Requests.deleteRequest(indexName).id(conversationId);
228-
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().newStoredContext(true)) {
229-
ActionListener<Boolean> internalListener = ActionListener.runBefore(listener, () -> threadContext.restore());
230-
// When we get the delete response, do this:
231-
ActionListener<DeleteResponse> al = ActionListener.wrap(deleteResponse -> {
232-
if (deleteResponse.getResult() == Result.DELETED) {
233-
internalListener.onResponse(true);
234-
} else if (deleteResponse.status() == RestStatus.NOT_FOUND) {
235-
internalListener.onResponse(true);
236-
} else {
237-
internalListener.onResponse(false);
238-
}
239-
}, e -> {
240-
log.error("Failure deleting conversation " + conversationId, e);
241-
internalListener.onFailure(e);
242-
});
243-
this.checkAccess(conversationId, ActionListener.wrap(access -> {
244-
if (access) {
229+
String userstr = client.threadPool().getThreadContext().getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT);
230+
String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName();
231+
this.checkAccess(conversationId, ActionListener.wrap(access -> {
232+
if (access) {
233+
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
234+
ActionListener<Boolean> internalListener = ActionListener.runBefore(listener, () -> threadContext.restore());
235+
// When we get the delete response, do this:
236+
ActionListener<DeleteResponse> al = ActionListener.wrap(deleteResponse -> {
237+
if (deleteResponse.getResult() == Result.DELETED) {
238+
internalListener.onResponse(true);
239+
} else if (deleteResponse.status() == RestStatus.NOT_FOUND) {
240+
internalListener.onResponse(true);
241+
} else {
242+
internalListener.onResponse(false);
243+
}
244+
}, e -> {
245+
log.error("Failure deleting conversation " + conversationId, e);
246+
internalListener.onFailure(e);
247+
});
245248
client.delete(delRequest, al);
246-
} else {
247-
String userstr = client
248-
.threadPool()
249-
.getThreadContext()
250-
.getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT);
251-
String user = User.parse(userstr).getName();
252-
throw new OpenSearchSecurityException("User [" + user + "] does not have access to conversation " + conversationId);
249+
} catch (Exception e) {
250+
log.error("Failed deleting conversation with id=" + conversationId, e);
251+
listener.onFailure(e);
253252
}
254-
}, e -> { internalListener.onFailure(e); }));
255-
} catch (Exception e) {
256-
log.error("Failed deleting conversation with id=" + conversationId, e);
257-
listener.onFailure(e);
258-
}
253+
} else {
254+
throw new OpenSearchSecurityException("User [" + user + "] does not have access to conversation " + conversationId);
255+
}
256+
}, e -> { listener.onFailure(e); }));
259257
}
260258

261259
/**
@@ -269,13 +267,9 @@ public void checkAccess(String conversationId, ActionListener<Boolean> listener)
269267
listener.onResponse(true);
270268
return;
271269
}
272-
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().newStoredContext(true)) {
270+
String userstr = client.threadPool().getThreadContext().getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT);
271+
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
273272
ActionListener<Boolean> internalListener = ActionListener.runBefore(listener, () -> threadContext.restore());
274-
String userstr = client
275-
.threadPool()
276-
.getThreadContext()
277-
.getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT);
278-
log.info("USERSTR: " + userstr);
279273
// If security is off - User doesn't exist - you have permission
280274
if (userstr == null || User.parse(userstr) == null) {
281275
internalListener.onResponse(true);

memory/src/main/java/org/opensearch/ml/memory/index/InteractionsIndex.java

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ public void initInteractionsIndexIfAbsent(ActionListener<Boolean> listener) {
7575
if (!clusterService.state().metadata().hasIndex(indexName)) {
7676
log.debug("No interactions index found. Adding it");
7777
CreateIndexRequest request = Requests.createIndexRequest(indexName).mapping(ConversationalIndexConstants.INTERACTIONS_MAPPINGS);
78-
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().newStoredContext(true)) {
78+
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
7979
ActionListener<Boolean> internalListener = ActionListener.runBefore(listener, () -> threadContext.restore());
8080
ActionListener<CreateIndexResponse> al = ActionListener.wrap(r -> {
8181
if (r.equals(new CreateIndexResponse(true, true, indexName))) {
@@ -130,6 +130,11 @@ public void createInteraction(
130130
ActionListener<String> listener
131131
) {
132132
initInteractionsIndexIfAbsent(ActionListener.wrap(indexExists -> {
133+
String userstr = client
134+
.threadPool()
135+
.getThreadContext()
136+
.getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT);
137+
String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName();
133138
if (indexExists) {
134139
this.conversationMetaIndex.checkAccess(conversationId, ActionListener.wrap(access -> {
135140
if (access) {
@@ -151,7 +156,7 @@ public void createInteraction(
151156
ConversationalIndexConstants.INTERACTIONS_CREATE_TIME_FIELD,
152157
timestamp
153158
);
154-
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().newStoredContext(true)) {
159+
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
155160
ActionListener<String> internalListener = ActionListener.runBefore(listener, () -> threadContext.restore());
156161
ActionListener<IndexResponse> al = ActionListener.wrap(resp -> {
157162
if (resp.status() == RestStatus.CREATED) {
@@ -165,13 +170,6 @@ public void createInteraction(
165170
listener.onFailure(e);
166171
}
167172
} else {
168-
String userstr = client
169-
.threadPool()
170-
.getThreadContext()
171-
.getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT);
172-
String user = User.parse(userstr) == null
173-
? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS
174-
: User.parse(userstr).getName();
175173
throw new OpenSearchSecurityException("User [" + user + "] does not have access to conversation " + conversationId);
176174
}
177175
}, e -> { listener.onFailure(e); }));
@@ -313,7 +311,9 @@ public void deleteConversation(String conversationId, ActionListener<Boolean> li
313311
listener.onResponse(true);
314312
return;
315313
}
316-
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().newStoredContext(true)) {
314+
String userstr = client.threadPool().getThreadContext().getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT);
315+
String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName();
316+
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
317317
ActionListener<Boolean> internalListener = ActionListener.runBefore(listener, () -> threadContext.restore());
318318
ActionListener<List<Interaction>> searchListener = ActionListener.wrap(interactions -> {
319319
BulkRequest request = Requests.bulkRequest();
@@ -330,11 +330,6 @@ public void deleteConversation(String conversationId, ActionListener<Boolean> li
330330
if (access) {
331331
getAllInteractions(conversationId, resultsAtATime, searchListener);
332332
} else {
333-
String userstr = client
334-
.threadPool()
335-
.getThreadContext()
336-
.getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT);
337-
String user = User.parse(userstr) == null ? ActionConstants.DEFAULT_USERNAME_FOR_ERRORS : User.parse(userstr).getName();
338333
throw new OpenSearchSecurityException("User [" + user + "] does not have access to conversation " + conversationId);
339334
}
340335
}, e -> { listener.onFailure(e); });

memory/src/test/java/org/opensearch/ml/memory/index/ConversationMetaIndexTests.java

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -402,17 +402,6 @@ public void testDelete_DeleteFails_ThenFail() {
402402
assert (argCaptor.getValue().getMessage().equals("Test Fail in Delete"));
403403
}
404404

405-
public void testDelete_HighLevelFailure_ThenFail() {
406-
doReturn(true).when(metadata).hasIndex(anyString());
407-
doThrow(new RuntimeException("Check Fail")).when(conversationMetaIndex).checkAccess(any(), any());
408-
@SuppressWarnings("unchecked")
409-
ActionListener<Boolean> deleteConversationListener = mock(ActionListener.class);
410-
conversationMetaIndex.deleteConversation("test-id", deleteConversationListener);
411-
ArgumentCaptor<Exception> argCaptor = ArgumentCaptor.forClass(Exception.class);
412-
verify(deleteConversationListener, times(1)).onFailure(argCaptor.capture());
413-
assert (argCaptor.getValue().getMessage().equals("Check Fail"));
414-
}
415-
416405
public void testCheckAccess_DoesNotExist_ThenFail() {
417406
setupUser("user");
418407
doReturn(true).when(metadata).hasIndex(anyString());

0 commit comments

Comments
 (0)