Skip to content

Commit a966bf1

Browse files
stash context before accessing ml config index; increase master key size to 32 (#1092) (#1093)
Signed-off-by: Yaliang Wu <[email protected]> (cherry picked from commit d9d1190) Co-authored-by: Yaliang Wu <[email protected]>
1 parent 16245f0 commit a966bf1

File tree

16 files changed

+92
-46
lines changed

16 files changed

+92
-46
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/EncryptorImpl.java

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import org.opensearch.action.get.GetResponse;
1818
import org.opensearch.client.Client;
1919
import org.opensearch.cluster.service.ClusterService;
20+
import org.opensearch.common.util.concurrent.ThreadContext;
2021
import org.opensearch.ml.common.exception.MLException;
2122

2223
import javax.crypto.spec.SecretKeySpec;
@@ -62,9 +63,9 @@ public String encrypt(String plainText) {
6263
final AwsCrypto crypto = AwsCrypto.builder()
6364
.withCommitmentPolicy(CommitmentPolicy.RequireEncryptRequireDecrypt)
6465
.build();
65-
66+
byte[] bytes = Base64.getDecoder().decode(masterKey);
6667
JceMasterKey jceMasterKey
67-
= JceMasterKey.getInstance(new SecretKeySpec(masterKey.getBytes(), "AES"), "Custom", "",
68+
= JceMasterKey.getInstance(new SecretKeySpec(bytes, "AES"), "Custom", "",
6869
"AES/GCM/NoPadding");
6970

7071
final CryptoResult<byte[], JceMasterKey> encryptResult = crypto.encryptData(jceMasterKey,
@@ -79,8 +80,9 @@ public String decrypt(String encryptedText) {
7980
.withCommitmentPolicy(CommitmentPolicy.RequireEncryptRequireDecrypt)
8081
.build();
8182

83+
byte[] bytes = Base64.getDecoder().decode(masterKey);
8284
JceMasterKey jceMasterKey
83-
= JceMasterKey.getInstance(new SecretKeySpec(masterKey.getBytes(), "AES"), "Custom", "",
85+
= JceMasterKey.getInstance(new SecretKeySpec(bytes, "AES"), "Custom", "",
8486
"AES/GCM/NoPadding");
8587

8688
final CryptoResult<byte[], JceMasterKey> decryptedResult
@@ -90,7 +92,7 @@ public String decrypt(String encryptedText) {
9092

9193
@Override
9294
public String generateMasterKey() {
93-
byte[] keyBytes = new byte[16];
95+
byte[] keyBytes = new byte[32];
9496
new SecureRandom().nextBytes(keyBytes);
9597
String base64Key = Base64.getEncoder().encodeToString(keyBytes);
9698
return base64Key;
@@ -104,18 +106,20 @@ private void initMasterKey() {
104106

105107
CountDownLatch latch = new CountDownLatch(1);
106108
if (clusterService.state().metadata().hasIndex(ML_CONFIG_INDEX)) {
107-
GetRequest getRequest = new GetRequest(ML_CONFIG_INDEX).id(MASTER_KEY);
108-
client.get(getRequest, new LatchedActionListener(ActionListener.<GetResponse>wrap(r -> {
109-
if (r.isExists()) {
110-
String masterKey = (String) r.getSourceAsMap().get(MASTER_KEY);
111-
setMasterKey(masterKey);
112-
} else {
113-
exceptionRef.set(new ResourceNotFoundException("ML encryption master key not initialized yet"));
114-
}
115-
}, e -> {
116-
log.error("Failed to get ML encryption master key", e);
117-
exceptionRef.set(e);
118-
}), latch));
109+
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
110+
GetRequest getRequest = new GetRequest(ML_CONFIG_INDEX).id(MASTER_KEY);
111+
client.get(getRequest, new LatchedActionListener(ActionListener.<GetResponse>wrap(r -> {
112+
if (r.isExists()) {
113+
String masterKey = (String) r.getSourceAsMap().get(MASTER_KEY);
114+
setMasterKey(masterKey);
115+
} else {
116+
exceptionRef.set(new ResourceNotFoundException("ML encryption master key not initialized yet"));
117+
}
118+
}, e -> {
119+
log.error("Failed to get ML encryption master key", e);
120+
exceptionRef.set(e);
121+
}), latch));
122+
}
119123
} else {
120124
exceptionRef.set(new ResourceNotFoundException("ML encryption master key not initialized yet"));
121125
}

ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ public class MLEngineTest {
5353

5454
@Before
5555
public void setUp() {
56-
Encryptor encryptor = new EncryptorImpl("0000000000000000");
56+
Encryptor encryptor = new EncryptorImpl("m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w=");
5757
mlEngine = new MLEngine(Path.of("/tmp/test" + UUID.randomUUID()), encryptor);
5858
}
5959

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelationTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ public void setUp() throws IOException, URISyntaxException {
157157
System.setProperty("testMode", "true");
158158

159159
mlCachePath = Path.of("/tmp/djl_cache_" + UUID.randomUUID());
160-
encryptor = new EncryptorImpl("0000000000000001");
160+
encryptor = new EncryptorImpl("m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w=");
161161
mlEngine = new MLEngine(mlCachePath, encryptor);
162162
modelConfig = MetricsCorrelationModelConfig.builder()
163163
.modelType(MetricsCorrelation.MODEL_TYPE)

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ public class AwsConnectorExecutorTest {
6767
@Before
6868
public void setUp() {
6969
MockitoAnnotations.openMocks(this);
70-
encryptor = new EncryptorImpl("0000000000000001");
70+
encryptor = new EncryptorImpl("m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w=");
7171
}
7272

7373
@Test

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteModelTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ public class RemoteModelTest {
4848
public void setUp() {
4949
MockitoAnnotations.openMocks(this);
5050
remoteModel = new RemoteModel();
51-
encryptor = spy(new EncryptorImpl("0000000000000001"));
51+
encryptor = spy(new EncryptorImpl("m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w="));
5252
}
5353

5454
@Test

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/ModelHelperTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ public void setup() throws URISyntaxException {
5959
MockitoAnnotations.openMocks(this);
6060
modelFormat = MLModelFormat.TORCH_SCRIPT;
6161
modelId = "model_id";
62-
encryptor = new EncryptorImpl("0000000000000001");
62+
encryptor = new EncryptorImpl("m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w=");
6363
mlEngine = new MLEngine(Path.of("/tmp/test" + modelId), encryptor);
6464
modelHelper = new ModelHelper(mlEngine);
6565
}

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/TextEmbeddingModelTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ public class TextEmbeddingModelTest {
7373
@Before
7474
public void setUp() throws URISyntaxException {
7575
mlCachePath = Path.of("/tmp/ml_cache" + UUID.randomUUID());
76-
encryptor = new EncryptorImpl("0000000000000001");
76+
encryptor = new EncryptorImpl("m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w=");
7777
mlEngine = new MLEngine(mlCachePath, encryptor);
7878
modelId = "test_model_id";
7979
modelName = "test_model_name";

ml-algorithms/src/test/java/org/opensearch/ml/engine/encryptor/EncryptorImplTest.java

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@
1818
import org.opensearch.cluster.metadata.Metadata;
1919
import org.opensearch.cluster.service.ClusterService;
2020
import org.opensearch.common.settings.Settings;
21+
import org.opensearch.common.util.concurrent.ThreadContext;
22+
import org.opensearch.commons.ConfigConstants;
23+
import org.opensearch.threadpool.ThreadPool;
2124

2225
import java.time.Instant;
2326

@@ -43,10 +46,15 @@ public class EncryptorImplTest {
4346

4447
String masterKey;
4548

49+
@Mock
50+
ThreadPool threadPool;
51+
ThreadContext threadContext;
52+
final String USER_STRING = "myuser|role1,role2|myTenant";
53+
4654
@Before
4755
public void setUp() {
4856
MockitoAnnotations.openMocks(this);
49-
masterKey = "0000000000000001";
57+
masterKey = "m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w=";
5058

5159
doAnswer(invocation -> {
5260
ActionListener<GetResponse> listener = invocation.getArgument(1);
@@ -72,6 +80,12 @@ public void setUp() {
7280
.build())
7381
.build()).build();
7482
when(clusterState.metadata()).thenReturn(metadata);
83+
84+
Settings settings = Settings.builder().build();
85+
threadContext = new ThreadContext(settings);
86+
threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, USER_STRING);
87+
when(client.threadPool()).thenReturn(threadPool);
88+
when(threadPool.getThreadContext()).thenReturn(threadContext);
7589
}
7690

7791
@Test
@@ -83,6 +97,17 @@ public void encrypt() {
8397
Assert.assertEquals(masterKey, encryptor.getMasterKey());
8498
}
8599

100+
@Test
101+
public void encrypt_DifferentMasterKey() {
102+
Encryptor encryptor = new EncryptorImpl(masterKey);
103+
Assert.assertNotNull(encryptor.getMasterKey());
104+
String encrypted1 = encryptor.encrypt("test");
105+
106+
encryptor.setMasterKey(encryptor.generateMasterKey());
107+
String encrypted2 = encryptor.encrypt("test");
108+
Assert.assertNotEquals(encrypted1, encrypted2);
109+
}
110+
86111
@Test
87112
public void decrypt() {
88113
Encryptor encryptor = new EncryptorImpl(clusterService, client);

plugin/src/main/java/org/opensearch/ml/cluster/MLSyncUpCron.java

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import org.opensearch.client.Client;
3232
import org.opensearch.cluster.node.DiscoveryNode;
3333
import org.opensearch.cluster.service.ClusterService;
34+
import org.opensearch.common.util.concurrent.ThreadContext;
3435
import org.opensearch.index.query.BoolQueryBuilder;
3536
import org.opensearch.index.query.TermsQueryBuilder;
3637
import org.opensearch.ml.common.MLModel;
@@ -168,24 +169,26 @@ void initMLConfig() {
168169
}
169170
mlIndicesHandler.initMLConfigIndex(ActionListener.wrap(r -> {
170171
GetRequest getRequest = new GetRequest(ML_CONFIG_INDEX).id(MASTER_KEY);
171-
client.get(getRequest, ActionListener.wrap(getResponse -> {
172-
if (!getResponse.isExists()) {
173-
IndexRequest indexRequest = new IndexRequest(ML_CONFIG_INDEX).id(MASTER_KEY);
174-
final String masterKey = encryptor.generateMasterKey();
175-
indexRequest.source(ImmutableMap.of(MASTER_KEY, masterKey, CREATE_TIME_FIELD, Instant.now().toEpochMilli()));
176-
indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
177-
client.index(indexRequest, ActionListener.wrap(indexResponse -> {
178-
log.info("ML configuration initialized successfully");
172+
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
173+
client.get(getRequest, ActionListener.wrap(getResponse -> {
174+
if (!getResponse.isExists()) {
175+
IndexRequest indexRequest = new IndexRequest(ML_CONFIG_INDEX).id(MASTER_KEY);
176+
final String masterKey = encryptor.generateMasterKey();
177+
indexRequest.source(ImmutableMap.of(MASTER_KEY, masterKey, CREATE_TIME_FIELD, Instant.now().toEpochMilli()));
178+
indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
179+
client.index(indexRequest, ActionListener.wrap(indexResponse -> {
180+
log.info("ML configuration initialized successfully");
181+
encryptor.setMasterKey(masterKey);
182+
mlConfigInited = true;
183+
}, e -> { log.debug("Failed to save ML encryption master key", e); }));
184+
} else {
185+
final String masterKey = (String) getResponse.getSourceAsMap().get(MASTER_KEY);
179186
encryptor.setMasterKey(masterKey);
180187
mlConfigInited = true;
181-
}, e -> { log.debug("Failed to save ML encryption master key", e); }));
182-
} else {
183-
final String masterKey = (String) getResponse.getSourceAsMap().get(MASTER_KEY);
184-
encryptor.setMasterKey(masterKey);
185-
mlConfigInited = true;
186-
log.info("ML configuration already initialized, no action needed");
187-
}
188-
}, e -> { log.debug("Failed to get ML encryption master key", e); }));
188+
log.info("ML configuration already initialized, no action needed");
189+
}
190+
}, e -> { log.debug("Failed to get ML encryption master key", e); }));
191+
}
189192
}, e -> { log.debug("Failed to init ML config index", e); }));
190193
}
191194

plugin/src/test/java/org/opensearch/ml/action/deploy/TransportDeployModelActionTests.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ public void setup() {
142142
clusterSettings = new ClusterSettings(settings, new HashSet<>(Arrays.asList(ML_COMMONS_ALLOW_CUSTOM_DEPLOYMENT_PLAN)));
143143
when(clusterService.getClusterSettings()).thenReturn(clusterSettings);
144144

145-
encryptor = new EncryptorImpl("0000000000000001");
145+
encryptor = new EncryptorImpl("m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w=");
146146
mlEngine = new MLEngine(Path.of("/tmp/test" + randomAlphaOfLength(10)), encryptor);
147147
modelHelper = new ModelHelper(mlEngine);
148148
when(mlDeployModelRequest.getModelId()).thenReturn("mockModelId");

0 commit comments

Comments
 (0)