Skip to content

Commit 2d0957f

Browse files
committed
refactor checks to ml indices to return true when MultiTenancy enabled
Signed-off-by: Brian Flores <[email protected]>
1 parent 8ef3528 commit 2d0957f

File tree

9 files changed

+54
-11
lines changed

9 files changed

+54
-11
lines changed

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@
6868
import org.opensearch.ml.engine.Executable;
6969
import org.opensearch.ml.engine.annotation.Function;
7070
import org.opensearch.ml.engine.encryptor.Encryptor;
71+
import org.opensearch.ml.engine.indices.MLIndicesHandler;
7172
import org.opensearch.ml.engine.memory.ConversationIndexMemory;
7273
import org.opensearch.ml.engine.memory.ConversationIndexMessage;
7374
import org.opensearch.ml.engine.tools.QueryPlanningTool;
@@ -173,7 +174,7 @@ public void execute(Input input, ActionListener<Output> listener) {
173174
.fetchSourceContext(fetchSourceContext)
174175
.build();
175176

176-
if (clusterService.state().metadata().hasIndex(ML_AGENT_INDEX)) {
177+
if (MLIndicesHandler.doesMultiTenantIndexExists(clusterService, mlFeatureEnabledSetting.isMultiTenancyEnabled(), ML_AGENT_INDEX)) {
177178
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
178179
sdkClient
179180
.getDataObjectAsync(getDataObjectRequest, client.threadPool().executor("opensearch_ml_general"))

ml-algorithms/src/main/java/org/opensearch/ml/engine/indices/MLIndicesHandler.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import org.opensearch.ml.common.CommonValue;
3131
import org.opensearch.ml.common.MLIndex;
3232
import org.opensearch.ml.common.exception.MLException;
33+
import org.opensearch.ml.common.settings.MLFeatureEnabledSetting;
3334
import org.opensearch.transport.client.Client;
3435

3536
import lombok.AccessLevel;
@@ -44,6 +45,7 @@ public class MLIndicesHandler {
4445

4546
ClusterService clusterService;
4647
Client client;
48+
MLFeatureEnabledSetting mlFeatureEnabledSetting;
4749
private static final Map<String, AtomicBoolean> indexMappingUpdated = new HashMap<>();
4850

4951
static {
@@ -52,6 +54,10 @@ public class MLIndicesHandler {
5254
}
5355
}
5456

57+
public static boolean doesMultiTenantIndexExists(ClusterService clusterService, boolean isMultiTenancyEnabled, String indexName) {
58+
return isMultiTenancyEnabled || clusterService.state().metadata().hasIndex(indexName);
59+
}
60+
5561
public void initModelGroupIndexIfAbsent(ActionListener<Boolean> listener) {
5662
initMLIndexIfAbsent(MLIndex.MODEL_GROUP, listener);
5763
}
@@ -105,7 +111,7 @@ public void initMLIndexIfAbsent(MLIndex index, ActionListener<Boolean> listener)
105111
String mapping = index.getMapping();
106112
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
107113
ActionListener<Boolean> internalListener = ActionListener.runBefore(listener, () -> threadContext.restore());
108-
if (!clusterService.state().metadata().hasIndex(indexName)) {
114+
if (!MLIndicesHandler.doesMultiTenantIndexExists(clusterService, mlFeatureEnabledSetting.isMultiTenancyEnabled(), indexName)) {
109115
ActionListener<CreateIndexResponse> actionListener = ActionListener.wrap(r -> {
110116
if (r.isAcknowledged()) {
111117
log.info("create index:{}", indexName);

ml-algorithms/src/test/java/org/opensearch/ml/engine/indices/MLIndicesHandlerTest.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
import org.opensearch.common.settings.Settings;
3636
import org.opensearch.common.util.concurrent.ThreadContext;
3737
import org.opensearch.core.action.ActionListener;
38+
import org.opensearch.ml.common.settings.MLFeatureEnabledSetting;
3839
import org.opensearch.threadpool.ThreadPool;
3940
import org.opensearch.transport.client.AdminClient;
4041
import org.opensearch.transport.client.Client;
@@ -74,6 +75,9 @@ public class MLIndicesHandlerTest {
7475
@Mock
7576
private ThreadPool threadPool;
7677

78+
@Mock
79+
private MLFeatureEnabledSetting mlFeatureEnabledSetting;
80+
7781
Settings settings;
7882
ThreadContext threadContext;
7983
MLIndicesHandler indicesHandler;
@@ -102,7 +106,7 @@ public void setUp() {
102106
threadContext = new ThreadContext(settings);
103107
when(client.threadPool()).thenReturn(threadPool);
104108
when(threadPool.getThreadContext()).thenReturn(threadContext);
105-
indicesHandler = new MLIndicesHandler(clusterService, client);
109+
indicesHandler = new MLIndicesHandler(clusterService, client, mlFeatureEnabledSetting);
106110
}
107111

108112
@Test

plugin/src/main/java/org/opensearch/ml/action/connector/ExecuteConnectorTransportAction.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import org.opensearch.ml.engine.MLEngineClassLoader;
2727
import org.opensearch.ml.engine.algorithms.remote.RemoteConnectorExecutor;
2828
import org.opensearch.ml.engine.encryptor.EncryptorImpl;
29+
import org.opensearch.ml.engine.indices.MLIndicesHandler;
2930
import org.opensearch.ml.helper.ConnectorAccessControlHelper;
3031
import org.opensearch.script.ScriptService;
3132
import org.opensearch.tasks.Task;
@@ -74,7 +75,8 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLTask
7475
String connectorId = executeConnectorRequest.getConnectorId();
7576
String connectorAction = ConnectorAction.ActionType.EXECUTE.name();
7677

77-
if (clusterService.state().metadata().hasIndex(ML_CONNECTOR_INDEX)) {
78+
if (MLIndicesHandler
79+
.doesMultiTenantIndexExists(clusterService, mlFeatureEnabledSetting.isMultiTenancyEnabled(), ML_CONNECTOR_INDEX)) {
7880
ActionListener<Connector> listener = ActionListener.wrap(connector -> {
7981
if (connectorAccessControlHelper.validateConnectorAccess(client, connector)) {
8082
// adding tenantID as null, because we are not implement multi-tenancy for this feature yet.

plugin/src/main/java/org/opensearch/ml/action/handler/MLSearchHandler.java

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@
3939
import org.opensearch.ml.common.connector.HttpConnector;
4040
import org.opensearch.ml.common.exception.MLException;
4141
import org.opensearch.ml.common.exception.MLResourceNotFoundException;
42+
import org.opensearch.ml.common.settings.MLFeatureEnabledSetting;
43+
import org.opensearch.ml.engine.indices.MLIndicesHandler;
4244
import org.opensearch.ml.helper.ModelAccessControlHelper;
4345
import org.opensearch.ml.utils.RestActionUtils;
4446
import org.opensearch.remote.metadata.client.SdkClient;
@@ -65,17 +67,20 @@ public class MLSearchHandler {
6567
private ModelAccessControlHelper modelAccessControlHelper;
6668

6769
private ClusterService clusterService;
70+
private MLFeatureEnabledSetting mlFeatureEnabledSetting;
6871

6972
public MLSearchHandler(
7073
Client client,
7174
NamedXContentRegistry xContentRegistry,
7275
ModelAccessControlHelper modelAccessControlHelper,
73-
ClusterService clusterService
76+
ClusterService clusterService,
77+
MLFeatureEnabledSetting mlFeatureEnabledSetting
7478
) {
7579
this.modelAccessControlHelper = modelAccessControlHelper;
7680
this.client = client;
7781
this.xContentRegistry = xContentRegistry;
7882
this.clusterService = clusterService;
83+
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
7984
}
8085

8186
/**
@@ -132,7 +137,12 @@ public void search(SdkClient sdkClient, SearchRequest request, String tenantId,
132137
final ActionListener<SearchResponse> doubleWrapperListener = ActionListener
133138
.wrap(wrappedListener::onResponse, e -> wrapListenerToHandleSearchIndexNotFound(e, wrappedListener));
134139
if (modelAccessControlHelper.skipModelAccessControl(user)
135-
|| !clusterService.state().metadata().hasIndex(CommonValue.ML_MODEL_GROUP_INDEX)) {
140+
|| !MLIndicesHandler
141+
.doesMultiTenantIndexExists(
142+
clusterService,
143+
mlFeatureEnabledSetting.isMultiTenancyEnabled(),
144+
CommonValue.ML_MODEL_GROUP_INDEX
145+
)) {
136146

137147
SearchDataObjectRequest searchDataObjectRequest = SearchDataObjectRequest
138148
.builder()

plugin/src/main/java/org/opensearch/ml/action/tasks/CancelBatchJobTransportAction.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
import org.opensearch.ml.engine.algorithms.remote.ConnectorUtils;
5656
import org.opensearch.ml.engine.algorithms.remote.RemoteConnectorExecutor;
5757
import org.opensearch.ml.engine.encryptor.EncryptorImpl;
58+
import org.opensearch.ml.engine.indices.MLIndicesHandler;
5859
import org.opensearch.ml.helper.ConnectorAccessControlHelper;
5960
import org.opensearch.ml.helper.ModelAccessControlHelper;
6061
import org.opensearch.ml.model.MLModelManager;
@@ -199,7 +200,12 @@ private void processRemoteBatchPrediction(MLTask mlTask, ActionListener<MLCancel
199200
if (model.getConnector() != null) {
200201
Connector connector = model.getConnector();
201202
executeConnector(connector, mlInput, actionListener);
202-
} else if (clusterService.state().metadata().hasIndex(ML_CONNECTOR_INDEX)) {
203+
} else if (MLIndicesHandler
204+
.doesMultiTenantIndexExists(
205+
clusterService,
206+
mlFeatureEnabledSetting.isMultiTenancyEnabled(),
207+
ML_CONNECTOR_INDEX
208+
)) {
203209
ActionListener<Connector> listener = ActionListener
204210
.wrap(connector -> { executeConnector(connector, mlInput, actionListener); }, e -> {
205211
log.error("Failed to get connector {}", model.getConnectorId(), e);

plugin/src/main/java/org/opensearch/ml/action/tasks/GetTaskTransportAction.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@
8282
import org.opensearch.ml.engine.algorithms.remote.ConnectorUtils;
8383
import org.opensearch.ml.engine.algorithms.remote.RemoteConnectorExecutor;
8484
import org.opensearch.ml.engine.encryptor.EncryptorImpl;
85+
import org.opensearch.ml.engine.indices.MLIndicesHandler;
8586
import org.opensearch.ml.engine.utils.S3Utils;
8687
import org.opensearch.ml.helper.ConnectorAccessControlHelper;
8788
import org.opensearch.ml.helper.ModelAccessControlHelper;
@@ -390,7 +391,12 @@ private void processRemoteBatchPrediction(
390391
remoteJob,
391392
actionListener
392393
);
393-
} else if (clusterService.state().metadata().hasIndex(ML_CONNECTOR_INDEX)) {
394+
} else if (MLIndicesHandler
395+
.doesMultiTenantIndexExists(
396+
clusterService,
397+
mlFeatureEnabledSetting.isMultiTenancyEnabled(),
398+
ML_CONNECTOR_INDEX
399+
)) {
394400
ActionListener<Connector> listener = ActionListener.wrap(connector -> {
395401
executeConnector(
396402
connector,

plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -599,7 +599,7 @@ public Collection<Object> createComponents(
599599
Settings settings = environment.settings();
600600
Path dataPath = environment.dataFiles()[0];
601601

602-
mlIndicesHandler = new MLIndicesHandler(clusterService, client);
602+
mlIndicesHandler = new MLIndicesHandler(clusterService, client, mlFeatureEnabledSetting);
603603

604604
SdkClient sdkClient = SdkClientFactory
605605
.createSdkClient(
@@ -802,7 +802,13 @@ public Collection<Object> createComponents(
802802
MLToolExecutor toolExecutor = new MLToolExecutor(client, sdkClient, settings, clusterService, xContentRegistry, toolFactories);
803803
MLEngineClassLoader.register(FunctionName.TOOL, toolExecutor);
804804

805-
MLSearchHandler mlSearchHandler = new MLSearchHandler(client, xContentRegistry, modelAccessControlHelper, clusterService);
805+
MLSearchHandler mlSearchHandler = new MLSearchHandler(
806+
client,
807+
xContentRegistry,
808+
modelAccessControlHelper,
809+
clusterService,
810+
mlFeatureEnabledSetting
811+
);
806812
MLModelAutoReDeployer mlModelAutoRedeployer = new MLModelAutoReDeployer(
807813
clusterService,
808814
client,

plugin/src/test/java/org/opensearch/ml/action/models/SearchModelTransportActionTests.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,9 @@ public class SearchModelTransportActionTests extends OpenSearchTestCase {
114114
public void setup() {
115115
MockitoAnnotations.openMocks(this);
116116
sdkClient = SdkClientFactory.createSdkClient(client, NamedXContentRegistry.EMPTY, Collections.emptyMap());
117-
mlSearchHandler = spy(new MLSearchHandler(client, namedXContentRegistry, modelAccessControlHelper, clusterService));
117+
mlSearchHandler = spy(
118+
new MLSearchHandler(client, namedXContentRegistry, modelAccessControlHelper, clusterService, mlFeatureEnabledSetting)
119+
);
118120
searchModelTransportAction = new SearchModelTransportAction(
119121
transportService,
120122
actionFilters,

0 commit comments

Comments
 (0)