Skip to content

Commit c57632c

Browse files
committed
Add auto_create and db_groups to get_redshift_temp_engine(). #288
1 parent 863929c commit c57632c

File tree

2 files changed

+26
-3
lines changed

2 files changed

+26
-3
lines changed

awswrangler/db.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,8 @@ def get_redshift_temp_engine(
315315
user: str,
316316
database: Optional[str] = None,
317317
duration: int = 900,
318+
auto_create: bool = True,
319+
db_groups: Optional[List[str]] = None,
318320
boto3_session: Optional[boto3.Session] = None,
319321
) -> sqlalchemy.engine.Engine:
320322
"""Get Glue connection details.
@@ -332,6 +334,12 @@ def get_redshift_temp_engine(
332334
The number of seconds until the returned temporary password expires.
333335
Constraint: minimum 900, maximum 3600.
334336
Default: 900
337+
auto_create : bool
338+
Create a database user with the name specified for the user named in user if one does not exist.
339+
db_groups: List[str], optinal
340+
A list of the names of existing database groups that the user named in DbUser will join for the current session,
341+
in addition to any group memberships for an existing user.
342+
If not specified, a new user is added only to PUBLIC.
335343
boto3_session : boto3.Session(), optional
336344
Boto3 Session. The default boto3 session will be used if boto3_session receive None.
337345
@@ -347,9 +355,15 @@ def get_redshift_temp_engine(
347355
348356
"""
349357
client_redshift: boto3.client = _utils.client(service_name="redshift", session=boto3_session)
350-
res: Dict[str, Any] = client_redshift.get_cluster_credentials(
351-
DbUser=user, ClusterIdentifier=cluster_identifier, DurationSeconds=duration, AutoCreate=False
352-
)
358+
args: Dict[str, Any] = {
359+
"DbUser": user,
360+
"ClusterIdentifier": cluster_identifier,
361+
"DurationSeconds": duration,
362+
"AutoCreate": auto_create,
363+
}
364+
if db_groups is not None:
365+
args["DbGroups"] = db_groups
366+
res: Dict[str, Any] = client_redshift.get_cluster_credentials(**args)
353367
_user: str = _quote_plus(res["DbUser"])
354368
password: str = _quote_plus(res["DbPassword"])
355369
cluster: Dict[str, Any] = client_redshift.describe_clusters(ClusterIdentifier=cluster_identifier)["Clusters"][0]

tests/test_db.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,15 @@ def test_redshift_temp_engine(parameters):
178178
assert cursor.fetchall()[0][0] == 1
179179

180180

181+
def test_redshift_temp_engine2(parameters):
182+
engine = wr.db.get_redshift_temp_engine(
183+
cluster_identifier=parameters["redshift"]["identifier"], user="john_doe", duration=900, db_groups=[]
184+
)
185+
with engine.connect() as con:
186+
cursor = con.execute("SELECT 1")
187+
assert cursor.fetchall()[0][0] == 1
188+
189+
181190
def test_postgresql_param():
182191
engine = wr.catalog.get_engine(connection="aws-data-wrangler-postgresql")
183192
df = wr.db.read_sql_query(sql="SELECT %(value)s as col0", con=engine, params={"value": 1})

0 commit comments

Comments
 (0)