|
25 | 25 |
|
26 | 26 | import com.google.api.core.ApiFunction; |
27 | 27 | import com.google.api.gax.rpc.ApiCallContext; |
| 28 | +import com.google.api.gax.rpc.HeaderProvider; |
28 | 29 | import com.google.auth.oauth2.AccessToken; |
29 | 30 | import com.google.auth.oauth2.OAuth2Credentials; |
30 | 31 | import com.google.cloud.spanner.DatabaseAdminClient; |
@@ -151,6 +152,7 @@ public class GapicSpannerRpcTest { |
151 | 152 | private Server server; |
152 | 153 | private InetSocketAddress address; |
153 | 154 | private final Map<SpannerRpc.Option, Object> optionsMap = new HashMap<>(); |
| 155 | + private Metadata seenHeaders; |
154 | 156 |
|
155 | 157 | @BeforeClass |
156 | 158 | public static void checkNotEmulator() { |
@@ -183,6 +185,7 @@ public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall( |
183 | 185 | ServerCall<ReqT, RespT> call, |
184 | 186 | Metadata headers, |
185 | 187 | ServerCallHandler<ReqT, RespT> next) { |
| 188 | + seenHeaders = headers; |
186 | 189 | String auth = |
187 | 190 | headers.get(Key.of("authorization", Metadata.ASCII_STRING_MARSHALLER)); |
188 | 191 | assertThat(auth).isEqualTo("Bearer " + VARIABLE_OAUTH_TOKEN); |
@@ -502,6 +505,34 @@ public void testAdminRequestsLimitExceededRetryAlgorithm() { |
502 | 505 | assertThat(alg.shouldRetry(new Exception("random exception"), null)).isFalse(); |
503 | 506 | } |
504 | 507 |
|
| 508 | + @Test |
| 509 | + public void testCustomUserAgent() { |
| 510 | + for (final String headerId : new String[] {"user-agent", "User-Agent", "USER-AGENT"}) { |
| 511 | + final HeaderProvider userAgentHeaderProvider = |
| 512 | + new HeaderProvider() { |
| 513 | + @Override |
| 514 | + public Map<String, String> getHeaders() { |
| 515 | + final Map<String, String> headers = new HashMap<>(); |
| 516 | + headers.put(headerId, "test-agent"); |
| 517 | + return headers; |
| 518 | + } |
| 519 | + }; |
| 520 | + final SpannerOptions options = |
| 521 | + createSpannerOptions().toBuilder().setHeaderProvider(userAgentHeaderProvider).build(); |
| 522 | + try (Spanner spanner = options.getService()) { |
| 523 | + final DatabaseClient databaseClient = |
| 524 | + spanner.getDatabaseClient(DatabaseId.of("[PROJECT]", "[INSTANCE]", "[DATABASE]")); |
| 525 | + |
| 526 | + try (final ResultSet rs = databaseClient.singleUse().executeQuery(SELECT1AND2)) { |
| 527 | + rs.next(); |
| 528 | + } |
| 529 | + |
| 530 | + assertThat(seenHeaders.get(Key.of("user-agent", Metadata.ASCII_STRING_MARSHALLER))) |
| 531 | + .contains("test-agent"); |
| 532 | + } |
| 533 | + } |
| 534 | + } |
| 535 | + |
505 | 536 | @SuppressWarnings("rawtypes") |
506 | 537 | private SpannerOptions createSpannerOptions() { |
507 | 538 | String endpoint = address.getHostString() + ":" + server.getPort(); |
|
0 commit comments