Skip to content

Commit 5e794bf

Browse files
committed
In-progress changes.
1 parent 6263cce commit 5e794bf

File tree

13 files changed

+90
-72
lines changed

13 files changed

+90
-72
lines changed

netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,10 @@ private InternalProtocolNegotiators() {}
4242
*/
4343
public static InternalProtocolNegotiator.ProtocolNegotiator tls(SslContext sslContext,
4444
ObjectPool<? extends Executor> executorPool,
45-
Optional<Runnable> handshakeCompleteRunnable) {
45+
Optional<Runnable> handshakeCompleteRunnable,
46+
String sni) {
4647
final io.grpc.netty.ProtocolNegotiator negotiator = ProtocolNegotiators.tls(sslContext,
47-
executorPool, handshakeCompleteRunnable, null);
48+
executorPool, handshakeCompleteRunnable, null, sni);
4849
final class TlsNegotiator implements InternalProtocolNegotiator.ProtocolNegotiator {
4950

5051
@Override
@@ -71,8 +72,8 @@ public void close() {
7172
* be negotiated, the {@code handler} is added and writes to the {@link io.netty.channel.Channel}
7273
* may happen immediately, even before the TLS Handshake is complete.
7374
*/
74-
public static InternalProtocolNegotiator.ProtocolNegotiator tls(SslContext sslContext) {
75-
return tls(sslContext, null, Optional.absent());
75+
public static InternalProtocolNegotiator.ProtocolNegotiator tls(SslContext sslContext, String sni) {
76+
return tls(sslContext, null, Optional.absent(), sni);
7677
}
7778

7879
/**
@@ -170,7 +171,7 @@ public static ChannelHandler clientTlsHandler(
170171
ChannelHandler next, SslContext sslContext, String authority,
171172
ChannelLogger negotiationLogger) {
172173
return new ClientTlsHandler(next, sslContext, authority, null, negotiationLogger,
173-
Optional.absent(), null, null);
174+
Optional.absent(), null, null, sni);
174175
}
175176

176177
public static class ProtocolNegotiationHandler

netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -652,7 +652,7 @@ static ProtocolNegotiator createProtocolNegotiatorByType(
652652
case PLAINTEXT_UPGRADE:
653653
return ProtocolNegotiators.plaintextUpgrade();
654654
case TLS:
655-
return ProtocolNegotiators.tls(sslContext, executorPool, Optional.absent(), null);
655+
return ProtocolNegotiators.tls(sslContext, executorPool, Optional.absent(), null, null);
656656
default:
657657
throw new IllegalArgumentException("Unsupported negotiationType: " + negotiationType);
658658
}

netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
import io.grpc.internal.GrpcUtil;
4747
import io.grpc.internal.NoopSslSession;
4848
import io.grpc.internal.ObjectPool;
49+
import io.netty.channel.Channel;
4950
import io.netty.channel.ChannelDuplexHandler;
5051
import io.netty.channel.ChannelFutureListener;
5152
import io.netty.channel.ChannelHandler;
@@ -60,6 +61,7 @@
6061
import io.netty.handler.codec.http2.Http2ClientUpgradeCodec;
6162
import io.netty.handler.proxy.HttpProxyHandler;
6263
import io.netty.handler.proxy.ProxyConnectionEvent;
64+
import io.netty.handler.ssl.ClientAuth;
6365
import io.netty.handler.ssl.OpenSsl;
6466
import io.netty.handler.ssl.OpenSslEngine;
6567
import io.netty.handler.ssl.SslContext;
@@ -223,15 +225,15 @@ public static FromServerCredentialsResult from(ServerCredentials creds) {
223225
} // else use system default
224226
switch (tlsCreds.getClientAuth()) {
225227
case OPTIONAL:
226-
builder.clientAuth(io.netty.handler.ssl.ClientAuth.OPTIONAL);
228+
builder.clientAuth(ClientAuth.OPTIONAL);
227229
break;
228230

229231
case REQUIRE:
230-
builder.clientAuth(io.netty.handler.ssl.ClientAuth.REQUIRE);
232+
builder.clientAuth(ClientAuth.REQUIRE);
231233
break;
232234

233235
case NONE:
234-
builder.clientAuth(io.netty.handler.ssl.ClientAuth.NONE);
236+
builder.clientAuth(ClientAuth.NONE);
235237
break;
236238

237239
default:
@@ -578,21 +580,23 @@ protected void userEventTriggered0(ChannelHandlerContext ctx, Object evt) throws
578580
static final class ClientTlsProtocolNegotiator implements ProtocolNegotiator {
579581

580582
public ClientTlsProtocolNegotiator(SslContext sslContext,
581-
ObjectPool<? extends Executor> executorPool, Optional<Runnable> handshakeCompleteRunnable,
582-
X509TrustManager x509ExtendedTrustManager) {
583+
ObjectPool<? extends Executor> executorPool, Optional<Runnable> handshakeCompleteRunnable,
584+
X509TrustManager x509ExtendedTrustManager, String sni) {
583585
this.sslContext = Preconditions.checkNotNull(sslContext, "sslContext");
584586
this.executorPool = executorPool;
585587
if (this.executorPool != null) {
586588
this.executor = this.executorPool.getObject();
587589
}
588590
this.handshakeCompleteRunnable = handshakeCompleteRunnable;
589591
this.x509ExtendedTrustManager = x509ExtendedTrustManager;
592+
this.sni = sni;
590593
}
591594

592595
private final SslContext sslContext;
593596
private final ObjectPool<? extends Executor> executorPool;
594597
private final Optional<Runnable> handshakeCompleteRunnable;
595598
private final X509TrustManager x509ExtendedTrustManager;
599+
private final String sni;
596600
private Executor executor;
597601

598602
@Override
@@ -606,7 +610,7 @@ public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) {
606610
ChannelLogger negotiationLogger = grpcHandler.getNegotiationLogger();
607611
ChannelHandler cth = new ClientTlsHandler(gnh, sslContext, grpcHandler.getAuthority(),
608612
this.executor, negotiationLogger, handshakeCompleteRunnable, this,
609-
x509ExtendedTrustManager);
613+
x509ExtendedTrustManager, sni);
610614
return new WaitUntilActiveHandler(cth, negotiationLogger);
611615
}
612616

@@ -631,15 +635,17 @@ static final class ClientTlsHandler extends ProtocolNegotiationHandler {
631635
private Executor executor;
632636
private final Optional<Runnable> handshakeCompleteRunnable;
633637
private final X509TrustManager x509ExtendedTrustManager;
638+
private final String sni;
634639
private SSLEngine sslEngine;
635640

636641
ClientTlsHandler(ChannelHandler next, SslContext sslContext, String authority,
637-
Executor executor, ChannelLogger negotiationLogger,
638-
Optional<Runnable> handshakeCompleteRunnable,
639-
ClientTlsProtocolNegotiator clientTlsProtocolNegotiator,
640-
X509TrustManager x509ExtendedTrustManager) {
642+
Executor executor, ChannelLogger negotiationLogger,
643+
Optional<Runnable> handshakeCompleteRunnable,
644+
ClientTlsProtocolNegotiator clientTlsProtocolNegotiator,
645+
X509TrustManager x509ExtendedTrustManager, String sni) {
641646
super(next, negotiationLogger);
642647
this.sslContext = Preconditions.checkNotNull(sslContext, "sslContext");
648+
this.sni = sni;
643649
HostPort hostPort = parseAuthority(authority);
644650
this.host = hostPort.host;
645651
this.port = hostPort.port;
@@ -651,11 +657,7 @@ static final class ClientTlsHandler extends ProtocolNegotiationHandler {
651657
@Override
652658
@IgnoreJRERequirement
653659
protected void handlerAdded0(ChannelHandlerContext ctx) {
654-
/*if (host.equals("psm-grpc-server")) {
655-
sslEngine = sslContext.newEngine(ctx.alloc(), "kannanj-psm-server-20250604-1226-8bkw5-830293263384.us-east7.run.app", 443);
656-
} else {*/
657-
sslEngine = sslContext.newEngine(ctx.alloc(), host, port);
658-
// }
660+
sslEngine = sslContext.newEngine(ctx.alloc(), sni != null? sni : host, port);
659661
SSLParameters sslParams = sslEngine.getSSLParameters();
660662
sslParams.setEndpointIdentificationAlgorithm("HTTPS");
661663
sslEngine.setSSLParameters(sslParams);
@@ -748,25 +750,27 @@ static HostPort parseAuthority(String authority) {
748750

749751
/**
750752
* Returns a {@link ProtocolNegotiator} that ensures the pipeline is set up so that TLS will
751-
* be negotiated, the {@code handler} is added and writes to the {@link io.netty.channel.Channel}
753+
* be negotiated, the {@code handler} is added and writes to the {@link Channel}
752754
* may happen immediately, even before the TLS Handshake is complete.
755+
*
753756
* @param executorPool a dedicated {@link Executor} pool for time-consuming TLS tasks
757+
* @param sni
754758
*/
755759
public static ProtocolNegotiator tls(SslContext sslContext,
756-
ObjectPool<? extends Executor> executorPool, Optional<Runnable> handshakeCompleteRunnable,
757-
X509TrustManager x509ExtendedTrustManager) {
760+
ObjectPool<? extends Executor> executorPool, Optional<Runnable> handshakeCompleteRunnable,
761+
X509TrustManager x509ExtendedTrustManager, String sni) {
758762
return new ClientTlsProtocolNegotiator(sslContext, executorPool, handshakeCompleteRunnable,
759-
x509ExtendedTrustManager);
763+
x509ExtendedTrustManager, sni);
760764
}
761765

762766
/**
763767
* Returns a {@link ProtocolNegotiator} that ensures the pipeline is set up so that TLS will
764-
* be negotiated, the {@code handler} is added and writes to the {@link io.netty.channel.Channel}
768+
* be negotiated, the {@code handler} is added and writes to the {@link Channel}
765769
* may happen immediately, even before the TLS Handshake is complete.
766770
*/
767771
public static ProtocolNegotiator tls(SslContext sslContext,
768772
X509TrustManager x509ExtendedTrustManager) {
769-
return tls(sslContext, null, Optional.absent(), x509ExtendedTrustManager);
773+
return tls(sslContext, null, Optional.absent(), x509ExtendedTrustManager, null);
770774
}
771775

772776
public static ProtocolNegotiator.ClientFactory tlsClientFactory(SslContext sslContext,
@@ -908,8 +912,8 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exc
908912
}
909913

910914
/**
911-
* Returns a {@link io.netty.channel.ChannelHandler} that ensures that the {@code handler} is
912-
* added to the pipeline writes to the {@link io.netty.channel.Channel} may happen immediately,
915+
* Returns a {@link ChannelHandler} that ensures that the {@code handler} is
916+
* added to the pipeline writes to the {@link Channel} may happen immediately,
913917
* even before it is active.
914918
*/
915919
public static ProtocolNegotiator plaintext() {

netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -836,7 +836,7 @@ public void tlsNegotiationServerExecutorShouldSucceed() throws Exception {
836836
.keyManager(clientCert, clientKey)
837837
.build();
838838
ProtocolNegotiator negotiator = ProtocolNegotiators.tls(clientContext, clientExecutorPool,
839-
Optional.absent(), null);
839+
Optional.absent(), null, sni);
840840
// after starting the client, the Executor in the client pool should be used
841841
assertEquals(true, clientExecutorPool.isInUse());
842842
final NettyClientTransport transport = newTransport(negotiator);

netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -918,7 +918,7 @@ public String applicationProtocol() {
918918

919919
ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext,
920920
"authority", elg, noopLogger, Optional.absent(),
921-
getClientTlsProtocolNegotiator(), null);
921+
getClientTlsProtocolNegotiator(), null, sni);
922922
pipeline.addLast(handler);
923923
pipeline.replace(SslHandler.class, null, goodSslHandler);
924924
pipeline.fireUserEventTriggered(ProtocolNegotiationEvent.DEFAULT);
@@ -957,7 +957,7 @@ public String applicationProtocol() {
957957

958958
ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext,
959959
"authority", elg, noopLogger, Optional.absent(),
960-
getClientTlsProtocolNegotiator(), null);
960+
getClientTlsProtocolNegotiator(), null, sni);
961961
pipeline.addLast(handler);
962962
pipeline.replace(SslHandler.class, null, goodSslHandler);
963963
pipeline.fireUserEventTriggered(ProtocolNegotiationEvent.DEFAULT);
@@ -982,7 +982,7 @@ public String applicationProtocol() {
982982

983983
ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext,
984984
"authority", elg, noopLogger, Optional.absent(),
985-
getClientTlsProtocolNegotiator(), null);
985+
getClientTlsProtocolNegotiator(), null, sni);
986986
pipeline.addLast(handler);
987987

988988
final AtomicReference<Throwable> error = new AtomicReference<>();
@@ -1011,7 +1011,7 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
10111011
public void clientTlsHandler_closeDuringNegotiation() throws Exception {
10121012
ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext,
10131013
"authority", null, noopLogger, Optional.absent(),
1014-
getClientTlsProtocolNegotiator(), null);
1014+
getClientTlsProtocolNegotiator(), null, sni);
10151015
pipeline.addLast(new WriteBufferingAndExceptionHandler(handler));
10161016
ChannelFuture pendingWrite = channel.writeAndFlush(NettyClientHandler.NOOP_MESSAGE);
10171017

@@ -1026,7 +1026,7 @@ public void clientTlsHandler_closeDuringNegotiation() throws Exception {
10261026
private ClientTlsProtocolNegotiator getClientTlsProtocolNegotiator() throws SSLException {
10271027
return new ClientTlsProtocolNegotiator(GrpcSslContexts.forClient().trustManager(
10281028
TlsTesting.loadCert("ca.pem")).build(),
1029-
null, Optional.absent(), null);
1029+
null, Optional.absent(), null, sni);
10301030
}
10311031

10321032
@Test
@@ -1277,7 +1277,7 @@ public void clientTlsHandler_firesNegotiation() throws Exception {
12771277
}
12781278
FakeGrpcHttp2ConnectionHandler gh = FakeGrpcHttp2ConnectionHandler.newHandler();
12791279
ClientTlsProtocolNegotiator pn = new ClientTlsProtocolNegotiator(clientSslContext,
1280-
null, Optional.absent(), null);
1280+
null, Optional.absent(), null, sni);
12811281
WriteBufferingAndExceptionHandler clientWbaeh =
12821282
new WriteBufferingAndExceptionHandler(pn.newHandler(gh));
12831283

xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) {
146146

147147
childLbHelper.updateDropPolicies(config.dropCategories);
148148
childLbHelper.updateMaxConcurrentRequests(config.maxConcurrentRequests);
149-
childLbHelper.updateSslContextProviderSupplier(config.tlsContext);
149+
childLbHelper.updateSslContext(config.tlsContext);
150150
childLbHelper.updateFilterMetadata(config.filterMetadata);
151151

152152
childSwitchLb.handleResolvedAddresses(
@@ -184,7 +184,7 @@ public void shutdown() {
184184
if (childSwitchLb != null) {
185185
childSwitchLb.shutdown();
186186
if (childLbHelper != null) {
187-
childLbHelper.updateSslContextProviderSupplier(null);
187+
childLbHelper.updateSslContext(null);
188188
childLbHelper = null;
189189
}
190190
}
@@ -204,7 +204,7 @@ private final class ClusterImplLbHelper extends ForwardingLoadBalancerHelper {
204204
private List<DropOverload> dropPolicies = Collections.emptyList();
205205
private long maxConcurrentRequests = DEFAULT_PER_CLUSTER_MAX_CONCURRENT_REQUESTS;
206206
@Nullable
207-
private SslContextProviderSupplier sslContextProviderSupplier;
207+
private UpstreamTlsContext tlsContext;
208208
private Map<String, Struct> filterMetadata = ImmutableMap.of();
209209
@Nullable
210210
private final ServerInfo lrsServerInfo;
@@ -293,10 +293,12 @@ private List<EquivalentAddressGroup> withAdditionalAttributes(
293293
for (EquivalentAddressGroup eag : addresses) {
294294
Attributes.Builder attrBuilder = eag.getAttributes().toBuilder().set(
295295
XdsAttributes.ATTR_CLUSTER_NAME, cluster);
296-
if (sslContextProviderSupplier != null) {
296+
if (tlsContext != null) {
297297
attrBuilder.set(
298298
SecurityProtocolNegotiators.ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER,
299-
sslContextProviderSupplier);
299+
new SslContextProviderSupplier(tlsContext,
300+
(TlsContextManager) xdsClient.getSecurityConfig(),
301+
eag.getAttributes().get(XdsAttributes.ATTR_ADDRESS_NAME)));
300302
}
301303
newAddresses.add(new EquivalentAddressGroup(eag.getAddresses(), attrBuilder.build()));
302304
}
@@ -348,22 +350,11 @@ private void updateMaxConcurrentRequests(@Nullable Long maxConcurrentRequests) {
348350
updateBalancingState(currentState, currentPicker);
349351
}
350352

351-
private void updateSslContextProviderSupplier(@Nullable UpstreamTlsContext tlsContext) {
352-
UpstreamTlsContext currentTlsContext =
353-
sslContextProviderSupplier != null
354-
? (UpstreamTlsContext)sslContextProviderSupplier.getTlsContext()
355-
: null;
356-
if (Objects.equals(currentTlsContext, tlsContext)) {
353+
private void updateSslContext(@Nullable UpstreamTlsContext tlsContext) {
354+
if (Objects.equals(this.tlsContext, tlsContext)) {
357355
return;
358356
}
359-
if (sslContextProviderSupplier != null) {
360-
sslContextProviderSupplier.close();
361-
}
362-
sslContextProviderSupplier =
363-
tlsContext != null
364-
? new SslContextProviderSupplier(tlsContext,
365-
(TlsContextManager) xdsClient.getSecurityConfig())
366-
: null;
357+
this.tlsContext = tlsContext;
367358
}
368359

369360
private void updateFilterMetadata(Map<String, Struct> filterMetadata) {

xds/src/main/java/io/grpc/xds/EnvoyServerProtoData.java

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,15 +73,28 @@ public int hashCode() {
7373

7474
public static final class UpstreamTlsContext extends BaseTlsContext {
7575

76+
private final String sni;
77+
private final boolean auto_host_sni;
78+
7679
@VisibleForTesting
77-
public UpstreamTlsContext(CommonTlsContext commonTlsContext) {
78-
super(commonTlsContext);
80+
public UpstreamTlsContext(io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext upstreamTlsContext) {
81+
super(upstreamTlsContext.getCommonTlsContext());
82+
this.sni = upstreamTlsContext.getSni();
83+
this.auto_host_sni = upstreamTlsContext.getAutoHostSni();
7984
}
8085

8186
public static UpstreamTlsContext fromEnvoyProtoUpstreamTlsContext(
8287
io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext
8388
upstreamTlsContext) {
84-
return new UpstreamTlsContext(upstreamTlsContext.getCommonTlsContext());
89+
return new UpstreamTlsContext(upstreamTlsContext);
90+
}
91+
92+
public String getSni() {
93+
return sni;
94+
}
95+
96+
public boolean getAutoHostSni() {
97+
return auto_host_sni;
8598
}
8699

87100
@Override

xds/src/main/java/io/grpc/xds/internal/security/SecurityProtocolNegotiators.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ protected void handlerAdded0(final ChannelHandlerContext ctx) {
213213
new SslContextProvider.Callback(ctx.executor()) {
214214

215215
@Override
216-
public void updateSslContext(SslContext sslContext) {
216+
public void updateSslContext(SslContext sslContext, String sni) {
217217
if (ctx.isRemoved()) {
218218
return;
219219
}
@@ -222,7 +222,7 @@ public void updateSslContext(SslContext sslContext) {
222222
"ClientSecurityHandler.updateSslContext authority={0}, ctx.name={1}",
223223
new Object[]{grpcHandler.getAuthority(), ctx.name()});
224224
ChannelHandler handler =
225-
InternalProtocolNegotiators.tls(sslContext).newHandler(grpcHandler);
225+
InternalProtocolNegotiators.tls(sslContext, sni).newHandler(grpcHandler);
226226

227227
// Delegate rest of handshake to TLS handler
228228
ctx.pipeline().addAfter(ctx.name(), null, handler);
@@ -356,7 +356,7 @@ protected void handlerAdded0(final ChannelHandlerContext ctx) {
356356
new SslContextProvider.Callback(ctx.executor()) {
357357

358358
@Override
359-
public void updateSslContext(SslContext sslContext) {
359+
public void updateSslContext(SslContext sslContext, String sni) {
360360
ChannelHandler handler =
361361
InternalProtocolNegotiators.serverTls(sslContext).newHandler(grpcHandler);
362362

xds/src/main/java/io/grpc/xds/internal/security/SslContextProvider.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ protected Callback(Executor executor) {
5757
}
5858

5959
/** Informs callee of new/updated SslContext. */
60-
@VisibleForTesting public abstract void updateSslContext(SslContext sslContext);
60+
@VisibleForTesting public abstract void updateSslContext(SslContext sslContext, String sni);
6161

6262
/** Informs callee of an exception that was generated. */
6363
@VisibleForTesting protected abstract void onException(Throwable throwable);
@@ -120,7 +120,7 @@ protected final void performCallback(
120120
public void run() {
121121
try {
122122
SslContext sslContext = sslContextGetter.get();
123-
callback.updateSslContext(sslContext);
123+
callback.updateSslContext(sslContext, sni);
124124
} catch (Throwable e) {
125125
callback.onException(e);
126126
}

0 commit comments

Comments
 (0)