Skip to content

Commit 61703c8

Browse files
committed
Added testRemoteClusterAction
1 parent c2aaf2f commit 61703c8

File tree

2 files changed

+70
-4
lines changed

2 files changed

+70
-4
lines changed

x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/GetInferenceFieldsCrossClusterIT.java

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,22 +7,45 @@
77

88
package org.elasticsearch.xpack.inference.integration;
99

10+
import org.elasticsearch.action.support.PlainActionFuture;
11+
import org.elasticsearch.client.internal.Client;
12+
import org.elasticsearch.client.internal.RemoteClusterClient;
1013
import org.elasticsearch.common.settings.Settings;
14+
import org.elasticsearch.common.util.concurrent.EsExecutors;
15+
import org.elasticsearch.inference.TaskType;
1116
import org.elasticsearch.license.LicenseSettings;
1217
import org.elasticsearch.plugins.Plugin;
1318
import org.elasticsearch.test.AbstractMultiClustersTestCase;
19+
import org.elasticsearch.transport.RemoteClusterService;
20+
import org.elasticsearch.xcontent.XContentBuilder;
1421
import org.elasticsearch.xpack.core.inference.action.GetInferenceFieldsAction;
22+
import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
23+
import org.elasticsearch.xpack.inference.FakeMlPlugin;
1524
import org.elasticsearch.xpack.inference.LocalStateInferencePlugin;
25+
import org.elasticsearch.xpack.inference.mock.TestInferenceServicePlugin;
26+
import org.junit.Before;
1627

28+
import java.io.IOException;
1729
import java.util.Collection;
1830
import java.util.List;
1931
import java.util.Map;
2032
import java.util.Set;
2133

34+
import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
35+
import static org.elasticsearch.xpack.inference.integration.GetInferenceFieldsIT.assertInferenceFieldsMap;
36+
import static org.elasticsearch.xpack.inference.integration.GetInferenceFieldsIT.assertInferenceResultsMap;
37+
import static org.elasticsearch.xpack.inference.integration.IntegrationTestUtils.createInferenceEndpoint;
38+
import static org.elasticsearch.xpack.inference.integration.IntegrationTestUtils.generateSemanticTextMapping;
2239
import static org.hamcrest.Matchers.containsString;
2340

2441
public class GetInferenceFieldsCrossClusterIT extends AbstractMultiClustersTestCase {
2542
private static final String REMOTE_CLUSTER = "cluster_a";
43+
private static final String INDEX_NAME = "test-index";
44+
private static final String INFERENCE_FIELD = "test-inference-field";
45+
private static final String INFERENCE_ID = "test-inference-id";
46+
private static final Map<String, Object> INFERENCE_ENDPOINT_SERVICE_SETTINGS = Map.of("model", "my_model", "api_key", "my_api_key");
47+
48+
private boolean clustersConfigured = false;
2649

2750
@Override
2851
protected List<String> remoteClusterAlias() {
@@ -41,7 +64,15 @@ protected Settings nodeSettings() {
4164

4265
@Override
4366
protected Collection<Class<? extends Plugin>> nodePlugins(String clusterAlias) {
44-
return List.of(LocalStateInferencePlugin.class);
67+
return List.of(LocalStateInferencePlugin.class, TestInferenceServicePlugin.class, FakeMlPlugin.class);
68+
}
69+
70+
@Before
71+
public void configureClusters() throws Exception {
72+
if (clustersConfigured == false) {
73+
setupTwoClusters();
74+
clustersConfigured = true;
75+
}
4576
}
4677

4778
public void testRemoteIndex() {
@@ -52,4 +83,39 @@ public void testRemoteIndex() {
5283
);
5384
assertThat(e.getMessage(), containsString("GetInferenceFieldsAction does not support remote indices"));
5485
}
86+
87+
public void testRemoteClusterAction() {
88+
RemoteClusterClient remoteClusterClient = client().getRemoteClusterClient(
89+
REMOTE_CLUSTER,
90+
EsExecutors.DIRECT_EXECUTOR_SERVICE,
91+
RemoteClusterService.DisconnectedStrategy.RECONNECT_IF_DISCONNECTED
92+
);
93+
94+
var request = new GetInferenceFieldsAction.Request(Set.of(INDEX_NAME), Set.of(INFERENCE_FIELD), false, false, "foo");
95+
PlainActionFuture<GetInferenceFieldsAction.Response> future = new PlainActionFuture<>();
96+
remoteClusterClient.execute(GetInferenceFieldsAction.REMOTE_TYPE, request, future);
97+
98+
var response = future.actionGet(TEST_REQUEST_TIMEOUT);
99+
assertInferenceFieldsMap(
100+
response.getInferenceFieldsMap(),
101+
Map.of(INDEX_NAME, Set.of(new GetInferenceFieldsIT.InferenceFieldAndId(INFERENCE_FIELD, INFERENCE_ID)))
102+
);
103+
assertInferenceResultsMap(response.getInferenceResultsMap(), Map.of(INFERENCE_ID, TextExpansionResults.class));
104+
}
105+
106+
private void setupTwoClusters() throws IOException {
107+
setupCluster(LOCAL_CLUSTER);
108+
setupCluster(REMOTE_CLUSTER);
109+
}
110+
111+
private void setupCluster(String clusterAlias) throws IOException {
112+
final Client client = client(clusterAlias);
113+
114+
createInferenceEndpoint(client, TaskType.SPARSE_EMBEDDING, INFERENCE_ID, INFERENCE_ENDPOINT_SERVICE_SETTINGS);
115+
116+
int dataNodeCount = cluster(clusterAlias).numDataNodes();
117+
XContentBuilder mappings = generateSemanticTextMapping(Map.of(INFERENCE_FIELD, INFERENCE_ID));
118+
Settings indexSettings = indexSettings(randomIntBetween(1, dataNodeCount), 0).build();
119+
assertAcked(client.admin().indices().prepareCreate(INDEX_NAME).setSettings(indexSettings).setMapping(mappings));
120+
}
55121
}

x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/GetInferenceFieldsIT.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -434,7 +434,7 @@ private static <T extends Exception> void assertFailedRequest(
434434
exceptionValidator.accept(exception);
435435
}
436436

437-
private static void assertInferenceFieldsMap(
437+
static void assertInferenceFieldsMap(
438438
Map<String, List<InferenceFieldMetadata>> inferenceFieldsMap,
439439
Map<String, Set<InferenceFieldAndId>> expectedInferenceFields
440440
) {
@@ -458,7 +458,7 @@ private static void assertInferenceFieldsMap(
458458
}
459459
}
460460

461-
private static void assertInferenceResultsMap(
461+
static void assertInferenceResultsMap(
462462
Map<String, InferenceResults> inferenceResultsMap,
463463
Map<String, Class<? extends InferenceResults>> expectedInferenceResults
464464
) {
@@ -490,5 +490,5 @@ private static Map<String, Class<? extends InferenceResults>> filterExpectedInfe
490490
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
491491
}
492492

493-
private record InferenceFieldAndId(String field, String inferenceId) {}
493+
record InferenceFieldAndId(String field, String inferenceId) {}
494494
}

0 commit comments

Comments
 (0)