Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import org.elasticsearch.action.admin.indices.refresh.RefreshRequest;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.client.internal.ProjectClient;
import org.elasticsearch.cluster.block.ClusterBlockLevel;
import org.elasticsearch.cluster.metadata.ProjectId;
import org.elasticsearch.cluster.service.ClusterService;
Expand Down Expand Up @@ -77,7 +77,7 @@ public class GeoIpDownloader extends AllocatedPersistentTask {
static final String DATABASES_INDEX_PATTERN = DATABASES_INDEX + "*";
static final int MAX_CHUNK_SIZE = 1024 * 1024;

private final Client client;
private final ProjectClient client;
private final HttpClient httpClient;
private final ClusterService clusterService;
private final ThreadPool threadPool;
Expand All @@ -99,7 +99,7 @@ public class GeoIpDownloader extends AllocatedPersistentTask {
private final ProjectId projectId;

GeoIpDownloader(
Client client,
ProjectClient client,
HttpClient httpClient,
ClusterService clusterService,
ThreadPool threadPool,
Expand All @@ -116,7 +116,7 @@ public class GeoIpDownloader extends AllocatedPersistentTask {
ProjectId projectId
) {
super(id, type, action, description, parentTask, headers);
this.client = client.projectClient(projectId);
this.client = client;
this.httpClient = httpClient;
this.clusterService = clusterService;
this.threadPool = threadPool;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ protected GeoIpDownloader createTask(
) {
ProjectId projectId = projectResolver.getProjectId();
return new GeoIpDownloader(
client,
client.projectClient(projectId),
httpClient,
clusterService,
threadPool,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.client.internal.ProjectClient;
import org.elasticsearch.cluster.ClusterName;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.metadata.AliasMetadata;
Expand Down Expand Up @@ -63,6 +64,7 @@
import org.junit.After;
import org.junit.Before;
import org.mockito.ArgumentCaptor;
import org.mockito.Mockito;
import org.mockito.stubbing.Answer;

import java.io.ByteArrayInputStream;
Expand Down Expand Up @@ -112,6 +114,7 @@
public class DatabaseNodeServiceTests extends ESTestCase {

private Client client;
private ProjectClient projectClient;
private Path geoIpTmpDir;
private ThreadPool threadPool;
private DatabaseNodeService databaseNodeService;
Expand All @@ -138,7 +141,9 @@ public void setup() throws IOException {
Settings settings = Settings.builder().put("resource.reload.interval.high", TimeValue.timeValueMillis(100)).build();
resourceWatcherService = new ResourceWatcherService(settings, threadPool);

projectClient = mock(ProjectClient.class);
client = mock(Client.class);
when(client.projectClient(any())).thenReturn(projectClient);
ingestService = mock(IngestService.class);
clusterService = mock(ClusterService.class);
geoIpTmpDir = createTempDir();
Expand All @@ -161,6 +166,8 @@ public void cleanup() {
threadPool.shutdownNow();
Releasables.close(toRelease);
toRelease.clear();
verify(client, Mockito.atLeast(0)).projectClient(any());
verifyNoMoreInteractions(client);
}

public void testCheckDatabases() throws Exception {
Expand All @@ -181,7 +188,7 @@ public void testCheckDatabases() throws Exception {
databaseNodeService.checkDatabases(state);
DatabaseReaderLazyLoader database = databaseNodeService.getDatabaseReaderLazyLoader(projectId, "GeoIP2-City.mmdb");
assertThat(database, nullValue());
verify(client, times(0)).search(any());
verify(projectClient, times(0)).search(any());
verify(ingestService, times(0)).reloadPipeline(any(), anyString());
try (Stream<Path> files = Files.list(geoIpTmpDir.resolve("geoip-databases").resolve("nodeId"))) {
assertEquals(0, files.count());
Expand All @@ -199,7 +206,7 @@ public void testCheckDatabases() throws Exception {
databaseNodeService.checkDatabases(state);
database = databaseNodeService.getDatabaseReaderLazyLoader(projectId, "GeoIP2-City.mmdb");
assertThat(database, notNullValue());
verify(client, times(10)).search(any());
verify(projectClient, times(10)).search(any());
try (Stream<Path> files = Files.list(geoIpTmpDir.resolve("geoip-databases").resolve("nodeId"))) {
assertThat(files.count(), greaterThanOrEqualTo(1L));
}
Expand All @@ -226,7 +233,7 @@ public void testCheckDatabases_dontCheckDatabaseOnNonIngestNode() throws Excepti

databaseNodeService.checkDatabases(state);
assertThat(databaseNodeService.getDatabase(projectId, "GeoIP2-City.mmdb"), nullValue());
verify(client, never()).search(any());
verify(projectClient, never()).search(any());
try (Stream<Path> files = Files.list(geoIpTmpDir.resolve("geoip-databases").resolve("nodeId"))) {
assertThat(files.toList(), empty());
}
Expand All @@ -246,7 +253,7 @@ public void testCheckDatabases_dontCheckDatabaseWhenNoDatabasesIndex() throws Ex

databaseNodeService.checkDatabases(state);
assertThat(databaseNodeService.getDatabase(projectId, "GeoIP2-City.mmdb"), nullValue());
verify(client, never()).search(any());
verify(projectClient, never()).search(any());
try (Stream<Path> files = Files.list(geoIpTmpDir.resolve("geoip-databases").resolve("nodeId"))) {
assertThat(files.toList(), empty());
}
Expand All @@ -261,7 +268,7 @@ public void testCheckDatabases_dontCheckDatabaseWhenGeoIpDownloadTask() throws E

databaseNodeService.checkDatabases(state);
assertThat(databaseNodeService.getDatabase(projectId, "GeoIP2-City.mmdb"), nullValue());
verify(client, never()).search(any());
verify(projectClient, never()).search(any());
try (Stream<Path> files = Files.list(geoIpTmpDir.resolve("geoip-databases").resolve("nodeId"))) {
assertThat(files.toList(), empty());
}
Expand All @@ -281,7 +288,7 @@ public void testRetrieveDatabase() throws Exception {
verify(failureHandler, never()).accept(any());
verify(chunkConsumer, times(30)).accept(any());
verify(completedHandler, times(1)).run();
verify(client, times(30)).search(any());
verify(projectClient, times(30)).search(any());
}

public void testRetrieveDatabaseCorruption() throws Exception {
Expand All @@ -305,7 +312,7 @@ public void testRetrieveDatabaseCorruption() throws Exception {
);
verify(chunkConsumer, times(10)).accept(any());
verify(completedHandler, times(0)).run();
verify(client, times(10)).search(any());
verify(projectClient, times(10)).search(any());
}

public void testUpdateDatabase() throws Exception {
Expand Down Expand Up @@ -371,8 +378,7 @@ private String mockSearches(String databaseName, int firstChunk, int lastChunk)
});
requestMap.put(databaseName + "_" + i, actionFuture);
}
when(client.projectClient(any())).thenReturn(client);
when(client.search(any())).thenAnswer(invocationOnMock -> {
when(projectClient.search(any())).thenAnswer(invocationOnMock -> {
SearchRequest req = (SearchRequest) invocationOnMock.getArguments()[0];
TermQueryBuilder term = (TermQueryBuilder) req.source().query();
String id = (String) term.value();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.elasticsearch.action.index.IndexResponse;
import org.elasticsearch.action.index.TransportIndexAction;
import org.elasticsearch.action.support.broadcast.BroadcastResponse;
import org.elasticsearch.client.internal.ProjectClient;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.block.ClusterBlocks;
import org.elasticsearch.cluster.metadata.IndexMetadata;
Expand Down Expand Up @@ -702,12 +703,14 @@ private GeoIpTaskState.Metadata newGeoIpTaskStateMetadata(boolean expired) {
return new GeoIpTaskState.Metadata(0, 0, 0, randomAlphaOfLength(20), lastChecked.toEpochMilli());
}

private static class MockClient extends NoOpClient {
private static class MockClient extends NoOpClient implements ProjectClient {

private final Map<ActionType<?>, BiConsumer<? extends ActionRequest, ? extends ActionListener<?>>> handlers = new HashMap<>();
private final ProjectId projectId;

private MockClient(ThreadPool threadPool, ProjectId projectId) {
super(threadPool, TestProjectResolvers.singleProject(projectId));
this.projectId = projectId;
}

public <Response extends ActionResponse, Request extends ActionRequest> void addHandler(
Expand All @@ -717,6 +720,11 @@ public <Response extends ActionResponse, Request extends ActionRequest> void add
handlers.put(action, listener);
}

@Override
public ProjectId projectId() {
return projectId;
}

@SuppressWarnings("unchecked")
@Override
protected <Request extends ActionRequest, Response extends ActionResponse> void doExecute(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,7 @@ public interface Client extends ElasticsearchClient {
/**
* Returns a client that executes every request in the context of the given project.
*/
Client projectClient(ProjectId projectId);
ProjectClient projectClient(ProjectId projectId);

/**
* Returns this client's project resolver.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.client.internal;

import org.elasticsearch.cluster.metadata.ProjectId;

/**
* A {@link Client} that is scoped to a specific project. It should execute any request in the scope of that project. This scope is usually
* defined by the thread context.
*/
public interface ProjectClient extends Client {

ProjectId projectId();
}
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,10 @@
import org.elasticsearch.client.internal.AdminClient;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.client.internal.FilterClient;
import org.elasticsearch.client.internal.ProjectClient;
import org.elasticsearch.cluster.metadata.ProjectId;
import org.elasticsearch.cluster.project.ProjectResolver;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.core.Nullable;
Expand All @@ -96,6 +98,7 @@ public abstract class AbstractClient implements Client {
private final ThreadPool threadPool;
private final ProjectResolver projectResolver;
private final AdminClient admin;
private final ProjectClient defaultProjectClient;

@SuppressWarnings("this-escape")
public AbstractClient(Settings settings, ThreadPool threadPool, ProjectResolver projectResolver) {
Expand All @@ -104,6 +107,14 @@ public AbstractClient(Settings settings, ThreadPool threadPool, ProjectResolver
this.projectResolver = projectResolver;
this.admin = new AdminClient(this);
this.logger = LogManager.getLogger(this.getClass());
// We create a dedicated project client for the default project to avoid having to reconstruct it on every invocation.
// This aims to reduce the overhead of creating a project client when the client is used in a single-project context.
// TODO: only create the default project client if the project resolver does not support multiple projects.
if (this instanceof ProjectClient == false) {
this.defaultProjectClient = new ProjectClientImpl(this, ProjectId.DEFAULT);
} else {
this.defaultProjectClient = null;
}
}

@Override
Expand Down Expand Up @@ -417,29 +428,13 @@ protected <Request extends ActionRequest, Response extends ActionResponse> void
}

@Override
public Client projectClient(ProjectId projectId) {
public ProjectClient projectClient(ProjectId projectId) {
// We only take the shortcut when the given project ID matches the "current" project ID. If it doesn't, we'll let #executeOnProject
// take care of error handling.
if (projectResolver.supportsMultipleProjects() == false && projectId.equals(projectResolver.getProjectId())) {
return this;
return defaultProjectClient;
}
return new FilterClient(this) {
@Override
protected <Request extends ActionRequest, Response extends ActionResponse> void doExecute(
ActionType<Response> action,
Request request,
ActionListener<Response> listener
) {
projectResolver.executeOnProject(projectId, () -> super.doExecute(action, request, listener));
}

@Override
public Client projectClient(ProjectId projectId) {
throw new IllegalStateException(
"Unable to create a project client for project [" + projectId + "], nested project client creation is not supported"
);
}
};
return new ProjectClientImpl(this, projectId);
}

/**
Expand Down Expand Up @@ -477,4 +472,35 @@ public R get() throws InterruptedException, ExecutionException {
return super.get();
}
}

private static class ProjectClientImpl extends FilterClient implements ProjectClient {

private final ProjectId projectId;

ProjectClientImpl(Client in, ProjectId projectId) {
super(in);
this.projectId = projectId;
}

@Override
public ProjectId projectId() {
return projectId;
}

@Override
protected <Request extends ActionRequest, Response extends ActionResponse> void doExecute(
ActionType<Response> action,
Request request,
ActionListener<Response> listener
) {
projectResolver().executeOnProject(projectId, () -> super.doExecute(action, request, listener));
}

@Override
public ProjectClient projectClient(ProjectId projectId) {
throw new IllegalStateException(Strings.format("""
Unable to create a project client for project [%s] from project client with project ID [%s],\
nested project client creation is not supported""", projectId, this.projectId));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import org.elasticsearch.client.internal.AdminClient;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.client.internal.IndicesAdminClient;
import org.elasticsearch.client.internal.ProjectClient;
import org.elasticsearch.cluster.ClusterStateObserver;
import org.elasticsearch.cluster.ProjectState;
import org.elasticsearch.cluster.metadata.IndexMetadata;
Expand All @@ -26,6 +27,7 @@
public abstract class AbstractStepTestCase<T extends Step> extends ESTestCase {

protected Client client;
protected ProjectClient projectClient;
protected AdminClient adminClient;
protected IndicesAdminClient indicesClient;

Expand All @@ -34,9 +36,10 @@ public void setupClient() {
client = Mockito.mock(Client.class);
adminClient = Mockito.mock(AdminClient.class);
indicesClient = Mockito.mock(IndicesAdminClient.class);
projectClient = Mockito.mock(ProjectClient.class);

Mockito.when(client.projectClient(Mockito.any())).thenReturn(client);
Mockito.when(client.admin()).thenReturn(adminClient);
Mockito.when(client.projectClient(Mockito.any())).thenReturn(projectClient);
Mockito.when(projectClient.admin()).thenReturn(adminClient);
Mockito.when(adminClient.indices()).thenReturn(indicesClient);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.cluster.metadata.LifecycleExecutionState;
import org.elasticsearch.cluster.metadata.ProjectMetadata;
import org.elasticsearch.cluster.project.TestProjectResolvers;
import org.elasticsearch.index.IndexVersion;
import org.elasticsearch.test.client.NoOpClient;
import org.elasticsearch.threadpool.ThreadPool;
Expand Down Expand Up @@ -105,7 +106,7 @@ public void testPerformAction() {
}

private NoOpClient getDeleteSnapshotRequestAssertingClient(ThreadPool threadPool, String expectedSnapshotName) {
return new NoOpClient(threadPool) {
return new NoOpClient(threadPool, TestProjectResolvers.usingRequestHeader(threadPool.getThreadContext())) {
@Override
protected <Request extends ActionRequest, Response extends ActionResponse> void doExecute(
ActionType<Response> action,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ public void onFailure(Exception e) {

assertEquals(true, actionCompleted.get());
Mockito.verify(client).projectClient(state.projectId());
Mockito.verify(client).admin();
Mockito.verify(projectClient).admin();
Mockito.verifyNoMoreInteractions(client);
Mockito.verify(adminClient, Mockito.only()).indices();
Mockito.verify(indicesClient, Mockito.only()).close(Mockito.any(), Mockito.any());
Expand Down Expand Up @@ -110,7 +110,7 @@ public void testPerformActionFailure() {

assertSame(exception, expectThrows(Exception.class, () -> performActionAndWait(step, indexMetadata, state, null)));
Mockito.verify(client).projectClient(state.projectId());
Mockito.verify(client).admin();
Mockito.verify(projectClient).admin();
Mockito.verifyNoMoreInteractions(client);
Mockito.verify(adminClient, Mockito.only()).indices();
Mockito.verify(indicesClient, Mockito.only()).close(Mockito.any(), Mockito.any());
Expand Down
Loading