diff --git a/awswrangler/data_api/rds.py b/awswrangler/data_api/rds.py index dd217b917..4a28715ba 100644 --- a/awswrangler/data_api/rds.py +++ b/awswrangler/data_api/rds.py @@ -256,7 +256,9 @@ def connect( return RdsDataApi(resource_arn, database, secret_arn=secret_arn, boto3_session=boto3_session, **kwargs) -def read_sql_query(sql: str, con: RdsDataApi, database: str | None = None) -> pd.DataFrame: +def read_sql_query( + sql: str, con: RdsDataApi, database: str | None = None, parameters: list[dict[str, Any]] | None = None +) -> pd.DataFrame: """Run an SQL query on an RdsDataApi connection and return the result as a DataFrame. Parameters @@ -267,12 +269,31 @@ def read_sql_query(sql: str, con: RdsDataApi, database: str | None = None) -> pd A RdsDataApi connection instance database Database to run query on - defaults to the database specified by `con`. + parameters + A list of named parameters e.g. [{"name": "col", "value": {"stringValue": "val1"}}]. Returns ------- A Pandas DataFrame containing the query results. 
+ + Examples + -------- + >>> import awswrangler as wr + >>> df = wr.data_api.rds.read_sql_query( + >>> sql="SELECT * FROM public.my_table", + >>> con=con, + >>> ) + + >>> import awswrangler as wr + >>> df = wr.data_api.rds.read_sql_query( + >>> sql="SELECT * FROM public.my_table WHERE col = :name", + >>> con=con, + >>> parameters=[ + >>> {"name": "name", "value": {"stringValue": "val1"}} + >>> ], + >>> ) """ - return con.execute(sql, database=database) + return con.execute(sql, database=database, parameters=parameters) def _drop_table(con: RdsDataApi, table: str, database: str, transaction_id: str, sql_mode: str) -> None: diff --git a/awswrangler/data_api/redshift.py b/awswrangler/data_api/redshift.py index 86ceccc8d..283a30bb6 100644 --- a/awswrangler/data_api/redshift.py +++ b/awswrangler/data_api/redshift.py @@ -112,8 +112,6 @@ def _execute_statement( ) -> str: if transaction_id: raise exceptions.InvalidArgument("`transaction_id` not supported for Redshift Data API") - if parameters: - raise exceptions.InvalidArgument("`parameters` not supported for Redshift Data API") self._validate_redshift_target() self._validate_auth_method() @@ -130,6 +128,8 @@ def _execute_statement( args["ClusterIdentifier"] = self.cluster_id if self.workgroup_name: args["WorkgroupName"] = self.workgroup_name + if parameters: + args["Parameters"] = parameters # type: ignore[assignment] _logger.debug("Executing %s", sql) response = self.client.execute_statement( @@ -285,7 +285,12 @@ def connect( ) -def read_sql_query(sql: str, con: RedshiftDataApi, database: str | None = None) -> pd.DataFrame: +def read_sql_query( + sql: str, + con: RedshiftDataApi, + database: str | None = None, + parameters: list[dict[str, Any]] | None = None, +) -> pd.DataFrame: """Run an SQL query on a RedshiftDataApi connection and return the result as a DataFrame. 
Parameters @@ -296,9 +301,29 @@ def read_sql_query(sql: str, con: RedshiftDataApi, database: str | None = None) A RedshiftDataApi connection instance database Database to run query on - defaults to the database specified by `con`. + parameters + A list of named parameters e.g. [{"name": "id", "value": "42"}]. Returns ------- A Pandas DataFrame containing the query results. + + Examples + -------- + >>> import awswrangler as wr + >>> df = wr.data_api.redshift.read_sql_query( + >>> sql="SELECT * FROM public.my_table", + >>> con=con, + >>> ) + + >>> import awswrangler as wr + >>> df = wr.data_api.redshift.read_sql_query( + >>> sql="SELECT * FROM public.my_table WHERE id >= :id", + >>> con=con, + >>> parameters=[ + >>> {"name": "id", "value": "42"}, + >>> ], + >>> ) + """ - return con.execute(sql, database=database) + return con.execute(sql, database=database, parameters=parameters) diff --git a/tests/unit/test_data_api.py b/tests/unit/test_data_api.py index 2657ccd24..c3d335a7c 100644 --- a/tests/unit/test_data_api.py +++ b/tests/unit/test_data_api.py @@ -124,6 +124,24 @@ def test_data_api_redshift_basic_select(redshift_connector: "RedshiftDataApi", r assert_pandas_equals(dataframe, expected_dataframe) +def test_data_api_redshift_parameters(redshift_connector: "RedshiftDataApi", redshift_table: str) -> None: + wr.data_api.redshift.read_sql_query( + f"CREATE TABLE public.{redshift_table} (id INT, name VARCHAR)", con=redshift_connector + ) + wr.data_api.redshift.read_sql_query( + f"INSERT INTO public.{redshift_table} VALUES (41, 'test1'), (42, 'test2')", con=redshift_connector + ) + expected_dataframe = pd.DataFrame([[42, "test2"]], columns=["id", "name"]) + + dataframe = wr.data_api.redshift.read_sql_query( + f"SELECT * FROM public.{redshift_table} WHERE id >= :id", + con=redshift_connector, + parameters=[{"name": "id", "value": "42"}], + ) + + assert_pandas_equals(dataframe, expected_dataframe) + + def test_data_api_redshift_empty_results_select(redshift_connector: 
"RedshiftDataApi", redshift_table: str) -> None: wr.data_api.redshift.read_sql_query( f"CREATE TABLE public.{redshift_table} (id INT, name VARCHAR)", con=redshift_connector @@ -301,3 +319,25 @@ def test_data_api_postgresql(postgresql_serverless_connector: "RdsDataApi", post ) expected_dataframe = pd.DataFrame([["test"]], columns=["name"]) assert_pandas_equals(out_frame, expected_dataframe) + + +def test_data_api_mysql_parameters( + mysql_serverless_connector: "RdsDataApi", + mysql_serverless_table: str, +) -> None: + database = "test" + df = pd.DataFrame([[42, "test"]], columns=["id", "name"]) + + wr.data_api.rds.to_sql( + df=df, + con=mysql_serverless_connector, + table=mysql_serverless_table, + database=database, + ) + + out_df = wr.data_api.rds.read_sql_query( + f"SELECT * FROM {database}.{mysql_serverless_table} WHERE name = :name", + con=mysql_serverless_connector, + parameters=[{"name": "name", "value": {"stringValue": "test"}}], + ) + assert_pandas_equals(out_df, df)