Skip to content

Commit f7a821e

Browse files
GET memory API (opensearch-project#4069)
* Add memory API Signed-off-by: rithin-pullela-aws <[email protected]> * Fix feature flag Signed-off-by: rithin-pullela-aws <[email protected]> * Add UTs Signed-off-by: rithin-pullela-aws <[email protected]> * Fix test Signed-off-by: rithin-pullela-aws <[email protected]> * add feature flag for Get memory API Signed-off-by: rithin-pullela-aws <[email protected]> --------- Signed-off-by: rithin-pullela-aws <[email protected]>
1 parent 2916b8d commit f7a821e

File tree

12 files changed

+1257
-1
lines changed

12 files changed

+1257
-1
lines changed

common/src/main/java/org/opensearch/ml/common/memorycontainer/MemoryContainerConstants.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ public class MemoryContainerConstants {
7272
public static final String SEARCH_MEMORIES_PATH = MEMORIES_PATH + "/_search";
7373
public static final String DELETE_MEMORY_PATH = MEMORIES_PATH + "/{" + PARAMETER_MEMORY_ID + "}";
7474
public static final String UPDATE_MEMORY_PATH = MEMORIES_PATH + "/{" + PARAMETER_MEMORY_ID + "}";
75+
public static final String GET_MEMORY_PATH = MEMORIES_PATH + "/{" + PARAMETER_MEMORY_ID + "}";
7576

7677
// Memory types are defined in MemoryType enum
7778

common/src/main/java/org/opensearch/ml/common/settings/MLFeatureEnabledSetting.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings)
105105
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_EXECUTE_TOOL_ENABLED, it -> isExecuteToolEnabled = it);
106106
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_AGENTIC_SEARCH_ENABLED, it -> isAgenticSearchEnabled = it);
107107
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_MCP_CONNECTOR_ENABLED, it -> isMcpConnectorEnabled = it);
108-
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_AGENTIC_MEMORY_ENABLED, it -> isMcpConnectorEnabled = it);
108+
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_AGENTIC_MEMORY_ENABLED, it -> isAgenticMemoryEnabled = it);
109109
}
110110

111111
/**
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.transport.memorycontainer.memory;
7+
8+
import org.opensearch.action.ActionType;
9+
10+
public class MLGetMemoryAction extends ActionType<MLGetMemoryResponse> {
11+
public static final MLGetMemoryAction INSTANCE = new MLGetMemoryAction();
12+
public static final String NAME = "cluster:admin/opensearch/ml/memory_containers/memory/get";
13+
14+
private MLGetMemoryAction() {
15+
super(NAME, MLGetMemoryResponse::new);
16+
}
17+
}
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.transport.memorycontainer.memory;
7+
8+
import static org.opensearch.action.ValidateActions.addValidationError;
9+
10+
import java.io.ByteArrayInputStream;
11+
import java.io.ByteArrayOutputStream;
12+
import java.io.IOException;
13+
import java.io.UncheckedIOException;
14+
15+
import org.opensearch.action.ActionRequest;
16+
import org.opensearch.action.ActionRequestValidationException;
17+
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
18+
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
19+
import org.opensearch.core.common.io.stream.StreamInput;
20+
import org.opensearch.core.common.io.stream.StreamOutput;
21+
22+
import lombok.AccessLevel;
23+
import lombok.Builder;
24+
import lombok.Getter;
25+
import lombok.ToString;
26+
import lombok.experimental.FieldDefaults;
27+
28+
@Getter
29+
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
30+
@ToString
31+
public class MLGetMemoryRequest extends ActionRequest {
32+
33+
String memoryContainerId;
34+
String memoryId;
35+
36+
@Builder
37+
public MLGetMemoryRequest(String memoryContainerId, String memoryId) {
38+
this.memoryContainerId = memoryContainerId;
39+
this.memoryId = memoryId;
40+
}
41+
42+
public MLGetMemoryRequest(StreamInput in) throws IOException {
43+
super(in);
44+
this.memoryContainerId = in.readString();
45+
this.memoryId = in.readString();
46+
}
47+
48+
@Override
49+
public void writeTo(StreamOutput out) throws IOException {
50+
super.writeTo(out);
51+
out.writeString(this.memoryContainerId);
52+
out.writeString(this.memoryId);
53+
}
54+
55+
@Override
56+
public ActionRequestValidationException validate() {
57+
ActionRequestValidationException exception = null;
58+
59+
if (this.memoryContainerId == null || this.memoryId == null) {
60+
exception = addValidationError("memoryContainerId and memoryId id can not be null", exception);
61+
}
62+
63+
return exception;
64+
}
65+
66+
public static MLGetMemoryRequest fromActionRequest(ActionRequest actionRequest) {
67+
if (actionRequest instanceof MLGetMemoryRequest) {
68+
return (MLGetMemoryRequest) actionRequest;
69+
}
70+
71+
try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
72+
actionRequest.writeTo(osso);
73+
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
74+
return new MLGetMemoryRequest(input);
75+
}
76+
} catch (IOException e) {
77+
throw new UncheckedIOException("failed to parse ActionRequest into MLMemoryGetRequest", e);
78+
}
79+
}
80+
}
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.transport.memorycontainer.memory;
7+
8+
import java.io.ByteArrayInputStream;
9+
import java.io.ByteArrayOutputStream;
10+
import java.io.IOException;
11+
import java.io.UncheckedIOException;
12+
13+
import org.opensearch.core.action.ActionResponse;
14+
import org.opensearch.core.common.io.stream.InputStreamStreamInput;
15+
import org.opensearch.core.common.io.stream.OutputStreamStreamOutput;
16+
import org.opensearch.core.common.io.stream.StreamInput;
17+
import org.opensearch.core.common.io.stream.StreamOutput;
18+
import org.opensearch.core.xcontent.ToXContentObject;
19+
import org.opensearch.core.xcontent.XContentBuilder;
20+
import org.opensearch.ml.common.memorycontainer.MLMemory;
21+
22+
import lombok.Builder;
23+
import lombok.Getter;
24+
import lombok.ToString;
25+
26+
@Getter
27+
@ToString
28+
public class MLGetMemoryResponse extends ActionResponse implements ToXContentObject {
29+
MLMemory mlMemory;
30+
31+
@Builder
32+
public MLGetMemoryResponse(MLMemory mlMemory) {
33+
this.mlMemory = mlMemory;
34+
}
35+
36+
public MLGetMemoryResponse(StreamInput in) throws IOException {
37+
super(in);
38+
mlMemory = new MLMemory(in);
39+
}
40+
41+
@Override
42+
public void writeTo(StreamOutput out) throws IOException {
43+
mlMemory.writeTo(out);
44+
}
45+
46+
@Override
47+
public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params) throws IOException {
48+
return mlMemory.toXContent(xContentBuilder, params);
49+
}
50+
51+
public static MLGetMemoryResponse fromActionResponse(ActionResponse actionResponse) {
52+
if (actionResponse instanceof MLGetMemoryResponse) {
53+
return (MLGetMemoryResponse) actionResponse;
54+
}
55+
56+
try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
57+
actionResponse.writeTo(osso);
58+
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
59+
return new MLGetMemoryResponse(input);
60+
}
61+
} catch (IOException e) {
62+
throw new UncheckedIOException("failed to parse ActionResponse into MLMemoryGetResponse", e);
63+
}
64+
}
65+
}
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.ml.common.transport.memorycontainer.memory;
7+
8+
import static org.junit.Assert.assertEquals;
9+
import static org.junit.Assert.assertNotNull;
10+
import static org.junit.Assert.assertNull;
11+
import static org.junit.Assert.assertTrue;
12+
13+
import java.io.IOException;
14+
import java.io.UncheckedIOException;
15+
16+
import org.junit.Before;
17+
import org.junit.Test;
18+
import org.opensearch.action.ActionRequest;
19+
import org.opensearch.action.ActionRequestValidationException;
20+
import org.opensearch.common.io.stream.BytesStreamOutput;
21+
import org.opensearch.core.common.io.stream.StreamInput;
22+
import org.opensearch.core.common.io.stream.StreamOutput;
23+
24+
public class MLGetMemoryRequestTest {
25+
26+
private MLGetMemoryRequest requestNormal;
27+
private MLGetMemoryRequest requestWithNulls;
28+
29+
@Before
30+
public void setUp() {
31+
requestNormal = MLGetMemoryRequest.builder().memoryContainerId("container-123").memoryId("memory-456").build();
32+
33+
requestWithNulls = MLGetMemoryRequest.builder().memoryContainerId(null).memoryId(null).build();
34+
}
35+
36+
@Test
37+
public void testBuilderNormal() {
38+
assertNotNull(requestNormal);
39+
assertEquals("container-123", requestNormal.getMemoryContainerId());
40+
assertEquals("memory-456", requestNormal.getMemoryId());
41+
}
42+
43+
@Test
44+
public void testStreamInputOutput() throws IOException {
45+
BytesStreamOutput out = new BytesStreamOutput();
46+
requestNormal.writeTo(out);
47+
StreamInput in = out.bytes().streamInput();
48+
MLGetMemoryRequest deserialized = new MLGetMemoryRequest(in);
49+
50+
assertEquals(requestNormal.getMemoryContainerId(), deserialized.getMemoryContainerId());
51+
assertEquals(requestNormal.getMemoryId(), deserialized.getMemoryId());
52+
}
53+
54+
@Test
55+
public void testValidateSuccess() {
56+
ActionRequestValidationException exception = requestNormal.validate();
57+
assertNull(exception);
58+
}
59+
60+
@Test
61+
public void testValidateWithNullValues() {
62+
ActionRequestValidationException exception = requestWithNulls.validate();
63+
assertNotNull(exception);
64+
assertEquals(1, exception.validationErrors().size());
65+
assertTrue(exception.validationErrors().get(0).contains("memoryContainerId and memoryId id can not be null"));
66+
}
67+
68+
@Test
69+
public void testFromActionRequestSameInstance() {
70+
MLGetMemoryRequest result = MLGetMemoryRequest.fromActionRequest(requestNormal);
71+
assertEquals(requestNormal, result);
72+
}
73+
74+
@Test
75+
public void testFromActionRequestDifferentInstance() throws IOException {
76+
// Create a mock ActionRequest that's not MLGetMemoryRequest
77+
ActionRequest mockRequest = new ActionRequest() {
78+
@Override
79+
public ActionRequestValidationException validate() {
80+
return null;
81+
}
82+
83+
@Override
84+
public void writeTo(StreamOutput out) throws IOException {
85+
super.writeTo(out);
86+
out.writeString("test-container");
87+
out.writeString("test-memory");
88+
}
89+
};
90+
91+
MLGetMemoryRequest result = MLGetMemoryRequest.fromActionRequest(mockRequest);
92+
assertNotNull(result);
93+
assertEquals("test-container", result.getMemoryContainerId());
94+
assertEquals("test-memory", result.getMemoryId());
95+
}
96+
97+
@Test(expected = UncheckedIOException.class)
98+
public void testFromActionRequestIOException() {
99+
// Create a mock ActionRequest that throws IOException
100+
ActionRequest mockRequest = new ActionRequest() {
101+
@Override
102+
public ActionRequestValidationException validate() {
103+
return null;
104+
}
105+
106+
@Override
107+
public void writeTo(StreamOutput out) throws IOException {
108+
throw new IOException("Test exception");
109+
}
110+
};
111+
112+
MLGetMemoryRequest.fromActionRequest(mockRequest);
113+
}
114+
}

0 commit comments

Comments
 (0)