 from typing import Any, Dict, Iterator, List, Optional, Tuple, Union

 import boto3
+import botocore
 import pandas as pd
 import pyarrow as pa
 import redshift_connector
@@ -54,18 +55,52 @@ def _does_table_exist(cursor: redshift_connector.Cursor, schema: Optional[str],
     return len(cursor.fetchall()) > 0


+def _make_s3_auth_string(
+    aws_access_key_id: Optional[str] = None,
+    aws_secret_access_key: Optional[str] = None,
+    aws_session_token: Optional[str] = None,
+    iam_role: Optional[str] = None,
+    boto3_session: Optional[boto3.Session] = None,
+) -> str:
+    if aws_access_key_id is not None and aws_secret_access_key is not None:
+        auth_str: str = f"ACCESS_KEY_ID '{aws_access_key_id}'\nSECRET_ACCESS_KEY '{aws_secret_access_key}'\n"
+        if aws_session_token is not None:
+            auth_str += f"SESSION_TOKEN '{aws_session_token}'\n"
+    elif iam_role is not None:
+        auth_str = f"IAM_ROLE '{iam_role}'\n"
+    else:
+        _logger.debug("Attempting to get S3 authorization credentials from boto3 session.")
+        credentials: botocore.credentials.ReadOnlyCredentials
+        credentials = _utils.get_credentials_from_session(boto3_session=boto3_session)
+        if credentials.access_key is None or credentials.secret_key is None:
+            raise exceptions.InvalidArgument(
+                "One of IAM Role or AWS ACCESS_KEY_ID and SECRET_ACCESS_KEY must be "
+                "given. Unable to find ACCESS_KEY_ID and SECRET_ACCESS_KEY in boto3 "
+                "session."
+            )
+
+        auth_str = f"ACCESS_KEY_ID '{credentials.access_key}'\nSECRET_ACCESS_KEY '{credentials.secret_key}'\n"
+        if credentials.token is not None:
+            auth_str += f"SESSION_TOKEN '{credentials.token}'\n"
+
+    return auth_str
+
+
 def _copy(
     cursor: redshift_connector.Cursor,
     path: str,
     table: str,
-    iam_role: str,
+    iam_role: Optional[str] = None,
+    boto3_session: Optional[boto3.Session] = None,
     schema: Optional[str] = None,
 ) -> None:
     if schema is None:
         table_name: str = f'"{table}"'
     else:
         table_name = f'"{schema}"."{table}"'
-    sql: str = f"COPY {table_name} FROM '{path}'\nIAM_ROLE '{iam_role}'\nFORMAT AS PARQUET"
+
+    auth_str: str = _make_s3_auth_string(iam_role=iam_role, boto3_session=boto3_session)
+    sql: str = f"COPY {table_name} FROM '{path}' {auth_str}\nFORMAT AS PARQUET"
     _logger.debug("copy query:\n%s", sql)
     cursor.execute(sql)

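For reference (not part of the diff): a minimal sketch of the three credential paths the new _make_s3_auth_string helper covers, assuming it is importable from awswrangler.redshift as added above; the key, role, and account values below are placeholders.

    import boto3
    from awswrangler.redshift import _make_s3_auth_string  # assumed import path for the helper above

    # Explicit keys (plus an optional session token) take precedence.
    print(_make_s3_auth_string(aws_access_key_id="AKIAEXAMPLE", aws_secret_access_key="wJalrEXAMPLEKEY"))
    # ACCESS_KEY_ID 'AKIAEXAMPLE'
    # SECRET_ACCESS_KEY 'wJalrEXAMPLEKEY'

    # Otherwise an explicit IAM role ARN is used verbatim.
    print(_make_s3_auth_string(iam_role="arn:aws:iam::123456789012:role/redshift-copy-unload"))
    # IAM_ROLE 'arn:aws:iam::123456789012:role/redshift-copy-unload'

    # With neither, credentials are read from the boto3 session;
    # exceptions.InvalidArgument is raised if the session has no access key/secret.
    print(_make_s3_auth_string(boto3_session=boto3.Session()))
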
@@ -388,6 +423,9 @@ def connect_temp(
         Default: 900
     auto_create : bool
         Create a database user with the name specified for the user named in user if one does not exist.
+    db_groups : List[str], optional
+        A list of the names of existing database groups that the user named in user will join for the current session,
+        in addition to any group memberships for an existing user. If not specified, a new user is added only to PUBLIC.
     boto3_session : boto3.Session(), optional
         Boto3 Session. The default boto3 session will be used if boto3_session receives None.
     ssl: bool
@@ -703,7 +741,7 @@ def unload_to_files(
     sql: str,
     path: str,
     con: redshift_connector.Connection,
-    iam_role: str,
+    iam_role: Optional[str] = None,
     region: Optional[str] = None,
     max_file_size: Optional[float] = None,
     kms_key_id: Optional[str] = None,
@@ -730,7 +768,7 @@ def unload_to_files(
     con : redshift_connector.Connection
         Use redshift_connector.connect() to use credentials directly
         or wr.redshift.connect() to fetch it from the Glue Catalog.
-    iam_role : str
+    iam_role : str, optional
         AWS IAM role with the related permissions.
     region : str, optional
         Specifies the AWS Region where the target Amazon S3 bucket is located.
@@ -782,10 +820,13 @@ def unload_to_files(
     region_str: str = f"\nREGION AS '{region}'" if region is not None else ""
     max_file_size_str: str = f"\nMAXFILESIZE AS {max_file_size} MB" if max_file_size is not None else ""
     kms_key_id_str: str = f"\nKMS_KEY_ID '{kms_key_id}'" if kms_key_id is not None else ""
+
+    auth_str: str = _make_s3_auth_string(iam_role=iam_role, boto3_session=boto3_session)
+
     sql = (
         f"UNLOAD ('{sql}')\n"
         f"TO '{path}'\n"
-        f"IAM_ROLE '{iam_role}'\n"
+        f"{auth_str}"
         "ALLOWOVERWRITE\n"
         "PARALLEL ON\n"
         "FORMAT PARQUET\n"
@@ -804,7 +845,7 @@ def unload(
     sql: str,
     path: str,
     con: redshift_connector.Connection,
-    iam_role: str,
+    iam_role: Optional[str] = None,
     region: Optional[str] = None,
     max_file_size: Optional[float] = None,
     kms_key_id: Optional[str] = None,
@@ -857,7 +898,7 @@ def unload(
     con : redshift_connector.Connection
         Use redshift_connector.connect() to use credentials directly
         or wr.redshift.connect() to fetch it from the Glue Catalog.
-    iam_role : str
+    iam_role : str, optional
         AWS IAM role with the related permissions.
     region : str, optional
         Specifies the AWS Region where the target Amazon S3 bucket is located.
@@ -949,7 +990,7 @@ def copy_from_files( # pylint: disable=too-many-locals,too-many-arguments
     con: redshift_connector.Connection,
     table: str,
     schema: str,
-    iam_role: str,
+    iam_role: Optional[str] = None,
     parquet_infer_sampling: float = 1.0,
     mode: str = "append",
     diststyle: str = "AUTO",
@@ -992,7 +1033,7 @@ def copy_from_files( # pylint: disable=too-many-locals,too-many-arguments
         Table name
     schema : str
         Schema name
-    iam_role : str
+    iam_role : str, optional
         AWS IAM role with the related permissions.
     parquet_infer_sampling : float
         Random sample ratio of files that will have the metadata inspected.
@@ -1089,6 +1130,7 @@ def copy_from_files( # pylint: disable=too-many-locals,too-many-arguments
                 table=created_table,
                 schema=created_schema,
                 iam_role=iam_role,
+                boto3_session=boto3_session,
             )
             if table != created_table:  # upsert
                 _upsert(cursor=cursor, schema=schema, table=table, temp_table=created_table, primary_keys=primary_keys)
@@ -1107,7 +1149,7 @@ def copy( # pylint: disable=too-many-arguments
     con: redshift_connector.Connection,
     table: str,
     schema: str,
-    iam_role: str,
+    iam_role: Optional[str] = None,
     index: bool = False,
     dtype: Optional[Dict[str, str]] = None,
     mode: str = "append",
@@ -1161,7 +1203,7 @@ def copy( # pylint: disable=too-many-arguments
         Table name
     schema : str
         Schema name
-    iam_role : str
+    iam_role : str, optional
         AWS IAM role with the related permissions.
     index : bool
         True to store the DataFrame index in file, otherwise False to ignore it.
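
Taken together, these signature changes make iam_role optional on the public copy, copy_from_files, unload, and unload_to_files entry points, with the boto3 session acting as the fallback credential source. A hedged end-to-end sketch (bucket, table, and Glue connection names are placeholders, and it assumes the session's credentials can access the S3 path):

    import boto3
    import pandas as pd
    import awswrangler as wr

    session = boto3.Session()  # default credential chain; used when iam_role is omitted
    con = wr.redshift.connect(connection="my-glue-redshift-connection")  # placeholder connection name

    df = pd.DataFrame({"id": [1, 2], "name": ["foo", "boo"]})

    # COPY without an explicit IAM role: auth falls back to the session's key/secret (and token, if any).
    wr.redshift.copy(
        df=df,
        path="s3://my-bucket/staging/",
        con=con,
        table="my_table",
        schema="public",
        boto3_session=session,
    )

    # UNLOAD can still pin an explicit role instead.
    wr.redshift.unload_to_files(
        sql="SELECT * FROM public.my_table",
        path="s3://my-bucket/unload/",
        con=con,
        iam_role="arn:aws:iam::123456789012:role/redshift-copy-unload",
        boto3_session=session,
    )

    con.close()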