Commit cfb04a5

Feature/copy creds (#485)
* [feat] added functionality to use boto3_session to unload to s3.
* [refactor] Refactored authorization string creation and getting credentials from boto3 session.
* [feat/doc] Added functionality to use boto3 session credentials for unload function. Also updated doc for db_users.
* [doc] Updated docstring
* [refactor] reformatted using black
* [refactor] fixed black string reformat
1 parent 92ae19d commit cfb04a5

File tree

awswrangler/_utils.py
awswrangler/redshift.py

2 files changed: +63 -11 lines

awswrangler/_utils.py

Lines changed: 10 additions & 0 deletions
@@ -232,6 +232,16 @@ def get_region_from_session(boto3_session: Optional[boto3.Session] = None, defau
     raise exceptions.InvalidArgument("There is no region_name defined on boto3, please configure it.")


+def get_credentials_from_session(
+    boto3_session: Optional[boto3.Session] = None,
+) -> botocore.credentials.ReadOnlyCredentials:
+    """Get AWS credentials from boto3 session."""
+    session: boto3.Session = ensure_session(session=boto3_session)
+    credentials: botocore.credentials.Credentials = session.get_credentials()
+    frozen_credentials: botocore.credentials.ReadOnlyCredentials = credentials.get_frozen_credentials()
+    return frozen_credentials
+
+
 def list_sampling(lst: List[Any], sampling: float) -> List[Any]:
     """Random List sampling."""
     if sampling == 1.0:
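
For reference, a minimal sketch of how the new helper can be exercised. get_credentials_from_session is an internal helper in awswrangler._utils rather than public API, and the profile name below is a placeholder:

import boto3

from awswrangler import _utils

# Explicit session; passing None falls back to the default boto3 session via ensure_session().
session = boto3.Session(profile_name="my-profile")  # placeholder profile name
creds = _utils.get_credentials_from_session(boto3_session=session)

# ReadOnlyCredentials is a frozen named tuple: access_key, secret_key, token.
print(creds.access_key is not None, creds.token is not None)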

awswrangler/redshift.py

Lines changed: 53 additions & 11 deletions
@@ -5,6 +5,7 @@
 from typing import Any, Dict, Iterator, List, Optional, Tuple, Union

 import boto3
+import botocore
 import pandas as pd
 import pyarrow as pa
 import redshift_connector
@@ -54,18 +55,52 @@ def _does_table_exist(cursor: redshift_connector.Cursor, schema: Optional[str],
     return len(cursor.fetchall()) > 0


+def _make_s3_auth_string(
+    aws_access_key_id: Optional[str] = None,
+    aws_secret_access_key: Optional[str] = None,
+    aws_session_token: Optional[str] = None,
+    iam_role: Optional[str] = None,
+    boto3_session: Optional[boto3.Session] = None,
+) -> str:
+    if aws_access_key_id is not None and aws_secret_access_key is not None:
+        auth_str: str = f"ACCESS_KEY_ID '{aws_access_key_id}'\nSECRET_ACCESS_KEY '{aws_secret_access_key}'\n"
+        if aws_session_token is not None:
+            auth_str += f"SESSION_TOKEN '{aws_session_token}'\n"
+    elif iam_role is not None:
+        auth_str = f"IAM_ROLE '{iam_role}'\n"
+    else:
+        _logger.debug("Attempting to get S3 authorization credentials from boto3 session.")
+        credentials: botocore.credentials.ReadOnlyCredentials
+        credentials = _utils.get_credentials_from_session(boto3_session=boto3_session)
+        if credentials.access_key is None or credentials.secret_key is None:
+            raise exceptions.InvalidArgument(
+                "One of IAM Role or AWS ACCESS_KEY_ID and SECRET_ACCESS_KEY must be "
+                "given. Unable to find ACCESS_KEY_ID and SECRET_ACCESS_KEY in boto3 "
+                "session."
+            )
+
+        auth_str = f"ACCESS_KEY_ID '{credentials.access_key}'\nSECRET_ACCESS_KEY '{credentials.secret_key}'\n"
+        if credentials.token is not None:
+            auth_str += f"SESSION_TOKEN '{credentials.token}'\n"
+
+    return auth_str
+
+
 def _copy(
     cursor: redshift_connector.Cursor,
     path: str,
     table: str,
-    iam_role: str,
+    iam_role: Optional[str] = None,
+    boto3_session: Optional[str] = None,
     schema: Optional[str] = None,
 ) -> None:
     if schema is None:
         table_name: str = f'"{table}"'
     else:
         table_name = f'"{schema}"."{table}"'
-    sql: str = f"COPY {table_name} FROM '{path}'\nIAM_ROLE '{iam_role}'\nFORMAT AS PARQUET"
+
+    auth_str: str = _make_s3_auth_string(iam_role=iam_role, boto3_session=boto3_session)
+    sql: str = f"COPY {table_name} FROM '{path}'{auth_str}\nFORMAT AS PARQUET"
     _logger.debug("copy query:\n%s", sql)
     cursor.execute(sql)

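
For reference, a sketch of the credential clauses the new helper emits for COPY/UNLOAD. _make_s3_auth_string is a private function, and the role ARN, key, and token values below are placeholders:

from awswrangler import redshift

# IAM role branch
print(redshift._make_s3_auth_string(iam_role="arn:aws:iam::111122223333:role/redshift-s3-access"))
# IAM_ROLE 'arn:aws:iam::111122223333:role/redshift-s3-access'

# Explicit key branch; SESSION_TOKEN is appended only when a token is given
print(
    redshift._make_s3_auth_string(
        aws_access_key_id="AKIAEXAMPLE",
        aws_secret_access_key="secretExample",
        aws_session_token="tokenExample",
    )
)
# ACCESS_KEY_ID 'AKIAEXAMPLE'
# SECRET_ACCESS_KEY 'secretExample'
# SESSION_TOKEN 'tokenExample'

# With neither keys nor a role, credentials are resolved from the (default) boto3 session.
print(redshift._make_s3_auth_string())
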
@@ -388,6 +423,9 @@ def connect_temp(
         Default: 900
     auto_create : bool
         Create a database user with the name specified for the user named in user if one does not exist.
+    db_groups : List[str], optional
+        A list of the names of existing database groups that the user named in user will join for the current session,
+        in addition to any group memberships for an existing user. If not specified, a new user is added only to PUBLIC.
     boto3_session : boto3.Session(), optional
         Boto3 Session. The default boto3 session will be used if boto3_session receive None.
     ssl: bool
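
A hedged sketch of connect_temp with the newly documented db_groups argument; the cluster identifier, user, and group name below are placeholders:

import awswrangler as wr

con = wr.redshift.connect_temp(
    cluster_identifier="my-cluster",  # placeholder cluster
    user="analyst",                   # placeholder database user
    auto_create=True,
    db_groups=["readonly_group"],     # existing groups the user joins for this session
)
cursor = con.cursor()
cursor.execute("SELECT 1")
con.close()
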
@@ -703,7 +741,7 @@ def unload_to_files(
     sql: str,
     path: str,
     con: redshift_connector.Connection,
-    iam_role: str,
+    iam_role: Optional[str] = None,
     region: Optional[str] = None,
     max_file_size: Optional[float] = None,
     kms_key_id: Optional[str] = None,
@@ -730,7 +768,7 @@ def unload_to_files(
     con : redshift_connector.Connection
         Use redshift_connector.connect() to use "
         "credentials directly or wr.redshift.connect() to fetch it from the Glue Catalog.
-    iam_role : str
+    iam_role : str, optional
         AWS IAM role with the related permissions.
     region : str, optional
         Specifies the AWS Region where the target Amazon S3 bucket is located.
@@ -782,10 +820,13 @@ def unload_to_files(
     region_str: str = f"\nREGION AS '{region}'" if region is not None else ""
     max_file_size_str: str = f"\nMAXFILESIZE AS {max_file_size} MB" if max_file_size is not None else ""
     kms_key_id_str: str = f"\nKMS_KEY_ID '{kms_key_id}'" if kms_key_id is not None else ""
+
+    auth_str: str = _make_s3_auth_string(iam_role=iam_role, boto3_session=boto3_session)
+
     sql = (
         f"UNLOAD ('{sql}')\n"
         f"TO '{path}'\n"
-        f"IAM_ROLE '{iam_role}'\n"
+        f"{auth_str}"
         "ALLOWOVERWRITE\n"
         "PARALLEL ON\n"
         "FORMAT PARQUET\n"
@@ -804,7 +845,7 @@ def unload(
     sql: str,
     path: str,
     con: redshift_connector.Connection,
-    iam_role: str,
+    iam_role: Optional[str],
     region: Optional[str] = None,
     max_file_size: Optional[float] = None,
     kms_key_id: Optional[str] = None,
@@ -857,7 +898,7 @@ def unload(
     con : redshift_connector.Connection
         Use redshift_connector.connect() to use "
         "credentials directly or wr.redshift.connect() to fetch it from the Glue Catalog.
-    iam_role : str
+    iam_role : str, optional
         AWS IAM role with the related permissions.
     region : str, optional
         Specifies the AWS Region where the target Amazon S3 bucket is located.
@@ -949,7 +990,7 @@ def copy_from_files( # pylint: disable=too-many-locals,too-many-arguments
     con: redshift_connector.Connection,
     table: str,
     schema: str,
-    iam_role: str,
+    iam_role: Optional[str] = None,
     parquet_infer_sampling: float = 1.0,
     mode: str = "append",
     diststyle: str = "AUTO",
@@ -992,7 +1033,7 @@ def copy_from_files( # pylint: disable=too-many-locals,too-many-arguments
         Table name
     schema : str
         Schema name
-    iam_role : str
+    iam_role : str, optional
         AWS IAM role with the related permissions.
     parquet_infer_sampling : float
         Random sample ratio of files that will have the metadata inspected.
@@ -1089,6 +1130,7 @@ def copy_from_files( # pylint: disable=too-many-locals,too-many-arguments
         table=created_table,
         schema=created_schema,
         iam_role=iam_role,
+        boto3_session=boto3_session,
     )
     if table != created_table:  # upsert
         _upsert(cursor=cursor, schema=schema, table=table, temp_table=created_table, primary_keys=primary_keys)
@@ -1107,7 +1149,7 @@ def copy( # pylint: disable=too-many-arguments
     con: redshift_connector.Connection,
     table: str,
     schema: str,
-    iam_role: str,
+    iam_role: Optional[str] = None,
     index: bool = False,
     dtype: Optional[Dict[str, str]] = None,
     mode: str = "append",
@@ -1161,7 +1203,7 @@ def copy( # pylint: disable=too-many-arguments
         Table name
     schema : str
         Schema name
-    iam_role : str
+    iam_role : str, optional
         AWS IAM role with the related permissions.
     index : bool
         True to store the DataFrame index in file, otherwise False to ignore it.

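Taken together, a minimal usage sketch of what this commit enables: iam_role can now be omitted and the COPY/UNLOAD credentials fall back to the boto3 session. The bucket, table, and Glue connection names below are placeholders:

import awswrangler as wr
import boto3
import pandas as pd

session = boto3.Session()  # credentials from the environment / default profile
con = wr.redshift.connect(connection="my-glue-connection")  # placeholder Glue connection name

# COPY without iam_role: _make_s3_auth_string resolves ACCESS_KEY_ID/SECRET_ACCESS_KEY
# (and SESSION_TOKEN, if present) from the boto3 session.
wr.redshift.copy(
    df=pd.DataFrame({"id": [1, 2]}),
    path="s3://my-bucket/stage/",
    con=con,
    table="my_table",
    schema="public",
    boto3_session=session,
)

# UNLOAD the same way; iam_role is now Optional.
df = wr.redshift.unload(
    sql="SELECT * FROM public.my_table",
    path="s3://my-bucket/unload/",
    con=con,
    iam_role=None,
    boto3_session=session,
)
con.close()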