Skip to content

Commit 145f9f8

Browse files
Add feature flag for agentic search (opensearch-project#4021)
* Add feature flag for agentic search Signed-off-by: rithin-pullela-aws <[email protected]> * Address edge case and comments Signed-off-by: rithin-pullela-aws <[email protected]> * Use MLFeatureEnabled Settings Signed-off-by: rithin-pullela-aws <[email protected]> * Fix failing tests Signed-off-by: rithin-pullela-aws <[email protected]> * Prevent qp tool create when not enabled Signed-off-by: rithin-pullela-aws <[email protected]> * Fix failing test Signed-off-by: rithin-pullela-aws <[email protected]> * Fix failing tests Signed-off-by: rithin-pullela-aws <[email protected]> * Add UTs Signed-off-by: rithin-pullela-aws <[email protected]> --------- Signed-off-by: rithin-pullela-aws <[email protected]>
1 parent 4b4b409 commit 145f9f8

File tree

13 files changed

+517
-28
lines changed

13 files changed

+517
-28
lines changed

common/src/main/java/org/opensearch/ml/common/settings/MLCommonsSettings.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,12 @@ private MLCommonsSettings() {}
216216
public static final Setting<Boolean> ML_COMMONS_MEMORY_FEATURE_ENABLED = Setting
217217
.boolSetting("plugins.ml_commons.memory_feature_enabled", true, Setting.Property.NodeScope, Setting.Property.Dynamic);
218218

219+
public static final Setting<Boolean> ML_COMMONS_AGENTIC_SEARCH_ENABLED = Setting
220+
.boolSetting("plugins.ml_commons.agentic_search_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic);
221+
public static final String ML_COMMONS_AGENTIC_SEARCH_DISABLED_MESSAGE =
222+
"The QueryPlanningTool tool for Agentic Search is not enabled. To enable, please update the setting "
223+
+ ML_COMMONS_AGENTIC_SEARCH_ENABLED.getKey();
224+
219225
public static final Setting<Boolean> ML_COMMONS_MCP_CONNECTOR_ENABLED = Setting
220226
.boolSetting("plugins.ml_commons.mcp_connector_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic);
221227
public static final String ML_COMMONS_MCP_CONNECTOR_DISABLED_MESSAGE =

common/src/main/java/org/opensearch/ml/common/settings/MLFeatureEnabledSetting.java

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,13 @@
55

66
package org.opensearch.ml.common.settings;
77

8+
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_AGENTIC_SEARCH_ENABLED;
89
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_AGENT_FRAMEWORK_ENABLED;
910
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_CONNECTOR_PRIVATE_IP_ENABLED;
1011
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_CONTROLLER_ENABLED;
1112
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_EXECUTE_TOOL_ENABLED;
1213
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_LOCAL_MODEL_ENABLED;
14+
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_MCP_CONNECTOR_ENABLED;
1315
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_MCP_SERVER_ENABLED;
1416
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_METRIC_COLLECTION_ENABLED;
1517
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_MULTI_TENANCY_ENABLED;
@@ -52,6 +54,10 @@ public class MLFeatureEnabledSetting {
5254

5355
private volatile Boolean isExecuteToolEnabled;
5456

57+
private volatile Boolean isAgenticSearchEnabled;
58+
59+
private volatile Boolean isMcpConnectorEnabled;
60+
5561
private final List<SettingsChangeListener> listeners = new ArrayList<>();
5662

5763
public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings) {
@@ -68,6 +74,8 @@ public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings)
6874
isMetricCollectionEnabled = ML_COMMONS_METRIC_COLLECTION_ENABLED.get(settings);
6975
isStaticMetricCollectionEnabled = ML_COMMONS_STATIC_METRIC_COLLECTION_ENABLED.get(settings);
7076
isExecuteToolEnabled = ML_COMMONS_EXECUTE_TOOL_ENABLED.get(settings);
77+
isAgenticSearchEnabled = ML_COMMONS_AGENTIC_SEARCH_ENABLED.get(settings);
78+
isMcpConnectorEnabled = ML_COMMONS_MCP_CONNECTOR_ENABLED.get(settings);
7179

7280
clusterService
7381
.getClusterSettings()
@@ -91,6 +99,8 @@ public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings)
9199
.getClusterSettings()
92100
.addSettingsUpdateConsumer(MLCommonsSettings.ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED, it -> isRagSearchPipelineEnabled = it);
93101
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_EXECUTE_TOOL_ENABLED, it -> isExecuteToolEnabled = it);
102+
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_AGENTIC_SEARCH_ENABLED, it -> isAgenticSearchEnabled = it);
103+
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_MCP_CONNECTOR_ENABLED, it -> isMcpConnectorEnabled = it);
94104
}
95105

96106
/**
@@ -195,4 +205,12 @@ public void notifyMultiTenancyListeners(boolean isEnabled) {
195205
listener.onMultiTenancyEnabledChanged(isEnabled);
196206
}
197207
}
208+
209+
public boolean isAgenticSearchEnabled() {
210+
return isAgenticSearchEnabled;
211+
}
212+
213+
public boolean isMcpConnectorEnabled() {
214+
return isMcpConnectorEnabled;
215+
}
198216
}

common/src/test/java/org/opensearch/ml/common/settings/MLFeatureEnabledSettingTests.java

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,9 @@ public void setUp() {
4444
MLCommonsSettings.ML_COMMONS_RAG_PIPELINE_FEATURE_ENABLED,
4545
MLCommonsSettings.ML_COMMONS_METRIC_COLLECTION_ENABLED,
4646
MLCommonsSettings.ML_COMMONS_STATIC_METRIC_COLLECTION_ENABLED,
47-
MLCommonsSettings.ML_COMMONS_EXECUTE_TOOL_ENABLED
47+
MLCommonsSettings.ML_COMMONS_EXECUTE_TOOL_ENABLED,
48+
MLCommonsSettings.ML_COMMONS_AGENTIC_SEARCH_ENABLED,
49+
MLCommonsSettings.ML_COMMONS_MCP_CONNECTOR_ENABLED
4850
)
4951
);
5052
when(mockClusterService.getClusterSettings()).thenReturn(mockClusterSettings);
@@ -66,6 +68,8 @@ public void testDefaults_allFeaturesEnabled() {
6668
.put("plugins.ml_commons.rag_pipeline_feature_enabled", true)
6769
.put("plugins.ml_commons.metrics_collection_enabled", true)
6870
.put("plugins.ml_commons.metrics_static_collection_enabled", true)
71+
.put("plugins.ml_commons.mcp_connector_enabled", true)
72+
.put("plugins.ml_commons.agentic_search_enabled", true)
6973
.build();
7074

7175
MLFeatureEnabledSetting setting = new MLFeatureEnabledSetting(mockClusterService, settings);
@@ -82,6 +86,8 @@ public void testDefaults_allFeaturesEnabled() {
8286
assertTrue(setting.isRagSearchPipelineEnabled());
8387
assertTrue(setting.isMetricCollectionEnabled());
8488
assertTrue(setting.isStaticMetricCollectionEnabled());
89+
assertTrue(setting.isMcpConnectorEnabled());
90+
assertTrue(setting.isAgenticSearchEnabled());
8591
}
8692

8793
@Test
@@ -100,6 +106,8 @@ public void testDefaults_someFeaturesDisabled() {
100106
.put("plugins.ml_commons.rag_pipeline_feature_enabled", false)
101107
.put("plugins.ml_commons.metrics_collection_enabled", false)
102108
.put("plugins.ml_commons.metrics_static_collection_enabled", false)
109+
.put("plugins.ml_commons.mcp_connector_enabled", false)
110+
.put("plugins.ml_commons.agentic_search_enabled", false)
103111
.build();
104112

105113
MLFeatureEnabledSetting setting = new MLFeatureEnabledSetting(mockClusterService, settings);
@@ -116,6 +124,8 @@ public void testDefaults_someFeaturesDisabled() {
116124
assertFalse(setting.isRagSearchPipelineEnabled());
117125
assertFalse(setting.isMetricCollectionEnabled());
118126
assertFalse(setting.isStaticMetricCollectionEnabled());
127+
assertFalse(setting.isMcpConnectorEnabled());
128+
assertFalse(setting.isAgenticSearchEnabled());
119129
}
120130

121131
@Test

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414
import static org.opensearch.ml.common.MLTask.STATE_FIELD;
1515
import static org.opensearch.ml.common.MLTask.TASK_ID_FIELD;
1616
import static org.opensearch.ml.common.output.model.ModelTensorOutput.INFERENCE_RESULT_FIELD;
17+
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_AGENTIC_SEARCH_DISABLED_MESSAGE;
1718
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_MCP_CONNECTOR_DISABLED_MESSAGE;
18-
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_MCP_CONNECTOR_ENABLED;
1919
import static org.opensearch.ml.common.utils.MLTaskUtils.updateMLTaskDirectly;
2020

2121
import java.security.AccessController;
@@ -52,6 +52,7 @@
5252
import org.opensearch.ml.common.MLTaskType;
5353
import org.opensearch.ml.common.agent.MLAgent;
5454
import org.opensearch.ml.common.agent.MLMemorySpec;
55+
import org.opensearch.ml.common.agent.MLToolSpec;
5556
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
5657
import org.opensearch.ml.common.input.Input;
5758
import org.opensearch.ml.common.input.execute.agent.AgentMLInput;
@@ -60,6 +61,7 @@
6061
import org.opensearch.ml.common.output.model.ModelTensor;
6162
import org.opensearch.ml.common.output.model.ModelTensorOutput;
6263
import org.opensearch.ml.common.output.model.ModelTensors;
64+
import org.opensearch.ml.common.settings.MLFeatureEnabledSetting;
6365
import org.opensearch.ml.common.settings.SettingsChangeListener;
6466
import org.opensearch.ml.common.spi.memory.Memory;
6567
import org.opensearch.ml.common.spi.tools.Tool;
@@ -68,6 +70,7 @@
6870
import org.opensearch.ml.engine.encryptor.Encryptor;
6971
import org.opensearch.ml.engine.memory.ConversationIndexMemory;
7072
import org.opensearch.ml.engine.memory.ConversationIndexMessage;
73+
import org.opensearch.ml.engine.tools.QueryPlanningTool;
7174
import org.opensearch.ml.memory.action.conversation.CreateInteractionResponse;
7275
import org.opensearch.ml.memory.action.conversation.GetInteractionAction;
7376
import org.opensearch.ml.memory.action.conversation.GetInteractionRequest;
@@ -108,7 +111,7 @@ public class MLAgentExecutor implements Executable, SettingsChangeListener {
108111
private Map<String, Memory.Factory> memoryFactoryMap;
109112
private volatile Boolean isMultiTenancyEnabled;
110113
private Encryptor encryptor;
111-
private static volatile boolean mcpConnectorIsEnabled;
114+
private MLFeatureEnabledSetting mlFeatureEnabledSetting;
112115

113116
public MLAgentExecutor(
114117
Client client,
@@ -118,7 +121,7 @@ public MLAgentExecutor(
118121
NamedXContentRegistry xContentRegistry,
119122
Map<String, Tool.Factory> toolFactories,
120123
Map<String, Memory.Factory> memoryFactoryMap,
121-
Boolean isMultiTenancyEnabled,
124+
MLFeatureEnabledSetting mlFeatureEnabledSetting,
122125
Encryptor encryptor
123126
) {
124127
this.client = client;
@@ -128,10 +131,9 @@ public MLAgentExecutor(
128131
this.xContentRegistry = xContentRegistry;
129132
this.toolFactories = toolFactories;
130133
this.memoryFactoryMap = memoryFactoryMap;
131-
this.isMultiTenancyEnabled = isMultiTenancyEnabled;
134+
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
132135
this.encryptor = encryptor;
133-
this.mcpConnectorIsEnabled = ML_COMMONS_MCP_CONNECTOR_ENABLED.get(clusterService.getSettings());
134-
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_MCP_CONNECTOR_ENABLED, it -> mcpConnectorIsEnabled = it);
136+
this.isMultiTenancyEnabled = mlFeatureEnabledSetting.isMultiTenancyEnabled();
135137
}
136138

137139
@Override
@@ -389,11 +391,21 @@ private void executeAgent(
389391
ActionListener<Output> listener
390392
) {
391393
String mcpConnectorConfigJSON = (mlAgent.getParameters() != null) ? mlAgent.getParameters().get(MCP_CONNECTORS_FIELD) : null;
392-
if (mcpConnectorConfigJSON != null && !mcpConnectorIsEnabled) {
394+
if (mcpConnectorConfigJSON != null && !mlFeatureEnabledSetting.isMcpConnectorEnabled()) {
393395
// MCP connector provided as tools but MCP feature is disabled, so abort.
394396
listener.onFailure(new OpenSearchException(ML_COMMONS_MCP_CONNECTOR_DISABLED_MESSAGE));
395397
return;
396398
}
399+
List<MLToolSpec> tools = mlAgent.getTools();
400+
if (tools != null) {
401+
for (MLToolSpec tool : tools) {
402+
if (tool.getType().equals(QueryPlanningTool.TYPE) && !mlFeatureEnabledSetting.isAgenticSearchEnabled()) {
403+
listener.onFailure(new OpenSearchException(ML_COMMONS_AGENTIC_SEARCH_DISABLED_MESSAGE));
404+
return;
405+
}
406+
}
407+
}
408+
397409
MLAgentRunner mlAgentRunner = getAgentRunner(mlAgent);
398410
// If async is true, index ML task and return the taskID. Also add memoryID to the task if it exists
399411
if (isAsync) {

ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/QueryPlanningTool.java

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
package org.opensearch.ml.engine.tools;
77

8+
import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_AGENTIC_SEARCH_DISABLED_MESSAGE;
89
import static org.opensearch.ml.common.utils.StringUtils.gson;
910
import static org.opensearch.ml.engine.tools.QueryPlanningPromptTemplate.DEFAULT_QUERY;
1011
import static org.opensearch.ml.engine.tools.QueryPlanningPromptTemplate.DEFAULT_SYSTEM_PROMPT;
@@ -13,7 +14,9 @@
1314
import java.util.Map;
1415

1516
import org.apache.commons.text.StringSubstitutor;
17+
import org.opensearch.OpenSearchException;
1618
import org.opensearch.core.action.ActionListener;
19+
import org.opensearch.ml.common.settings.MLFeatureEnabledSetting;
1720
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
1821
import org.opensearch.ml.common.spi.tools.WithModelTool;
1922
import org.opensearch.ml.repackage.com.google.common.annotations.VisibleForTesting;
@@ -116,6 +119,7 @@ public boolean validate(Map<String, String> parameters) {
116119
public static class Factory implements WithModelTool.Factory<QueryPlanningTool> {
117120
private Client client;
118121
private static volatile Factory INSTANCE;
122+
private static MLFeatureEnabledSetting mlFeatureEnabledSetting;
119123

120124
public static Factory getInstance() {
121125
if (INSTANCE != null) {
@@ -130,13 +134,18 @@ public static Factory getInstance() {
130134
}
131135
}
132136

133-
public void init(Client client) {
137+
public void init(Client client, MLFeatureEnabledSetting mlFeatureEnabledSetting) {
134138
this.client = client;
139+
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
135140
}
136141

137142
@Override
138143
public QueryPlanningTool create(Map<String, Object> map) {
139144

145+
if (!mlFeatureEnabledSetting.isAgenticSearchEnabled()) {
146+
throw new OpenSearchException(ML_COMMONS_AGENTIC_SEARCH_DISABLED_MESSAGE);
147+
}
148+
140149
MLModelTool queryGenerationTool = MLModelTool.Factory.getInstance().create(map);
141150

142151
String type = (String) map.get(GENERATION_TYPE_FIELD);

0 commit comments

Comments
 (0)