Skip to content

Commit f48a05f

Browse files
committed
feat: add redshift data api params
1 parent bcb1041 commit f48a05f

File tree

2 files changed

+47
-4
lines changed

2 files changed

+47
-4
lines changed

awswrangler/data_api/redshift.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,6 @@ def _execute_statement(
112112
) -> str:
113113
if transaction_id:
114114
raise exceptions.InvalidArgument("`transaction_id` not supported for Redshift Data API")
115-
if parameters:
116-
raise exceptions.InvalidArgument("`parameters` not supported for Redshift Data API")
117115

118116
self._validate_redshift_target()
119117
self._validate_auth_method()
@@ -130,6 +128,8 @@ def _execute_statement(
130128
args["ClusterIdentifier"] = self.cluster_id
131129
if self.workgroup_name:
132130
args["WorkgroupName"] = self.workgroup_name
131+
if parameters:
132+
args["Parameters"] = parameters
133133

134134
_logger.debug("Executing %s", sql)
135135
response = self.client.execute_statement(
@@ -285,7 +285,12 @@ def connect(
285285
)
286286

287287

288-
def read_sql_query(sql: str, con: RedshiftDataApi, database: str | None = None) -> pd.DataFrame:
288+
def read_sql_query(
289+
sql: str,
290+
con: RedshiftDataApi,
291+
database: str | None = None,
292+
parameters: list[dict[str, Any]] | None = None,
293+
) -> pd.DataFrame:
289294
"""Run an SQL query on a RedshiftDataApi connection and return the result as a DataFrame.
290295
291296
Parameters
@@ -296,9 +301,29 @@ def read_sql_query(sql: str, con: RedshiftDataApi, database: str | None = None)
296301
A RedshiftDataApi connection instance
297302
database
298303
Database to run query on - defaults to the database specified by `con`.
304+
parameters
305+
A list of named parameters e.g. [{"name": "id", "value": "42"}].
299306
300307
Returns
301308
-------
302309
A Pandas DataFrame containing the query results.
310+
311+
Examples
312+
--------
313+
>>> import awswrangler as wr
314+
>>> df = wr.data_api.redshift.read_sql_query(
315+
>>> sql="SELECT * FROM public.my_table",
316+
>>> con=con,
317+
>>> )
318+
319+
>>> import awswrangler as wr
320+
>>> df = wr.data_api.redshift.read_sql_query(
321+
>>> sql="SELECT * FROM public.my_table WHERE id >= :id",
322+
>>> con=con,
323+
>>> parameters=[
324+
>>> {"name": "id", "value": "42"},
325+
>>> ],
326+
>>> )
327+
303328
"""
304-
return con.execute(sql, database=database)
329+
return con.execute(sql, database=database, parameters=parameters)

tests/unit/test_data_api.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,24 @@ def test_data_api_redshift_basic_select(redshift_connector: "RedshiftDataApi", r
124124
assert_pandas_equals(dataframe, expected_dataframe)
125125

126126

127+
def test_data_api_redshift_parameters(redshift_connector: "RedshiftDataApi", redshift_table: str) -> None:
128+
wr.data_api.redshift.read_sql_query(
129+
f"CREATE TABLE public.{redshift_table} (id INT, name VARCHAR)", con=redshift_connector
130+
)
131+
wr.data_api.redshift.read_sql_query(
132+
f"INSERT INTO public.{redshift_table} VALUES (41, 'test1'), (42, 'test2')", con=redshift_connector
133+
)
134+
expected_dataframe = pd.DataFrame([[42, "test"]], columns=["id", "name"])
135+
136+
dataframe = wr.data_api.redshift.read_sql_query(
137+
f"SELECT * FROM public.{redshift_table} WHERE id >= :id",
138+
con=redshift_connector,
139+
parameters=[{"name": "id", "value": "42"}],
140+
)
141+
142+
assert_pandas_equals(dataframe, expected_dataframe)
143+
144+
127145
def test_data_api_redshift_empty_results_select(redshift_connector: "RedshiftDataApi", redshift_table: str) -> None:
128146
wr.data_api.redshift.read_sql_query(
129147
f"CREATE TABLE public.{redshift_table} (id INT, name VARCHAR)", con=redshift_connector

0 commit comments

Comments
 (0)