diff --git a/common/src/main/java/org/opensearch/ml/common/settings/MLCommonsSettings.java b/common/src/main/java/org/opensearch/ml/common/settings/MLCommonsSettings.java index 7d631ac2f6..d474355aa2 100644 --- a/common/src/main/java/org/opensearch/ml/common/settings/MLCommonsSettings.java +++ b/common/src/main/java/org/opensearch/ml/common/settings/MLCommonsSettings.java @@ -349,5 +349,5 @@ private MLCommonsSettings() {} // Feature flag for enabling telemetry static metric collection job -- MLStatsJobProcessor public static final Setting ML_COMMONS_STATIC_METRIC_COLLECTION_ENABLED = Setting - .boolSetting("plugins.ml_commons.metrics_static_collection_enabled", false, Setting.Property.NodeScope, Setting.Property.Final); + .boolSetting("plugins.ml_commons.metrics_static_collection_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic); } diff --git a/common/src/main/java/org/opensearch/ml/common/settings/MLFeatureEnabledSetting.java b/common/src/main/java/org/opensearch/ml/common/settings/MLFeatureEnabledSetting.java index 786af9e29c..ccd9687946 100644 --- a/common/src/main/java/org/opensearch/ml/common/settings/MLFeatureEnabledSetting.java +++ b/common/src/main/java/org/opensearch/ml/common/settings/MLFeatureEnabledSetting.java @@ -88,6 +88,12 @@ public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings) clusterService .getClusterSettings() .addSettingsUpdateConsumer(MLCommonsSettings.ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED, it -> isRagSearchPipelineEnabled = it); + clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_STATIC_METRIC_COLLECTION_ENABLED, it -> { + isStaticMetricCollectionEnabled = it; + for (SettingsChangeListener listener : listeners) { + listener.onStaticMetricCollectionEnabledChanged(it); + } + }); } /** diff --git a/common/src/main/java/org/opensearch/ml/common/settings/SettingsChangeListener.java b/common/src/main/java/org/opensearch/ml/common/settings/SettingsChangeListener.java index 946e88239a..c8576ffbc0 100644 --- a/common/src/main/java/org/opensearch/ml/common/settings/SettingsChangeListener.java +++ b/common/src/main/java/org/opensearch/ml/common/settings/SettingsChangeListener.java @@ -18,5 +18,20 @@ public interface SettingsChangeListener { *
  • false if multi-tenancy is disabled
  • * */ - void onMultiTenancyEnabledChanged(boolean isEnabled); + default void onMultiTenancyEnabledChanged(boolean isEnabled) { + // do nothing + } + + /** + * Callback method that gets triggered when the static metric collection setting changes. + * + * @param isEnabled A boolean value indicating the new state of the static metric collection setting: + * + */ + default void onStaticMetricCollectionEnabledChanged(boolean isEnabled) { + // do nothing + } } diff --git a/common/src/test/java/org/opensearch/ml/common/settings/MLFeatureEnabledSettingTests.java b/common/src/test/java/org/opensearch/ml/common/settings/MLFeatureEnabledSettingTests.java index e1dc2b2030..72f7a49689 100644 --- a/common/src/test/java/org/opensearch/ml/common/settings/MLFeatureEnabledSettingTests.java +++ b/common/src/test/java/org/opensearch/ml/common/settings/MLFeatureEnabledSettingTests.java @@ -129,4 +129,19 @@ public void testMultiTenancyChangeNotifiesListeners() { setting.notifyMultiTenancyListeners(true); verify(mockListener).onMultiTenancyEnabledChanged(true); } + + @Test + public void testStaticMetricCollectionSettingChangeNotifiesListeners() { + Settings settings = Settings.builder().put("plugins.ml_commons.metrics_static_collection_enabled", false).build(); + + MLFeatureEnabledSetting setting = new MLFeatureEnabledSetting(mockClusterService, settings); + + SettingsChangeListener mockListener = mock(SettingsChangeListener.class); + setting.addListener(mockListener); + + mockClusterSettings.applySettings(Settings.builder().put("plugins.ml_commons.metrics_static_collection_enabled", true).build()); + + verify(mockListener).onStaticMetricCollectionEnabledChanged(true); + assertTrue(setting.isStaticMetricCollectionEnabled()); + } } diff --git a/plugin/src/main/java/org/opensearch/ml/cluster/MLCommonsClusterEventListener.java b/plugin/src/main/java/org/opensearch/ml/cluster/MLCommonsClusterEventListener.java index 53aa465eec..f883d1f3b3 100644 --- a/plugin/src/main/java/org/opensearch/ml/cluster/MLCommonsClusterEventListener.java +++ b/plugin/src/main/java/org/opensearch/ml/cluster/MLCommonsClusterEventListener.java @@ -89,9 +89,9 @@ public void clusterChanged(ClusterChangedEvent event) { * The following logic implements this behavior. */ for (DiscoveryNode node : state.nodes()) { - if (node.isDataNode() && Version.V_3_1_0.onOrAfter(node.getVersion())) { + if (node.isDataNode() && node.getVersion().onOrAfter(Version.V_3_1_0)) { if (mlFeatureEnabledSetting.isMetricCollectionEnabled() && mlFeatureEnabledSetting.isStaticMetricCollectionEnabled()) { - mlTaskManager.startStatsCollectorJob(); + mlTaskManager.indexStatsCollectorJob(true); } if (clusterService.state().getMetadata().hasIndex(TASK_POLLING_JOB_INDEX)) { diff --git a/plugin/src/main/java/org/opensearch/ml/jobs/MLJobParameter.java b/plugin/src/main/java/org/opensearch/ml/jobs/MLJobParameter.java index ffad0d5022..e8c4c6b57a 100644 --- a/plugin/src/main/java/org/opensearch/ml/jobs/MLJobParameter.java +++ b/plugin/src/main/java/org/opensearch/ml/jobs/MLJobParameter.java @@ -49,14 +49,14 @@ public class MLJobParameter implements ScheduledJobParameter { public MLJobParameter() {} - public MLJobParameter(String name, Schedule schedule, Long lockDurationSeconds, Double jitter, MLJobType jobType) { + public MLJobParameter(String name, Schedule schedule, Long lockDurationSeconds, Double jitter, MLJobType jobType, boolean isEnabled) { this.jobName = name; this.schedule = schedule; this.lockDurationSeconds = lockDurationSeconds; this.jitter = jitter; Instant now = Instant.now(); - this.isEnabled = true; + this.isEnabled = isEnabled; this.enabledTime = now; this.lastUpdateTime = now; this.jobType = jobType; diff --git a/plugin/src/main/java/org/opensearch/ml/jobs/MLJobRunner.java b/plugin/src/main/java/org/opensearch/ml/jobs/MLJobRunner.java index 7295d368c0..1a22e144bb 100644 --- a/plugin/src/main/java/org/opensearch/ml/jobs/MLJobRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/jobs/MLJobRunner.java @@ -93,6 +93,10 @@ public void runJob(ScheduledJobParameter scheduledJobParameter, JobExecutionCont throw new IllegalArgumentException("Job parameters is invalid."); } + if (!jobParameter.isEnabled()) { + throw new IllegalStateException(String.format("Attempted to run disabled job of type: %s", jobParameter.getJobType().name())); + } + switch (jobParameter.getJobType()) { case STATS_COLLECTOR: MLStatsJobProcessor diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index e1d7e78d2b..4b74852c1c 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -632,6 +632,7 @@ public Collection createComponents( modelAccessControlHelper = new ModelAccessControlHelper(clusterService, settings); connectorAccessControlHelper = new ConnectorAccessControlHelper(clusterService, settings); mlFeatureEnabledSetting = new MLFeatureEnabledSetting(clusterService, settings); + mlFeatureEnabledSetting.addListener(mlTaskManager); mlModelManager = new MLModelManager( clusterService, scriptService, diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLTaskManager.java b/plugin/src/main/java/org/opensearch/ml/task/MLTaskManager.java index b51433f1d5..a5efeaba64 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLTaskManager.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLTaskManager.java @@ -51,6 +51,7 @@ import org.opensearch.ml.common.exception.MLException; import org.opensearch.ml.common.exception.MLLimitExceededException; import org.opensearch.ml.common.exception.MLResourceNotFoundException; +import org.opensearch.ml.common.settings.SettingsChangeListener; import org.opensearch.ml.engine.indices.MLIndicesHandler; import org.opensearch.ml.jobs.MLJobParameter; import org.opensearch.ml.jobs.MLJobType; @@ -73,7 +74,7 @@ * MLTaskManager is responsible for managing MLTask. */ @Log4j2 -public class MLTaskManager { +public class MLTaskManager implements SettingsChangeListener { public static int TASK_SEMAPHORE_TIMEOUT = 5000; // 5 seconds private final Map taskCaches; private final Client client; @@ -553,7 +554,8 @@ public void startTaskPollingJob() { new IntervalSchedule(Instant.now(), 1, ChronoUnit.MINUTES), 20L, null, - MLJobType.BATCH_TASK_UPDATE + MLJobType.BATCH_TASK_UPDATE, + true ); IndexRequest indexRequest = new IndexRequest() @@ -562,24 +564,27 @@ public void startTaskPollingJob() { .source(jobParameter.toXContent(JsonXContent.contentBuilder(), null)) .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - startJob(indexRequest, MLJobType.BATCH_TASK_UPDATE, () -> this.taskPollingJobStarted = true); + indexJob(indexRequest, MLJobType.BATCH_TASK_UPDATE, () -> this.taskPollingJobStarted = true); } catch (IOException e) { log.error("Failed to index task polling job", e); } } - public void startStatsCollectorJob() { - if (statsCollectorJobStarted) { - return; - } + @Override + public void onStaticMetricCollectionEnabledChanged(boolean isEnabled) { + log.info("Static metric collection setting changed to: {}", isEnabled); + indexStatsCollectorJob(isEnabled); + } + public void indexStatsCollectorJob(boolean enabled) { try { MLJobParameter jobParameter = new MLJobParameter( MLJobType.STATS_COLLECTOR.name(), new IntervalSchedule(Instant.now(), 5, ChronoUnit.MINUTES), 60L, null, - MLJobType.STATS_COLLECTOR + MLJobType.STATS_COLLECTOR, + enabled ); IndexRequest indexRequest = new IndexRequest() @@ -588,7 +593,7 @@ public void startStatsCollectorJob() { .source(jobParameter.toXContent(JsonXContent.contentBuilder(), null)) .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - startJob(indexRequest, MLJobType.STATS_COLLECTOR, () -> this.statsCollectorJobStarted = true); + indexJob(indexRequest, MLJobType.STATS_COLLECTOR, () -> {}); } catch (IOException e) { log.error("Failed to index stats collection job", e); } @@ -601,7 +606,7 @@ public void startStatsCollectorJob() { * @param jobType the type of job being started * @param successCallback callback to execute on successful job indexing */ - private void startJob(IndexRequest indexRequest, MLJobType jobType, Runnable successCallback) { + private void indexJob(IndexRequest indexRequest, MLJobType jobType, Runnable successCallback) { mlIndicesHandler.initMLJobsIndex(ActionListener.wrap(success -> { if (success) { try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { diff --git a/plugin/src/test/java/org/opensearch/ml/cluster/MLCommonsClusterEventListenerTests.java b/plugin/src/test/java/org/opensearch/ml/cluster/MLCommonsClusterEventListenerTests.java new file mode 100644 index 0000000000..be671e3407 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/cluster/MLCommonsClusterEventListenerTests.java @@ -0,0 +1,136 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.cluster; + +import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.ml.common.CommonValue.TASK_POLLING_JOB_INDEX; + +import java.util.Collections; + +import org.junit.Before; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.Version; +import org.opensearch.cluster.ClusterChangedEvent; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.metadata.Metadata; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.cluster.node.DiscoveryNodeRole; +import org.opensearch.cluster.node.DiscoveryNodes; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.ml.autoredeploy.MLModelAutoReDeployer; +import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; +import org.opensearch.ml.model.MLModelCacheHelper; +import org.opensearch.ml.model.MLModelManager; +import org.opensearch.ml.task.MLTaskManager; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.transport.client.Client; + +public class MLCommonsClusterEventListenerTests extends OpenSearchTestCase { + + @Mock + private ClusterService clusterService; + @Mock + private MLModelManager mlModelManager; + @Mock + private MLTaskManager mlTaskManager; + @Mock + private MLModelCacheHelper modelCacheHelper; + @Mock + private MLModelAutoReDeployer mlModelAutoReDeployer; + @Mock + private Client client; + @Mock + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + @Mock + private ClusterChangedEvent event; + @Mock + private ClusterState clusterState; + @Mock + private Metadata metadata; + + private MLCommonsClusterEventListener listener; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + listener = new MLCommonsClusterEventListener( + clusterService, + mlModelManager, + mlTaskManager, + modelCacheHelper, + mlModelAutoReDeployer, + client, + mlFeatureEnabledSetting + ); + } + + public void testClusterChanged_WithV31DataNode_MetricCollectionEnabled() { + DiscoveryNode dataNode = createDataNode(Version.V_3_1_0); + setupClusterState(dataNode, false); + + when(mlFeatureEnabledSetting.isMetricCollectionEnabled()).thenReturn(true); + when(mlFeatureEnabledSetting.isStaticMetricCollectionEnabled()).thenReturn(true); + + listener.clusterChanged(event); + + verify(mlTaskManager).indexStatsCollectorJob(true); + verify(mlTaskManager, never()).startTaskPollingJob(); + } + + public void testClusterChanged_WithV31DataNode_TaskPollingIndexExists() { + DiscoveryNode dataNode = createDataNode(Version.V_3_1_0); + setupClusterState(dataNode, true); + + when(mlFeatureEnabledSetting.isMetricCollectionEnabled()).thenReturn(false); + + listener.clusterChanged(event); + + verify(mlTaskManager, never()).indexStatsCollectorJob(anyBoolean()); + verify(mlTaskManager).startTaskPollingJob(); + } + + public void testClusterChanged_WithPreV31DataNode_NoJobsStarted() { + DiscoveryNode dataNode = createDataNode(Version.V_3_0_0); + setupClusterState(dataNode, true); + + when(mlFeatureEnabledSetting.isMetricCollectionEnabled()).thenReturn(true); + when(mlFeatureEnabledSetting.isStaticMetricCollectionEnabled()).thenReturn(true); + + listener.clusterChanged(event); + + verify(mlTaskManager, never()).indexStatsCollectorJob(anyBoolean()); + verify(mlTaskManager, never()).startTaskPollingJob(); + } + + private DiscoveryNode createDataNode(Version version) { + return new DiscoveryNode( + "dataNode", + "dataNodeId", + buildNewFakeTransportAddress(), + Collections.emptyMap(), + Collections.singleton(DiscoveryNodeRole.DATA_ROLE), + version + ); + } + + private void setupClusterState(DiscoveryNode node, boolean hasTaskPollingIndex) { + DiscoveryNodes nodes = DiscoveryNodes.builder().add(node).build(); + + when(event.state()).thenReturn(clusterState); + when(event.previousState()).thenReturn(clusterState); + when(event.nodesDelta()).thenReturn(mock(DiscoveryNodes.Delta.class)); + when(clusterState.nodes()).thenReturn(nodes); + when(clusterState.getMetadata()).thenReturn(metadata); + when(clusterService.state()).thenReturn(clusterState); + when(metadata.hasIndex(TASK_POLLING_JOB_INDEX)).thenReturn(hasTaskPollingIndex); + when(metadata.settings()).thenReturn(org.opensearch.common.settings.Settings.EMPTY); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/jobs/MLJobParameterTests.java b/plugin/src/test/java/org/opensearch/ml/jobs/MLJobParameterTests.java index 270b41f399..5da304ad01 100644 --- a/plugin/src/test/java/org/opensearch/ml/jobs/MLJobParameterTests.java +++ b/plugin/src/test/java/org/opensearch/ml/jobs/MLJobParameterTests.java @@ -35,7 +35,7 @@ public void setUp() { lockDurationSeconds = 20L; jitter = 0.5; jobType = null; - jobParameter = new MLJobParameter(jobName, schedule, lockDurationSeconds, jitter, jobType); + jobParameter = new MLJobParameter(jobName, schedule, lockDurationSeconds, jitter, jobType, true); } @Test @@ -54,7 +54,7 @@ public void testToXContent() throws Exception { @Test public void testNullCase() throws IOException { String newJobName = "test-job"; - MLJobParameter nullParameter = new MLJobParameter(newJobName, null, null, null, null); + MLJobParameter nullParameter = new MLJobParameter(newJobName, null, null, null, null, true); nullParameter.setLastUpdateTime(null); nullParameter.setEnabledTime(null); @@ -64,6 +64,7 @@ public void testNullCase() throws IOException { assertTrue(jsonString.contains(newJobName)); assertEquals(newJobName, nullParameter.getName()); + assertTrue(nullParameter.isEnabled()); assertNull(nullParameter.getSchedule()); assertNull(nullParameter.getLockDurationSeconds()); assertNull(nullParameter.getJitter()); diff --git a/plugin/src/test/java/org/opensearch/ml/jobs/MLJobRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/jobs/MLJobRunnerTests.java index 0b2561d7c4..7118166720 100644 --- a/plugin/src/test/java/org/opensearch/ml/jobs/MLJobRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/jobs/MLJobRunnerTests.java @@ -73,4 +73,11 @@ public void testRunJobWithNullJobType() { when(jobParameter.getJobType()).thenReturn(null); jobRunner.runJob(jobParameter, jobExecutionContext); } + + @Test(expected = IllegalStateException.class) + public void testRunJobWithDisabledJob() { + when(jobParameter.isEnabled()).thenReturn(false); + when(jobParameter.getJobType()).thenReturn(MLJobType.STATS_COLLECTOR); + jobRunner.runJob(jobParameter, jobExecutionContext); + } } diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLTaskManagerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLTaskManagerTests.java index a642895b43..474776555e 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLTaskManagerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLTaskManagerTests.java @@ -404,7 +404,7 @@ public void testStartStatsCollectorJob() throws IOException { return null; }).when(client).index(any(), any()); - mlTaskManager.startStatsCollectorJob(); + mlTaskManager.indexStatsCollectorJob(true); ArgumentCaptor indexRequestCaptor = ArgumentCaptor.forClass(IndexRequest.class); verify(client).index(indexRequestCaptor.capture(), any()); @@ -429,7 +429,7 @@ public void testStartStatsCollectorJob_IndexException() throws IOException { return null; }).when(client).index(any(), any()); - mlTaskManager.startStatsCollectorJob(); + mlTaskManager.indexStatsCollectorJob(true); verify(client).index(any(), any()); } @@ -469,4 +469,24 @@ public void testUpdateMLTaskDirectly_TaskDoneState() { mlTaskManager.updateMLTaskDirectly("task_id", updatedFields, listener); verify(listener).onResponse(any(UpdateResponse.class)); } + + public void testOnStaticMetricCollectionEnabledChanged() { + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(0); + listener.onResponse(true); + return null; + }).when(mlIndicesHandler).initMLJobsIndex(any()); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(indexResponse); + return null; + }).when(client).index(any(), any()); + + mlTaskManager.onStaticMetricCollectionEnabledChanged(true); + verify(mlTaskManager).indexStatsCollectorJob(true); + + mlTaskManager.onStaticMetricCollectionEnabledChanged(false); + verify(mlTaskManager).indexStatsCollectorJob(false); + } }