| 
 | 1 | +/*  | 
 | 2 | + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one  | 
 | 3 | + * or more contributor license agreements. Licensed under the "Elastic License  | 
 | 4 | + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side  | 
 | 5 | + * Public License v 1"; you may not use this file except in compliance with, at  | 
 | 6 | + * your election, the "Elastic License 2.0", the "GNU Affero General Public  | 
 | 7 | + * License v3.0 only", or the "Server Side Public License, v 1".  | 
 | 8 | + */  | 
 | 9 | + | 
 | 10 | +package org.elasticsearch.search.aggregations.bucket;  | 
 | 11 | + | 
 | 12 | +import org.elasticsearch.action.bulk.BulkRequestBuilder;  | 
 | 13 | +import org.elasticsearch.action.search.SearchRequestBuilder;  | 
 | 14 | +import org.elasticsearch.action.search.TransportSearchAction;  | 
 | 15 | +import org.elasticsearch.action.support.WriteRequest;  | 
 | 16 | +import org.elasticsearch.common.settings.Settings;  | 
 | 17 | +import org.elasticsearch.common.util.CollectionUtils;  | 
 | 18 | +import org.elasticsearch.core.TimeValue;  | 
 | 19 | +import org.elasticsearch.index.mapper.OnScriptError;  | 
 | 20 | +import org.elasticsearch.plugins.Plugin;  | 
 | 21 | +import org.elasticsearch.plugins.ScriptPlugin;  | 
 | 22 | +import org.elasticsearch.script.LongFieldScript;  | 
 | 23 | +import org.elasticsearch.script.ScriptContext;  | 
 | 24 | +import org.elasticsearch.script.ScriptEngine;  | 
 | 25 | +import org.elasticsearch.search.aggregations.bucket.filter.FiltersAggregator.KeyedFilter;  | 
 | 26 | +import org.elasticsearch.search.lookup.SearchLookup;  | 
 | 27 | +import org.elasticsearch.tasks.TaskInfo;  | 
 | 28 | +import org.elasticsearch.test.ESIntegTestCase;  | 
 | 29 | +import org.elasticsearch.xcontent.XContentBuilder;  | 
 | 30 | +import org.elasticsearch.xcontent.json.JsonXContent;  | 
 | 31 | + | 
 | 32 | +import java.util.Collection;  | 
 | 33 | +import java.util.List;  | 
 | 34 | +import java.util.Map;  | 
 | 35 | +import java.util.Set;  | 
 | 36 | +import java.util.concurrent.Semaphore;  | 
 | 37 | + | 
 | 38 | +import static org.elasticsearch.index.query.QueryBuilders.termQuery;  | 
 | 39 | +import static org.elasticsearch.search.aggregations.AggregationBuilders.filters;  | 
 | 40 | +import static org.elasticsearch.search.aggregations.AggregationBuilders.terms;  | 
 | 41 | +import static org.hamcrest.Matchers.empty;  | 
 | 42 | +import static org.hamcrest.Matchers.greaterThan;  | 
 | 43 | +import static org.hamcrest.Matchers.not;  | 
 | 44 | + | 
 | 45 | +@ESIntegTestCase.SuiteScopeTestCase  | 
 | 46 | +public class FiltersCancellationIT extends ESIntegTestCase {  | 
 | 47 | + | 
 | 48 | +    private static final String INDEX = "idx";  | 
 | 49 | +    private static final String PAUSE_FIELD = "pause";  | 
 | 50 | +    private static final String NUMERIC_FIELD = "value";  | 
 | 51 | + | 
 | 52 | +    private static final int NUM_DOCS = 100_000;  | 
 | 53 | +    private static final int SEMAPHORE_PERMITS = NUM_DOCS - 1000;  | 
 | 54 | +    private static final Semaphore SCRIPT_SEMAPHORE = new Semaphore(0);  | 
 | 55 | + | 
 | 56 | +    @Override  | 
 | 57 | +    protected Collection<Class<? extends Plugin>> nodePlugins() {  | 
 | 58 | +        return CollectionUtils.appendToCopy(super.nodePlugins(), pausableFieldPluginClass());  | 
 | 59 | +    }  | 
 | 60 | + | 
 | 61 | +    protected Class<? extends Plugin> pausableFieldPluginClass() {  | 
 | 62 | +        return PauseScriptPlugin.class;  | 
 | 63 | +    }  | 
 | 64 | + | 
 | 65 | +    @Override  | 
 | 66 | +    public void setupSuiteScopeCluster() throws Exception {  | 
 | 67 | +        try (XContentBuilder mapping = JsonXContent.contentBuilder()) {  | 
 | 68 | +            mapping.startObject();  | 
 | 69 | +            mapping.startObject("runtime");  | 
 | 70 | +            {  | 
 | 71 | +                mapping.startObject(PAUSE_FIELD);  | 
 | 72 | +                {  | 
 | 73 | +                    mapping.field("type", "long");  | 
 | 74 | +                    mapping.startObject("script").field("source", "").field("lang", PauseScriptPlugin.PAUSE_SCRIPT_LANG).endObject();  | 
 | 75 | +                }  | 
 | 76 | +                mapping.endObject();  | 
 | 77 | +                mapping.startObject(NUMERIC_FIELD);  | 
 | 78 | +                {  | 
 | 79 | +                    mapping.field("type", "long");  | 
 | 80 | +                }  | 
 | 81 | +                mapping.endObject();  | 
 | 82 | +            }  | 
 | 83 | +            mapping.endObject();  | 
 | 84 | +            mapping.endObject();  | 
 | 85 | + | 
 | 86 | +            client().admin().indices().prepareCreate(INDEX).setMapping(mapping).get();  | 
 | 87 | +        }  | 
 | 88 | + | 
 | 89 | +        int DOCS_PER_BULK = 100_000;  | 
 | 90 | +        for (int i = 0; i < NUM_DOCS; i += DOCS_PER_BULK) {  | 
 | 91 | +            BulkRequestBuilder bulk = client().prepareBulk().setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);  | 
 | 92 | +            for (int j = 0; j < DOCS_PER_BULK; j++) {  | 
 | 93 | +                int docId = i + j;  | 
 | 94 | +                bulk.add(prepareIndex(INDEX).setId(Integer.toString(docId)).setSource(NUMERIC_FIELD, docId));  | 
 | 95 | +            }  | 
 | 96 | +            bulk.get();  | 
 | 97 | +        }  | 
 | 98 | + | 
 | 99 | +        client().admin().indices().prepareForceMerge(INDEX).setMaxNumSegments(1).get();  | 
 | 100 | +    }  | 
 | 101 | + | 
 | 102 | +    public void testFiltersCountCancellation() throws Exception {  | 
 | 103 | +        ensureProperCancellation(  | 
 | 104 | +            client().prepareSearch(INDEX)  | 
 | 105 | +                .addAggregation(  | 
 | 106 | +                    filters(  | 
 | 107 | +                        "filters",  | 
 | 108 | +                        new KeyedFilter[] {  | 
 | 109 | +                            new KeyedFilter("filter1", termQuery(PAUSE_FIELD, 1)),  | 
 | 110 | +                            new KeyedFilter("filter2", termQuery(PAUSE_FIELD, 2)) }  | 
 | 111 | +                    )  | 
 | 112 | +                )  | 
 | 113 | +        );  | 
 | 114 | +    }  | 
 | 115 | + | 
 | 116 | +    public void testFiltersSubAggsCancellation() throws Exception {  | 
 | 117 | +        ensureProperCancellation(  | 
 | 118 | +            client().prepareSearch(INDEX)  | 
 | 119 | +                .addAggregation(  | 
 | 120 | +                    filters(  | 
 | 121 | +                        "filters",  | 
 | 122 | +                        new KeyedFilter[] {  | 
 | 123 | +                            new KeyedFilter("filter1", termQuery(PAUSE_FIELD, 1)),  | 
 | 124 | +                            new KeyedFilter("filter2", termQuery(PAUSE_FIELD, 2)) }  | 
 | 125 | +                    ).subAggregation(terms("sub").field(PAUSE_FIELD))  | 
 | 126 | +                )  | 
 | 127 | +        );  | 
 | 128 | +    }  | 
 | 129 | + | 
 | 130 | +    private void ensureProperCancellation(SearchRequestBuilder searchRequestBuilder) throws Exception {  | 
 | 131 | +        var searchRequestFuture = searchRequestBuilder.setTimeout(TimeValue.timeValueSeconds(1)).execute();  | 
 | 132 | +        assertFalse(searchRequestFuture.isCancelled());  | 
 | 133 | +        assertFalse(searchRequestFuture.isDone());  | 
 | 134 | + | 
 | 135 | +        // Check that there are search tasks running  | 
 | 136 | +        assertThat(getSearchTasks(), not(empty()));  | 
 | 137 | + | 
 | 138 | +        // Wait for the script field to get blocked  | 
 | 139 | +        assertBusy(() -> { assertThat(SCRIPT_SEMAPHORE.getQueueLength(), greaterThan(0)); });  | 
 | 140 | + | 
 | 141 | +        // Cancel the tasks  | 
 | 142 | +        // Warning: Adding a waitForCompletion(true)/execute() here sometimes causes tasks to not get canceled and threads to get stuck  | 
 | 143 | +        client().admin().cluster().prepareCancelTasks().setActions(TransportSearchAction.NAME + "*").get();  | 
 | 144 | + | 
 | 145 | +        SCRIPT_SEMAPHORE.release(SEMAPHORE_PERMITS);  | 
 | 146 | + | 
 | 147 | +        // Ensure the search request finished and that there are no more search tasks  | 
 | 148 | +        assertBusy(() -> {  | 
 | 149 | +            assertTrue(searchRequestFuture.isDone());  | 
 | 150 | +            assertThat(getSearchTasks(), empty());  | 
 | 151 | +        });  | 
 | 152 | +    }  | 
 | 153 | + | 
 | 154 | +    private List<TaskInfo> getSearchTasks() {  | 
 | 155 | +        return client().admin()  | 
 | 156 | +            .cluster()  | 
 | 157 | +            .prepareListTasks()  | 
 | 158 | +            .setActions(TransportSearchAction.NAME + "*")  | 
 | 159 | +            .setDetailed(true)  | 
 | 160 | +            .get()  | 
 | 161 | +            .getTasks();  | 
 | 162 | +    }  | 
 | 163 | + | 
 | 164 | +    public static class PauseScriptPlugin extends Plugin implements ScriptPlugin {  | 
 | 165 | +        public static final String PAUSE_SCRIPT_LANG = "pause";  | 
 | 166 | + | 
 | 167 | +        @Override  | 
 | 168 | +        public ScriptEngine getScriptEngine(Settings settings, Collection<ScriptContext<?>> contexts) {  | 
 | 169 | +            return new ScriptEngine() {  | 
 | 170 | +                @Override  | 
 | 171 | +                public String getType() {  | 
 | 172 | +                    return PAUSE_SCRIPT_LANG;  | 
 | 173 | +                }  | 
 | 174 | + | 
 | 175 | +                @Override  | 
 | 176 | +                @SuppressWarnings("unchecked")  | 
 | 177 | +                public <FactoryType> FactoryType compile(  | 
 | 178 | +                    String name,  | 
 | 179 | +                    String code,  | 
 | 180 | +                    ScriptContext<FactoryType> context,  | 
 | 181 | +                    Map<String, String> params  | 
 | 182 | +                ) {  | 
 | 183 | +                    if (context == LongFieldScript.CONTEXT) {  | 
 | 184 | +                        return (FactoryType) new LongFieldScript.Factory() {  | 
 | 185 | +                            @Override  | 
 | 186 | +                            public LongFieldScript.LeafFactory newFactory(  | 
 | 187 | +                                String fieldName,  | 
 | 188 | +                                Map<String, Object> params,  | 
 | 189 | +                                SearchLookup searchLookup,  | 
 | 190 | +                                OnScriptError onScriptError  | 
 | 191 | +                            ) {  | 
 | 192 | +                                return ctx -> new LongFieldScript(fieldName, params, searchLookup, onScriptError, ctx) {  | 
 | 193 | +                                    @Override  | 
 | 194 | +                                    public void execute() {  | 
 | 195 | +                                        try {  | 
 | 196 | +                                            SCRIPT_SEMAPHORE.acquire();  | 
 | 197 | +                                        } catch (InterruptedException e) {  | 
 | 198 | +                                            throw new AssertionError(e);  | 
 | 199 | +                                        }  | 
 | 200 | +                                        emit(1);  | 
 | 201 | +                                    }  | 
 | 202 | +                                };  | 
 | 203 | +                            }  | 
 | 204 | +                        };  | 
 | 205 | +                    }  | 
 | 206 | +                    throw new IllegalStateException("unsupported type " + context);  | 
 | 207 | +                }  | 
 | 208 | + | 
 | 209 | +                @Override  | 
 | 210 | +                public Set<ScriptContext<?>> getSupportedContexts() {  | 
 | 211 | +                    return Set.of(LongFieldScript.CONTEXT);  | 
 | 212 | +                }  | 
 | 213 | +            };  | 
 | 214 | +        }  | 
 | 215 | +    }  | 
 | 216 | +}  | 
0 commit comments