Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import org.elasticsearch.action.admin.indices.refresh.RefreshRequest;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.client.internal.ProjectClient;
import org.elasticsearch.cluster.block.ClusterBlockLevel;
import org.elasticsearch.cluster.metadata.ProjectId;
import org.elasticsearch.cluster.service.ClusterService;
Expand Down Expand Up @@ -77,7 +77,7 @@ public class GeoIpDownloader extends AllocatedPersistentTask {
static final String DATABASES_INDEX_PATTERN = DATABASES_INDEX + "*";
static final int MAX_CHUNK_SIZE = 1024 * 1024;

private final Client client;
private final ProjectClient client;
private final HttpClient httpClient;
private final ClusterService clusterService;
private final ThreadPool threadPool;
Expand All @@ -99,7 +99,7 @@ public class GeoIpDownloader extends AllocatedPersistentTask {
private final ProjectId projectId;

GeoIpDownloader(
Client client,
ProjectClient client,
HttpClient httpClient,
ClusterService clusterService,
ThreadPool threadPool,
Expand All @@ -116,7 +116,7 @@ public class GeoIpDownloader extends AllocatedPersistentTask {
ProjectId projectId
) {
super(id, type, action, description, parentTask, headers);
this.client = client.projectClient(projectId);
this.client = client;
this.httpClient = httpClient;
this.clusterService = clusterService;
this.threadPool = threadPool;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ protected GeoIpDownloader createTask(
) {
ProjectId projectId = projectResolver.getProjectId();
return new GeoIpDownloader(
client,
client.projectClient(projectId),
httpClient,
clusterService,
threadPool,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.client.internal.ProjectClient;
import org.elasticsearch.cluster.ClusterName;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.metadata.AliasMetadata;
Expand Down Expand Up @@ -63,6 +64,7 @@
import org.junit.After;
import org.junit.Before;
import org.mockito.ArgumentCaptor;
import org.mockito.Mockito;
import org.mockito.stubbing.Answer;

import java.io.ByteArrayInputStream;
Expand Down Expand Up @@ -112,6 +114,7 @@
public class DatabaseNodeServiceTests extends ESTestCase {

private Client client;
private ProjectClient projectClient;
private Path geoIpTmpDir;
private ThreadPool threadPool;
private DatabaseNodeService databaseNodeService;
Expand All @@ -138,7 +141,9 @@ public void setup() throws IOException {
Settings settings = Settings.builder().put("resource.reload.interval.high", TimeValue.timeValueMillis(100)).build();
resourceWatcherService = new ResourceWatcherService(settings, threadPool);

projectClient = mock(ProjectClient.class);
client = mock(Client.class);
when(client.projectClient(any())).thenReturn(projectClient);
ingestService = mock(IngestService.class);
clusterService = mock(ClusterService.class);
geoIpTmpDir = createTempDir();
Expand All @@ -161,6 +166,8 @@ public void cleanup() {
threadPool.shutdownNow();
Releasables.close(toRelease);
toRelease.clear();
verify(client, Mockito.atLeast(0)).projectClient(any());
verifyNoMoreInteractions(client);
}

public void testCheckDatabases() throws Exception {
Expand All @@ -181,7 +188,7 @@ public void testCheckDatabases() throws Exception {
databaseNodeService.checkDatabases(state);
DatabaseReaderLazyLoader database = databaseNodeService.getDatabaseReaderLazyLoader(projectId, "GeoIP2-City.mmdb");
assertThat(database, nullValue());
verify(client, times(0)).search(any());
verify(projectClient, times(0)).search(any());
verify(ingestService, times(0)).reloadPipeline(any(), anyString());
try (Stream<Path> files = Files.list(geoIpTmpDir.resolve("geoip-databases").resolve("nodeId"))) {
assertEquals(0, files.count());
Expand All @@ -199,7 +206,7 @@ public void testCheckDatabases() throws Exception {
databaseNodeService.checkDatabases(state);
database = databaseNodeService.getDatabaseReaderLazyLoader(projectId, "GeoIP2-City.mmdb");
assertThat(database, notNullValue());
verify(client, times(10)).search(any());
verify(projectClient, times(10)).search(any());
try (Stream<Path> files = Files.list(geoIpTmpDir.resolve("geoip-databases").resolve("nodeId"))) {
assertThat(files.count(), greaterThanOrEqualTo(1L));
}
Expand All @@ -226,7 +233,7 @@ public void testCheckDatabases_dontCheckDatabaseOnNonIngestNode() throws Excepti

databaseNodeService.checkDatabases(state);
assertThat(databaseNodeService.getDatabase(projectId, "GeoIP2-City.mmdb"), nullValue());
verify(client, never()).search(any());
verify(projectClient, never()).search(any());
try (Stream<Path> files = Files.list(geoIpTmpDir.resolve("geoip-databases").resolve("nodeId"))) {
assertThat(files.toList(), empty());
}
Expand All @@ -246,7 +253,7 @@ public void testCheckDatabases_dontCheckDatabaseWhenNoDatabasesIndex() throws Ex

databaseNodeService.checkDatabases(state);
assertThat(databaseNodeService.getDatabase(projectId, "GeoIP2-City.mmdb"), nullValue());
verify(client, never()).search(any());
verify(projectClient, never()).search(any());
try (Stream<Path> files = Files.list(geoIpTmpDir.resolve("geoip-databases").resolve("nodeId"))) {
assertThat(files.toList(), empty());
}
Expand All @@ -261,7 +268,7 @@ public void testCheckDatabases_dontCheckDatabaseWhenGeoIpDownloadTask() throws E

databaseNodeService.checkDatabases(state);
assertThat(databaseNodeService.getDatabase(projectId, "GeoIP2-City.mmdb"), nullValue());
verify(client, never()).search(any());
verify(projectClient, never()).search(any());
try (Stream<Path> files = Files.list(geoIpTmpDir.resolve("geoip-databases").resolve("nodeId"))) {
assertThat(files.toList(), empty());
}
Expand All @@ -281,7 +288,7 @@ public void testRetrieveDatabase() throws Exception {
verify(failureHandler, never()).accept(any());
verify(chunkConsumer, times(30)).accept(any());
verify(completedHandler, times(1)).run();
verify(client, times(30)).search(any());
verify(projectClient, times(30)).search(any());
}

public void testRetrieveDatabaseCorruption() throws Exception {
Expand All @@ -305,7 +312,7 @@ public void testRetrieveDatabaseCorruption() throws Exception {
);
verify(chunkConsumer, times(10)).accept(any());
verify(completedHandler, times(0)).run();
verify(client, times(10)).search(any());
verify(projectClient, times(10)).search(any());
}

public void testUpdateDatabase() throws Exception {
Expand Down Expand Up @@ -371,8 +378,7 @@ private String mockSearches(String databaseName, int firstChunk, int lastChunk)
});
requestMap.put(databaseName + "_" + i, actionFuture);
}
when(client.projectClient(any())).thenReturn(client);
when(client.search(any())).thenAnswer(invocationOnMock -> {
when(projectClient.search(any())).thenAnswer(invocationOnMock -> {
SearchRequest req = (SearchRequest) invocationOnMock.getArguments()[0];
TermQueryBuilder term = (TermQueryBuilder) req.source().query();
String id = (String) term.value();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.elasticsearch.action.index.IndexResponse;
import org.elasticsearch.action.index.TransportIndexAction;
import org.elasticsearch.action.support.broadcast.BroadcastResponse;
import org.elasticsearch.client.internal.ProjectClient;
import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.block.ClusterBlocks;
import org.elasticsearch.cluster.metadata.IndexMetadata;
Expand Down Expand Up @@ -702,12 +703,14 @@ private GeoIpTaskState.Metadata newGeoIpTaskStateMetadata(boolean expired) {
return new GeoIpTaskState.Metadata(0, 0, 0, randomAlphaOfLength(20), lastChecked.toEpochMilli());
}

private static class MockClient extends NoOpClient {
private static class MockClient extends NoOpClient implements ProjectClient {

private final Map<ActionType<?>, BiConsumer<? extends ActionRequest, ? extends ActionListener<?>>> handlers = new HashMap<>();
private final ProjectId projectId;

private MockClient(ThreadPool threadPool, ProjectId projectId) {
super(threadPool, TestProjectResolvers.singleProject(projectId));
this.projectId = projectId;
}

public <Response extends ActionResponse, Request extends ActionRequest> void addHandler(
Expand All @@ -717,6 +720,11 @@ public <Response extends ActionResponse, Request extends ActionRequest> void add
handlers.put(action, listener);
}

@Override
public ProjectId projectId() {
return projectId;
}

@SuppressWarnings("unchecked")
@Override
protected <Request extends ActionRequest, Response extends ActionResponse> void doExecute(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,7 @@ public interface Client extends ElasticsearchClient {
/**
* Returns a client that executes every request in the context of the given project.
*/
Client projectClient(ProjectId projectId);
ProjectClient projectClient(ProjectId projectId);

/**
* Returns this client's project resolver.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.client.internal;

import org.elasticsearch.cluster.metadata.ProjectId;

/**
* A {@link Client} that is scoped to a specific project. It should execute any request in the scope of that project. This scope is usually
* defined by the thread context.
*/
public interface ProjectClient extends Client {

ProjectId projectId();
}
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
import org.elasticsearch.client.internal.AdminClient;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.client.internal.FilterClient;
import org.elasticsearch.client.internal.ProjectClient;
import org.elasticsearch.cluster.metadata.ProjectId;
import org.elasticsearch.cluster.project.ProjectResolver;
import org.elasticsearch.common.settings.Settings;
Expand All @@ -96,6 +97,7 @@ public abstract class AbstractClient implements Client {
private final ThreadPool threadPool;
private final ProjectResolver projectResolver;
private final AdminClient admin;
private ProjectClient defaultProjectClient;

@SuppressWarnings("this-escape")
public AbstractClient(Settings settings, ThreadPool threadPool, ProjectResolver projectResolver) {
Expand Down Expand Up @@ -417,29 +419,16 @@ protected <Request extends ActionRequest, Response extends ActionResponse> void
}

@Override
public Client projectClient(ProjectId projectId) {
public ProjectClient projectClient(ProjectId projectId) {
// We only take the shortcut when the given project ID matches the "current" project ID. If it doesn't, we'll let #executeOnProject
// take care of error handling.
if (projectResolver.supportsMultipleProjects() == false && projectId.equals(projectResolver.getProjectId())) {
return this;
}
return new FilterClient(this) {
@Override
protected <Request extends ActionRequest, Response extends ActionResponse> void doExecute(
ActionType<Response> action,
Request request,
ActionListener<Response> listener
) {
projectResolver.executeOnProject(projectId, () -> super.doExecute(action, request, listener));
}

@Override
public Client projectClient(ProjectId projectId) {
throw new IllegalStateException(
"Unable to create a project client for project [" + projectId + "], nested project client creation is not supported"
);
if (defaultProjectClient == null) {
defaultProjectClient = new ProjectClientImpl(this, ProjectId.DEFAULT);
}
};
return defaultProjectClient;
}
return new ProjectClientImpl(this, projectId);
}

/**
Expand Down Expand Up @@ -477,4 +466,28 @@ public R get() throws InterruptedException, ExecutionException {
return super.get();
}
}

private static class ProjectClientImpl extends FilterClient implements ProjectClient {

private final ProjectId projectId;

ProjectClientImpl(Client in, ProjectId projectId) {
super(in);
this.projectId = projectId;
}

@Override
public ProjectId projectId() {
return projectId;
}

@Override
protected <Request extends ActionRequest, Response extends ActionResponse> void doExecute(
ActionType<Response> action,
Request request,
ActionListener<Response> listener
) {
projectResolver().executeOnProject(projectId, () -> super.doExecute(action, request, listener));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import org.elasticsearch.client.internal.AdminClient;
import org.elasticsearch.client.internal.Client;
import org.elasticsearch.client.internal.IndicesAdminClient;
import org.elasticsearch.client.internal.ProjectClient;
import org.elasticsearch.cluster.ClusterStateObserver;
import org.elasticsearch.cluster.ProjectState;
import org.elasticsearch.cluster.metadata.IndexMetadata;
Expand All @@ -26,6 +27,7 @@
public abstract class AbstractStepTestCase<T extends Step> extends ESTestCase {

protected Client client;
protected ProjectClient projectClient;
protected AdminClient adminClient;
protected IndicesAdminClient indicesClient;

Expand All @@ -34,9 +36,10 @@ public void setupClient() {
client = Mockito.mock(Client.class);
adminClient = Mockito.mock(AdminClient.class);
indicesClient = Mockito.mock(IndicesAdminClient.class);
projectClient = Mockito.mock(ProjectClient.class);

Mockito.when(client.projectClient(Mockito.any())).thenReturn(client);
Mockito.when(client.admin()).thenReturn(adminClient);
Mockito.when(client.projectClient(Mockito.any())).thenReturn(projectClient);
Mockito.when(projectClient.admin()).thenReturn(adminClient);
Mockito.when(adminClient.indices()).thenReturn(indicesClient);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.elasticsearch.cluster.metadata.IndexMetadata;
import org.elasticsearch.cluster.metadata.LifecycleExecutionState;
import org.elasticsearch.cluster.metadata.ProjectMetadata;
import org.elasticsearch.cluster.project.TestProjectResolvers;
import org.elasticsearch.index.IndexVersion;
import org.elasticsearch.test.client.NoOpClient;
import org.elasticsearch.threadpool.ThreadPool;
Expand Down Expand Up @@ -105,7 +106,7 @@ public void testPerformAction() {
}

private NoOpClient getDeleteSnapshotRequestAssertingClient(ThreadPool threadPool, String expectedSnapshotName) {
return new NoOpClient(threadPool) {
return new NoOpClient(threadPool, TestProjectResolvers.usingRequestHeader(threadPool.getThreadContext())) {
@Override
protected <Request extends ActionRequest, Response extends ActionResponse> void doExecute(
ActionType<Response> action,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ public void onFailure(Exception e) {

assertEquals(true, actionCompleted.get());
Mockito.verify(client).projectClient(state.projectId());
Mockito.verify(client).admin();
Mockito.verify(projectClient).admin();
Mockito.verifyNoMoreInteractions(client);
Mockito.verify(adminClient, Mockito.only()).indices();
Mockito.verify(indicesClient, Mockito.only()).close(Mockito.any(), Mockito.any());
Expand Down Expand Up @@ -110,7 +110,7 @@ public void testPerformActionFailure() {

assertSame(exception, expectThrows(Exception.class, () -> performActionAndWait(step, indexMetadata, state, null)));
Mockito.verify(client).projectClient(state.projectId());
Mockito.verify(client).admin();
Mockito.verify(projectClient).admin();
Mockito.verifyNoMoreInteractions(client);
Mockito.verify(adminClient, Mockito.only()).indices();
Mockito.verify(indicesClient, Mockito.only()).close(Mockito.any(), Mockito.any());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ public void testNextStepKey() {
ProjectState state = projectStateFromProject(ProjectMetadata.builder(randomProjectIdOrDefault()).put(indexMetadata, true));
{
try (var threadPool = createThreadPool()) {
final var client = new NoOpClient(threadPool);
final var client = new NoOpClient(threadPool, TestProjectResolvers.usingRequestHeader(threadPool.getThreadContext()));
StepKey nextKeyOnComplete = randomStepKey();
StepKey nextKeyOnIncomplete = randomStepKey();
CreateSnapshotStep completeStep = new CreateSnapshotStep(randomStepKey(), nextKeyOnComplete, nextKeyOnIncomplete, client) {
Expand All @@ -170,7 +170,7 @@ void createSnapshot(ProjectId projectId, IndexMetadata indexMetadata, ActionList

{
try (var threadPool = createThreadPool()) {
final var client = new NoOpClient(threadPool);
final var client = new NoOpClient(threadPool, TestProjectResolvers.usingRequestHeader(threadPool.getThreadContext()));
StepKey nextKeyOnComplete = randomStepKey();
StepKey nextKeyOnIncomplete = randomStepKey();
CreateSnapshotStep incompleteStep = new CreateSnapshotStep(
Expand All @@ -191,7 +191,7 @@ void createSnapshot(ProjectId projectId, IndexMetadata indexMetadata, ActionList

{
try (var threadPool = createThreadPool()) {
final var client = new NoOpClient(threadPool);
final var client = new NoOpClient(threadPool, TestProjectResolvers.usingRequestHeader(threadPool.getThreadContext()));
StepKey nextKeyOnComplete = randomStepKey();
StepKey nextKeyOnIncomplete = randomStepKey();
CreateSnapshotStep doubleInvocationStep = new CreateSnapshotStep(
Expand Down
Loading