Skip to content

Commit a706c8d

Browse files
authored
enhancement(data-api): Add boto3 session to connect (#1261)
1 parent 8ec237a commit a706c8d

File tree

3 files changed

+32 -8 lines changed

awswrangler/data_api/rds.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import boto3
88
import pandas as pd
99

10+
from awswrangler import _utils
1011
from awswrangler.data_api import connector
1112

1213

@@ -27,6 +28,8 @@ class RdsDataApi(connector.DataApiConnector):
2728
Factor by which to increase the sleep between connection attempts to paused clusters - defaults to 1.0.
2829
retries: int
2930
Maximum number of connection attempts to paused clusters - defaults to 30.
31+
boto3_session : boto3.Session(), optional
32+
The boto3 session. If `None`, the default boto3 session is used.
3033
"""
3134

3235
def __init__(
@@ -37,12 +40,13 @@ def __init__(
3740
sleep: float = 0.5,
3841
backoff: float = 1.0,
3942
retries: int = 30,
43+
boto3_session: Optional[boto3.Session] = None,
4044
) -> None:
4145
self.resource_arn = resource_arn
4246
self.database = database
4347
self.secret_arn = secret_arn
4448
self.wait_config = connector.WaitConfig(sleep, backoff, retries)
45-
self.client = boto3.client("rds-data")
49+
self.client: boto3.client = _utils.client(service_name="rds-data", session=boto3_session)
4650
self.results: Dict[str, Dict[str, Any]] = {}
4751
logger: logging.Logger = logging.getLogger(__name__)
4852
super().__init__(self.client, logger)
@@ -114,7 +118,9 @@ def _get_statement_result(self, request_id: str) -> pd.DataFrame:
114118
return dataframe
115119

116120

117-
def connect(resource_arn: str, database: str, secret_arn: str = "", **kwargs: Any) -> RdsDataApi:
121+
def connect(
122+
resource_arn: str, database: str, secret_arn: str = "", boto3_session: Optional[boto3.Session] = None, **kwargs: Any
123+
) -> RdsDataApi:
118124
"""Create a RDS Data API connection.
119125
120126
Parameters
@@ -125,14 +131,16 @@ def connect(resource_arn: str, database: str, secret_arn: str = "", **kwargs: An
125131
Target database name.
126132
secret_arn: str
127133
The ARN for the secret to be used for authentication.
134+
boto3_session : boto3.Session(), optional
135+
The boto3 session. If `None`, the default boto3 session is used.
128136
**kwargs
129137
Any additional kwargs are passed to the underlying RdsDataApi class.
130138
131139
Returns
132140
-------
133141
A RdsDataApi connection instance that can be used with `wr.rds.data_api.read_sql_query`.
134142
"""
135-
return RdsDataApi(resource_arn, database, secret_arn=secret_arn, **kwargs)
143+
return RdsDataApi(resource_arn, database, secret_arn=secret_arn, boto3_session=boto3_session, **kwargs)
136144

137145

138146
def read_sql_query(sql: str, con: RdsDataApi, database: Optional[str] = None) -> pd.DataFrame:

awswrangler/data_api/redshift.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import boto3
77
import pandas as pd
88

9+
from awswrangler import _utils
910
from awswrangler.data_api import connector
1011

1112

@@ -28,6 +29,8 @@ class RedshiftDataApi(connector.DataApiConnector):
2829
Factor by which to increase the sleep between result fetch attempts - defaults to 1.5.
2930
retries: int
3031
Maximum number of result fetch attempts - defaults to 15.
32+
boto3_session : boto3.Session(), optional
33+
The boto3 session. If `None`, the default boto3 session is used.
3134
"""
3235

3336
def __init__(
@@ -39,12 +42,13 @@ def __init__(
3942
sleep: float = 0.25,
4043
backoff: float = 1.5,
4144
retries: int = 15,
45+
boto3_session: Optional[boto3.Session] = None,
4246
) -> None:
4347
self.cluster_id = cluster_id
4448
self.database = database
4549
self.secret_arn = secret_arn
4650
self.db_user = db_user
47-
self.client = boto3.client("redshift-data")
51+
self.client: boto3.client = _utils.client(service_name="redshift-data", session=boto3_session)
4852
self.waiter = RedshiftDataApiWaiter(self.client, sleep, backoff, retries)
4953
logger: logging.Logger = logging.getLogger(__name__)
5054
super().__init__(self.client, logger)
@@ -162,7 +166,14 @@ class RedshiftDataApiTimeoutException(Exception):
162166
"""Indicates a statement execution did not complete in the expected wait time."""
163167

164168

165-
def connect(cluster_id: str, database: str, secret_arn: str = "", db_user: str = "", **kwargs: Any) -> RedshiftDataApi:
169+
def connect(
170+
cluster_id: str,
171+
database: str,
172+
secret_arn: str = "",
173+
db_user: str = "",
174+
boto3_session: Optional[boto3.Session] = None,
175+
**kwargs: Any,
176+
) -> RedshiftDataApi:
166177
"""Create a Redshift Data API connection.
167178
168179
Parameters
@@ -175,14 +186,18 @@ def connect(cluster_id: str, database: str, secret_arn: str = "", db_user: str =
175186
The ARN for the secret to be used for authentication - only required if `db_user` not provided.
176187
db_user: str
177188
The database user to generate temporary credentials for - only required if `secret_arn` not provided.
189+
boto3_session : boto3.Session(), optional
190+
The boto3 session. If `None`, the default boto3 session is used.
178191
**kwargs
179192
Any additional kwargs are passed to the underlying RedshiftDataApi class.
180193
181194
Returns
182195
-------
183196
A RedshiftDataApi connection instance that can be used with `wr.redshift.data_api.read_sql_query`.
184197
"""
185-
return RedshiftDataApi(cluster_id, database, secret_arn=secret_arn, db_user=db_user, **kwargs)
198+
return RedshiftDataApi(
199+
cluster_id, database, secret_arn=secret_arn, db_user=db_user, boto3_session=boto3_session, **kwargs
200+
)
186201

187202

188203
def read_sql_query(sql: str, con: RedshiftDataApi, database: Optional[str] = None) -> pd.DataFrame:

tests/test_data_api.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import boto3
12
import pandas as pd
23
import pytest
34

@@ -11,15 +12,15 @@ def redshift_connector(databases_parameters):
1112
cluster_id = databases_parameters["redshift"]["identifier"]
1213
database = databases_parameters["redshift"]["database"]
1314
secret_arn = databases_parameters["redshift"]["secret_arn"]
14-
conn = wr.data_api.redshift.connect(cluster_id, database, secret_arn=secret_arn)
15+
conn = wr.data_api.redshift.connect(cluster_id, database, secret_arn=secret_arn, boto3_session=None)
1516
return conn
1617

1718

1819
def create_rds_connector(rds_type, parameters):
1920
cluster_id = parameters[rds_type]["arn"]
2021
database = parameters[rds_type]["database"]
2122
secret_arn = parameters[rds_type]["secret_arn"]
22-
conn = wr.data_api.rds.connect(cluster_id, database, secret_arn=secret_arn)
23+
conn = wr.data_api.rds.connect(cluster_id, database, secret_arn=secret_arn, boto3_session=boto3.DEFAULT_SESSION)
2324
return conn
2425

2526

0 commit comments

Comments
 (0)