Skip to content

Commit b304b14

Browse files
feat: move credential handling responsibility to the proxy
stack-info: PR: #9690, branch: igorbernstein2/stack/3
1 parent e70f70b commit b304b14

File tree

5 files changed

+163
-9
lines changed

5 files changed

+163
-9
lines changed

bigtable/bigtable-proxy/pom.xml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,14 @@
5757
<groupId>io.grpc</groupId>
5858
<artifactId>grpc-netty-shaded</artifactId>
5959
</dependency>
60+
<dependency>
61+
<groupId>io.grpc</groupId>
62+
<artifactId>grpc-auth</artifactId>
63+
</dependency>
64+
<dependency>
65+
<groupId>com.google.auth</groupId>
66+
<artifactId>google-auth-library-oauth2-http</artifactId>
67+
</dependency>
6068

6169

6270
<!-- service defs -->

bigtable/bigtable-proxy/src/main/java/com/google/cloud/bigtable/examples/proxy/commands/Serve.java

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,18 +16,22 @@
1616

1717
package com.google.cloud.bigtable.examples.proxy.commands;
1818

19+
import com.google.auth.Credentials;
20+
import com.google.auth.oauth2.GoogleCredentials;
1921
import com.google.bigtable.admin.v2.BigtableInstanceAdminGrpc;
2022
import com.google.bigtable.admin.v2.BigtableTableAdminGrpc;
2123
import com.google.bigtable.v2.BigtableGrpc;
2224
import com.google.cloud.bigtable.examples.proxy.core.ProxyHandler;
2325
import com.google.cloud.bigtable.examples.proxy.core.Registry;
2426
import com.google.common.collect.ImmutableMap;
2527
import com.google.longrunning.OperationsGrpc;
28+
import io.grpc.CallCredentials;
2629
import io.grpc.InsecureServerCredentials;
2730
import io.grpc.ManagedChannel;
2831
import io.grpc.ManagedChannelBuilder;
2932
import io.grpc.Server;
3033
import io.grpc.ServerCallHandler;
34+
import io.grpc.auth.MoreCallCredentials;
3135
import io.grpc.netty.shaded.io.grpc.netty.NettyServerBuilder;
3236
import java.io.IOException;
3337
import java.net.InetSocketAddress;
@@ -67,6 +71,7 @@ public class Serve implements Callable<Void> {
6771

6872
ManagedChannel adminChannel = null;
6973
ManagedChannel dataChannel = null;
74+
Credentials credentials = null;
7075
Server server;
7176

7277
@Override
@@ -95,17 +100,21 @@ void start() throws IOException {
95100
.disableRetry()
96101
.build();
97102
}
103+
if (credentials == null) {
104+
credentials = GoogleCredentials.getApplicationDefault();
105+
}
106+
CallCredentials callCredentials = MoreCallCredentials.from(credentials);
98107

99108
Map<String, ServerCallHandler<byte[], byte[]>> serviceMap =
100109
ImmutableMap.of(
101110
BigtableGrpc.SERVICE_NAME,
102-
new ProxyHandler<>(dataChannel),
111+
new ProxyHandler<>(dataChannel, callCredentials),
103112
BigtableInstanceAdminGrpc.SERVICE_NAME,
104-
new ProxyHandler<>(adminChannel),
113+
new ProxyHandler<>(adminChannel, callCredentials),
105114
BigtableTableAdminGrpc.SERVICE_NAME,
106-
new ProxyHandler<>(adminChannel),
115+
new ProxyHandler<>(adminChannel, callCredentials),
107116
OperationsGrpc.SERVICE_NAME,
108-
new ProxyHandler<>(adminChannel));
117+
new ProxyHandler<>(adminChannel, callCredentials));
109118

110119
server =
111120
NettyServerBuilder.forAddress(

bigtable/bigtable-proxy/src/main/java/com/google/cloud/bigtable/examples/proxy/core/ProxyHandler.java

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
package com.google.cloud.bigtable.examples.proxy.core;
1818

19+
import io.grpc.CallCredentials;
1920
import io.grpc.CallOptions;
2021
import io.grpc.Channel;
2122
import io.grpc.ClientCall;
@@ -25,15 +26,23 @@
2526

2627
/** A factory pairing of an incoming server call to an outgoing client call. */
2728
public final class ProxyHandler<ReqT, RespT> implements ServerCallHandler<ReqT, RespT> {
29+
private static final Metadata.Key<String> AUTHORIZATION_KEY =
30+
Metadata.Key.of("Authorization", Metadata.ASCII_STRING_MARSHALLER);
31+
2832
private final Channel channel;
33+
private final CallCredentials callCredentials;
2934

30-
public ProxyHandler(Channel channel) {
35+
public ProxyHandler(Channel channel, CallCredentials callCredentials) {
3136
this.channel = channel;
37+
this.callCredentials = callCredentials;
3238
}
3339

3440
@Override
3541
public ServerCall.Listener<ReqT> startCall(ServerCall<ReqT, RespT> serverCall, Metadata headers) {
36-
CallOptions callOptions = CallOptions.DEFAULT;
42+
// Strip incoming credentials
43+
headers.removeAll(AUTHORIZATION_KEY);
44+
// Inject proxy credentials
45+
CallOptions callOptions = CallOptions.DEFAULT.withCallCredentials(callCredentials);
3746

3847
ClientCall<ReqT, RespT> clientCall =
3948
channel.newCall(serverCall.getMethodDescriptor(), callOptions);

bigtable/bigtable-proxy/src/test/java/com/google/cloud/bigtable/examples/proxy/commands/ServeTest.java

Lines changed: 117 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import static com.google.common.truth.Truth.assertThat;
2222
import static com.google.common.truth.Truth.assertWithMessage;
2323

24+
import com.google.auth.Credentials;
2425
import com.google.bigtable.admin.v2.BigtableInstanceAdminGrpc;
2526
import com.google.bigtable.admin.v2.BigtableInstanceAdminGrpc.BigtableInstanceAdminFutureStub;
2627
import com.google.bigtable.admin.v2.BigtableInstanceAdminGrpc.BigtableInstanceAdminImplBase;
@@ -36,6 +37,7 @@
3637
import com.google.bigtable.v2.BigtableGrpc.BigtableImplBase;
3738
import com.google.bigtable.v2.CheckAndMutateRowRequest;
3839
import com.google.bigtable.v2.CheckAndMutateRowResponse;
40+
import com.google.common.collect.Lists;
3941
import com.google.common.collect.Range;
4042
import com.google.common.util.concurrent.ListenableFuture;
4143
import com.google.longrunning.GetOperationRequest;
@@ -67,7 +69,10 @@
6769
import io.grpc.testing.GrpcCleanupRule;
6870
import java.io.IOException;
6971
import java.net.ServerSocket;
72+
import java.net.URI;
7073
import java.time.Duration;
74+
import java.util.List;
75+
import java.util.Map;
7176
import java.util.UUID;
7277
import java.util.concurrent.BlockingDeque;
7378
import java.util.concurrent.BlockingQueue;
@@ -100,6 +105,7 @@ public class ServeTest {
100105
private FakeTableAdminService tableAdminService;
101106
private OperationService operationService;
102107
private ManagedChannel fakeServiceChannel;
108+
private FakeCredentials fakeCredentials;
103109

104110
// Proxy
105111
private Serve serve;
@@ -115,6 +121,8 @@ public void setUp() throws IOException {
115121
tableAdminService = new FakeTableAdminService();
116122
operationService = new OperationService();
117123

124+
fakeCredentials = new FakeCredentials();
125+
118126
grpcCleanup.register(
119127
InProcessServerBuilder.forName(targetServerName)
120128
.intercept(callContextInterceptor)
@@ -131,7 +139,9 @@ public void setUp() throws IOException {
131139
InProcessChannelBuilder.forName(targetServerName).usePlaintext().build());
132140

133141
// Create the proxy
134-
serve = createAndStartCommand(fakeServiceChannel);
142+
// Inject fakes for upstream calls. For unit tests we want to shim communications to the
143+
// bigtable service.
144+
serve = createAndStartCommand(fakeServiceChannel, fakeCredentials);
135145

136146
proxyChannel =
137147
grpcCleanup.register(
@@ -363,11 +373,86 @@ public void testDeadlinePropagation()
363373
.isIn(Range.closed(Duration.ofMinutes(9), Duration.ofMinutes(10)));
364374
}
365375

366-
private static Serve createAndStartCommand(ManagedChannel targetChannel) throws IOException {
376+
@Test
377+
public void testCredentials() throws InterruptedException, ExecutionException, TimeoutException {
378+
BigtableFutureStub proxyStub = BigtableGrpc.newFutureStub(proxyChannel);
379+
380+
CheckAndMutateRowRequest request =
381+
CheckAndMutateRowRequest.newBuilder().setTableName("some-table").build();
382+
final ListenableFuture<CheckAndMutateRowResponse> proxyFuture =
383+
proxyStub.checkAndMutateRow(request);
384+
StreamObserver<CheckAndMutateRowResponse> serverObserver =
385+
dataService
386+
.calls
387+
.computeIfAbsent(request, (ignored) -> new LinkedBlockingDeque<>())
388+
.poll(1, TimeUnit.SECONDS);
389+
390+
assertWithMessage("Timed out waiting for the proxied RPC on the fake server")
391+
.that(serverObserver)
392+
.isNotNull();
393+
394+
serverObserver.onNext(CheckAndMutateRowResponse.newBuilder().setPredicateMatched(true).build());
395+
serverObserver.onCompleted();
396+
proxyFuture.get(1, TimeUnit.SECONDS);
397+
398+
assertThat(metadataInterceptor.requestHeaders.poll(1, TimeUnit.SECONDS))
399+
.hasValue("authorization", "fake-token");
400+
}
401+
402+
@Test
403+
public void testCredentialsClobber()
404+
throws InterruptedException, ExecutionException, TimeoutException {
405+
BigtableFutureStub proxyStub =
406+
BigtableGrpc.newFutureStub(proxyChannel)
407+
.withInterceptors(
408+
new ClientInterceptor() {
409+
@Override
410+
public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
411+
MethodDescriptor<ReqT, RespT> methodDescriptor,
412+
CallOptions callOptions,
413+
Channel channel) {
414+
return new SimpleForwardingClientCall<ReqT, RespT>(
415+
channel.newCall(methodDescriptor, callOptions)) {
416+
@Override
417+
public void start(Listener<RespT> responseListener, Metadata headers) {
418+
headers.put(
419+
Metadata.Key.of("authorization", Metadata.ASCII_STRING_MARSHALLER),
420+
"pre-proxied-value");
421+
super.start(responseListener, headers);
422+
}
423+
};
424+
}
425+
});
426+
427+
CheckAndMutateRowRequest request =
428+
CheckAndMutateRowRequest.newBuilder().setTableName("some-table").build();
429+
final ListenableFuture<CheckAndMutateRowResponse> proxyFuture =
430+
proxyStub.checkAndMutateRow(request);
431+
StreamObserver<CheckAndMutateRowResponse> serverObserver =
432+
dataService
433+
.calls
434+
.computeIfAbsent(request, (ignored) -> new LinkedBlockingDeque<>())
435+
.poll(1, TimeUnit.SECONDS);
436+
437+
assertWithMessage("Timed out waiting for the proxied RPC on the fake server")
438+
.that(serverObserver)
439+
.isNotNull();
440+
441+
serverObserver.onNext(CheckAndMutateRowResponse.newBuilder().setPredicateMatched(true).build());
442+
serverObserver.onCompleted();
443+
proxyFuture.get(1, TimeUnit.SECONDS);
444+
445+
Metadata serverRequestHeaders = metadataInterceptor.requestHeaders.poll(1, TimeUnit.SECONDS);
446+
assertThat(serverRequestHeaders).hasValue("authorization", "fake-token");
447+
}
448+
449+
private static Serve createAndStartCommand(
450+
ManagedChannel targetChannel, FakeCredentials targetCredentials) throws IOException {
367451
for (int i = 10; i >= 0; i--) {
368452
Serve s = new Serve();
369453
s.dataChannel = targetChannel;
370454
s.adminChannel = targetChannel;
455+
s.credentials = targetCredentials;
371456

372457
try (ServerSocket serverSocket = new ServerSocket(0)) {
373458
s.listenPort = serverSocket.getLocalPort();
@@ -477,4 +562,34 @@ public void getOperation(
477562
.add(responseObserver);
478563
}
479564
}
565+
566+
private static class FakeCredentials extends Credentials {
567+
private static final String HEADER_NAME = "authorization";
568+
private String fakeValue = "fake-token";
569+
570+
@Override
571+
public String getAuthenticationType() {
572+
return "fake";
573+
}
574+
575+
@Override
576+
public Map<String, List<String>> getRequestMetadata(URI uri) throws IOException {
577+
return Map.of(HEADER_NAME, Lists.newArrayList(fakeValue));
578+
}
579+
580+
@Override
581+
public boolean hasRequestMetadata() {
582+
return true;
583+
}
584+
585+
@Override
586+
public boolean hasRequestMetadataOnly() {
587+
return true;
588+
}
589+
590+
@Override
591+
public void refresh() throws IOException {
592+
// noop
593+
}
594+
}
480595
}

bigtable/bigtable-proxy/src/test/java/com/google/cloud/bigtable/examples/proxy/utils/MetadataSubject.java

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
import com.google.common.truth.FailureMetadata;
2222
import com.google.common.truth.Subject;
2323
import io.grpc.Metadata;
24+
import java.util.ArrayList;
25+
import java.util.Optional;
2426
import org.jspecify.annotations.Nullable;
2527

2628
public class MetadataSubject extends Subject {
@@ -52,6 +54,17 @@ public void hasValue(String key, String value) {
5254
}
5355

5456
public <T> void hasValue(Metadata.Key<T> key, T value) {
55-
check("get(" + key + ")").that(metadata.get(key)).isEqualTo(value);
57+
Iterable<T> actualValues = Optional.ofNullable(metadata.getAll(key)).orElse(new ArrayList<>());
58+
check("get(" + key + ")").that(actualValues).containsExactly(value);
59+
}
60+
61+
public void containsValue(String key, String value) {
62+
check("get(" + key + ")")
63+
.that(metadata.getAll(Metadata.Key.of(key, Metadata.ASCII_STRING_MARSHALLER)))
64+
.contains(value);
65+
}
66+
67+
public <T> void containsValue(Metadata.Key<T> key, T value) {
68+
check("get(" + key + ")").that(metadata.getAll(key)).contains(value);
5669
}
5770
}

0 commit comments

Comments
 (0)