Skip to content

Commit 3118b5d

Browse files
fix stream transport TLS cert hot-reload by using live SSLContext from SecureTransportSettingsProvider (#20733)
Signed-off-by: Rishabh Maurya <rishabhmaurya05@gmail.com>
1 parent cdab17c commit 3118b5d

File tree

10 files changed

+754
-448
lines changed

10 files changed

+754
-448
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
2525
- Added TopN selection logic for streaming terms aggregations ([#20481](https://github.com/opensearch-project/OpenSearch/pull/20481))
2626
- Added support for Intra Segment Search ([#19704](https://github.com/opensearch-project/OpenSearch/pull/19704))
2727
- Introduce AdditionalCodecs and EnginePlugin::getAdditionalCodecs hook to allow additional Codec registration ([#20411](https://github.com/opensearch-project/OpenSearch/pull/20411))
28-
- Support TLS cert hot-reload for Arrow Flight transport ([#20700](https://github.com/opensearch-project/OpenSearch/pull/20700))
2928

3029
### Changed
3130
- Handle custom metadata files in subdirectory-store ([#20157](https://github.com/opensearch-project/OpenSearch/pull/20157))
@@ -63,6 +62,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
6362
- Leveraging segment-global ordinal mapping for efficient terms aggregation ([#20624](https://github.com/opensearch-project/OpenSearch/pull/20624))
6463
- Harden detection of HTTP/3 support by ensuring Quic native libraries are available for the target platform ([#20680](https://github.com/opensearch-project/OpenSearch/pull/20680))
6564
- Fix the regression of terms agg optimization at high cardinality ([#20623](https://github.com/opensearch-project/OpenSearch/pull/20623))
65+
- Fix TLS cert hot-reload for Arrow Flight transport ([#20732](https://github.com/opensearch-project/OpenSearch/pull/20732))
6666

6767
### Dependencies
6868
- Bump `com.google.auth:google-auth-library-oauth2-http` from 1.38.0 to 1.41.0 ([#20183](https://github.com/opensearch-project/OpenSearch/pull/20183))
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,239 @@
1+
/*
2+
* SPDX-License-Identifier: Apache-2.0
3+
*
4+
* The OpenSearch Contributors require contributions made to
5+
* this file be licensed under the Apache-2.0 license or a
6+
* compatible open source license.
7+
*/
8+
9+
package org.opensearch.arrow.flight.bootstrap.tls;
10+
11+
import org.apache.arrow.flight.Criteria;
12+
import org.apache.arrow.flight.FlightClient;
13+
import org.apache.arrow.flight.FlightServer;
14+
import org.apache.arrow.flight.Location;
15+
import org.apache.arrow.flight.NoOpFlightProducer;
16+
import org.apache.arrow.flight.OSFlightClient;
17+
import org.apache.arrow.flight.OSFlightServer;
18+
import org.apache.arrow.memory.RootAllocator;
19+
import org.opensearch.common.SuppressForbidden;
20+
import org.opensearch.common.settings.Settings;
21+
import org.opensearch.plugins.SecureTransportSettingsProvider;
22+
import org.opensearch.test.OpenSearchTestCase;
23+
24+
import javax.net.ssl.SSLContext;
25+
import javax.net.ssl.SSLEngine;
26+
import javax.net.ssl.TrustManager;
27+
import javax.net.ssl.X509ExtendedTrustManager;
28+
29+
import java.lang.reflect.Field;
30+
import java.net.Socket;
31+
import java.security.cert.CertificateException;
32+
import java.security.cert.X509Certificate;
33+
import java.util.List;
34+
import java.util.Locale;
35+
import java.util.Optional;
36+
import java.util.concurrent.ExecutorService;
37+
import java.util.concurrent.Executors;
38+
import java.util.function.Supplier;
39+
40+
import io.grpc.ManagedChannel;
41+
import io.netty.pkitesting.CertificateBuilder;
42+
import io.netty.pkitesting.X509Bundle;
43+
44+
import static org.mockito.ArgumentMatchers.any;
45+
import static org.mockito.Mockito.mock;
46+
import static org.mockito.Mockito.when;
47+
48+
/**
49+
* Verifies that {@link DefaultSslContextProvider} picks up a new certificate after cert reload,
50+
* without restarting the Flight server or client.
51+
* <p>
52+
* Two reload scenarios are tested:
53+
* <ol>
54+
* <li><b>In-place mutation</b>: {@code SSLContext.init()} replaces key material on the same instance.</li>
55+
* <li><b>New instance</b>: {@link SecureTransportSettingsProvider#buildSecureTransportContext}
56+
* returns a brand-new {@link SSLContext} on each call. This requires
57+
* {@link DefaultSslContextProvider} to invoke {@code buildSecureTransportContext} on every
58+
* {@code newEngine()} call rather than caching the result from startup.</li>
59+
* </ol>
60+
* {@code ManagedChannel.enterIdle()} drops the current TCP connection so the next RPC call
61+
* triggers a fresh TLS handshake on the same client instance.
62+
*/
63+
public class DefaultSslContextProviderFlightIT extends OpenSearchTestCase {
64+
65+
private Locale savedLocale;
66+
67+
@Override
68+
public void setUp() throws Exception {
69+
super.setUp();
70+
// CertificateBuilder uses BouncyCastle's SimpleDateFormat for UTCTime, which is
71+
// locale-sensitive. Non-Latin locales produce non-ASCII date strings that BouncyCastle
72+
// cannot parse back, causing test failures. Pin to Locale.ROOT for the duration.
73+
savedLocale = Locale.getDefault();
74+
Locale.setDefault(Locale.ROOT);
75+
}
76+
77+
@Override
78+
public void tearDown() throws Exception {
79+
Locale.setDefault(savedLocale);
80+
super.tearDown();
81+
}
82+
83+
/** Cert reload via in-place {@code SSLContext.init()} on the same instance. */
84+
public void testHotReloadInPlace() throws Exception {
85+
X509Bundle cert1 = selfSigned("cert1.example.com");
86+
X509Bundle cert2 = selfSigned("cert2.example.com");
87+
88+
SSLContext serverCtx = SSLContext.getInstance("TLS");
89+
serverCtx.init(cert1.toKeyManagerFactory().getKeyManagers(), null, null);
90+
91+
runReloadTest(9600, () -> serverCtx, () -> {
92+
try {
93+
serverCtx.init(cert2.toKeyManagerFactory().getKeyManagers(), null, null);
94+
} catch (Exception e) {
95+
throw new RuntimeException(e);
96+
}
97+
}, "cert1.example.com", "cert2.example.com");
98+
}
99+
100+
/** Cert reload where the provider returns a new {@link SSLContext} instance on each call. */
101+
public void testHotReloadWithNewInstance() throws Exception {
102+
X509Bundle cert1 = selfSigned("cert1.example.com");
103+
X509Bundle cert2 = selfSigned("cert2.example.com");
104+
105+
SSLContext ctx1 = SSLContext.getInstance("TLS");
106+
ctx1.init(cert1.toKeyManagerFactory().getKeyManagers(), null, null);
107+
SSLContext ctx2 = SSLContext.getInstance("TLS");
108+
ctx2.init(cert2.toKeyManagerFactory().getKeyManagers(), null, null);
109+
110+
SSLContext[] current = { ctx1 };
111+
runReloadTest(9700, () -> current[0], () -> current[0] = ctx2, "cert1.example.com", "cert2.example.com");
112+
}
113+
114+
// ---- core test logic ----
115+
116+
private void runReloadTest(int basePort, Supplier<SSLContext> serverCtxSupplier, Runnable reload, String cnBefore, String cnAfter)
117+
throws Exception {
118+
X509Certificate[] capturedChain = { null };
119+
TrustManager capturingTm = capturingTrustManager(capturedChain);
120+
121+
SecureTransportSettingsProvider.SecureTransportParameters params = mock(
122+
SecureTransportSettingsProvider.SecureTransportParameters.class
123+
);
124+
when(params.clientAuth()).thenReturn(Optional.of("NONE"));
125+
when(params.cipherSuites()).thenReturn(List.of());
126+
127+
SecureTransportSettingsProvider serverProvider = mock(SecureTransportSettingsProvider.class);
128+
when(serverProvider.buildSecureTransportContext(any())).thenAnswer(inv -> Optional.of(serverCtxSupplier.get()));
129+
when(serverProvider.parameters(any())).thenReturn(Optional.of(params));
130+
131+
SSLContext clientCtx = SSLContext.getInstance("TLS");
132+
clientCtx.init(null, new TrustManager[] { capturingTm }, null);
133+
134+
SecureTransportSettingsProvider clientProvider = mock(SecureTransportSettingsProvider.class);
135+
when(clientProvider.buildSecureTransportContext(any())).thenReturn(Optional.of(clientCtx));
136+
when(clientProvider.parameters(any())).thenReturn(Optional.of(params));
137+
138+
Settings settings = Settings.builder().put("transport.ssl.enforce_hostname_verification", false).build();
139+
140+
DefaultSslContextProvider sslServer = new DefaultSslContextProvider(serverProvider, settings);
141+
DefaultSslContextProvider sslClient = new DefaultSslContextProvider(clientProvider, settings);
142+
143+
int port = getBasePort(basePort) + randomIntBetween(0, 99);
144+
Location location = Location.forGrpcTls("localhost", port);
145+
ExecutorService exec = Executors.newSingleThreadExecutor();
146+
147+
try (RootAllocator allocator = new RootAllocator(Integer.MAX_VALUE)) {
148+
FlightServer server = OSFlightServer.builder()
149+
.allocator(allocator.newChildAllocator("server", 0, Long.MAX_VALUE))
150+
.location(location)
151+
.producer(new NoOpFlightProducer())
152+
.sslContext(sslServer.getServerSslContext())
153+
.executor(exec)
154+
.build();
155+
server.start();
156+
157+
try (
158+
FlightClient client = OSFlightClient.builder()
159+
.allocator(allocator.newChildAllocator("client", 0, Long.MAX_VALUE))
160+
.location(location)
161+
.sslContext(sslClient.getClientSslContext())
162+
.build()
163+
) {
164+
try {
165+
triggerHandshake(client);
166+
assertEquals(cnBefore, getCN(capturedChain));
167+
168+
reload.run();
169+
getChannel(client).enterIdle();
170+
171+
triggerHandshake(client);
172+
assertEquals(cnAfter, getCN(capturedChain));
173+
} finally {
174+
server.shutdown();
175+
server.awaitTermination();
176+
server.close();
177+
exec.shutdownNow();
178+
}
179+
}
180+
}
181+
}
182+
183+
// ---- helpers ----
184+
185+
private static X509Bundle selfSigned(String cn) throws Exception {
186+
return new CertificateBuilder().subject("CN=" + cn).setIsCertificateAuthority(true).buildSelfSigned();
187+
}
188+
189+
private static void triggerHandshake(FlightClient client) {
190+
try {
191+
client.listFlights(Criteria.ALL).forEach(f -> {});
192+
} catch (Exception ignored) {} // NoOpFlightProducer throws UNIMPLEMENTED; handshake already done
193+
}
194+
195+
@SuppressForbidden(reason = "need access to FlightClient's private channel field for test verification")
196+
private static ManagedChannel getChannel(FlightClient client) throws Exception {
197+
Field f = FlightClient.class.getDeclaredField("channel");
198+
f.setAccessible(true);
199+
return (ManagedChannel) f.get(client);
200+
}
201+
202+
private static String getCN(X509Certificate[] chain) {
203+
assertNotNull("TrustManager was never called — handshake did not occur", chain[0]);
204+
return chain[0].getSubjectX500Principal().getName().replaceFirst(".*CN=([^,]+).*", "$1");
205+
}
206+
207+
private static TrustManager capturingTrustManager(X509Certificate[] capturedChain) {
208+
return new X509ExtendedTrustManager() {
209+
@Override
210+
public void checkServerTrusted(X509Certificate[] chain, String authType, SSLEngine engine) {
211+
capturedChain[0] = chain[0];
212+
}
213+
214+
@Override
215+
public void checkServerTrusted(X509Certificate[] chain, String authType, Socket socket) {
216+
capturedChain[0] = chain[0];
217+
}
218+
219+
@Override
220+
public void checkServerTrusted(X509Certificate[] chain, String authType) throws CertificateException {
221+
capturedChain[0] = chain[0];
222+
}
223+
224+
@Override
225+
public void checkClientTrusted(X509Certificate[] c, String a, SSLEngine e) {}
226+
227+
@Override
228+
public void checkClientTrusted(X509Certificate[] c, String a, Socket s) {}
229+
230+
@Override
231+
public void checkClientTrusted(X509Certificate[] c, String a) throws CertificateException {}
232+
233+
@Override
234+
public X509Certificate[] getAcceptedIssuers() {
235+
return new X509Certificate[0];
236+
}
237+
};
238+
}
239+
}

0 commit comments

Comments
 (0)