Skip to content

Commit d35ee71

Browse files
feat: move credential handling responsibility to the proxy
stack-info: PR: #9690, branch: igorbernstein2/stack/3
1 parent 834bf5b commit d35ee71

File tree

4 files changed

+153
-9
lines changed

4 files changed

+153
-9
lines changed

bigtable/bigtable-proxy/src/main/java/com/google/cloud/bigtable/examples/proxy/commands/Serve.java

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,18 +15,22 @@
1515
*/
1616
package com.google.cloud.bigtable.examples.proxy.commands;
1717

18+
import com.google.auth.Credentials;
19+
import com.google.auth.oauth2.GoogleCredentials;
1820
import com.google.bigtable.admin.v2.BigtableInstanceAdminGrpc;
1921
import com.google.bigtable.admin.v2.BigtableTableAdminGrpc;
2022
import com.google.bigtable.v2.BigtableGrpc;
2123
import com.google.cloud.bigtable.examples.proxy.core.ProxyHandler;
2224
import com.google.cloud.bigtable.examples.proxy.core.Registry;
2325
import com.google.common.collect.ImmutableMap;
2426
import com.google.longrunning.OperationsGrpc;
27+
import io.grpc.CallCredentials;
2528
import io.grpc.InsecureServerCredentials;
2629
import io.grpc.ManagedChannel;
2730
import io.grpc.ManagedChannelBuilder;
2831
import io.grpc.Server;
2932
import io.grpc.ServerCallHandler;
33+
import io.grpc.auth.MoreCallCredentials;
3034
import io.grpc.netty.shaded.io.grpc.netty.NettyServerBuilder;
3135
import java.io.IOException;
3236
import java.net.InetSocketAddress;
@@ -66,6 +70,7 @@ public class Serve implements Callable<Void> {
6670

6771
ManagedChannel adminChannel = null;
6872
ManagedChannel dataChannel = null;
73+
Credentials credentials = null;
6974
Server server;
7075

7176
@Override
@@ -93,17 +98,21 @@ void start() throws IOException {
9398
.disableRetry()
9499
.build();
95100
}
101+
if (credentials == null) {
102+
credentials = GoogleCredentials.getApplicationDefault();
103+
}
104+
CallCredentials callCredentials = MoreCallCredentials.from(credentials);
96105

97106
Map<String, ServerCallHandler<byte[], byte[]>> serviceMap =
98107
ImmutableMap.of(
99108
BigtableGrpc.SERVICE_NAME,
100-
new ProxyHandler<>(dataChannel),
109+
new ProxyHandler<>(dataChannel, callCredentials),
101110
BigtableInstanceAdminGrpc.SERVICE_NAME,
102-
new ProxyHandler<>(adminChannel),
111+
new ProxyHandler<>(adminChannel, callCredentials),
103112
BigtableTableAdminGrpc.SERVICE_NAME,
104-
new ProxyHandler<>(adminChannel),
113+
new ProxyHandler<>(adminChannel, callCredentials),
105114
OperationsGrpc.SERVICE_NAME,
106-
new ProxyHandler<>(adminChannel));
115+
new ProxyHandler<>(adminChannel, callCredentials));
107116

108117
server =
109118
NettyServerBuilder.forAddress(

bigtable/bigtable-proxy/src/main/java/com/google/cloud/bigtable/examples/proxy/core/ProxyHandler.java

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
*/
1616
package com.google.cloud.bigtable.examples.proxy.core;
1717

18+
import io.grpc.CallCredentials;
1819
import io.grpc.CallOptions;
1920
import io.grpc.Channel;
2021
import io.grpc.ClientCall;
@@ -24,15 +25,23 @@
2425

2526
/** A factory pairing of an incoming server call to an outgoing client call. */
2627
public final class ProxyHandler<ReqT, RespT> implements ServerCallHandler<ReqT, RespT> {
28+
private final Metadata.Key<String> AUTHORIZATION_KEY =
29+
Metadata.Key.of("Authorization", Metadata.ASCII_STRING_MARSHALLER);
30+
2731
private final Channel channel;
32+
private final CallCredentials callCredentials;
2833

29-
public ProxyHandler(Channel channel) {
34+
public ProxyHandler(Channel channel, CallCredentials callCredentials) {
3035
this.channel = channel;
36+
this.callCredentials = callCredentials;
3137
}
3238

3339
@Override
3440
public ServerCall.Listener<ReqT> startCall(ServerCall<ReqT, RespT> serverCall, Metadata headers) {
35-
CallOptions callOptions = CallOptions.DEFAULT;
41+
// Strip incoming credentials
42+
headers.removeAll(AUTHORIZATION_KEY);
43+
// Inject proxy credentials
44+
CallOptions callOptions = CallOptions.DEFAULT.withCallCredentials(callCredentials);
3645

3746
ClientCall<ReqT, RespT> clientCall =
3847
channel.newCall(serverCall.getMethodDescriptor(), callOptions);

bigtable/bigtable-proxy/src/test/java/com/google/cloud/bigtable/examples/proxy/commands/ServeTest.java

Lines changed: 115 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import static com.google.common.truth.Truth.assertThat;
2020
import static com.google.common.truth.Truth.assertWithMessage;
2121

22+
import com.google.auth.Credentials;
2223
import com.google.bigtable.admin.v2.BigtableInstanceAdminGrpc;
2324
import com.google.bigtable.admin.v2.BigtableInstanceAdminGrpc.BigtableInstanceAdminFutureStub;
2425
import com.google.bigtable.admin.v2.BigtableInstanceAdminGrpc.BigtableInstanceAdminImplBase;
@@ -34,6 +35,7 @@
3435
import com.google.bigtable.v2.BigtableGrpc.BigtableImplBase;
3536
import com.google.bigtable.v2.CheckAndMutateRowRequest;
3637
import com.google.bigtable.v2.CheckAndMutateRowResponse;
38+
import com.google.common.collect.Lists;
3739
import com.google.common.util.concurrent.ListenableFuture;
3840
import com.google.longrunning.GetOperationRequest;
3941
import com.google.longrunning.Operation;
@@ -62,6 +64,9 @@
6264
import io.grpc.testing.GrpcCleanupRule;
6365
import java.io.IOException;
6466
import java.net.ServerSocket;
67+
import java.net.URI;
68+
import java.util.List;
69+
import java.util.Map;
6570
import java.util.UUID;
6671
import java.util.concurrent.BlockingDeque;
6772
import java.util.concurrent.BlockingQueue;
@@ -93,6 +98,7 @@ public class ServeTest {
9398
private FakeTableAdminService tableAdminService;
9499
private OperationService operationService;
95100
private ManagedChannel fakeServiceChannel;
101+
private FakeCredentials fakeCredentials;
96102

97103
// Proxy
98104
private Serve serve;
@@ -107,6 +113,8 @@ public void setUp() throws IOException {
107113
tableAdminService = new FakeTableAdminService();
108114
operationService = new OperationService();
109115

116+
fakeCredentials = new FakeCredentials();
117+
110118
grpcCleanup.register(
111119
InProcessServerBuilder.forName(TARGET_SERVER_NAME)
112120
.intercept(metadataInterceptor)
@@ -122,7 +130,9 @@ public void setUp() throws IOException {
122130
InProcessChannelBuilder.forName(TARGET_SERVER_NAME).usePlaintext().build());
123131

124132
// Create the proxy
125-
serve = createAndStartCommand(fakeServiceChannel);
133+
// Inject fakes for upstream calls. For unit tests we want to shim communications to the
134+
// bigtable service.
135+
serve = createAndStartCommand(fakeServiceChannel, fakeCredentials);
126136

127137
proxyChannel =
128138
grpcCleanup.register(
@@ -318,11 +328,84 @@ public void onClose(Status status, Metadata trailers) {
318328
assertThat(clientRecvTrailer.get()).hasValue("trailer", "trailer-value");
319329
}
320330

321-
private static Serve createAndStartCommand(ManagedChannel targetChannel) throws IOException {
331+
@Test
332+
public void testCredentials() throws InterruptedException, ExecutionException, TimeoutException {
333+
BigtableFutureStub proxyStub = BigtableGrpc.newFutureStub(proxyChannel);
334+
335+
CheckAndMutateRowRequest request =
336+
CheckAndMutateRowRequest.newBuilder().setTableName("some-table").build();
337+
ListenableFuture<CheckAndMutateRowResponse> proxyFuture = proxyStub.checkAndMutateRow(request);
338+
StreamObserver<CheckAndMutateRowResponse> serverObserver =
339+
dataService
340+
.calls
341+
.computeIfAbsent(request, (ignored) -> new LinkedBlockingDeque<>())
342+
.poll(1, TimeUnit.SECONDS);
343+
344+
assertWithMessage("Timed out waiting for the proxied RPC on the fake server")
345+
.that(serverObserver)
346+
.isNotNull();
347+
348+
serverObserver.onNext(CheckAndMutateRowResponse.newBuilder().setPredicateMatched(true).build());
349+
serverObserver.onCompleted();
350+
proxyFuture.get(1, TimeUnit.SECONDS);
351+
352+
assertThat(metadataInterceptor.requestHeaders.poll(1, TimeUnit.SECONDS))
353+
.hasValue("authorization", "fake-token");
354+
}
355+
356+
@Test
357+
public void testCredentialsClobber()
358+
throws InterruptedException, ExecutionException, TimeoutException {
359+
BigtableFutureStub proxyStub =
360+
BigtableGrpc.newFutureStub(proxyChannel)
361+
.withInterceptors(
362+
new ClientInterceptor() {
363+
@Override
364+
public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
365+
MethodDescriptor<ReqT, RespT> methodDescriptor,
366+
CallOptions callOptions,
367+
Channel channel) {
368+
return new SimpleForwardingClientCall<ReqT, RespT>(
369+
channel.newCall(methodDescriptor, callOptions)) {
370+
@Override
371+
public void start(Listener<RespT> responseListener, Metadata headers) {
372+
headers.put(
373+
Metadata.Key.of("authorization", Metadata.ASCII_STRING_MARSHALLER),
374+
"pre-proxied-value");
375+
super.start(responseListener, headers);
376+
}
377+
};
378+
}
379+
});
380+
381+
CheckAndMutateRowRequest request =
382+
CheckAndMutateRowRequest.newBuilder().setTableName("some-table").build();
383+
ListenableFuture<CheckAndMutateRowResponse> proxyFuture = proxyStub.checkAndMutateRow(request);
384+
StreamObserver<CheckAndMutateRowResponse> serverObserver =
385+
dataService
386+
.calls
387+
.computeIfAbsent(request, (ignored) -> new LinkedBlockingDeque<>())
388+
.poll(1, TimeUnit.SECONDS);
389+
390+
assertWithMessage("Timed out waiting for the proxied RPC on the fake server")
391+
.that(serverObserver)
392+
.isNotNull();
393+
394+
serverObserver.onNext(CheckAndMutateRowResponse.newBuilder().setPredicateMatched(true).build());
395+
serverObserver.onCompleted();
396+
proxyFuture.get(1, TimeUnit.SECONDS);
397+
398+
Metadata serverRequestHeaders = metadataInterceptor.requestHeaders.poll(1, TimeUnit.SECONDS);
399+
assertThat(serverRequestHeaders).hasValue("authorization", "fake-token");
400+
}
401+
402+
private static Serve createAndStartCommand(
403+
ManagedChannel targetChannel, FakeCredentials targetCredentials) throws IOException {
322404
for (int i = 10; i >= 0; i--) {
323405
Serve s = new Serve();
324406
s.dataChannel = targetChannel;
325407
s.adminChannel = targetChannel;
408+
s.credentials = targetCredentials;
326409

327410
try (ServerSocket serverSocket = new ServerSocket(0)) {
328411
s.listenPort = serverSocket.getLocalPort();
@@ -419,4 +502,34 @@ public void getOperation(
419502
.add(responseObserver);
420503
}
421504
}
505+
506+
private static class FakeCredentials extends Credentials {
507+
private static final String HEADER_NAME = "authorization";
508+
private String fakeValue = "fake-token";
509+
510+
@Override
511+
public String getAuthenticationType() {
512+
return "fake";
513+
}
514+
515+
@Override
516+
public Map<String, List<String>> getRequestMetadata(URI uri) throws IOException {
517+
return Map.of(HEADER_NAME, Lists.newArrayList(fakeValue));
518+
}
519+
520+
@Override
521+
public boolean hasRequestMetadata() {
522+
return true;
523+
}
524+
525+
@Override
526+
public boolean hasRequestMetadataOnly() {
527+
return true;
528+
}
529+
530+
@Override
531+
public void refresh() throws IOException {
532+
// noop
533+
}
534+
}
422535
}

bigtable/bigtable-proxy/src/test/java/com/google/cloud/bigtable/examples/proxy/utils/MetadataSubject.java

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import com.google.common.truth.FailureMetadata;
66
import com.google.common.truth.Subject;
77
import io.grpc.Metadata;
8+
import java.util.ArrayList;
9+
import java.util.Optional;
810
import org.jspecify.annotations.Nullable;
911

1012
public class MetadataSubject extends Subject {
@@ -36,6 +38,17 @@ public void hasValue(String key, String value) {
3638
}
3739

3840
public <T> void hasValue(Metadata.Key<T> key, T value) {
39-
check("get(" + key + ")").that(metadata.get(key)).isEqualTo(value);
41+
Iterable<T> actualValues = Optional.ofNullable(metadata.getAll(key)).orElse(new ArrayList<>());
42+
check("get(" + key + ")").that(actualValues).containsExactly(value);
43+
}
44+
45+
public void containsValue(String key, String value) {
46+
check("get(" + key + ")")
47+
.that(metadata.getAll(Metadata.Key.of(key, Metadata.ASCII_STRING_MARSHALLER)))
48+
.contains(value);
49+
}
50+
51+
public <T> void containsValue(Metadata.Key<T> key, T value) {
52+
check("get(" + key + ")").that(metadata.getAll(key)).contains(value);
4053
}
4154
}

0 commit comments

Comments
 (0)