2424import com .amazonaws .athena .connector .lambda .data .BlockWriter ;
2525import com .amazonaws .athena .connector .lambda .domain .Split ;
2626import com .amazonaws .athena .connector .lambda .domain .TableName ;
27+ import com .amazonaws .athena .connector .lambda .domain .predicate .functions .StandardFunctions ;
2728import com .amazonaws .athena .connector .lambda .domain .spill .SpillLocation ;
2829import com .amazonaws .athena .connector .lambda .handlers .GlueMetadataHandler ;
2930import com .amazonaws .athena .connector .lambda .metadata .GetDataSourceCapabilitiesRequest ;
3940import com .amazonaws .athena .connector .lambda .metadata .ListTablesResponse ;
4041import com .amazonaws .athena .connector .lambda .metadata .MetadataRequest ;
4142import com .amazonaws .athena .connector .lambda .metadata .glue .GlueFieldLexer ;
43+ import com .amazonaws .athena .connector .lambda .metadata .optimizations .DataSourceOptimizations ;
4244import com .amazonaws .athena .connector .lambda .metadata .optimizations .OptimizationSubType ;
45+ import com .amazonaws .athena .connector .lambda .metadata .optimizations .pushdown .ComplexExpressionPushdownSubType ;
46+ import com .amazonaws .athena .connector .lambda .metadata .optimizations .pushdown .LimitPushdownSubType ;
4347import com .amazonaws .athena .connector .lambda .security .EncryptionKeyFactory ;
4448import com .amazonaws .athena .connectors .docdb .qpt .DocDBQueryPassthrough ;
49+ import com .fasterxml .jackson .databind .JsonNode ;
50+ import com .fasterxml .jackson .databind .ObjectMapper ;
4551import com .google .common .base .Strings ;
4652import com .google .common .collect .ImmutableMap ;
4753import com .mongodb .client .MongoClient ;
6268import java .util .ArrayList ;
6369import java .util .LinkedHashSet ;
6470import java .util .List ;
71+ import java .util .Map ;
6572import java .util .Set ;
66- import java .util .stream .Collectors ;
6773import java .util .stream .Stream ;
6874
75+ import static com .amazonaws .athena .connector .lambda .connection .EnvironmentConstants .ENFORCE_SSL ;
76+ import static com .amazonaws .athena .connector .lambda .connection .EnvironmentConstants .JDBC_PARAMS ;
77+ import static com .amazonaws .athena .connector .lambda .connection .EnvironmentConstants .PORT ;
6978import static com .amazonaws .athena .connector .lambda .metadata .ListTablesRequest .UNLIMITED_PAGE_SIZE_VALUE ;
7079
7180/**
@@ -86,13 +95,13 @@ public class DocDBMetadataHandler
8695
8796 //Used to denote the 'type' of this connector for diagnostic purposes.
8897 private static final String SOURCE_TYPE = "documentdb" ;
98+ private static final String CONNECTION_STRING_TEMPLATE = "mongodb://%s:%s@%s:%s/%s" ;
99+ private static final String ENFORCE_SSL_JDBC_PARAM = "ssl=true&ssl_ca_certs=rds-combined-ca-bundle.pem" ;
89100 //Field name used to store the connection string as a property on Split objects.
90101 protected static final String DOCDB_CONN_STR = "connStr" ;
91102 //The Env variable name used to store the default DocDB connection string if no catalog specific
92103 //env variable is set.
93104 private static final String DEFAULT_DOCDB = "default_docdb" ;
94- //The env secret_name to use if defined
95- private static final String SECRET_NAME = "secret_name" ;
96105 //The Glue table property that indicates that a table matching the name of an DocDB table
97106 //is indeed enabled for use by this connector.
98107 private static final String DOCDB_METADATA_FLAG = "docdb-metadata-flag" ;
@@ -103,6 +112,14 @@ public class DocDBMetadataHandler
103112 // used to filter out Glue databases which lack the docdb-metadata-flag in the URI.
104113 private static final DatabaseFilter DB_FILTER = (Database database ) -> (database .locationUri () != null && database .locationUri ().contains (DOCDB_METADATA_FLAG ));
105114
115+ private static final String SECRET_ARN_KEY = "secret_arn" ;
116+ private static final String AUTH_DB_KEY = "AUTHENTICATION_DATABASE" ;
117+
118+ // JSON credential field names
119+ private static final String USERNAME_FIELD = "username" ;
120+ private static final String PASSWORD_FIELD = "password" ;
121+ public static final String HOST = "host" ;
122+
106123 private final GlueClient glue ;
107124 private final DocDBConnectionFactory connectionFactory ;
108125 private final DocDBQueryPassthrough queryPassthrough = new DocDBQueryPassthrough ();
@@ -140,6 +157,16 @@ private MongoClient getOrCreateConn(MetadataRequest request)
140157 /**
141158 * Retrieves the DocDB connection details from an env variable matching the catalog name, if no such
142159 * env variable exists we fall back to the default env variable defined by DEFAULT_DOCDB.
160+ *
161+ * <p>For federated requests, this method dynamically constructs the connection string using:
162+ * <ul>
163+ * <li>Host and port from federated identity config options</li>
164+ * <li>Username and password extracted from AWS Secrets Manager (JSON format)</li>
165+ * <li>SSL enforcement and authentication database settings</li>
166+ * </ul>
167+ *
168+ * @param request The metadata request containing catalog name and federated identity information
169+ * @return The DocDB connection string, either from environment variables or dynamically constructed for federated requests
143170 */
144171 private String getConnStr (MetadataRequest request )
145172 {
@@ -149,6 +176,11 @@ private String getConnStr(MetadataRequest request)
149176 request .getCatalogName (), DEFAULT_DOCDB );
150177 conStr = configOptions .get (DEFAULT_DOCDB );
151178 }
179+ if (isRequestFederated (request )) {
180+ logger .info ("Using federated request to frame default_docdb connection string." );
181+ final Map <String , String > configOptionsFromFederatedIdentity = request .getIdentity ().getConfigOptions ();
182+ conStr = getConfigOptionsFromFederatedIdentity (configOptionsFromFederatedIdentity );
183+ }
152184 return conStr ;
153185 }
154186
@@ -157,6 +189,30 @@ public GetDataSourceCapabilitiesResponse doGetDataSourceCapabilities(BlockAlloca
157189 {
158190 ImmutableMap .Builder <String , List <OptimizationSubType >> capabilities = ImmutableMap .builder ();
159191 queryPassthrough .addQueryPassthroughCapabilityIfEnabled (capabilities , configOptions );
192+ capabilities .put (DataSourceOptimizations .SUPPORTS_LIMIT_PUSHDOWN .withSupportedSubTypes (
193+ LimitPushdownSubType .INTEGER_CONSTANT
194+ ));
195+
196+ List <StandardFunctions > supportedFunctions = new ArrayList <>();
197+ supportedFunctions .add (StandardFunctions .AND_FUNCTION_NAME );
198+ supportedFunctions .add (StandardFunctions .IN_PREDICATE_FUNCTION_NAME );
199+ supportedFunctions .add (StandardFunctions .NOT_FUNCTION_NAME );
200+ supportedFunctions .add (StandardFunctions .IS_NULL_FUNCTION_NAME );
201+ supportedFunctions .add (StandardFunctions .EQUAL_OPERATOR_FUNCTION_NAME );
202+ supportedFunctions .add (StandardFunctions .GREATER_THAN_OPERATOR_FUNCTION_NAME );
203+ supportedFunctions .add (StandardFunctions .LESS_THAN_OPERATOR_FUNCTION_NAME );
204+ supportedFunctions .add (StandardFunctions .GREATER_THAN_OR_EQUAL_OPERATOR_FUNCTION_NAME );
205+ supportedFunctions .add (StandardFunctions .LESS_THAN_OR_EQUAL_OPERATOR_FUNCTION_NAME );
206+ supportedFunctions .add (StandardFunctions .NOT_EQUAL_OPERATOR_FUNCTION_NAME );
207+
208+ // To check for $nin and $nor
209+
210+ capabilities .put (DataSourceOptimizations .SUPPORTS_COMPLEX_EXPRESSION_PUSHDOWN .withSupportedSubTypes (
211+ ComplexExpressionPushdownSubType .SUPPORTED_FUNCTION_EXPRESSION_TYPES
212+ .withSubTypeProperties (supportedFunctions .stream ()
213+ .map (f -> f .getFunctionName ().getFunctionName ())
214+ .toArray (String []::new ))
215+ ));
160216
161217 return new GetDataSourceCapabilitiesResponse (request .getCatalogName (), capabilities .build ());
162218 }
@@ -215,23 +271,32 @@ public ListTablesResponse doListTables(BlockAllocator blockAllocator, ListTables
215271 logger .warn ("doListTables: Unable to retrieve tables from AWSGlue in database/schema {}" , request .getSchemaName (), e );
216272 }
217273 }
218-
219274 MongoClient client = getOrCreateConn (request );
220- Stream < String > tableNames = doListTablesWithCommand ( client , request );
221- int startToken = request .getNextToken () != null ? Integer .parseInt (request .getNextToken ()) : 0 ;
275+
276+ int offset = request .getNextToken () != null ? Integer .parseInt (request .getNextToken ()) : 0 ;
222277 int pageSize = request .getPageSize ();
278+
279+ Stream <String > stream = doListTablesWithCommand (client , request ).skip (offset );
280+
281+ List <String > pagePlusOne ;
223282 String nextToken = null ;
224283
225- if (pageSize != UNLIMITED_PAGE_SIZE_VALUE ) {
226- logger .info ("Starting at token {} w/ page size {}" , startToken , pageSize );
227- tableNames = tableNames .skip (startToken ).limit (request .getPageSize ());
228- nextToken = Integer .toString (startToken + pageSize );
284+ if (pageSize == UNLIMITED_PAGE_SIZE_VALUE ) {
285+ pagePlusOne = stream .collect (java .util .stream .Collectors .toList ());
229286 }
287+ else {
288+ pagePlusOne = stream .limit ((long ) pageSize + 1 ).collect (java .util .stream .Collectors .toList ());
289+ if (pagePlusOne .size () > pageSize ) {
290+ pagePlusOne = pagePlusOne .subList (0 , pageSize );
291+ nextToken = Integer .toString (offset + pageSize );
292+ }
293+ }
294+
295+ List <TableName > tables = pagePlusOne .stream ()
296+ .map (name -> new TableName (request .getSchemaName (), name ))
297+ .collect (java .util .stream .Collectors .toList ());
230298
231- List <TableName > paginatedTables = tableNames .map (tableName -> new TableName (request .getSchemaName (), tableName )).collect (Collectors .toList ());
232- combinedTables .addAll (paginatedTables );
233- logger .info ("doListTables returned {} tables. Next token is {}" , paginatedTables .size (), nextToken );
234- return new ListTablesResponse (request .getCatalogName (), new ArrayList <>(combinedTables ), nextToken );
299+ return new ListTablesResponse (request .getCatalogName (), new ArrayList <>(tables ), nextToken );
235300 }
236301
237302 /**
@@ -303,8 +368,7 @@ public GetTableResponse doGetTable(BlockAllocator blockAllocator, GetTableReques
303368 catch (RuntimeException ex ) {
304369 logger .warn ("doGetTable: Unable to retrieve table[{}:{}] from AWS Glue." ,
305370 request .getTableName ().getSchemaName (),
306- request .getTableName ().getTableName (),
307- ex );
371+ request .getTableName ().getTableName ());
308372 }
309373
310374 if (schema == null ) {
@@ -365,4 +429,96 @@ protected Field convertField(String name, String glueType)
365429 {
366430 return GlueFieldLexer .lex (name , glueType );
367431 }
432+
433+ /**
434+ * Constructs a DocDB connection string from federated identity configuration options.
435+ *
436+ * <p>This method dynamically builds a MongoDB connection string by:
437+ * <ul>
438+ * <li>Extracting host and port from the provided config options</li>
439+ * <li>Retrieving credentials from AWS Secrets Manager using the secret ARN</li>
440+ * <li>Parsing JSON credentials to extract username and password</li>
441+ * <li>Applying SSL enforcement and authentication database settings</li>
442+ * <li>Constructing the final MongoDB connection string with proper formatting</li>
443+ * </ul>
444+ *
445+ * <p>Expected JSON credential format from Secrets Manager:
446+ * <pre>
447+ * {
448+ * "username": "mongodbadmin",
449+ * "password": "secretpassword",
450+ * "engine": "mongo",
451+ * "host": "cluster.docdb.amazonaws.com",
452+ * "port": 27017
453+ * }
454+ * </pre>
455+ *
456+ * @param configOptions Map containing federated identity configuration including:
457+ * HOST, PORT, secret_arn, JDBC_PARAMS, ENFORCE_SSL, AUTHENTICATION_DATABASE
458+ * @return Fully constructed MongoDB connection string in format: mongodb://username:password@host:port/?jdbcParams
459+ * @throws RuntimeException if JSON credential parsing fails or required parameters are missing
460+ */
461+ private String getConfigOptionsFromFederatedIdentity (Map <String , String > configOptions )
462+ {
463+ final String secretName = getSecretNameFromArn (configOptions .get (SECRET_ARN_KEY ));
464+ final String credentials = getSecret (secretName , getRequestOverrideConfig (configOptions ));
465+ final String username ;
466+ final String password ;
467+ final String host ;
468+ try {
469+ ObjectMapper mapper = new ObjectMapper ();
470+ JsonNode credNode = mapper .readTree (credentials );
471+ username = credNode .get (USERNAME_FIELD ).asText ();
472+ password = credNode .get (PASSWORD_FIELD ).asText ();
473+ host = credNode .get (HOST ).asText ();
474+ }
475+ catch (Exception e ) {
476+ logger .error ("Failed to parse JSON credentials" , e );
477+ throw new RuntimeException ("Invalid JSON credentials format" , e );
478+ }
479+
480+ String jdbcParams = configOptions .get (JDBC_PARAMS );
481+ String enforceSsl = configOptions .get (ENFORCE_SSL );
482+ String authDb = configOptions .getOrDefault (AUTH_DB_KEY , "" );
483+
484+ if (Boolean .parseBoolean (enforceSsl )) {
485+ if (jdbcParams == null ) {
486+ jdbcParams = ENFORCE_SSL_JDBC_PARAM ;
487+ }
488+ else if (!jdbcParams .contains (ENFORCE_SSL_JDBC_PARAM )) {
489+ jdbcParams = ENFORCE_SSL_JDBC_PARAM + "&" + jdbcParams ;
490+ }
491+ }
492+
493+ String connStr = String .format (CONNECTION_STRING_TEMPLATE , username , password , host , configOptions .get (PORT ),
494+ authDb );
495+ if (jdbcParams != null ) {
496+ connStr += "?" + jdbcParams ;
497+ }
498+ return connStr ;
499+ }
500+
501+ /**
502+ * Extracts the secret name from an AWS Secrets Manager ARN.
503+ *
504+ * <p>AWS Secrets Manager ARNs follow the format:
505+ * {@code arn:aws:secretsmanager:region:account:secret:name-suffix}
506+ *
507+ * <p>This method extracts the secret name by:
508+ * <ul>
509+ * <li>Splitting the ARN by colons to get individual components</li>
510+ * <li>Taking the 7th component (index 6) which contains "name-suffix"</li>
511+ * <li>Removing the suffix (everything after the last hyphen) to get the clean secret name</li>
512+ * </ul>
513+ *
514+ * @param secretArn The full AWS Secrets Manager ARN
515+ * @return The extracted secret name without the suffix
516+ * @throws ArrayIndexOutOfBoundsException if the ARN format is invalid
517+ */
518+ private static String getSecretNameFromArn (String secretArn )
519+ {
520+ final String [] parts = secretArn .split (":" );
521+ final String nameWithSuffix = parts [6 ];
522+ return nameWithSuffix .substring (0 , nameWithSuffix .lastIndexOf ('-' ));
523+ }
368524}
0 commit comments