Skip to content

Commit c31995e

Browse files
committed
In-progress changes for Authority verify in okhttp transport.
1 parent a79982c commit c31995e

File tree

4 files changed

+50
-21
lines changed

4 files changed

+50
-21
lines changed

core/src/main/java/io/grpc/internal/ManagedChannelImplBuilder.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,10 @@ public ManagedChannelImplBuilder(SocketAddress directServerAddress, String autho
360360
InternalConfiguratorRegistry.configureChannelBuilder(this);
361361
}
362362

363+
public ChannelCredentials getChannelCredentials() {
364+
return channelCredentials;
365+
}
366+
363367
@Override
364368
public ManagedChannelImplBuilder directExecutor() {
365369
return executor(MoreExecutors.directExecutor());

okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelBuilder.java

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
import io.grpc.internal.SharedResourceHolder.Resource;
4848
import io.grpc.internal.SharedResourcePool;
4949
import io.grpc.internal.TransportTracer;
50+
import io.grpc.internal.TransportTracer.Factory;
5051
import io.grpc.okhttp.internal.CipherSuite;
5152
import io.grpc.okhttp.internal.ConnectionSpec;
5253
import io.grpc.okhttp.internal.Platform;
@@ -536,7 +537,8 @@ OkHttpTransportFactory buildTransportFactory() {
536537
keepAliveWithoutCalls,
537538
maxInboundMetadataSize,
538539
transportTracerFactory,
539-
useGetForSafeMethods);
540+
useGetForSafeMethods,
541+
managedChannelImplBuilder.getChannelCredentials());
540542
}
541543

542544
OkHttpChannelBuilder disableCheckAuthority() {
@@ -799,6 +801,7 @@ static final class OkHttpTransportFactory implements ClientTransportFactory {
799801
private final boolean keepAliveWithoutCalls;
800802
final int maxInboundMetadataSize;
801803
final boolean useGetForSafeMethods;
804+
private final ChannelCredentials channelCredentials;
802805
private boolean closed;
803806

804807
private OkHttpTransportFactory(
@@ -815,8 +818,9 @@ private OkHttpTransportFactory(
815818
int flowControlWindow,
816819
boolean keepAliveWithoutCalls,
817820
int maxInboundMetadataSize,
818-
TransportTracer.Factory transportTracerFactory,
819-
boolean useGetForSafeMethods) {
821+
Factory transportTracerFactory,
822+
boolean useGetForSafeMethods,
823+
ChannelCredentials channelCredentials) {
820824
this.executorPool = executorPool;
821825
this.executor = executorPool.getObject();
822826
this.scheduledExecutorServicePool = scheduledExecutorServicePool;
@@ -834,6 +838,7 @@ private OkHttpTransportFactory(
834838
this.keepAliveWithoutCalls = keepAliveWithoutCalls;
835839
this.maxInboundMetadataSize = maxInboundMetadataSize;
836840
this.useGetForSafeMethods = useGetForSafeMethods;
841+
this.channelCredentials = channelCredentials;
837842

838843
this.transportTracerFactory =
839844
Preconditions.checkNotNull(transportTracerFactory, "transportTracerFactory");
@@ -861,7 +866,8 @@ public void run() {
861866
options.getUserAgent(),
862867
options.getEagAttributes(),
863868
options.getHttpConnectProxiedSocketAddress(),
864-
tooManyPingsRunnable);
869+
tooManyPingsRunnable,
870+
channelCredentials);
865871
if (enableKeepAlive) {
866872
transport.enableKeepAlive(
867873
true, keepAliveTimeNanosState.get(), keepAliveTimeoutNanos, keepAliveWithoutCalls);
@@ -897,7 +903,7 @@ public SwapChannelCredentialsResult swapChannelCredentials(ChannelCredentials ch
897903
keepAliveWithoutCalls,
898904
maxInboundMetadataSize,
899905
transportTracerFactory,
900-
useGetForSafeMethods);
906+
useGetForSafeMethods, managedChannelImplBuilder.getChannelCredentials());
901907
return new SwapChannelCredentialsResult(factory, result.callCredentials);
902908
}
903909

okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import com.google.common.util.concurrent.SettableFuture;
3030
import io.grpc.Attributes;
3131
import io.grpc.CallOptions;
32+
import io.grpc.ChannelCredentials;
3233
import io.grpc.ClientStreamTracer;
3334
import io.grpc.Grpc;
3435
import io.grpc.HttpConnectProxiedSocketAddress;
@@ -42,6 +43,8 @@
4243
import io.grpc.Status;
4344
import io.grpc.Status.Code;
4445
import io.grpc.StatusException;
46+
import io.grpc.TlsChannelCredentials;
47+
import io.grpc.internal.ClientStream;
4548
import io.grpc.internal.ClientStreamListener.RpcProgress;
4649
import io.grpc.internal.ConnectionClientTransport;
4750
import io.grpc.internal.GrpcAttributes;
@@ -54,6 +57,7 @@
5457
import io.grpc.internal.StatsTraceContext;
5558
import io.grpc.internal.TransportTracer;
5659
import io.grpc.okhttp.ExceptionHandlingFrameWriter.TransportExceptionHandler;
60+
import io.grpc.okhttp.OkHttpChannelBuilder.OkHttpTransportFactory;
5761
import io.grpc.okhttp.internal.ConnectionSpec;
5862
import io.grpc.okhttp.internal.Credentials;
5963
import io.grpc.okhttp.internal.StatusLine;
@@ -82,6 +86,7 @@
8286
import java.util.List;
8387
import java.util.Locale;
8488
import java.util.Map;
89+
import java.util.Optional;
8590
import java.util.Random;
8691
import java.util.concurrent.BrokenBarrierException;
8792
import java.util.concurrent.CountDownLatch;
@@ -99,6 +104,7 @@
99104
import javax.net.ssl.SSLSession;
100105
import javax.net.ssl.SSLSocket;
101106
import javax.net.ssl.SSLSocketFactory;
107+
import javax.net.ssl.X509ExtendedTrustManager;
102108
import okio.Buffer;
103109
import okio.BufferedSink;
104110
import okio.BufferedSource;
@@ -114,6 +120,7 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep
114120
OutboundFlowController.Transport {
115121
private static final Map<ErrorCode, Status> ERROR_CODE_TO_STATUS = buildErrorCodeToStatusMap();
116122
private static final Logger log = Logger.getLogger(OkHttpClientTransport.class.getName());
123+
private final ChannelCredentials channelCredentials;
117124

118125
private static Map<ErrorCode, Status> buildErrorCodeToStatusMap() {
119126
Map<ErrorCode, Status> errorToStatus = new EnumMap<>(ErrorCode.class);
@@ -205,6 +212,8 @@ private static Map<ErrorCode, Status> buildErrorCodeToStatusMap() {
205212
private final boolean useGetForSafeMethods;
206213
@GuardedBy("lock")
207214
private final TransportTracer transportTracer;
215+
private Optional<X509ExtendedTrustManager> x509ExtendedTrustManager;
216+
208217
@GuardedBy("lock")
209218
private final InUseStateAggregator<OkHttpClientStream> inUseState =
210219
new InUseStateAggregator<OkHttpClientStream>() {
@@ -233,13 +242,14 @@ protected void handleNotInUse() {
233242
SettableFuture<Void> connectedFuture;
234243

235244
public OkHttpClientTransport(
236-
OkHttpChannelBuilder.OkHttpTransportFactory transportFactory,
245+
OkHttpTransportFactory transportFactory,
237246
InetSocketAddress address,
238247
String authority,
239248
@Nullable String userAgent,
240249
Attributes eagAttrs,
241250
@Nullable HttpConnectProxiedSocketAddress proxiedAddr,
242-
Runnable tooManyPingsRunnable) {
251+
Runnable tooManyPingsRunnable,
252+
ChannelCredentials channelCredentials) {
243253
this(
244254
transportFactory,
245255
address,
@@ -249,19 +259,21 @@ public OkHttpClientTransport(
249259
GrpcUtil.STOPWATCH_SUPPLIER,
250260
new Http2(),
251261
proxiedAddr,
252-
tooManyPingsRunnable);
262+
tooManyPingsRunnable,
263+
channelCredentials);
253264
}
254265

255266
private OkHttpClientTransport(
256-
OkHttpChannelBuilder.OkHttpTransportFactory transportFactory,
267+
OkHttpTransportFactory transportFactory,
257268
InetSocketAddress address,
258269
String authority,
259270
@Nullable String userAgent,
260271
Attributes eagAttrs,
261272
Supplier<Stopwatch> stopwatchFactory,
262273
Variant variant,
263274
@Nullable HttpConnectProxiedSocketAddress proxiedAddr,
264-
Runnable tooManyPingsRunnable) {
275+
Runnable tooManyPingsRunnable,
276+
ChannelCredentials channelCredentials) {
265277
this.address = Preconditions.checkNotNull(address, "address");
266278
this.defaultAuthority = authority;
267279
this.maxMessageSize = transportFactory.maxMessageSize;
@@ -291,6 +303,7 @@ private OkHttpClientTransport(
291303
this.attributes = Attributes.newBuilder()
292304
.set(GrpcAttributes.ATTR_CLIENT_EAG_ATTRS, eagAttrs).build();
293305
this.useGetForSafeMethods = transportFactory.useGetForSafeMethods;
306+
this.channelCredentials = channelCredentials;
294307
initTransportTracer();
295308
}
296309

@@ -316,7 +329,8 @@ private OkHttpClientTransport(
316329
stopwatchFactory,
317330
variant,
318331
null,
319-
tooManyPingsRunnable);
332+
tooManyPingsRunnable,
333+
null);
320334
this.connectingCallback = connectingCallback;
321335
this.connectedFuture = Preconditions.checkNotNull(connectedFuture, "connectedFuture");
322336
}
@@ -389,13 +403,18 @@ public void ping(final PingCallback callback, Executor executor) {
389403
}
390404

391405
@Override
392-
public OkHttpClientStream newStream(
406+
public ClientStream newStream(
393407
MethodDescriptor<?, ?> method, Metadata headers, CallOptions callOptions,
394408
ClientStreamTracer[] tracers) {
395409
Preconditions.checkNotNull(method, "method");
396410
Preconditions.checkNotNull(headers, "headers");
397411
StatsTraceContext statsTraceContext =
398412
StatsTraceContext.newClientContext(tracers, getAttributes(), headers);
413+
if (callOptions.getAuthority() != null && channelCredentials instanceof TlsChannelCredentials) {
414+
if (x509ExtendedTrustManager == null) {
415+
416+
}
417+
}
399418
// FIXME: it is likely wrong to pass the transportTracer here as it'll exit the lock's scope
400419
synchronized (lock) { // to make @GuardedBy linter happy
401420
return new OkHttpClientStream(

okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ public void testToString() throws Exception {
241241
/*userAgent=*/ null,
242242
EAG_ATTRS,
243243
NO_PROXY,
244-
tooManyPingsRunnable);
244+
tooManyPingsRunnable, channelCredentials);
245245
String s = clientTransport.toString();
246246
assertTrue("Unexpected: " + s, s.contains("OkHttpClientTransport"));
247247
assertTrue("Unexpected: " + s, s.contains(address.toString()));
@@ -259,7 +259,7 @@ public void testTransportExecutorWithTooFewThreads() throws Exception {
259259
null,
260260
EAG_ATTRS,
261261
NO_PROXY,
262-
tooManyPingsRunnable);
262+
tooManyPingsRunnable, channelCredentials);
263263
clientTransport.start(transportListener);
264264
ArgumentCaptor<Status> statusCaptor = ArgumentCaptor.forClass(Status.class);
265265
verify(transportListener, timeout(TIME_OUT_MS)).transportShutdown(statusCaptor.capture());
@@ -1726,7 +1726,7 @@ public void invalidAuthorityPropagates() {
17261726
"userAgent",
17271727
EAG_ATTRS,
17281728
NO_PROXY,
1729-
tooManyPingsRunnable);
1729+
tooManyPingsRunnable, channelCredentials);
17301730

17311731
String host = clientTransport.getOverridenHost();
17321732
int port = clientTransport.getOverridenPort();
@@ -1744,7 +1744,7 @@ public void unreachableServer() throws Exception {
17441744
"userAgent",
17451745
EAG_ATTRS,
17461746
NO_PROXY,
1747-
tooManyPingsRunnable);
1747+
tooManyPingsRunnable, channelCredentials);
17481748

17491749
ManagedClientTransport.Listener listener = mock(ManagedClientTransport.Listener.class);
17501750
clientTransport.start(listener);
@@ -1774,7 +1774,7 @@ public void customSocketFactory() throws Exception {
17741774
"userAgent",
17751775
EAG_ATTRS,
17761776
NO_PROXY,
1777-
tooManyPingsRunnable);
1777+
tooManyPingsRunnable, channelCredentials);
17781778

17791779
ManagedClientTransport.Listener listener = mock(ManagedClientTransport.Listener.class);
17801780
clientTransport.start(listener);
@@ -1799,7 +1799,7 @@ public void proxy_200() throws Exception {
17991799
.setTargetAddress(targetAddress)
18001800
.setProxyAddress(new InetSocketAddress("localhost", serverSocket.getLocalPort()))
18011801
.build(),
1802-
tooManyPingsRunnable);
1802+
tooManyPingsRunnable, channelCredentials);
18031803
clientTransport.start(transportListener);
18041804

18051805
Socket sock = serverSocket.accept();
@@ -1848,7 +1848,7 @@ public void proxy_500() throws Exception {
18481848
.setTargetAddress(targetAddress)
18491849
.setProxyAddress(new InetSocketAddress("localhost", serverSocket.getLocalPort()))
18501850
.build(),
1851-
tooManyPingsRunnable);
1851+
tooManyPingsRunnable, channelCredentials);
18521852
clientTransport.start(transportListener);
18531853

18541854
Socket sock = serverSocket.accept();
@@ -1896,7 +1896,7 @@ public void proxy_immediateServerClose() throws Exception {
18961896
.setTargetAddress(targetAddress)
18971897
.setProxyAddress(new InetSocketAddress("localhost", serverSocket.getLocalPort()))
18981898
.build(),
1899-
tooManyPingsRunnable);
1899+
tooManyPingsRunnable, channelCredentials);
19001900
clientTransport.start(transportListener);
19011901

19021902
Socket sock = serverSocket.accept();
@@ -1927,7 +1927,7 @@ public void proxy_serverHangs() throws Exception {
19271927
.setTargetAddress(targetAddress)
19281928
.setProxyAddress(new InetSocketAddress("localhost", serverSocket.getLocalPort()))
19291929
.build(),
1930-
tooManyPingsRunnable);
1930+
tooManyPingsRunnable, channelCredentials);
19311931
clientTransport.proxySocketTimeout = 10;
19321932
clientTransport.start(transportListener);
19331933

0 commit comments

Comments
 (0)