Skip to content

Commit 3d53fc7

Browse files
author
AbdulRehman Faraj
committed
DocDB Substrait Implementation
1 parent 1f77111 commit 3d53fc7

File tree

14 files changed

+2198
-321
lines changed

14 files changed

+2198
-321
lines changed

athena-docdb/pom.xml

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,6 @@
1414
<artifactId>aws-athena-federation-sdk</artifactId>
1515
<version>2022.47.1</version>
1616
<classifier>withdep</classifier>
17-
<exclusions>
18-
<!-- replaced with jcl-over-slf4j -->
19-
<exclusion>
20-
<groupId>commons-logging</groupId>
21-
<artifactId>commons-logging</artifactId>
22-
</exclusion>
23-
</exclusions>
2417
</dependency>
2518
<dependency>
2619
<groupId>com.amazonaws</groupId>
@@ -34,6 +27,13 @@
3427
<artifactId>docdb</artifactId>
3528
<version>${aws-sdk-v2.version}</version>
3629
<scope>test</scope>
30+
<exclusions>
31+
<!-- replaced with jcl-over-slf4j -->
32+
<exclusion>
33+
<groupId>commons-logging</groupId>
34+
<artifactId>commons-logging</artifactId>
35+
</exclusion>
36+
</exclusions>
3737
</dependency>
3838
<!-- https://mvnrepository.com/artifact/software.amazon.awscdk/docdb -->
3939
<dependency>

athena-docdb/src/main/java/com/amazonaws/athena/connectors/docdb/DocDBMetadataHandler.java

Lines changed: 172 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import com.amazonaws.athena.connector.lambda.data.BlockWriter;
2525
import com.amazonaws.athena.connector.lambda.domain.Split;
2626
import com.amazonaws.athena.connector.lambda.domain.TableName;
27+
import com.amazonaws.athena.connector.lambda.domain.predicate.functions.StandardFunctions;
2728
import com.amazonaws.athena.connector.lambda.domain.spill.SpillLocation;
2829
import com.amazonaws.athena.connector.lambda.handlers.GlueMetadataHandler;
2930
import com.amazonaws.athena.connector.lambda.metadata.GetDataSourceCapabilitiesRequest;
@@ -39,9 +40,14 @@
3940
import com.amazonaws.athena.connector.lambda.metadata.ListTablesResponse;
4041
import com.amazonaws.athena.connector.lambda.metadata.MetadataRequest;
4142
import com.amazonaws.athena.connector.lambda.metadata.glue.GlueFieldLexer;
43+
import com.amazonaws.athena.connector.lambda.metadata.optimizations.DataSourceOptimizations;
4244
import com.amazonaws.athena.connector.lambda.metadata.optimizations.OptimizationSubType;
45+
import com.amazonaws.athena.connector.lambda.metadata.optimizations.pushdown.ComplexExpressionPushdownSubType;
46+
import com.amazonaws.athena.connector.lambda.metadata.optimizations.pushdown.LimitPushdownSubType;
4347
import com.amazonaws.athena.connector.lambda.security.EncryptionKeyFactory;
4448
import com.amazonaws.athena.connectors.docdb.qpt.DocDBQueryPassthrough;
49+
import com.fasterxml.jackson.databind.JsonNode;
50+
import com.fasterxml.jackson.databind.ObjectMapper;
4551
import com.google.common.base.Strings;
4652
import com.google.common.collect.ImmutableMap;
4753
import com.mongodb.client.MongoClient;
@@ -62,10 +68,13 @@
6268
import java.util.ArrayList;
6369
import java.util.LinkedHashSet;
6470
import java.util.List;
71+
import java.util.Map;
6572
import java.util.Set;
66-
import java.util.stream.Collectors;
6773
import java.util.stream.Stream;
6874

75+
import static com.amazonaws.athena.connector.lambda.connection.EnvironmentConstants.ENFORCE_SSL;
76+
import static com.amazonaws.athena.connector.lambda.connection.EnvironmentConstants.JDBC_PARAMS;
77+
import static com.amazonaws.athena.connector.lambda.connection.EnvironmentConstants.PORT;
6978
import static com.amazonaws.athena.connector.lambda.metadata.ListTablesRequest.UNLIMITED_PAGE_SIZE_VALUE;
7079

7180
/**
@@ -86,13 +95,13 @@ public class DocDBMetadataHandler
8695

8796
//Used to denote the 'type' of this connector for diagnostic purposes.
8897
private static final String SOURCE_TYPE = "documentdb";
98+
private static final String CONNECTION_STRING_TEMPLATE = "mongodb://%s:%s@%s:%s/%s";
99+
private static final String ENFORCE_SSL_JDBC_PARAM = "ssl=true&ssl_ca_certs=rds-combined-ca-bundle.pem";
89100
//Field name used to store the connection string as a property on Split objects.
90101
protected static final String DOCDB_CONN_STR = "connStr";
91102
//The Env variable name used to store the default DocDB connection string if no catalog specific
92103
//env variable is set.
93104
private static final String DEFAULT_DOCDB = "default_docdb";
94-
//The env secret_name to use if defined
95-
private static final String SECRET_NAME = "secret_name";
96105
//The Glue table property that indicates that a table matching the name of an DocDB table
97106
//is indeed enabled for use by this connector.
98107
private static final String DOCDB_METADATA_FLAG = "docdb-metadata-flag";
@@ -103,6 +112,14 @@ public class DocDBMetadataHandler
103112
// used to filter out Glue databases which lack the docdb-metadata-flag in the URI.
104113
private static final DatabaseFilter DB_FILTER = (Database database) -> (database.locationUri() != null && database.locationUri().contains(DOCDB_METADATA_FLAG));
105114

115+
private static final String SECRET_ARN_KEY = "secret_arn";
116+
private static final String AUTH_DB_KEY = "AUTHENTICATION_DATABASE";
117+
118+
// JSON credential field names
119+
private static final String USERNAME_FIELD = "username";
120+
private static final String PASSWORD_FIELD = "password";
121+
public static final String HOST = "host";
122+
106123
private final GlueClient glue;
107124
private final DocDBConnectionFactory connectionFactory;
108125
private final DocDBQueryPassthrough queryPassthrough = new DocDBQueryPassthrough();
@@ -140,6 +157,16 @@ private MongoClient getOrCreateConn(MetadataRequest request)
140157
/**
141158
* Retrieves the DocDB connection details from an env variable matching the catalog name, if no such
142159
* env variable exists we fall back to the default env variable defined by DEFAULT_DOCDB.
160+
*
161+
* <p>For federated requests, this method dynamically constructs the connection string using:
162+
* <ul>
163+
* <li>Host and port from federated identity config options</li>
164+
* <li>Username and password extracted from AWS Secrets Manager (JSON format)</li>
165+
* <li>SSL enforcement and authentication database settings</li>
166+
* </ul>
167+
*
168+
* @param request The metadata request containing catalog name and federated identity information
169+
* @return The DocDB connection string, either from environment variables or dynamically constructed for federated requests
143170
*/
144171
private String getConnStr(MetadataRequest request)
145172
{
@@ -149,6 +176,11 @@ private String getConnStr(MetadataRequest request)
149176
request.getCatalogName(), DEFAULT_DOCDB);
150177
conStr = configOptions.get(DEFAULT_DOCDB);
151178
}
179+
if (isRequestFederated(request)) {
180+
logger.info("Using federated request to frame default_docdb connection string.");
181+
final Map<String, String> configOptionsFromFederatedIdentity = request.getIdentity().getConfigOptions();
182+
conStr = getConfigOptionsFromFederatedIdentity(configOptionsFromFederatedIdentity);
183+
}
152184
return conStr;
153185
}
154186

@@ -157,6 +189,30 @@ public GetDataSourceCapabilitiesResponse doGetDataSourceCapabilities(BlockAlloca
157189
{
158190
ImmutableMap.Builder<String, List<OptimizationSubType>> capabilities = ImmutableMap.builder();
159191
queryPassthrough.addQueryPassthroughCapabilityIfEnabled(capabilities, configOptions);
192+
capabilities.put(DataSourceOptimizations.SUPPORTS_LIMIT_PUSHDOWN.withSupportedSubTypes(
193+
LimitPushdownSubType.INTEGER_CONSTANT
194+
));
195+
196+
List<StandardFunctions> supportedFunctions = new ArrayList<>();
197+
supportedFunctions.add(StandardFunctions.AND_FUNCTION_NAME);
198+
supportedFunctions.add(StandardFunctions.IN_PREDICATE_FUNCTION_NAME);
199+
supportedFunctions.add(StandardFunctions.NOT_FUNCTION_NAME);
200+
supportedFunctions.add(StandardFunctions.IS_NULL_FUNCTION_NAME);
201+
supportedFunctions.add(StandardFunctions.EQUAL_OPERATOR_FUNCTION_NAME);
202+
supportedFunctions.add(StandardFunctions.GREATER_THAN_OPERATOR_FUNCTION_NAME);
203+
supportedFunctions.add(StandardFunctions.LESS_THAN_OPERATOR_FUNCTION_NAME);
204+
supportedFunctions.add(StandardFunctions.GREATER_THAN_OR_EQUAL_OPERATOR_FUNCTION_NAME);
205+
supportedFunctions.add(StandardFunctions.LESS_THAN_OR_EQUAL_OPERATOR_FUNCTION_NAME);
206+
supportedFunctions.add(StandardFunctions.NOT_EQUAL_OPERATOR_FUNCTION_NAME);
207+
208+
// To check for $nin and $nor
209+
210+
capabilities.put(DataSourceOptimizations.SUPPORTS_COMPLEX_EXPRESSION_PUSHDOWN.withSupportedSubTypes(
211+
ComplexExpressionPushdownSubType.SUPPORTED_FUNCTION_EXPRESSION_TYPES
212+
.withSubTypeProperties(supportedFunctions.stream()
213+
.map(f -> f.getFunctionName().getFunctionName())
214+
.toArray(String[]::new))
215+
));
160216

161217
return new GetDataSourceCapabilitiesResponse(request.getCatalogName(), capabilities.build());
162218
}
@@ -215,23 +271,32 @@ public ListTablesResponse doListTables(BlockAllocator blockAllocator, ListTables
215271
logger.warn("doListTables: Unable to retrieve tables from AWSGlue in database/schema {}", request.getSchemaName(), e);
216272
}
217273
}
218-
219274
MongoClient client = getOrCreateConn(request);
220-
Stream<String> tableNames = doListTablesWithCommand(client, request);
221-
int startToken = request.getNextToken() != null ? Integer.parseInt(request.getNextToken()) : 0;
275+
276+
int offset = request.getNextToken() != null ? Integer.parseInt(request.getNextToken()) : 0;
222277
int pageSize = request.getPageSize();
278+
279+
Stream<String> stream = doListTablesWithCommand(client, request).skip(offset);
280+
281+
List<String> pagePlusOne;
223282
String nextToken = null;
224283

225-
if (pageSize != UNLIMITED_PAGE_SIZE_VALUE) {
226-
logger.info("Starting at token {} w/ page size {}", startToken, pageSize);
227-
tableNames = tableNames.skip(startToken).limit(request.getPageSize());
228-
nextToken = Integer.toString(startToken + pageSize);
284+
if (pageSize == UNLIMITED_PAGE_SIZE_VALUE) {
285+
pagePlusOne = stream.collect(java.util.stream.Collectors.toList());
229286
}
287+
else {
288+
pagePlusOne = stream.limit((long) pageSize + 1).collect(java.util.stream.Collectors.toList());
289+
if (pagePlusOne.size() > pageSize) {
290+
pagePlusOne = pagePlusOne.subList(0, pageSize);
291+
nextToken = Integer.toString(offset + pageSize);
292+
}
293+
}
294+
295+
List<TableName> tables = pagePlusOne.stream()
296+
.map(name -> new TableName(request.getSchemaName(), name))
297+
.collect(java.util.stream.Collectors.toList());
230298

231-
List<TableName> paginatedTables = tableNames.map(tableName -> new TableName(request.getSchemaName(), tableName)).collect(Collectors.toList());
232-
combinedTables.addAll(paginatedTables);
233-
logger.info("doListTables returned {} tables. Next token is {}", paginatedTables.size(), nextToken);
234-
return new ListTablesResponse(request.getCatalogName(), new ArrayList<>(combinedTables), nextToken);
299+
return new ListTablesResponse(request.getCatalogName(), new ArrayList<>(tables), nextToken);
235300
}
236301

237302
/**
@@ -303,8 +368,7 @@ public GetTableResponse doGetTable(BlockAllocator blockAllocator, GetTableReques
303368
catch (RuntimeException ex) {
304369
logger.warn("doGetTable: Unable to retrieve table[{}:{}] from AWS Glue.",
305370
request.getTableName().getSchemaName(),
306-
request.getTableName().getTableName(),
307-
ex);
371+
request.getTableName().getTableName());
308372
}
309373

310374
if (schema == null) {
@@ -365,4 +429,96 @@ protected Field convertField(String name, String glueType)
365429
{
366430
return GlueFieldLexer.lex(name, glueType);
367431
}
432+
433+
/**
434+
* Constructs a DocDB connection string from federated identity configuration options.
435+
*
436+
* <p>This method dynamically builds a MongoDB connection string by:
437+
* <ul>
438+
* <li>Extracting host and port from the provided config options</li>
439+
* <li>Retrieving credentials from AWS Secrets Manager using the secret ARN</li>
440+
* <li>Parsing JSON credentials to extract username and password</li>
441+
* <li>Applying SSL enforcement and authentication database settings</li>
442+
* <li>Constructing the final MongoDB connection string with proper formatting</li>
443+
* </ul>
444+
*
445+
* <p>Expected JSON credential format from Secrets Manager:
446+
* <pre>
447+
* {
448+
* "username": "mongodbadmin",
449+
* "password": "secretpassword",
450+
* "engine": "mongo",
451+
* "host": "cluster.docdb.amazonaws.com",
452+
* "port": 27017
453+
* }
454+
* </pre>
455+
*
456+
* @param configOptions Map containing federated identity configuration including:
457+
* HOST, PORT, secret_arn, JDBC_PARAMS, ENFORCE_SSL, AUTHENTICATION_DATABASE
458+
* @return Fully constructed MongoDB connection string in format: mongodb://username:password@host:port/?jdbcParams
459+
* @throws RuntimeException if JSON credential parsing fails or required parameters are missing
460+
*/
461+
private String getConfigOptionsFromFederatedIdentity(Map<String, String> configOptions)
462+
{
463+
final String secretName = getSecretNameFromArn(configOptions.get(SECRET_ARN_KEY));
464+
final String credentials = getSecret(secretName, getRequestOverrideConfig(configOptions));
465+
final String username;
466+
final String password;
467+
final String host;
468+
try {
469+
ObjectMapper mapper = new ObjectMapper();
470+
JsonNode credNode = mapper.readTree(credentials);
471+
username = credNode.get(USERNAME_FIELD).asText();
472+
password = credNode.get(PASSWORD_FIELD).asText();
473+
host = credNode.get(HOST).asText();
474+
}
475+
catch (Exception e) {
476+
logger.error("Failed to parse JSON credentials", e);
477+
throw new RuntimeException("Invalid JSON credentials format", e);
478+
}
479+
480+
String jdbcParams = configOptions.get(JDBC_PARAMS);
481+
String enforceSsl = configOptions.get(ENFORCE_SSL);
482+
String authDb = configOptions.getOrDefault(AUTH_DB_KEY, "");
483+
484+
if (Boolean.parseBoolean(enforceSsl)) {
485+
if (jdbcParams == null) {
486+
jdbcParams = ENFORCE_SSL_JDBC_PARAM;
487+
}
488+
else if (!jdbcParams.contains(ENFORCE_SSL_JDBC_PARAM)) {
489+
jdbcParams = ENFORCE_SSL_JDBC_PARAM + "&" + jdbcParams;
490+
}
491+
}
492+
493+
String connStr = String.format(CONNECTION_STRING_TEMPLATE, username, password, host, configOptions.get(PORT),
494+
authDb);
495+
if (jdbcParams != null) {
496+
connStr += "?" + jdbcParams;
497+
}
498+
return connStr;
499+
}
500+
501+
/**
502+
* Extracts the secret name from an AWS Secrets Manager ARN.
503+
*
504+
* <p>AWS Secrets Manager ARNs follow the format:
505+
* {@code arn:aws:secretsmanager:region:account:secret:name-suffix}
506+
*
507+
* <p>This method extracts the secret name by:
508+
* <ul>
509+
* <li>Splitting the ARN by colons to get individual components</li>
510+
* <li>Taking the 7th component (index 6) which contains "name-suffix"</li>
511+
* <li>Removing the suffix (everything after the last hyphen) to get the clean secret name</li>
512+
* </ul>
513+
*
514+
* @param secretArn The full AWS Secrets Manager ARN
515+
* @return The extracted secret name without the suffix
516+
* @throws ArrayIndexOutOfBoundsException if the ARN format is invalid
517+
*/
518+
private static String getSecretNameFromArn(String secretArn)
519+
{
520+
final String[] parts = secretArn.split(":");
521+
final String nameWithSuffix = parts[6];
522+
return nameWithSuffix.substring(0, nameWithSuffix.lastIndexOf('-'));
523+
}
368524
}

0 commit comments

Comments
 (0)