Skip to content

Commit 7ade595

Browse files
authored
adding feature settings for all the agentic memory apis (opensearch-project#4074)
Signed-off-by: Dhrubo Saha <[email protected]>
1 parent d6c7983 commit 7ade595

13 files changed

+122
-41
lines changed

plugin/src/main/java/org/opensearch/ml/action/memorycontainer/memory/TransportAddMemoriesAction.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.MEMORY_EMBEDDING_FIELD;
1515
import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.MEMORY_FIELD;
1616
import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.PERSONAL_INFORMATION_ORGANIZER_PROMPT;
17+
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_AGENTIC_MEMORY_DISABLED_MESSAGE;
1718

1819
import java.time.Instant;
1920
import java.util.ArrayList;
@@ -173,6 +174,11 @@ public TransportAddMemoriesAction(
173174

174175
@Override
175176
protected void doExecute(Task task, MLAddMemoriesRequest request, ActionListener<MLAddMemoriesResponse> actionListener) {
177+
if (!mlFeatureEnabledSetting.isAgenticMemoryEnabled()) {
178+
actionListener.onFailure(new OpenSearchStatusException(ML_COMMONS_AGENTIC_MEMORY_DISABLED_MESSAGE, RestStatus.FORBIDDEN));
179+
return;
180+
}
181+
176182
User user = RestActionUtils.getUserContext(client);
177183
MLAddMemoriesInput input = request.getMlAddMemoryInput();
178184

plugin/src/main/java/org/opensearch/ml/action/memorycontainer/memory/TransportDeleteMemoryAction.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import org.opensearch.core.action.ActionListener;
1818
import org.opensearch.core.rest.RestStatus;
1919
import org.opensearch.core.xcontent.NamedXContentRegistry;
20+
import org.opensearch.ml.common.settings.MLCommonsSettings;
2021
import org.opensearch.ml.common.settings.MLFeatureEnabledSetting;
2122
import org.opensearch.ml.common.transport.memorycontainer.memory.MLDeleteMemoryAction;
2223
import org.opensearch.ml.common.transport.memorycontainer.memory.MLDeleteMemoryRequest;
@@ -65,6 +66,14 @@ public TransportDeleteMemoryAction(
6566

6667
@Override
6768
protected void doExecute(Task task, ActionRequest request, ActionListener<DeleteResponse> actionListener) {
69+
if (!mlFeatureEnabledSetting.isAgenticMemoryEnabled()) {
70+
actionListener
71+
.onFailure(
72+
new OpenSearchStatusException(MLCommonsSettings.ML_COMMONS_AGENTIC_MEMORY_DISABLED_MESSAGE, RestStatus.FORBIDDEN)
73+
);
74+
return;
75+
}
76+
6877
MLDeleteMemoryRequest deleteRequest = MLDeleteMemoryRequest.fromActionRequest(request);
6978
String memoryContainerId = deleteRequest.getMemoryContainerId();
7079
String memoryId = deleteRequest.getMemoryId();

plugin/src/main/java/org/opensearch/ml/action/memorycontainer/memory/TransportSearchMemoriesAction.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
package org.opensearch.ml.action.memorycontainer.memory;
77

88
import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.*;
9+
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_AGENTIC_MEMORY_DISABLED_MESSAGE;
910

1011
import java.io.IOException;
1112
import java.time.Instant;
@@ -78,6 +79,11 @@ public TransportSearchMemoriesAction(
7879

7980
@Override
8081
protected void doExecute(Task task, MLSearchMemoriesRequest request, ActionListener<MLSearchMemoriesResponse> actionListener) {
82+
if (!mlFeatureEnabledSetting.isAgenticMemoryEnabled()) {
83+
actionListener.onFailure(new OpenSearchStatusException(ML_COMMONS_AGENTIC_MEMORY_DISABLED_MESSAGE, RestStatus.FORBIDDEN));
84+
return;
85+
}
86+
8187
MLSearchMemoriesInput input = request.getMlSearchMemoriesInput();
8288
String tenantId = request.getTenantId();
8389

plugin/src/main/java/org/opensearch/ml/action/memorycontainer/memory/TransportUpdateMemoryAction.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import org.opensearch.core.rest.RestStatus;
2626
import org.opensearch.core.xcontent.NamedXContentRegistry;
2727
import org.opensearch.ml.common.memorycontainer.MemoryStorageConfig;
28+
import org.opensearch.ml.common.settings.MLCommonsSettings;
2829
import org.opensearch.ml.common.settings.MLFeatureEnabledSetting;
2930
import org.opensearch.ml.common.transport.memorycontainer.memory.MLUpdateMemoryAction;
3031
import org.opensearch.ml.common.transport.memorycontainer.memory.MLUpdateMemoryRequest;
@@ -81,6 +82,14 @@ public TransportUpdateMemoryAction(
8182

8283
@Override
8384
protected void doExecute(Task task, ActionRequest request, ActionListener<UpdateResponse> actionListener) {
85+
if (!mlFeatureEnabledSetting.isAgenticMemoryEnabled()) {
86+
actionListener
87+
.onFailure(
88+
new OpenSearchStatusException(MLCommonsSettings.ML_COMMONS_AGENTIC_MEMORY_DISABLED_MESSAGE, RestStatus.FORBIDDEN)
89+
);
90+
return;
91+
}
92+
8493
MLUpdateMemoryRequest updateRequest = MLUpdateMemoryRequest.fromActionRequest(request);
8594
String memoryContainerId = updateRequest.getMemoryContainerId();
8695
String memoryId = updateRequest.getMemoryId();

plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -934,13 +934,13 @@ public List<RestHandler> getRestHandlers(
934934
mlFeatureEnabledSetting
935935
);
936936
RestMLGetMemoryContainerAction restMLGetMemoryContainerAction = new RestMLGetMemoryContainerAction(mlFeatureEnabledSetting);
937+
RestMLAddMemoriesAction restMLAddMemoriesAction = new RestMLAddMemoriesAction(mlFeatureEnabledSetting);
938+
RestMLSearchMemoriesAction restMLSearchMemoriesAction = new RestMLSearchMemoriesAction(mlFeatureEnabledSetting);
939+
RestMLDeleteMemoryAction restMLDeleteMemoryAction = new RestMLDeleteMemoryAction(mlFeatureEnabledSetting);
940+
RestMLUpdateMemoryAction restMLUpdateMemoryAction = new RestMLUpdateMemoryAction(mlFeatureEnabledSetting);
937941
RestMLDeleteMemoryContainerAction restMLDeleteMemoryContainerAction = new RestMLDeleteMemoryContainerAction(
938942
mlFeatureEnabledSetting
939943
);
940-
RestMLAddMemoriesAction restMLAddMemoriesAction = new RestMLAddMemoriesAction();
941-
RestMLSearchMemoriesAction restMLSearchMemoriesAction = new RestMLSearchMemoriesAction(mlFeatureEnabledSetting);
942-
RestMLDeleteMemoryAction restMLDeleteMemoryAction = new RestMLDeleteMemoryAction();
943-
RestMLUpdateMemoryAction restMLUpdateMemoryAction = new RestMLUpdateMemoryAction();
944944
RestMemorySearchConversationsAction restSearchConversationsAction = new RestMemorySearchConversationsAction(
945945
mlFeatureEnabledSetting
946946
);

plugin/src/main/java/org/opensearch/ml/rest/RestMLAddMemoriesAction.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,16 @@
77

88
import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.MEMORIES_PATH;
99
import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.PARAMETER_MEMORY_CONTAINER_ID;
10+
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_AGENTIC_MEMORY_DISABLED_MESSAGE;
1011
import static org.opensearch.ml.utils.RestActionUtils.getParameterId;
1112

1213
import java.io.IOException;
1314
import java.util.List;
1415

16+
import org.opensearch.OpenSearchStatusException;
17+
import org.opensearch.core.rest.RestStatus;
1518
import org.opensearch.core.xcontent.XContentParser;
19+
import org.opensearch.ml.common.settings.MLFeatureEnabledSetting;
1620
import org.opensearch.ml.common.transport.memorycontainer.memory.MLAddMemoriesAction;
1721
import org.opensearch.ml.common.transport.memorycontainer.memory.MLAddMemoriesInput;
1822
import org.opensearch.ml.common.transport.memorycontainer.memory.MLAddMemoriesRequest;
@@ -29,6 +33,14 @@
2933
public class RestMLAddMemoriesAction extends BaseRestHandler {
3034

3135
private static final String ML_ADD_MEMORIES_ACTION = "ml_add_memories_action";
36+
private final MLFeatureEnabledSetting mlFeatureEnabledSetting;
37+
38+
/**
39+
* Constructor
40+
*/
41+
public RestMLAddMemoriesAction(MLFeatureEnabledSetting mlFeatureEnabledSetting) {
42+
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
43+
}
3244

3345
@Override
3446
public List<Route> routes() {
@@ -42,6 +54,10 @@ public String getName() {
4254

4355
@Override
4456
protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException {
57+
if (!mlFeatureEnabledSetting.isAgenticMemoryEnabled()) {
58+
throw new OpenSearchStatusException(ML_COMMONS_AGENTIC_MEMORY_DISABLED_MESSAGE, RestStatus.FORBIDDEN);
59+
}
60+
4561
MLAddMemoriesRequest mlAddMemoryRequest = getRequest(request);
4662
return channel -> client.execute(MLAddMemoriesAction.INSTANCE, mlAddMemoryRequest, new RestToXContentListener<>(channel));
4763
}

plugin/src/main/java/org/opensearch/ml/rest/RestMLDeleteMemoryAction.java

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,15 @@
88
import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.DELETE_MEMORY_PATH;
99
import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.PARAMETER_MEMORY_CONTAINER_ID;
1010
import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.PARAMETER_MEMORY_ID;
11+
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_AGENTIC_MEMORY_DISABLED_MESSAGE;
1112
import static org.opensearch.ml.utils.RestActionUtils.getParameterId;
1213

1314
import java.io.IOException;
1415
import java.util.List;
1516

17+
import org.opensearch.OpenSearchStatusException;
18+
import org.opensearch.core.rest.RestStatus;
19+
import org.opensearch.ml.common.settings.MLFeatureEnabledSetting;
1620
import org.opensearch.ml.common.transport.memorycontainer.memory.MLDeleteMemoryAction;
1721
import org.opensearch.ml.common.transport.memorycontainer.memory.MLDeleteMemoryRequest;
1822
import org.opensearch.rest.BaseRestHandler;
@@ -29,11 +33,14 @@
2933
public class RestMLDeleteMemoryAction extends BaseRestHandler {
3034

3135
private static final String ML_DELETE_MEMORY_ACTION = "ml_delete_memory_action";
36+
private final MLFeatureEnabledSetting mlFeatureEnabledSetting;
3237

3338
/**
3439
* Constructor
3540
*/
36-
public RestMLDeleteMemoryAction() {}
41+
public RestMLDeleteMemoryAction(MLFeatureEnabledSetting mlFeatureEnabledSetting) {
42+
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
43+
}
3744

3845
@Override
3946
public String getName() {
@@ -47,6 +54,10 @@ public List<Route> routes() {
4754

4855
@Override
4956
protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException {
57+
if (!mlFeatureEnabledSetting.isAgenticMemoryEnabled()) {
58+
throw new OpenSearchStatusException(ML_COMMONS_AGENTIC_MEMORY_DISABLED_MESSAGE, RestStatus.FORBIDDEN);
59+
}
60+
5061
MLDeleteMemoryRequest mlDeleteMemoryRequest = getRequest(request);
5162
return channel -> client.execute(MLDeleteMemoryAction.INSTANCE, mlDeleteMemoryRequest, new RestToXContentListener<>(channel));
5263
}

plugin/src/main/java/org/opensearch/ml/rest/RestMLSearchMemoriesAction.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,15 @@
88
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
99
import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.PARAMETER_MEMORY_CONTAINER_ID;
1010
import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.SEARCH_MEMORIES_PATH;
11+
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_AGENTIC_MEMORY_DISABLED_MESSAGE;
1112
import static org.opensearch.ml.utils.RestActionUtils.getParameterId;
1213
import static org.opensearch.ml.utils.TenantAwareHelper.getTenantID;
1314

1415
import java.io.IOException;
1516
import java.util.List;
1617

18+
import org.opensearch.OpenSearchStatusException;
19+
import org.opensearch.core.rest.RestStatus;
1720
import org.opensearch.core.xcontent.XContentParser;
1821
import org.opensearch.ml.common.settings.MLFeatureEnabledSetting;
1922
import org.opensearch.ml.common.transport.memorycontainer.memory.MLSearchMemoriesAction;
@@ -55,6 +58,10 @@ public List<Route> routes() {
5558

5659
@Override
5760
protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException {
61+
if (!mlFeatureEnabledSetting.isAgenticMemoryEnabled()) {
62+
throw new OpenSearchStatusException(ML_COMMONS_AGENTIC_MEMORY_DISABLED_MESSAGE, RestStatus.FORBIDDEN);
63+
}
64+
5865
MLSearchMemoriesRequest mlSearchMemoriesRequest = getRequest(request);
5966
return channel -> client.execute(MLSearchMemoriesAction.INSTANCE, mlSearchMemoriesRequest, new RestToXContentListener<>(channel));
6067
}

plugin/src/main/java/org/opensearch/ml/rest/RestMLUpdateMemoryAction.java

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,16 @@
99
import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.PARAMETER_MEMORY_CONTAINER_ID;
1010
import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.PARAMETER_MEMORY_ID;
1111
import static org.opensearch.ml.common.memorycontainer.MemoryContainerConstants.UPDATE_MEMORY_PATH;
12+
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_AGENTIC_MEMORY_DISABLED_MESSAGE;
1213
import static org.opensearch.ml.utils.RestActionUtils.getParameterId;
1314

1415
import java.io.IOException;
1516
import java.util.List;
1617

18+
import org.opensearch.OpenSearchStatusException;
19+
import org.opensearch.core.rest.RestStatus;
1720
import org.opensearch.core.xcontent.XContentParser;
21+
import org.opensearch.ml.common.settings.MLFeatureEnabledSetting;
1822
import org.opensearch.ml.common.transport.memorycontainer.memory.MLUpdateMemoryAction;
1923
import org.opensearch.ml.common.transport.memorycontainer.memory.MLUpdateMemoryInput;
2024
import org.opensearch.ml.common.transport.memorycontainer.memory.MLUpdateMemoryRequest;
@@ -32,11 +36,14 @@
3236
public class RestMLUpdateMemoryAction extends BaseRestHandler {
3337

3438
private static final String ML_UPDATE_MEMORY_ACTION = "ml_update_memory_action";
39+
private final MLFeatureEnabledSetting mlFeatureEnabledSetting;
3540

3641
/**
3742
* Constructor
3843
*/
39-
public RestMLUpdateMemoryAction() {}
44+
public RestMLUpdateMemoryAction(MLFeatureEnabledSetting mlFeatureEnabledSetting) {
45+
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
46+
}
4047

4148
@Override
4249
public String getName() {
@@ -50,6 +57,10 @@ public List<Route> routes() {
5057

5158
@Override
5259
protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException {
60+
if (!mlFeatureEnabledSetting.isAgenticMemoryEnabled()) {
61+
throw new OpenSearchStatusException(ML_COMMONS_AGENTIC_MEMORY_DISABLED_MESSAGE, RestStatus.FORBIDDEN);
62+
}
63+
5364
MLUpdateMemoryRequest mlUpdateMemoryRequest = getRequest(request);
5465
return channel -> client.execute(MLUpdateMemoryAction.INSTANCE, mlUpdateMemoryRequest, new RestToXContentListener<>(channel));
5566
}

plugin/src/test/java/org/opensearch/ml/action/memorycontainer/TransportCreateMemoryContainerActionTests.java

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1087,14 +1087,14 @@ public void testDoExecuteWithNullMemoryStorageConfig() throws InterruptedExcepti
10871087
public void testDoExecuteWithTenantValidationFailure() throws InterruptedException {
10881088
// Enable multi-tenancy and provide null tenant ID to trigger validation failure
10891089
when(mlFeatureEnabledSetting.isMultiTenancyEnabled()).thenReturn(true);
1090-
1090+
10911091
MLCreateMemoryContainerInput invalidTenantInput = MLCreateMemoryContainerInput
1092-
.builder()
1093-
.name("invalid-tenant-container")
1094-
.description("Container with invalid tenant")
1095-
.memoryStorageConfig(memoryStorageConfig)
1096-
.tenantId(null) // This should trigger tenant validation failure
1097-
.build();
1092+
.builder()
1093+
.name("invalid-tenant-container")
1094+
.description("Container with invalid tenant")
1095+
.memoryStorageConfig(memoryStorageConfig)
1096+
.tenantId(null) // This should trigger tenant validation failure
1097+
.build();
10981098

10991099
MLCreateMemoryContainerRequest invalidTenantRequest = new MLCreateMemoryContainerRequest(invalidTenantInput);
11001100

0 commit comments

Comments
 (0)