Skip to content

Commit 16245f0

Browse files
init master key automatically (#1075) (#1091)
* init master key automatically Signed-off-by: Yaliang Wu <[email protected]> * remove unnecessary escape Signed-off-by: Yaliang Wu <[email protected]> * fix failed ut Signed-off-by: Yaliang Wu <[email protected]> * tune syncup jot interval Signed-off-by: Yaliang Wu <[email protected]> * tune syncup jot interval Signed-off-by: Yaliang Wu <[email protected]> * remove local config file code Signed-off-by: Yaliang Wu <[email protected]> * set master key when init remote model Signed-off-by: Yaliang Wu <[email protected]> * move init master key to encryptor Signed-off-by: Yaliang Wu <[email protected]> * fine tune code Signed-off-by: Yaliang Wu <[email protected]> * fine tune code Signed-off-by: Yaliang Wu <[email protected]> --------- Signed-off-by: Yaliang Wu <[email protected]> (cherry picked from commit 66672d9) Co-authored-by: Yaliang Wu <[email protected]>
1 parent 9908229 commit 16245f0

File tree

23 files changed

+415
-65
lines changed

23 files changed

+415
-65
lines changed

common/src/main/java/org/opensearch/ml/common/CommonValue.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ public class CommonValue {
2424
public static final String UNDEPLOYED = "undeployed";
2525
public static final String NOT_FOUND = "not_found";
2626

27+
public static final String MASTER_KEY = "master_key";
28+
public static final String CREATE_TIME_FIELD = "create_time";
29+
2730
public static final String BOX_TYPE_KEY = "box_type";
2831
//hot node
2932
public static String HOT_BOX_TYPE = "hot";
@@ -37,6 +40,8 @@ public class CommonValue {
3740
public static final String ML_CONNECTOR_INDEX = ".plugins-ml-connector";
3841
public static final Integer ML_TASK_INDEX_SCHEMA_VERSION = 1;
3942
public static final Integer ML_CONNECTOR_SCHEMA_VERSION = 1;
43+
public static final String ML_CONFIG_INDEX = ".plugins-ml-config";
44+
public static final Integer ML_CONFIG_INDEX_SCHEMA_VERSION = 1;
4045
public static final String USER_FIELD_MAPPING = " \""
4146
+ CommonValue.USER
4247
+ "\": {\n"
@@ -301,4 +306,19 @@ public class CommonValue {
301306
+ "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"}\n"
302307
+ " }\n"
303308
+ "}";
309+
310+
311+
public static final String ML_CONFIG_INDEX_MAPPING = "{\n"
312+
+ " \"_meta\": {\"schema_version\": "
313+
+ ML_CONFIG_INDEX_SCHEMA_VERSION
314+
+ "},\n"
315+
+ " \"properties\": {\n"
316+
+ " \""
317+
+ MASTER_KEY
318+
+ "\": {\"type\": \"keyword\"},\n"
319+
+ " \""
320+
+ CREATE_TIME_FIELD
321+
+ "\": {\"type\": \"date\", \"format\": \"strict_date_time||epoch_millis\"}\n"
322+
+ " }\n"
323+
+ "}";
304324
}

ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
package org.opensearch.ml.engine;
77

88
import lombok.Getter;
9+
import lombok.extern.log4j.Log4j2;
910
import org.opensearch.ml.common.FunctionName;
1011
import org.opensearch.ml.common.MLModel;
1112
import org.opensearch.ml.common.dataframe.DataFrame;
@@ -18,29 +19,33 @@
1819
import org.opensearch.ml.common.output.MLOutput;
1920
import org.opensearch.ml.common.output.Output;
2021
import org.opensearch.ml.engine.encryptor.Encryptor;
21-
2222
import java.nio.file.Path;
2323
import java.util.Locale;
2424
import java.util.Map;
2525

2626
/**
2727
* This is the interface to all ml algorithms.
2828
*/
29+
@Log4j2
2930
public class MLEngine {
3031

3132
public static final String REGISTER_MODEL_FOLDER = "register";
3233
public static final String DEPLOY_MODEL_FOLDER = "deploy";
3334
private final String MODEL_REPO = "https://artifacts.opensearch.org/models/ml-models";
3435

36+
@Getter
37+
private final Path mlConfigPath;
38+
3539
@Getter
3640
private final Path mlCachePath;
3741
private final Path mlModelsCachePath;
3842

39-
private final Encryptor encryptor;
43+
private Encryptor encryptor;
4044

4145
public MLEngine(Path opensearchDataFolder, Encryptor encryptor) {
42-
mlCachePath = opensearchDataFolder.resolve("ml_cache");
43-
mlModelsCachePath = mlCachePath.resolve("models_cache");
46+
this.mlCachePath = opensearchDataFolder.resolve("ml_cache");
47+
this.mlModelsCachePath = mlCachePath.resolve("models_cache");
48+
this.mlConfigPath = mlCachePath.resolve("config");
4449
this.encryptor = encryptor;
4550
}
4651

@@ -195,7 +200,4 @@ public String encrypt(String credential) {
195200
return encryptor.encrypt(credential);
196201
}
197202

198-
public void setMasterKey(String masterKey) {
199-
encryptor.setMasterKey(masterKey);
200-
}
201203
}

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -95,11 +95,6 @@ public static RemoteInferenceInputDataSet processInput(MLInput mlInput, Connecto
9595
} else {
9696
throw new IllegalArgumentException("Wrong input type");
9797
}
98-
Map<String, String> escapedParameters = new HashMap<>();
99-
inputData.getParameters().entrySet().forEach(entry -> {
100-
escapedParameters.put(entry.getKey(), escapeJava(entry.getValue()));
101-
});
102-
inputData.setParameters(escapedParameters);
10398
return inputData;
10499
}
105100

ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/Encryptor.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55

66
package org.opensearch.ml.engine.encryptor;
77

8+
import java.security.SecureRandom;
9+
import java.util.Base64;
10+
811
public interface Encryptor {
912

1013
/**
@@ -29,4 +32,8 @@ public interface Encryptor {
2932
* @param masterKey masterKey to be set.
3033
*/
3134
void setMasterKey(String masterKey);
35+
String getMasterKey();
36+
37+
String generateMasterKey();
38+
3239
}

ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/EncryptorImpl.java

Lines changed: 70 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,39 @@
99
import com.amazonaws.encryptionsdk.CommitmentPolicy;
1010
import com.amazonaws.encryptionsdk.CryptoResult;
1111
import com.amazonaws.encryptionsdk.jce.JceMasterKey;
12-
import org.opensearch.ml.engine.exceptions.MetaDataException;
12+
import lombok.extern.log4j.Log4j2;
13+
import org.opensearch.ResourceNotFoundException;
14+
import org.opensearch.action.ActionListener;
15+
import org.opensearch.action.LatchedActionListener;
16+
import org.opensearch.action.get.GetRequest;
17+
import org.opensearch.action.get.GetResponse;
18+
import org.opensearch.client.Client;
19+
import org.opensearch.cluster.service.ClusterService;
20+
import org.opensearch.ml.common.exception.MLException;
1321

1422
import javax.crypto.spec.SecretKeySpec;
1523
import java.nio.charset.StandardCharsets;
24+
import java.security.SecureRandom;
1625
import java.util.Base64;
26+
import java.util.concurrent.CountDownLatch;
27+
import java.util.concurrent.atomic.AtomicReference;
1728

29+
import static org.opensearch.ml.common.CommonValue.MASTER_KEY;
30+
import static org.opensearch.ml.common.CommonValue.ML_CONFIG_INDEX;
31+
32+
@Log4j2
1833
public class EncryptorImpl implements Encryptor {
1934

35+
private ClusterService clusterService;
36+
private Client client;
2037
private volatile String masterKey;
2138

39+
public EncryptorImpl(ClusterService clusterService, Client client) {
40+
this.masterKey = null;
41+
this.clusterService = clusterService;
42+
this.client = client;
43+
}
44+
2245
public EncryptorImpl(String masterKey) {
2346
this.masterKey = masterKey;
2447
}
@@ -28,9 +51,14 @@ public void setMasterKey(String masterKey) {
2851
this.masterKey = masterKey;
2952
}
3053

54+
@Override
55+
public String getMasterKey() {
56+
return masterKey;
57+
}
58+
3159
@Override
3260
public String encrypt(String plainText) {
33-
checkMasterKey();
61+
initMasterKey();
3462
final AwsCrypto crypto = AwsCrypto.builder()
3563
.withCommitmentPolicy(CommitmentPolicy.RequireEncryptRequireDecrypt)
3664
.build();
@@ -46,7 +74,7 @@ public String encrypt(String plainText) {
4674

4775
@Override
4876
public String decrypt(String encryptedText) {
49-
checkMasterKey();
77+
initMasterKey();
5078
final AwsCrypto crypto = AwsCrypto.builder()
5179
.withCommitmentPolicy(CommitmentPolicy.RequireEncryptRequireDecrypt)
5280
.build();
@@ -60,14 +88,45 @@ public String decrypt(String encryptedText) {
6088
return new String(decryptedResult.getResult());
6189
}
6290

63-
private void checkMasterKey() {
64-
if (masterKey == "0000000000000000" || masterKey == null) {
65-
throw new MetaDataException("Please provide a masterKey for credential encryption! Example: PUT /_cluster/settings\n" +
66-
"{\n" +
67-
" \"persistent\" : {\n" +
68-
" \"plugins.ml_commons.encryption.master_key\" : \"1234567x\" \n" +
69-
" }\n" +
70-
"}");
91+
@Override
92+
public String generateMasterKey() {
93+
byte[] keyBytes = new byte[16];
94+
new SecureRandom().nextBytes(keyBytes);
95+
String base64Key = Base64.getEncoder().encodeToString(keyBytes);
96+
return base64Key;
97+
}
98+
99+
private void initMasterKey() {
100+
if (masterKey != null) {
101+
return;
102+
}
103+
AtomicReference<Exception> exceptionRef = new AtomicReference<>();
104+
105+
CountDownLatch latch = new CountDownLatch(1);
106+
if (clusterService.state().metadata().hasIndex(ML_CONFIG_INDEX)) {
107+
GetRequest getRequest = new GetRequest(ML_CONFIG_INDEX).id(MASTER_KEY);
108+
client.get(getRequest, new LatchedActionListener(ActionListener.<GetResponse>wrap(r -> {
109+
if (r.isExists()) {
110+
String masterKey = (String) r.getSourceAsMap().get(MASTER_KEY);
111+
setMasterKey(masterKey);
112+
} else {
113+
exceptionRef.set(new ResourceNotFoundException("ML encryption master key not initialized yet"));
114+
}
115+
}, e -> {
116+
log.error("Failed to get ML encryption master key", e);
117+
exceptionRef.set(e);
118+
}), latch));
119+
} else {
120+
exceptionRef.set(new ResourceNotFoundException("ML encryption master key not initialized yet"));
121+
}
122+
123+
if (exceptionRef.get() != null) {
124+
log.debug("Failed to init master key", exceptionRef.get());
125+
if (exceptionRef.get() instanceof RuntimeException) {
126+
throw (RuntimeException) exceptionRef.get();
127+
} else {
128+
throw new MLException(exceptionRef.get());
129+
}
71130
}
72131
}
73132
}

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelationTest.java

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
import org.opensearch.ml.engine.MLEngine;
5959
import org.opensearch.ml.engine.ModelHelper;
6060
import org.opensearch.ml.engine.encryptor.Encryptor;
61+
import org.opensearch.ml.engine.encryptor.EncryptorImpl;
6162
import org.opensearch.search.SearchHit;
6263
import org.opensearch.search.SearchHits;
6364
import org.opensearch.search.aggregations.InternalAggregations;
@@ -128,7 +129,8 @@ public class MetricsCorrelationTest {
128129
ActionListener<MLDeployModelResponse> mlDeployModelResponseActionListener;
129130
private MetricsCorrelation metricsCorrelation;
130131
private MetricsCorrelationInput input, extendedInput;
131-
private Path djlCachePath;
132+
private Path mlCachePath;
133+
private Path mlConfigPath;
132134
private MLModel model;
133135

134136
private MetricsCorrelationModelConfig modelConfig;
@@ -144,7 +146,6 @@ public class MetricsCorrelationTest {
144146

145147
Map<String, Object> params = new HashMap<>();
146148

147-
@Mock
148149
private Encryptor encryptor;
149150

150151
public MetricsCorrelationTest() {
@@ -155,8 +156,9 @@ public void setUp() throws IOException, URISyntaxException {
155156

156157
System.setProperty("testMode", "true");
157158

158-
djlCachePath = Path.of("/tmp/djl_cache_" + UUID.randomUUID());
159-
mlEngine = new MLEngine(djlCachePath, encryptor);
159+
mlCachePath = Path.of("/tmp/djl_cache_" + UUID.randomUUID());
160+
encryptor = new EncryptorImpl("0000000000000001");
161+
mlEngine = new MLEngine(mlCachePath, encryptor);
160162
modelConfig = MetricsCorrelationModelConfig.builder()
161163
.modelType(MetricsCorrelation.MODEL_TYPE)
162164
.allConfig(null)

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ public void processInput_TextDocsInputDataSet_PreprocessFunction_MultiTextDoc()
8484
processInput_TextDocsInputDataSet_PreprocessFunction(
8585
"{\"input\": ${parameters.input}}",
8686
"{\"parameters\": { \"input\": [\"test_value1\", \"test_value2\"] } }",
87-
"[\\\"test_value1\\\",\\\"test_value2\\\"]");
87+
"[\"test_value1\",\"test_value2\"]");
8888
}
8989

9090
@Test

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/ModelHelperTest.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
2020
import org.opensearch.ml.engine.MLEngine;
2121
import org.opensearch.ml.engine.ModelHelper;
22+
import org.opensearch.ml.engine.encryptor.Encryptor;
23+
import org.opensearch.ml.engine.encryptor.EncryptorImpl;
2224

2325
import java.io.IOException;
2426
import java.net.URISyntaxException;
@@ -50,12 +52,15 @@ public class ModelHelperTest {
5052
@Mock
5153
ActionListener<MLRegisterModelInput> registerModelListener;
5254

55+
Encryptor encryptor;
56+
5357
@Before
5458
public void setup() throws URISyntaxException {
5559
MockitoAnnotations.openMocks(this);
5660
modelFormat = MLModelFormat.TORCH_SCRIPT;
5761
modelId = "model_id";
58-
mlEngine = new MLEngine(Path.of("/tmp/test" + modelId), null);
62+
encryptor = new EncryptorImpl("0000000000000001");
63+
mlEngine = new MLEngine(Path.of("/tmp/test" + modelId), encryptor);
5964
modelHelper = new ModelHelper(mlEngine);
6065
}
6166

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/text_embedding/TextEmbeddingModelTest.java

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import org.opensearch.ml.engine.MLEngine;
2828
import org.opensearch.ml.engine.ModelHelper;
2929
import org.opensearch.ml.engine.encryptor.Encryptor;
30+
import org.opensearch.ml.engine.encryptor.EncryptorImpl;
3031
import org.opensearch.ml.engine.utils.FileUtils;
3132

3233
import java.io.File;
@@ -62,16 +63,18 @@ public class TextEmbeddingModelTest {
6263
private ModelHelper modelHelper;
6364
private Map<String, Object> params;
6465
private TextEmbeddingModel textEmbeddingModel;
65-
private Path djlCachePath;
66+
private Path mlCachePath;
67+
private Path mlConfigPath;
6668
private TextDocsInputDataSet inputDataSet;
6769
private int dimension = 384;
6870
private MLEngine mlEngine;
6971
private Encryptor encryptor;
7072

7173
@Before
7274
public void setUp() throws URISyntaxException {
73-
djlCachePath = Path.of("/tmp/djl_cache_" + UUID.randomUUID());
74-
mlEngine = new MLEngine(djlCachePath, encryptor);
75+
mlCachePath = Path.of("/tmp/ml_cache" + UUID.randomUUID());
76+
encryptor = new EncryptorImpl("0000000000000001");
77+
mlEngine = new MLEngine(mlCachePath, encryptor);
7578
modelId = "test_model_id";
7679
modelName = "test_model_name";
7780
functionName = FunctionName.TEXT_EMBEDDING;
@@ -329,7 +332,7 @@ public void predict_BeforeInitingModel() {
329332

330333
@After
331334
public void tearDown() {
332-
FileUtils.deleteFileQuietly(djlCachePath);
335+
FileUtils.deleteFileQuietly(mlCachePath);
333336
}
334337

335338
private int findSentenceEmbeddingPosition(ModelTensors modelTensors) {

0 commit comments

Comments
 (0)