Skip to content

Commit c2de4b7

Browse files
authored
Allowlist tracestate header on remote server port (#112649)
The [`tracestate` header](https://www.elastic.co/guide/en/apm/agent/rum-js/current/distributed-tracing-guide.html#enable-tracestate) is an HTTP header used for distributed tracing; it's a valid header to persist in cross cluster requests and should therefore be allowlisted in the remote server port header check. Note: due to implementation details, `tracestate` today may be set on the fulfilling cluster (instead of arriving across the wire) _before_ the header check. Not allowing the header therefore can lead to failures to connect clusters (#112552). This PR allowlists the header to allow tracing with RCS 2.0. As a separate follow up, we may furthermore change behavior around sending the header from the query cluster to the fulfilling cluster (which we don't today). This is pending further discussion. Closes: #112552
1 parent 5f6bcc8 commit c2de4b7

File tree

4 files changed

+311
-0
lines changed

4 files changed

+311
-0
lines changed

docs/changelog/112649.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 112649
2+
summary: Allowlist `tracestate` header on remote server port
3+
area: Security
4+
type: bug
5+
issues: []
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0 and the Server Side Public License, v 1; you may not use this file except
5+
* in compliance with, at your election, the Elastic License 2.0 or the Server
6+
* Side Public License, v 1.
7+
*/
8+
9+
package org.elasticsearch.xpack.remotecluster;
10+
11+
import com.sun.net.httpserver.HttpExchange;
12+
import com.sun.net.httpserver.HttpServer;
13+
14+
import org.apache.logging.log4j.LogManager;
15+
import org.apache.logging.log4j.Logger;
16+
import org.elasticsearch.core.SuppressForbidden;
17+
import org.junit.rules.ExternalResource;
18+
19+
import java.io.BufferedReader;
20+
import java.io.IOException;
21+
import java.io.InputStream;
22+
import java.io.InputStreamReader;
23+
import java.net.InetAddress;
24+
import java.net.InetSocketAddress;
25+
import java.nio.charset.StandardCharsets;
26+
import java.util.List;
27+
import java.util.concurrent.ArrayBlockingQueue;
28+
import java.util.concurrent.TimeUnit;
29+
import java.util.function.Consumer;
30+
31+
@SuppressForbidden(reason = "Uses an HTTP server for testing")
32+
class ConsumingTestServer extends ExternalResource {
33+
private static final Logger logger = LogManager.getLogger(ConsumingTestServer.class);
34+
final ArrayBlockingQueue<String> received = new ArrayBlockingQueue<>(1000);
35+
36+
private static HttpServer server;
37+
private final Thread messageConsumerThread = consumerThread();
38+
private volatile Consumer<String> consumer;
39+
private volatile boolean consumerRunning = true;
40+
41+
@Override
42+
protected void before() throws Throwable {
43+
server = HttpServer.create();
44+
server.bind(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0), 0);
45+
server.createContext("/", this::handle);
46+
server.start();
47+
48+
messageConsumerThread.start();
49+
}
50+
51+
private Thread consumerThread() {
52+
return new Thread(() -> {
53+
while (consumerRunning) {
54+
if (consumer != null) {
55+
try {
56+
String msg = received.poll(1L, TimeUnit.SECONDS);
57+
if (msg != null && msg.isEmpty() == false) {
58+
consumer.accept(msg);
59+
}
60+
61+
} catch (InterruptedException e) {
62+
throw new RuntimeException(e);
63+
}
64+
}
65+
}
66+
});
67+
}
68+
69+
@Override
70+
protected void after() {
71+
server.stop(1);
72+
consumerRunning = false;
73+
}
74+
75+
private void handle(HttpExchange exchange) throws IOException {
76+
try (exchange) {
77+
try {
78+
try (InputStream requestBody = exchange.getRequestBody()) {
79+
if (requestBody != null) {
80+
var read = readJsonMessages(requestBody);
81+
received.addAll(read);
82+
}
83+
}
84+
85+
} catch (RuntimeException e) {
86+
logger.warn("failed to parse request", e);
87+
}
88+
exchange.sendResponseHeaders(201, 0);
89+
}
90+
}
91+
92+
private List<String> readJsonMessages(InputStream input) {
93+
// parse NDJSON
94+
return new BufferedReader(new InputStreamReader(input, StandardCharsets.UTF_8)).lines().toList();
95+
}
96+
97+
public int getPort() {
98+
return server.getAddress().getPort();
99+
}
100+
101+
public void addMessageConsumer(Consumer<String> messageConsumer) {
102+
this.consumer = messageConsumer;
103+
}
104+
}
Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.remotecluster;
9+
10+
import org.elasticsearch.client.Request;
11+
import org.elasticsearch.client.RequestOptions;
12+
import org.elasticsearch.client.Response;
13+
import org.elasticsearch.common.xcontent.support.XContentMapValues;
14+
import org.elasticsearch.tasks.Task;
15+
import org.elasticsearch.test.cluster.ElasticsearchCluster;
16+
import org.elasticsearch.test.cluster.local.distribution.DistributionType;
17+
import org.elasticsearch.test.cluster.util.resource.Resource;
18+
import org.elasticsearch.xcontent.XContentParser;
19+
import org.elasticsearch.xcontent.XContentParserConfiguration;
20+
import org.elasticsearch.xcontent.spi.XContentProvider;
21+
import org.hamcrest.Matcher;
22+
import org.hamcrest.StringDescription;
23+
import org.junit.ClassRule;
24+
import org.junit.rules.RuleChain;
25+
import org.junit.rules.TestRule;
26+
27+
import java.io.IOException;
28+
import java.util.Arrays;
29+
import java.util.Collections;
30+
import java.util.HashSet;
31+
import java.util.Map;
32+
import java.util.Set;
33+
import java.util.concurrent.CountDownLatch;
34+
import java.util.concurrent.TimeUnit;
35+
import java.util.concurrent.atomic.AtomicReference;
36+
import java.util.function.Consumer;
37+
import java.util.function.Predicate;
38+
import java.util.stream.Collectors;
39+
40+
import static org.hamcrest.Matchers.equalTo;
41+
42+
public class RemoteClusterSecurityWithApmTracingRestIT extends AbstractRemoteClusterSecurityTestCase {
43+
private static final AtomicReference<Map<String, Object>> API_KEY_MAP_REF = new AtomicReference<>();
44+
private static final XContentProvider.FormatProvider XCONTENT = XContentProvider.provider().getJsonXContent();
45+
final String traceIdValue = "0af7651916cd43dd8448eb211c80319c";
46+
final String traceParentValue = "00-" + traceIdValue + "-b7ad6b7169203331-01";
47+
48+
private static final ConsumingTestServer mockApmServer = new ConsumingTestServer();
49+
50+
static {
51+
fulfillingCluster = ElasticsearchCluster.local()
52+
.distribution(DistributionType.DEFAULT)
53+
.name("fulfilling-cluster")
54+
.apply(commonClusterConfig)
55+
.setting("telemetry.metrics.enabled", "false")
56+
.setting("telemetry.tracing.enabled", "true")
57+
.setting("telemetry.agent.metrics_interval", "1s")
58+
.setting("telemetry.agent.server_url", () -> "http://127.0.0.1:" + mockApmServer.getPort())
59+
// to ensure tracestate header is always set to cover RCS 2.0 handling of the tracestate header
60+
.setting("telemetry.agent.transaction_sample_rate", "1.0")
61+
.setting("remote_cluster_server.enabled", "true")
62+
.setting("remote_cluster.port", "0")
63+
.setting("xpack.security.remote_cluster_server.ssl.enabled", "true")
64+
.setting("xpack.security.remote_cluster_server.ssl.key", "remote-cluster.key")
65+
.setting("xpack.security.remote_cluster_server.ssl.certificate", "remote-cluster.crt")
66+
.keystore("xpack.security.remote_cluster_server.ssl.secure_key_passphrase", "remote-cluster-password")
67+
.rolesFile(Resource.fromClasspath("roles.yml"))
68+
.build();
69+
70+
queryCluster = ElasticsearchCluster.local()
71+
.distribution(DistributionType.DEFAULT)
72+
.name("query-cluster")
73+
.apply(commonClusterConfig)
74+
.setting("telemetry.metrics.enabled", "false")
75+
.setting("telemetry.tracing.enabled", "true")
76+
// to ensure tracestate header is always set to cover RCS 2.0 handling of the tracestate header
77+
.setting("telemetry.agent.transaction_sample_rate", "1.0")
78+
.setting("telemetry.agent.metrics_interval", "1s")
79+
.setting("telemetry.agent.server_url", () -> "http://127.0.0.1:" + mockApmServer.getPort())
80+
.setting("xpack.security.remote_cluster_client.ssl.enabled", "true")
81+
.setting("xpack.security.remote_cluster_client.ssl.certificate_authorities", "remote-cluster-ca.crt")
82+
.keystore("cluster.remote.my_remote_cluster.credentials", () -> {
83+
if (API_KEY_MAP_REF.get() == null) {
84+
final Map<String, Object> apiKeyMap = createCrossClusterAccessApiKey("""
85+
{
86+
"search": [
87+
{
88+
"names": ["*"]
89+
}
90+
]
91+
}""");
92+
API_KEY_MAP_REF.set(apiKeyMap);
93+
}
94+
return (String) API_KEY_MAP_REF.get().get("encoded");
95+
})
96+
.rolesFile(Resource.fromClasspath("roles.yml"))
97+
.user(REMOTE_METRIC_USER, PASS.toString(), "read_remote_shared_metrics", false)
98+
.build();
99+
}
100+
101+
@ClassRule
102+
// Use a RuleChain to ensure that fulfilling cluster is started before query cluster
103+
public static TestRule clusterRule = RuleChain.outerRule(mockApmServer).around(fulfillingCluster).around(queryCluster);
104+
105+
@SuppressWarnings("unchecked")
106+
public void testTracingCrossCluster() throws Exception {
107+
configureRemoteCluster();
108+
Set<Predicate<Map<String, Object>>> assertions = new HashSet<>(
109+
Set.of(
110+
// REST action on query cluster
111+
allTrue(
112+
transactionValue("name", equalTo("GET /_resolve/cluster/{name}")),
113+
transactionValue("trace_id", equalTo(traceIdValue))
114+
),
115+
// transport action on fulfilling cluster
116+
allTrue(
117+
transactionValue("name", equalTo("indices:admin/resolve/cluster")),
118+
transactionValue("trace_id", equalTo(traceIdValue))
119+
)
120+
)
121+
);
122+
123+
CountDownLatch finished = new CountDownLatch(1);
124+
125+
// a consumer that will remove the assertions from a map once it matched
126+
Consumer<String> messageConsumer = (String message) -> {
127+
var apmMessage = parseMap(message);
128+
if (isTransactionTraceMessage(apmMessage)) {
129+
logger.info("Apm transaction message received: {}", message);
130+
assertions.removeIf(e -> e.test(apmMessage));
131+
}
132+
133+
if (assertions.isEmpty()) {
134+
finished.countDown();
135+
}
136+
};
137+
138+
mockApmServer.addMessageConsumer(messageConsumer);
139+
140+
// Trigger an action that we know will cross clusters -- doesn't much matter which one
141+
final Request resolveRequest = new Request("GET", "/_resolve/cluster/my_remote_cluster:*");
142+
resolveRequest.setOptions(
143+
RequestOptions.DEFAULT.toBuilder()
144+
.addHeader("Authorization", headerFromRandomAuthMethod(REMOTE_METRIC_USER, PASS))
145+
.addHeader(Task.TRACE_PARENT_HTTP_HEADER, traceParentValue)
146+
);
147+
final Response response = client().performRequest(resolveRequest);
148+
assertOK(response);
149+
150+
finished.await(30, TimeUnit.SECONDS);
151+
assertThat(assertions, equalTo(Collections.emptySet()));
152+
}
153+
154+
private boolean isTransactionTraceMessage(Map<String, Object> apmMessage) {
155+
return apmMessage.containsKey("transaction");
156+
}
157+
158+
@SuppressWarnings("unchecked")
159+
private Predicate<Map<String, Object>> allTrue(Predicate<Map<String, Object>>... predicates) {
160+
var allTrueTest = Arrays.stream(predicates).reduce(v -> true, Predicate::and);
161+
return new Predicate<>() {
162+
@Override
163+
public boolean test(Map<String, Object> map) {
164+
return allTrueTest.test(map);
165+
}
166+
167+
@Override
168+
public String toString() {
169+
return Arrays.stream(predicates).map(Object::toString).collect(Collectors.joining(" and "));
170+
}
171+
};
172+
}
173+
174+
@SuppressWarnings("unchecked")
175+
private <T> Predicate<Map<String, Object>> transactionValue(String path, Matcher<T> expected) {
176+
return new Predicate<>() {
177+
@Override
178+
public boolean test(Map<String, Object> map) {
179+
var transaction = (Map<String, Object>) map.get("transaction");
180+
var value = XContentMapValues.extractValue(path, transaction);
181+
return expected.matches((T) value);
182+
}
183+
184+
@Override
185+
public String toString() {
186+
StringDescription matcherDescription = new StringDescription();
187+
expected.describeTo(matcherDescription);
188+
return path + " " + matcherDescription;
189+
}
190+
};
191+
}
192+
193+
private Map<String, Object> parseMap(String message) {
194+
try (XContentParser parser = XCONTENT.XContent().createParser(XContentParserConfiguration.EMPTY, message)) {
195+
return parser.map();
196+
} catch (IOException e) {
197+
fail(e);
198+
return Collections.emptyMap();
199+
}
200+
}
201+
}

x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/CrossClusterAccessServerTransportFilter.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ final class CrossClusterAccessServerTransportFilter extends ServerTransportFilte
4141
Set.of(CROSS_CLUSTER_ACCESS_CREDENTIALS_HEADER_KEY, CROSS_CLUSTER_ACCESS_SUBJECT_INFO_HEADER_KEY)
4242
);
4343
allowedHeaders.add(AuditUtil.AUDIT_REQUEST_ID);
44+
allowedHeaders.add(Task.TRACE_STATE);
4445
allowedHeaders.addAll(Task.HEADERS_TO_COPY);
4546
ALLOWED_TRANSPORT_HEADERS = Set.copyOf(allowedHeaders);
4647
}

0 commit comments

Comments
 (0)