@@ -17,6 +17,7 @@
 import org.apache.http.client.CredentialsProvider;
 import org.apache.http.impl.client.BasicCredentialsProvider;
 import org.apache.logging.log4j.core.config.plugins.util.PluginManager;
+import org.elasticsearch.ExceptionsHelper;
 import org.elasticsearch.client.Request;
 import org.elasticsearch.client.Response;
 import org.elasticsearch.client.ResponseException;
@@ -43,6 +44,9 @@
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
 import java.util.stream.Collectors;
 import java.util.stream.Stream;
 
@@ -53,6 +57,7 @@
 import static org.elasticsearch.xpack.esql.EsqlTestUtils.reader;
 
 public class CsvTestsDataLoader {
+    private static final int PARALLEL_THREADS = 10;
     private static final int BULK_DATA_SIZE = 100_000;
     private static final TestDataset EMPLOYEES = new TestDataset("employees", "mapping-default.json", "employees.csv").noSubfields();
     private static final TestDataset EMPLOYEES_INCOMPATIBLE = new TestDataset(
@@ -396,12 +401,27 @@ private static void loadDataSetIntoEs(
         boolean supportsSourceFieldMapping,
         boolean inferenceEnabled,
         IndexCreator indexCreator
     ) throws IOException {
+        if (PARALLEL_THREADS > 1) {
+            loadDataSetIntoEsParallel(client, supportsIndexModeLookup, supportsSourceFieldMapping, inferenceEnabled, indexCreator);
+        } else {
+            loadDataSetIntoEsSequential(client, supportsIndexModeLookup, supportsSourceFieldMapping, inferenceEnabled, indexCreator);
+        }
+    }
+
+    private static void loadDataSetIntoEsSequential(
+        RestClient client,
+        boolean supportsIndexModeLookup,
+        boolean supportsSourceFieldMapping,
+        boolean inferenceEnabled,
+        IndexCreator indexCreator
+    ) throws IOException {
         Logger logger = LogManager.getLogger(CsvTestsDataLoader.class);
 
         Set<String> loadedDatasets = new HashSet<>();
         for (var dataset : availableDatasetsForEs(supportsIndexModeLookup, supportsSourceFieldMapping, inferenceEnabled)) {
-            load(client, dataset, logger, indexCreator);
+            createIndex(client, dataset, indexCreator);
+            loadData(client, dataset, logger);
             loadedDatasets.add(dataset.indexName);
         }
         forceMerge(client, loadedDatasets, logger);
@@ -410,6 +430,71 @@ private static void loadDataSetIntoEs(
         }
     }
 
+    private static void loadDataSetIntoEsParallel(
+        RestClient client,
+        boolean supportsIndexModeLookup,
+        boolean supportsSourceFieldMapping,
+        boolean inferenceEnabled,
+        IndexCreator indexCreator
+    ) throws IOException {
+        Logger logger = LogManager.getLogger(CsvTestsDataLoader.class);
+        Set<TestDataset> datasets = availableDatasetsForEs(supportsIndexModeLookup, supportsSourceFieldMapping, inferenceEnabled);
+        ExecutorService executor = Executors.newFixedThreadPool(PARALLEL_THREADS);
+        try {
+            executeInParallel(
+                executor,
+                datasets,
+                dataset -> createIndex(client, dataset, indexCreator),
+                "Failed to create indices in parallel"
+            );
+
+            executeInParallel(executor, datasets, dataset -> loadData(client, dataset, logger), "Failed to load data in parallel");
+
+            forceMerge(client, datasets.stream().map(d -> d.indexName).collect(Collectors.toSet()), logger);
+
+            executeInParallel(
+                executor,
+                ENRICH_POLICIES,
+                policy -> loadEnrichPolicy(client, policy.policyName, policy.policyFileName, logger),
+                "Failed to load enrich policies in parallel"
+            );
+        } finally {
+            executor.shutdown();
+        }
+    }
+
+    @FunctionalInterface
+    private interface IOConsumer<T> {
+        void accept(T t) throws IOException;
+    }
+
+    private static <T> void executeInParallel(ExecutorService executor, Iterable<T> items, IOConsumer<T> consumer, String errorMessage)
+        throws IOException {
+        List<Future<?>> futures = new ArrayList<>();
+        for (T item : items) {
+            futures.add(executor.submit(() -> {
+                try {
+                    consumer.accept(item);
+                } catch (IOException e) {
+                    throw new RuntimeException(e);
+                }
+            }));
+        }
+
+        RuntimeException exception = null;
+        for (Future<?> future : futures) {
+            try {
+                future.get();
+            } catch (Exception e) {
+                exception = ExceptionsHelper.useOrSuppress(exception, ExceptionsHelper.convertToRuntime(e));
+            }
+        }
+
+        if (exception != null) {
+            throw new IOException(errorMessage, exception);
+        }
+    }
+
     public static void createInferenceEndpoints(RestClient client) throws IOException {
         if (clusterHasSparseEmbeddingInferenceEndpoint(client) == false) {
             createSparseEmbeddingInferenceEndpoint(client);
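
For reference, the fan-out/join shape of executeInParallel above can be exercised outside the test suite. The following is a minimal, self-contained sketch, not the PR's code: the class name and the runAll helper are made up for illustration, and a plain addSuppressed chain stands in for ExceptionsHelper.useOrSuppress / convertToRuntime.

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

public class ParallelLoadDemo {

    @FunctionalInterface
    interface IOConsumer<T> {
        void accept(T t) throws IOException;
    }

    // Same shape as executeInParallel: fan out one task per item, then join
    // every future so all failures are observed, not just the first one.
    static <T> void runAll(ExecutorService executor, Iterable<T> items, IOConsumer<T> task, String errorMessage)
        throws IOException {
        List<Future<?>> futures = new ArrayList<>();
        for (T item : items) {
            futures.add(executor.submit(() -> {
                try {
                    task.accept(item);
                } catch (IOException e) {
                    throw new UncheckedIOException(e);
                }
            }));
        }
        RuntimeException failure = null;
        for (Future<?> future : futures) {
            try {
                future.get();
            } catch (Exception e) {
                // Stand-in for ExceptionsHelper.useOrSuppress: keep the first
                // failure, attach later ones as suppressed exceptions.
                RuntimeException wrapped = (e instanceof RuntimeException re) ? re : new RuntimeException(e);
                if (failure == null) {
                    failure = wrapped;
                } else {
                    failure.addSuppressed(wrapped);
                }
            }
        }
        if (failure != null) {
            throw new IOException(errorMessage, failure);
        }
    }

    public static void main(String[] args) throws IOException {
        ExecutorService executor = Executors.newFixedThreadPool(4);
        try {
            // Phase 1 completes for every item before phase 2 starts,
            // mirroring create-indices then load-data in the loader above.
            runAll(executor, List.of("a", "b", "c"), item -> System.out.println("create " + item), "create failed");
            runAll(executor, List.of("a", "b", "c"), item -> System.out.println("load " + item), "load failed");
        } finally {
            executor.shutdown();
        }
    }
}
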
@@ -535,11 +620,13 @@ private static URL getResource(String name) {
         return result;
     }
 
-    private static void load(RestClient client, TestDataset dataset, Logger logger, IndexCreator indexCreator) throws IOException {
+    private static void createIndex(RestClient client, TestDataset dataset, IndexCreator indexCreator) throws IOException {
         URL mapping = getResource("/" + dataset.mappingFileName);
         Settings indexSettings = dataset.readSettingsFile();
         indexCreator.createIndex(client, dataset.indexName, readMappingFile(mapping, dataset.typeMapping), indexSettings);
+    }
 
+    private static void loadData(RestClient client, TestDataset dataset, Logger logger) throws IOException {
        // Some examples only test that the query and mappings are valid, and don't need example data. Use .noData() for those
         if (dataset.dataFileName != null) {
             URL data = getResource("/data/" + dataset.dataFileName);
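
Two properties of the parallel path are worth noting. Each executeInParallel call joins all of its futures before returning, so the phases stay strictly ordered: every index exists before any CSV data is bulk-loaded, and the force-merge completes before the enrich policies are created. Failures within a phase are aggregated with ExceptionsHelper.useOrSuppress, so the first error is rethrown (wrapped in an IOException) with later ones attached as suppressed rather than dropped; the single fixed pool of PARALLEL_THREADS workers is reused across all phases and only shut down in the finally block.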