diff --git a/plugins/workload-management/src/internalClusterTest/java/org/opensearch/plugin/wlm/WlmAutoTaggingIT.java b/plugins/workload-management/src/internalClusterTest/java/org/opensearch/plugin/wlm/WlmAutoTaggingIT.java index 4f8dfa89027ee..835eb2b27664a 100644 --- a/plugins/workload-management/src/internalClusterTest/java/org/opensearch/plugin/wlm/WlmAutoTaggingIT.java +++ b/plugins/workload-management/src/internalClusterTest/java/org/opensearch/plugin/wlm/WlmAutoTaggingIT.java @@ -439,6 +439,43 @@ public void testDeleteRuleForNonexistentId() throws Exception { assertTrue("Expected error message for nonexistent rule ID", exception.getMessage().contains("no such index")); } + public void testScrollRequestsAreAlsoTagged() throws Exception { + String workloadGroupId = "wlm_auto_tag_scroll"; + String ruleId = "wlm_auto_tag_scroll_rule"; + String indexName = "scroll_tagged_index"; + + setWlmMode("enabled"); + + WorkloadGroup workloadGroup = createWorkloadGroup("scroll_tagging_group", workloadGroupId); + updateWorkloadGroupInClusterState(PUT, workloadGroup); + + FeatureType featureType = AutoTaggingRegistry.getFeatureType(WorkloadGroupFeatureType.NAME); + createRule(ruleId, "scroll tagging rule", indexName, featureType, workloadGroupId); + + indexDocument(indexName); + + assertBusy(() -> { + int completionsBefore = getCompletions(workloadGroupId); + + SearchResponse initial = client().prepareSearch(indexName) + .setQuery(QueryBuilders.matchAllQuery()) + .setScroll(TimeValue.timeValueMinutes(1)) + .setSize(1) + .get(); + + String scrollId = initial.getScrollId(); + assertNotNull("scrollId must not be null", scrollId); + + int afterInitialSearch = getCompletions(workloadGroupId); + assertTrue("Expected completions to increase after initial search with scroll", afterInitialSearch > completionsBefore); + + client().prepareSearchScroll(scrollId).setScroll(TimeValue.timeValueMinutes(1)).get(); + + int afterScroll = getCompletions(workloadGroupId); + assertTrue("Expected completions to increase after scroll request as well", afterScroll > afterInitialSearch); + }); + } + // Helper functions private void createRule(String ruleId, String ruleName, String indexPattern, FeatureType featureType, String workloadGroupId) throws Exception { diff --git a/plugins/workload-management/src/main/java/org/opensearch/plugin/wlm/AutoTaggingActionFilter.java b/plugins/workload-management/src/main/java/org/opensearch/plugin/wlm/AutoTaggingActionFilter.java index c6294ed7ac242..2b57ccd74ad4e 100644 --- a/plugins/workload-management/src/main/java/org/opensearch/plugin/wlm/AutoTaggingActionFilter.java +++ b/plugins/workload-management/src/main/java/org/opensearch/plugin/wlm/AutoTaggingActionFilter.java @@ -11,14 +11,18 @@ import org.opensearch.action.ActionRequest; import org.opensearch.action.IndicesRequest; import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchScrollRequest; import org.opensearch.action.support.ActionFilter; import org.opensearch.action.support.ActionFilterChain; import org.opensearch.action.support.ActionRequestMetadata; +import org.opensearch.cluster.metadata.OptionallyResolvedIndices; +import org.opensearch.cluster.metadata.ResolvedIndices; import org.opensearch.core.action.ActionListener; import org.opensearch.core.action.ActionResponse; import org.opensearch.plugin.wlm.rule.attribute_extractor.IndicesExtractor; import org.opensearch.plugin.wlm.spi.AttributeExtractorExtension; import org.opensearch.rule.InMemoryRuleProcessingService; +import org.opensearch.rule.RuleAttribute; import org.opensearch.rule.attribute_extractor.AttributeExtractor; import org.opensearch.rule.autotagging.Attribute; import org.opensearch.rule.autotagging.FeatureType; @@ -31,6 +35,7 @@ import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.Set; import static org.opensearch.plugin.wlm.WorkloadManagementPlugin.PRINCIPAL_ATTRIBUTE_NAME; @@ -80,14 +85,44 @@ public void app ActionListener listener, ActionFilterChain chain ) { - final boolean isValidRequest = request instanceof SearchRequest; + final boolean isSearchRequest = request instanceof SearchRequest; + final boolean isSearchScrollRequest = request instanceof SearchScrollRequest; + final boolean isValidRequest = isSearchRequest || isSearchScrollRequest; if (!isValidRequest || wlmClusterSettingValuesProvider.getWlmMode() == WlmMode.DISABLED) { chain.proceed(task, action, request, listener); return; } List> attributeExtractors = new ArrayList<>(); - attributeExtractors.add(new IndicesExtractor((IndicesRequest) request)); + final OptionallyResolvedIndices optionallyResolved = actionRequestMetadata.resolvedIndices(); + final boolean hasResolvedIndices = optionallyResolved instanceof ResolvedIndices; + + if (hasResolvedIndices) { + final ResolvedIndices resolved = (ResolvedIndices) optionallyResolved; + final Set names = resolved.local().names(); + + attributeExtractors.add(new AttributeExtractor<>() { + @Override + public Attribute getAttribute() { + return RuleAttribute.INDEX_PATTERN; + } + + @Override + public Iterable extract() { + return names; + } + + @Override + public LogicalOperator getLogicalOperator() { + return LogicalOperator.AND; + } + }); + } else if (isSearchRequest) { + attributeExtractors.add(new IndicesExtractor((IndicesRequest) request)); + } else { + chain.proceed(task, action, request, listener); + return; + } if (featureType.getAllowedAttributesRegistry().containsKey(PRINCIPAL_ATTRIBUTE_NAME)) { Attribute attribute = featureType.getAllowedAttributesRegistry().get(PRINCIPAL_ATTRIBUTE_NAME); diff --git a/plugins/workload-management/src/test/java/org/opensearch/plugin/wlm/AutoTaggingActionFilterTests.java b/plugins/workload-management/src/test/java/org/opensearch/plugin/wlm/AutoTaggingActionFilterTests.java index ed5e8e25843ea..2766aebb6ec28 100644 --- a/plugins/workload-management/src/test/java/org/opensearch/plugin/wlm/AutoTaggingActionFilterTests.java +++ b/plugins/workload-management/src/test/java/org/opensearch/plugin/wlm/AutoTaggingActionFilterTests.java @@ -11,8 +11,10 @@ import org.opensearch.action.ActionRequest; import org.opensearch.action.admin.cluster.node.tasks.cancel.CancelTasksRequest; import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchScrollRequest; import org.opensearch.action.support.ActionFilterChain; import org.opensearch.action.support.ActionRequestMetadata; +import org.opensearch.cluster.metadata.ResolvedIndices; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; import org.opensearch.core.action.ActionResponse; @@ -93,6 +95,25 @@ public void testApplyForInValidRequest() { verify(ruleProcessingService, times(0)).evaluateLabel(anyList()); } + public void testApplyForScrollRequestWithResolvedIndices() { + SearchScrollRequest request = mock(SearchScrollRequest.class); + ActionFilterChain mockFilterChain = mock(TestActionFilterChain.class); + + @SuppressWarnings("unchecked") + ActionRequestMetadata metadata = mock(ActionRequestMetadata.class); + ResolvedIndices resolved = ResolvedIndices.of("logs-scroll-index"); + when(metadata.resolvedIndices()).thenReturn(resolved); + + try (ThreadContext.StoredContext context = threadPool.getThreadContext().stashContext()) { + when(ruleProcessingService.evaluateLabel(anyList())).thenReturn(Optional.of("ScrollQG_ID")); + + autoTaggingActionFilter.apply(mock(Task.class), "Test", request, metadata, null, mockFilterChain); + + assertEquals("ScrollQG_ID", threadPool.getThreadContext().getHeader(WorkloadGroupTask.WORKLOAD_GROUP_ID_HEADER)); + verify(ruleProcessingService, times(1)).evaluateLabel(anyList()); + } + } + public enum WLMFeatureType implements FeatureType { WLM; diff --git a/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java b/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java index 59bb88b0f6f67..00b6280af7323 100644 --- a/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java +++ b/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java @@ -774,7 +774,9 @@ public void sendSearchResponse(InternalSearchResponse internalSearchResponse, At raisePhaseFailure(new SearchPhaseExecutionException("", "Shard failures", null, failures)); } else { final Version minNodeVersion = clusterState.nodes().getMinNodeVersion(); - final String scrollId = request.scroll() != null ? TransportSearchHelper.buildScrollId(queryResults, minNodeVersion) : null; + final String scrollId = request.scroll() != null + ? TransportSearchHelper.buildScrollId(queryResults, request.indices(), minNodeVersion) + : null; final String searchContextId; if (buildPointInTimeFromSearchResults()) { searchContextId = SearchContextId.encode(queryResults.asList(), aliasFilter, minNodeVersion); diff --git a/server/src/main/java/org/opensearch/action/search/ParsedScrollId.java b/server/src/main/java/org/opensearch/action/search/ParsedScrollId.java index b723b97b5c413..82009af3b0cd1 100644 --- a/server/src/main/java/org/opensearch/action/search/ParsedScrollId.java +++ b/server/src/main/java/org/opensearch/action/search/ParsedScrollId.java @@ -53,11 +53,13 @@ public class ParsedScrollId { private final String type; private final SearchContextIdForNode[] context; + private final String[] originalIndices; - ParsedScrollId(String source, String type, SearchContextIdForNode[] context) { + ParsedScrollId(String source, String type, SearchContextIdForNode[] context, String[] originalIndices) { this.source = source; this.type = type; this.context = context; + this.originalIndices = originalIndices; } public String getSource() { @@ -72,6 +74,10 @@ public SearchContextIdForNode[] getContext() { return context; } + public String[] getOriginalIndices() { + return originalIndices; + } + public boolean hasLocalIndices() { return Arrays.stream(context).anyMatch(c -> c.getClusterAlias() == null); } diff --git a/server/src/main/java/org/opensearch/action/search/SearchScrollRequest.java b/server/src/main/java/org/opensearch/action/search/SearchScrollRequest.java index 044efdc36d04f..a25a3ff719c52 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchScrollRequest.java +++ b/server/src/main/java/org/opensearch/action/search/SearchScrollRequest.java @@ -61,6 +61,7 @@ public class SearchScrollRequest extends ActionRequest implements ToXContentObje private String scrollId; private Scroll scroll; + private transient ParsedScrollId parsedScrollId; public SearchScrollRequest() {} @@ -103,7 +104,10 @@ public SearchScrollRequest scrollId(String scrollId) { } public ParsedScrollId parseScrollId() { - return TransportSearchHelper.parseScrollId(scrollId); + if (parsedScrollId == null && scrollId != null) { + parsedScrollId = TransportSearchHelper.parseScrollId(scrollId); + } + return parsedScrollId; } /** diff --git a/server/src/main/java/org/opensearch/action/search/TransportSearchHelper.java b/server/src/main/java/org/opensearch/action/search/TransportSearchHelper.java index 5c260e02e7275..900b1223c6916 100644 --- a/server/src/main/java/org/opensearch/action/search/TransportSearchHelper.java +++ b/server/src/main/java/org/opensearch/action/search/TransportSearchHelper.java @@ -60,7 +60,13 @@ static InternalScrollSearchRequest internalScrollSearchRequest(ShardSearchContex return new InternalScrollSearchRequest(request, id); } + public static final Version INDICES_IN_SCROLL_ID_VERSION = Version.V_3_3_2; + static String buildScrollId(AtomicArray searchPhaseResults, Version version) { + return buildScrollId(searchPhaseResults, null, version); + } + + static String buildScrollId(AtomicArray searchPhaseResults, String[] originalIndices, Version version) { try { BytesStreamOutput out = new BytesStreamOutput(); out.writeString(INCLUDE_CONTEXT_UUID); @@ -78,6 +84,19 @@ static String buildScrollId(AtomicArray searchPhase out.writeString(searchShardTarget.getNodeId()); } } + + if (version.onOrAfter(INDICES_IN_SCROLL_ID_VERSION)) { + // To keep autotagging consistent between the initial SearchRequest + // and subsequent SearchScrollRequests, we store exactly the same + // index targets that were visible to the indices attribute during + // the "search" phase + if (originalIndices != null && originalIndices.length > 0) { + out.writeVInt(originalIndices.length); + for (String index : originalIndices) { + out.writeString(index); + } + } + } byte[] bytes = BytesReference.toBytes(out.bytes()); return Base64.getUrlEncoder().encodeToString(bytes); } catch (IOException e) { @@ -114,10 +133,22 @@ static ParsedScrollId parseScrollId(String scrollId) { } context[i] = new SearchContextIdForNode(clusterAlias, target, new ShardSearchContextId(contextUUID, id)); } + + String[] originalIndices; + if (in.getPosition() < bytes.length) { + final int numOriginalIndices = in.readVInt(); + originalIndices = new String[numOriginalIndices]; + for (int i = 0; i < numOriginalIndices; i++) { + originalIndices[i] = in.readString(); + } + } else { + originalIndices = new String[0]; + } + if (in.getPosition() != bytes.length) { throw new IllegalArgumentException("Not all bytes were read"); } - return new ParsedScrollId(scrollId, type, context); + return new ParsedScrollId(scrollId, type, context, originalIndices); } catch (Exception e) { throw new IllegalArgumentException("Cannot parse scroll id", e); } diff --git a/server/src/main/java/org/opensearch/action/search/TransportSearchScrollAction.java b/server/src/main/java/org/opensearch/action/search/TransportSearchScrollAction.java index b0f98a4c1703b..3db9e93fdd092 100644 --- a/server/src/main/java/org/opensearch/action/search/TransportSearchScrollAction.java +++ b/server/src/main/java/org/opensearch/action/search/TransportSearchScrollAction.java @@ -34,6 +34,9 @@ import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.action.support.TransportIndicesResolvingAction; +import org.opensearch.cluster.metadata.OptionallyResolvedIndices; +import org.opensearch.cluster.metadata.ResolvedIndices; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; import org.opensearch.core.action.ActionListener; @@ -48,7 +51,9 @@ * * @opensearch.internal */ -public class TransportSearchScrollAction extends HandledTransportAction { +public class TransportSearchScrollAction extends HandledTransportAction + implements + TransportIndicesResolvingAction { private final ClusterService clusterService; private final SearchTransportService searchTransportService; @@ -79,7 +84,7 @@ protected void doExecute(Task task, SearchScrollRequest request, ActionListener< ((WorkloadGroupTask) task).setWorkloadGroupId(threadPool.getThreadContext()); } - ParsedScrollId scrollId = TransportSearchHelper.parseScrollId(request.scrollId()); + ParsedScrollId scrollId = request.parseScrollId(); Runnable action; switch (scrollId.getType()) { case ParsedScrollId.QUERY_THEN_FETCH_TYPE: @@ -114,4 +119,26 @@ protected void doExecute(Task task, SearchScrollRequest request, ActionListener< listener.onFailure(e); } } + + @Override + public OptionallyResolvedIndices resolveIndices(SearchScrollRequest request) { + try { + final String scrollIdString = request.scrollId(); + if (scrollIdString == null || scrollIdString.isEmpty()) { + return OptionallyResolvedIndices.unknown(); + } + + final ParsedScrollId parsed = request.parseScrollId(); + if (parsed == null) { + return OptionallyResolvedIndices.unknown(); + } + final String[] originalIndices = parsed.getOriginalIndices(); + if (originalIndices == null || originalIndices.length == 0) { + return OptionallyResolvedIndices.unknown(); + } + return ResolvedIndices.of(originalIndices); + } catch (Exception e) { + return OptionallyResolvedIndices.unknown(); + } + } } diff --git a/server/src/test/java/org/opensearch/action/search/ParsedScrollIdTests.java b/server/src/test/java/org/opensearch/action/search/ParsedScrollIdTests.java index 2d90bf9ba1bdd..0985cb5308802 100644 --- a/server/src/test/java/org/opensearch/action/search/ParsedScrollIdTests.java +++ b/server/src/test/java/org/opensearch/action/search/ParsedScrollIdTests.java @@ -50,7 +50,12 @@ public void testHasLocalIndices() { new ShardSearchContextId(randomAlphaOfLength(8), randomLong()) ); } - final ParsedScrollId parsedScrollId = new ParsedScrollId(randomAlphaOfLength(8), randomAlphaOfLength(8), searchContextIdForNodes); + final ParsedScrollId parsedScrollId = new ParsedScrollId( + randomAlphaOfLength(8), + randomAlphaOfLength(8), + searchContextIdForNodes, + new String[0] + ); assertEquals(hasLocal, parsedScrollId.hasLocalIndices()); } diff --git a/server/src/test/java/org/opensearch/action/search/SearchScrollAsyncActionTests.java b/server/src/test/java/org/opensearch/action/search/SearchScrollAsyncActionTests.java index 12ab735c4d324..f5ceef0885520 100644 --- a/server/src/test/java/org/opensearch/action/search/SearchScrollAsyncActionTests.java +++ b/server/src/test/java/org/opensearch/action/search/SearchScrollAsyncActionTests.java @@ -481,7 +481,7 @@ protected void onFirstPhaseResult(int shardId, SearchAsyncActionTests.TestSearch private static ParsedScrollId getParsedScrollId(SearchContextIdForNode... idsForNodes) { List searchContextIdForNodes = Arrays.asList(idsForNodes); Collections.shuffle(searchContextIdForNodes, random()); - return new ParsedScrollId("", "test", searchContextIdForNodes.toArray(new SearchContextIdForNode[0])); + return new ParsedScrollId("", "test", searchContextIdForNodes.toArray(new SearchContextIdForNode[0]), new String[0]); } private ActionListener dummyListener() {