Skip to content

Commit d294610

Browse files
committed
Add secretmanager module and support for databases connections. #402
1 parent baca7fb commit d294610

File tree

12 files changed

+267
-19
lines changed

12 files changed

+267
-19
lines changed

README.md

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
1313

1414
[![Checked with mypy](http://www.mypy-lang.org/static/mypy_badge.svg)](http://mypy-lang.org/)
15-
[![Coverage](https://img.shields.io/badge/coverage-92%25-brightgreen.svg)](https://pypi.org/project/awswrangler/)
15+
[![Coverage](https://img.shields.io/badge/coverage-91%25-brightgreen.svg)](https://pypi.org/project/awswrangler/)
1616
![Static Checking](https://github.com/awslabs/aws-data-wrangler/workflows/Static%20Checking/badge.svg?branch=master)
1717
[![Documentation Status](https://readthedocs.org/projects/aws-data-wrangler/badge/?version=latest)](https://aws-data-wrangler.readthedocs.io/?badge=latest)
1818

@@ -138,6 +138,7 @@ FROM "sampleDB"."sampleTable" ORDER BY time DESC LIMIT 3
138138
- [Amazon CloudWatch Logs](https://aws-data-wrangler.readthedocs.io/en/stable/api.html#amazon-cloudwatch-logs)
139139
- [Amazon QuickSight](https://aws-data-wrangler.readthedocs.io/en/stable/api.html#amazon-quicksight)
140140
- [AWS STS](https://aws-data-wrangler.readthedocs.io/en/stable/api.html#aws-sts)
141+
- [AWS Secrets Manager](https://aws-data-wrangler.readthedocs.io/en/stable/api.html#aws-secrets-manager)
141142
- [**License**](https://github.com/awslabs/aws-data-wrangler/blob/master/LICENSE.txt)
142143
- [**Contributing**](https://github.com/awslabs/aws-data-wrangler/blob/master/CONTRIBUTING.md)
143144
- [**Legacy Docs** (pre-1.0.0)](https://aws-data-wrangler.readthedocs.io/en/0.3.3/)
@@ -202,6 +203,6 @@ Please [send a Pull Request](https://github.com/awslabs/aws-data-wrangler/edit/m
202203

203204
**Amazon SageMaker Data Wrangler** is a new SageMaker Studio feature that has a similar name but has a different purpose than the **AWS Data Wrangler** open source project.
204205

205-
* **AWS Data Wrangler** is open source, runs anywhere, and is focused on code.
206+
- **AWS Data Wrangler** is open source, runs anywhere, and is focused on code.
206207

207-
* **Amazon SageMaker Data Wrangler** is specific for the SageMaker Studio environment and is focused on a visual interface.
208+
- **Amazon SageMaker Data Wrangler** is specific for the SageMaker Studio environment and is focused on a visual interface.

awswrangler/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
quicksight,
1919
redshift,
2020
s3,
21+
secretsmanager,
2122
sts,
2223
timestream,
2324
)
@@ -36,6 +37,7 @@
3637
"redshift",
3738
"mysql",
3839
"postgresql",
40+
"secretsmanager",
3941
"config",
4042
"timestream",
4143
"__description__",

awswrangler/_databases.py

Lines changed: 54 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import pandas as pd
99
import pyarrow as pa
1010

11-
from awswrangler import _data_types
11+
from awswrangler import _data_types, _utils, exceptions, secretsmanager
1212
from awswrangler.catalog import get_connection
1313

1414
_logger: logging.Logger = logging.getLogger(__name__)
@@ -25,12 +25,15 @@ class ConnectionAttributes(NamedTuple):
2525
database: str
2626

2727

28-
def get_connection_attributes(
29-
connection: str,
30-
catalog_id: Optional[str] = None,
31-
boto3_session: Optional[boto3.Session] = None,
28+
def _get_dbname(cluster_id: str, boto3_session: Optional[boto3.Session] = None) -> str:
    """Return the ``DBName`` reported for a Redshift cluster.

    Parameters
    ----------
    cluster_id : str
        Identifier passed to ``describe_clusters`` as ``ClusterIdentifier``.
    boto3_session : boto3.Session(), optional
        Session handed to ``_utils.client`` when building the Redshift client.

    Returns
    -------
    str
        The ``DBName`` field of the first cluster in the response.
    """
    redshift: boto3.client = _utils.client(service_name="redshift", session=boto3_session)
    clusters = redshift.describe_clusters(ClusterIdentifier=cluster_id)["Clusters"]
    # Only the first entry of the response is inspected.
    first_cluster: Dict[str, Any] = clusters[0]
    return cast(str, first_cluster["DBName"])
32+
33+
34+
def _get_connection_attributes_from_catalog(
35+
connection: str, catalog_id: Optional[str], dbname: Optional[str], boto3_session: Optional[boto3.Session]
3236
) -> ConnectionAttributes:
33-
"""Get Connection Attributes."""
3437
details: Dict[str, Any] = get_connection(name=connection, catalog_id=catalog_id, boto3_session=boto3_session)[
3538
"ConnectionProperties"
3639
]
@@ -41,7 +44,51 @@ def get_connection_attributes(
4144
password=quote_plus(details["PASSWORD"]),
4245
host=details["JDBC_CONNECTION_URL"].split(":")[2].replace("/", ""),
4346
port=int(port),
44-
database=database,
47+
database=dbname if dbname is not None else database,
48+
)
49+
50+
51+
def _get_connection_attributes_from_secrets_manager(
52+
secret_id: str, dbname: Optional[str], boto3_session: Optional[boto3.Session]
53+
) -> ConnectionAttributes:
54+
secret_value: Dict[str, Any] = secretsmanager.get_secret_json(name=secret_id, boto3_session=boto3_session)
55+
kind: str = secret_value["engine"]
56+
if dbname is not None:
57+
_dbname: str = dbname
58+
elif "dbname" in secret_value:
59+
_dbname = secret_value["dbname"]
60+
else:
61+
if kind != "redshift":
62+
raise exceptions.InvalidConnection(f"The secret {secret_id} MUST have a dbname property.")
63+
_dbname = _get_dbname(cluster_id=secret_value["dbClusterIdentifier"], boto3_session=boto3_session)
64+
return ConnectionAttributes(
65+
kind=kind,
66+
user=quote_plus(secret_value["username"]),
67+
password=quote_plus(secret_value["password"]),
68+
host=secret_value["host"],
69+
port=secret_value["port"],
70+
database=_dbname,
71+
)
72+
73+
74+
def get_connection_attributes(
75+
connection: Optional[str] = None,
76+
secret_id: Optional[str] = None,
77+
catalog_id: Optional[str] = None,
78+
dbname: Optional[str] = None,
79+
boto3_session: Optional[boto3.Session] = None,
80+
) -> ConnectionAttributes:
81+
"""Get Connection Attributes."""
82+
if connection is None and secret_id is None:
83+
raise exceptions.InvalidArgumentCombination(
84+
"Failed attempt to connect. You MUST pass a connection name (Glue Catalog) OR a secret_id as argument."
85+
)
86+
if connection is not None:
87+
return _get_connection_attributes_from_catalog(
88+
connection=connection, catalog_id=catalog_id, dbname=dbname, boto3_session=boto3_session
89+
)
90+
return _get_connection_attributes_from_secrets_manager(
91+
secret_id=cast(str, secret_id), dbname=dbname, boto3_session=boto3_session
4592
)
4693

4794

awswrangler/mysql.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,10 @@ def _create_table(
6767

6868

6969
def connect(
70-
connection: str,
70+
connection: Optional[str] = None,
71+
secret_id: Optional[str] = None,
7172
catalog_id: Optional[str] = None,
73+
dbname: Optional[str] = None,
7274
boto3_session: Optional[boto3.Session] = None,
7375
read_timeout: Optional[int] = None,
7476
write_timeout: Optional[int] = None,
@@ -82,9 +84,14 @@ def connect(
8284
----------
8385
connection : Optional[str]
8486
Glue Catalog Connection name.
87+
secret_id : Optional[str]
88+
Specifies the secret containing the version that you want to retrieve.
89+
You can specify either the Amazon Resource Name (ARN) or the friendly name of the secret.
8590
catalog_id : str, optional
8691
The ID of the Data Catalog.
8792
If none is provided, the AWS account ID is used by default.
93+
dbname: Optional[str]
94+
Optional database name to overwrite the stored one.
8895
boto3_session : boto3.Session(), optional
8996
Boto3 Session. The default boto3 session will be used if boto3_session receive None.
9097
read_timeout: Optional[int]
@@ -117,7 +124,7 @@ def connect(
117124
118125
"""
119126
attrs: _db_utils.ConnectionAttributes = _db_utils.get_connection_attributes(
120-
connection=connection, catalog_id=catalog_id, boto3_session=boto3_session
127+
connection=connection, secret_id=secret_id, catalog_id=catalog_id, dbname=dbname, boto3_session=boto3_session
121128
)
122129
if attrs.kind != "mysql":
123130
exceptions.InvalidDatabaseType(f"Invalid connection type ({attrs.kind}. It must be a MySQL connection.)")

awswrangler/postgresql.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,10 @@ def _create_table(
7171

7272

7373
def connect(
74-
connection: str,
74+
connection: Optional[str] = None,
75+
secret_id: Optional[str] = None,
7576
catalog_id: Optional[str] = None,
77+
dbname: Optional[str] = None,
7678
boto3_session: Optional[boto3.Session] = None,
7779
ssl_context: Optional[Dict[Any, Any]] = None,
7880
timeout: Optional[int] = None,
@@ -84,11 +86,16 @@ def connect(
8486
8587
Parameters
8688
----------
87-
connection : str
89+
connection : Optional[str]
8890
Glue Catalog Connection name.
91+
secret_id : Optional[str]
92+
Specifies the secret containing the version that you want to retrieve.
93+
You can specify either the Amazon Resource Name (ARN) or the friendly name of the secret.
8994
catalog_id : str, optional
9095
The ID of the Data Catalog.
9196
If none is provided, the AWS account ID is used by default.
97+
dbname: Optional[str]
98+
Optional database name to overwrite the stored one.
9299
boto3_session : boto3.Session(), optional
93100
Boto3 Session. The default boto3 session will be used if boto3_session receive None.
94101
ssl_context: Optional[Dict]
@@ -121,7 +128,7 @@ def connect(
121128
122129
"""
123130
attrs: _db_utils.ConnectionAttributes = _db_utils.get_connection_attributes(
124-
connection=connection, catalog_id=catalog_id, boto3_session=boto3_session
131+
connection=connection, secret_id=secret_id, catalog_id=catalog_id, dbname=dbname, boto3_session=boto3_session
125132
)
126133
if attrs.kind != "postgresql":
127134
exceptions.InvalidDatabaseType(f"Invalid connection type ({attrs.kind}. It must be a postgresql connection.)")

awswrangler/redshift.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -262,25 +262,37 @@ def _read_parquet_iterator(
262262

263263

264264
def connect(
265-
connection: str,
265+
connection: Optional[str] = None,
266+
secret_id: Optional[str] = None,
266267
catalog_id: Optional[str] = None,
268+
dbname: Optional[str] = None,
267269
boto3_session: Optional[boto3.Session] = None,
268270
ssl: bool = True,
269271
timeout: Optional[int] = None,
270272
max_prepared_statements: int = 1000,
271273
tcp_keepalive: bool = True,
272274
) -> redshift_connector.Connection:
273-
"""Return a redshift_connector connection from a Glue Catalog Connection.
275+
"""Return a redshift_connector connection from a Glue Catalog or Secrets Manager.
276+
277+
Note
278+
----
279+
You MUST pass a `connection` OR `secret_id`
280+
274281
275282
https://github.com/aws/amazon-redshift-python-driver
276283
277284
Parameters
278285
----------
279-
connection : str
286+
connection : Optional[str]
280287
Glue Catalog Connection name.
288+
secret_id : Optional[str]
289+
Specifies the secret containing the version that you want to retrieve.
290+
You can specify either the Amazon Resource Name (ARN) or the friendly name of the secret.
281291
catalog_id : str, optional
282292
The ID of the Data Catalog.
283293
If none is provided, the AWS account ID is used by default.
294+
dbname: Optional[str]
295+
Optional database name to overwrite the stored one.
284296
boto3_session : boto3.Session(), optional
285297
Boto3 Session. The default boto3 session will be used if boto3_session receive None.
286298
ssl: bool
@@ -307,16 +319,27 @@ def connect(
307319
308320
Examples
309321
--------
322+
Fetching Redshift connection from Glue Catalog
323+
310324
>>> import awswrangler as wr
311325
>>> con = wr.redshift.connect("MY_GLUE_CONNECTION")
312326
>>> with con.cursor() as cursor:
313327
>>> cursor.execute("SELECT 1")
314328
>>> print(cursor.fetchall())
315329
>>> con.close()
316330
331+
Fetching Redshift connection from Secrets Manager
332+
333+
>>> import awswrangler as wr
334+
>>> con = wr.redshift.connect(secret_id="MY_SECRET")
335+
>>> with con.cursor() as cursor:
336+
>>> cursor.execute("SELECT 1")
337+
>>> print(cursor.fetchall())
338+
>>> con.close()
339+
317340
"""
318341
attrs: _db_utils.ConnectionAttributes = _db_utils.get_connection_attributes(
319-
connection=connection, catalog_id=catalog_id, boto3_session=boto3_session
342+
connection=connection, secret_id=secret_id, catalog_id=catalog_id, dbname=dbname, boto3_session=boto3_session
320343
)
321344
if attrs.kind != "redshift":
322345
exceptions.InvalidDatabaseType(f"Invalid connection type ({attrs.kind}. It must be a redshift connection.)")

awswrangler/secretsmanager.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
"""Secrets Manager module."""
2+
3+
import base64
4+
import json
5+
import logging
6+
from typing import Any, Dict, Optional, Union, cast
7+
8+
import boto3
9+
10+
from awswrangler import _utils
11+
12+
_logger: logging.Logger = logging.getLogger(__name__)
13+
14+
15+
def get_secret(name: str, boto3_session: Optional[boto3.Session] = None) -> Union[str, bytes]:
    """Get secret value.

    Parameters
    ----------
    name : str
        Specifies the secret containing the version that you want to retrieve.
        You can specify either the Amazon Resource Name (ARN) or the friendly name of the secret.
    boto3_session : boto3.Session(), optional
        Boto3 Session. The default boto3 session will be used if boto3_session receive None.

    Returns
    -------
    Union[str, bytes]
        The ``SecretString`` field when the response has one, otherwise the
        base64-decoded ``SecretBinary`` field.

    Examples
    --------
    >>> import awswrangler as wr
    >>> value = wr.secretsmanager.get_secret("my-secret")

    """
    sm_client: boto3.client = _utils.client(
        service_name="secretsmanager",
        session=_utils.ensure_session(session=boto3_session),
    )
    payload: Dict[str, Any] = sm_client.get_secret_value(SecretId=name)
    # Without a string payload, fall back to the binary field.
    if "SecretString" not in payload:
        return base64.b64decode(payload["SecretBinary"])
    return cast(str, payload["SecretString"])
43+
44+
45+
def get_secret_json(name: str, boto3_session: Optional[boto3.Session] = None) -> Dict[str, Any]:
    """Get JSON secret value.

    Parameters
    ----------
    name : str
        Specifies the secret containing the version that you want to retrieve.
        You can specify either the Amazon Resource Name (ARN) or the friendly name of the secret.
    boto3_session : boto3.Session(), optional
        Boto3 Session. The default boto3 session will be used if boto3_session receive None.

    Returns
    -------
    Dict[str, Any]
        Result of ``json.loads`` applied to the value returned by ``get_secret``.

    Examples
    --------
    >>> import awswrangler as wr
    >>> value = wr.secretsmanager.get_secret_json("my-secret-with-json-content")

    """
    raw_secret: Union[str, bytes] = get_secret(name=name, boto3_session=boto3_session)
    # No runtime validation: the cast only informs the type checker.
    parsed: Dict[str, Any] = cast(Dict[str, Any], json.loads(raw_secret))
    return parsed

0 commit comments

Comments
 (0)