Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.elasticsearch.Version;
import org.elasticsearch.client.Request;
import org.elasticsearch.client.Response;
import org.elasticsearch.client.ResponseListener;
import org.elasticsearch.client.RestClient;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.core.IOUtils;
Expand All @@ -38,6 +39,8 @@
import java.util.List;
import java.util.Locale;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

Expand Down Expand Up @@ -251,10 +254,7 @@ static RestClient twoClients(RestClient localClient, RestClient remoteClient) th
return bulkClient.performRequest(request);
} else {
Request[] clones = cloneRequests(request, 2);
Response resp1 = remoteClient.performRequest(clones[0]);
Response resp2 = localClient.performRequest(clones[1]);
assertEquals(resp1.getStatusLine().getStatusCode(), resp2.getStatusLine().getStatusCode());
return resp2;
return runInParallel(localClient, remoteClient, clones);
}
});
doAnswer(invocation -> {
Expand Down Expand Up @@ -289,6 +289,44 @@ static Request[] cloneRequests(Request orig, int numClones) throws IOException {
return clones;
}

/**
* Run {@link #cloneRequests cloned} requests in parallel.
*/
static Response runInParallel(RestClient localClient, RestClient remoteClient, Request[] clones) throws Throwable {
CompletableFuture<Response> remoteResponse = new CompletableFuture<>();
CompletableFuture<Response> localResponse = new CompletableFuture<>();
remoteClient.performRequestAsync(clones[0], new ResponseListener() {
@Override
public void onSuccess(Response response) {
remoteResponse.complete(response);
}

@Override
public void onFailure(Exception exception) {
remoteResponse.completeExceptionally(exception);
}
});
localClient.performRequestAsync(clones[1], new ResponseListener() {
@Override
public void onSuccess(Response response) {
localResponse.complete(response);
}

@Override
public void onFailure(Exception exception) {
localResponse.completeExceptionally(exception);
}
});
try {
Response remote = remoteResponse.get();
Response local = localResponse.get();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NIT: you could use PlainActionFuture instead of anonymous listeners implementations here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this code comes from a revert of a revert of an approved PR, I do not want to fiddle with it.

assertEquals(remote.getStatusLine().getStatusCode(), local.getStatusLine().getStatusCode());
return local;
} catch (ExecutionException e) {
throw e.getCause();
}
}

/**
* Convert FROM employees ... => FROM *:employees,employees
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.apache.http.client.CredentialsProvider;
import org.apache.http.impl.client.BasicCredentialsProvider;
import org.apache.logging.log4j.core.config.plugins.util.PluginManager;
import org.apache.lucene.util.IOConsumer;
import org.elasticsearch.client.Request;
import org.elasticsearch.client.Response;
import org.elasticsearch.client.ResponseException;
Expand All @@ -30,6 +31,7 @@
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.logging.LogManager;
import org.elasticsearch.logging.Logger;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.rest.ESRestTestCase;
import org.elasticsearch.xcontent.XContentType;

Expand All @@ -43,6 +45,7 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Semaphore;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
Expand All @@ -55,6 +58,7 @@
import static org.elasticsearch.xpack.esql.EsqlTestUtils.reader;

public class CsvTestsDataLoader {
private static final int PARALLEL_THREADS = 10;
private static final int BULK_DATA_SIZE = 100_000;
private static final TestDataset EMPLOYEES = new TestDataset("employees", "mapping-default.json", "employees.csv").noSubfields();
private static final TestDataset EMPLOYEES_INCOMPATIBLE = new TestDataset(
Expand Down Expand Up @@ -429,18 +433,42 @@ private static void loadDataSetIntoEs(
IndexCreator indexCreator
) throws IOException {
Logger logger = LogManager.getLogger(CsvTestsDataLoader.class);
List<TestDataset> datasets = availableDatasetsForEs(
supportsIndexModeLookup,
supportsSourceFieldMapping,
inferenceEnabled,
timeSeriesOnly
).stream().toList();

logger.info("Creating test indices");
executeInParallel(datasets, dataset -> createIndex(client, dataset, indexCreator), "Failed to create indices in parallel");

Set<String> loadedDatasets = new HashSet<>();
logger.info("Loading test datasets");
for (var dataset : availableDatasetsForEs(supportsIndexModeLookup, supportsSourceFieldMapping, inferenceEnabled, timeSeriesOnly)) {
load(client, dataset, logger, indexCreator);
loadedDatasets.add(dataset.indexName);
}
forceMerge(client, loadedDatasets, logger);
executeInParallel(datasets, dataset -> loadData(client, dataset, logger), "Failed to load data in parallel");

forceMerge(client, datasets.stream().map(d -> d.indexName).collect(Collectors.toSet()), logger);

logger.info("Loading enrich policies");
for (var policy : ENRICH_POLICIES) {
loadEnrichPolicy(client, policy.policyName, policy.policyFileName, logger);
}
executeInParallel(
ENRICH_POLICIES,
policy -> loadEnrichPolicy(client, policy.policyName, policy.policyFileName, logger),
"Failed to load enrich policies in parallel"
);

}

private static <T> void executeInParallel(List<T> items, IOConsumer<T> consumer, String errorMessage) {
Semaphore semaphore = new Semaphore(PARALLEL_THREADS);
ESTestCase.runInParallel(items.size(), i -> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ESTestCase.runInParallel internally creates a thread per each task.
Today we have 47 data sets but we only want to populate 10 at the time. I believe it would be much cheaper to use Executors.newFixedThreadPool(threads) rather than create a thread per task, especially if we want to limit parallelism.

try {
semaphore.acquire();
consumer.accept(items.get(i));
} catch (IOException | InterruptedException e) {
throw new RuntimeException(errorMessage, e);
} finally {
semaphore.release();
}
});
}

public static void createInferenceEndpoints(RestClient client) throws IOException {
Expand Down Expand Up @@ -598,12 +626,14 @@ private static URL getResource(String name) {
return result;
}

private static void load(RestClient client, TestDataset dataset, Logger logger, IndexCreator indexCreator) throws IOException {
logger.info("Loading dataset [{}] into ES index [{}]", dataset.dataFileName, dataset.indexName);
/**
 * Creates the ES index for {@code dataset}: loads the optional per-dataset settings and the
 * mapping resource, then delegates the actual index creation to {@code indexCreator}.
 */
private static void createIndex(RestClient client, TestDataset dataset, IndexCreator indexCreator) throws IOException {
    Settings indexSettings = dataset.readSettingsFile();
    URL mappingResource = getResource("/" + dataset.mappingFileName);
    indexCreator.createIndex(client, dataset.indexName, readMappingFile(mappingResource, dataset.typeMapping), indexSettings);
}

private static void loadData(RestClient client, TestDataset dataset, Logger logger) throws IOException {
logger.info("Loading dataset [{}] into ES index [{}]", dataset.dataFileName, dataset.indexName);
// Some examples only test that the query and mappings are valid, and don't need example data. Use .noData() for those
if (dataset.dataFileName != null) {
URL data = getResource("/data/" + dataset.dataFileName);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,18 @@

import java.net.ConnectException;

import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.startsWith;

public class CsvTestsDataLoaderTests extends ESTestCase {

    /**
     * Running the loader with no cluster reachable must fail with a connection error.
     * With parallel loading the {@link ConnectException} no longer surfaces directly:
     * it arrives wrapped (the test expects an {@link AssertionError} at the top), so we
     * walk the cause chain down to the root before asserting on it.
     */
    public void testCsvTestsDataLoaderExecution() {
        Throwable cause = expectThrows(AssertionError.class, () -> CsvTestsDataLoader.main(new String[] {}));
        // find the root cause
        while (cause.getCause() != null) {
            cause = cause.getCause();
        }
        assertThat(cause, instanceOf(ConnectException.class));
        assertThat(cause.getMessage(), startsWith("Connection refused"));
    }
}