Skip to content

Commit e47cdae

Browse files
committed
Add data_source arg for athena queries. #392
1 parent 6a0fe06 commit e47cdae

File tree

4 files changed

+49
-8
lines changed

4 files changed

+49
-8
lines changed

awswrangler/athena/_read.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,7 @@ def _resolve_query_with_cache(
328328
def _resolve_query_without_cache_ctas(
329329
sql: str,
330330
database: Optional[str],
331+
data_source: Optional[str],
331332
s3_output: Optional[str],
332333
keep_files: bool,
333334
chunksize: Union[int, bool, None],
@@ -357,6 +358,7 @@ def _resolve_query_without_cache_ctas(
357358
sql=sql,
358359
wg_config=wg_config,
359360
database=database,
361+
data_source=data_source,
360362
s3_output=s3_output,
361363
workgroup=workgroup,
362364
encryption=encryption,
@@ -408,6 +410,7 @@ def _resolve_query_without_cache_ctas(
408410
def _resolve_query_without_cache_regular(
409411
sql: str,
410412
database: Optional[str],
413+
data_source: Optional[str],
411414
s3_output: Optional[str],
412415
keep_files: bool,
413416
chunksize: Union[int, bool, None],
@@ -424,6 +427,7 @@ def _resolve_query_without_cache_regular(
424427
sql=sql,
425428
wg_config=wg_config,
426429
database=database,
430+
data_source=data_source,
427431
s3_output=s3_output,
428432
workgroup=workgroup,
429433
encryption=encryption,
@@ -447,6 +451,7 @@ def _resolve_query_without_cache(
447451
# pylint: disable=too-many-branches,too-many-locals,too-many-return-statements,too-many-statements
448452
sql: str,
449453
database: str,
454+
data_source: Optional[str],
450455
ctas_approach: bool,
451456
categories: Optional[List[str]],
452457
chunksize: Union[int, bool, None],
@@ -476,6 +481,7 @@ def _resolve_query_without_cache(
476481
return _resolve_query_without_cache_ctas(
477482
sql=sql,
478483
database=database,
484+
data_source=data_source,
479485
s3_output=_s3_output,
480486
keep_files=keep_files,
481487
chunksize=chunksize,
@@ -493,6 +499,7 @@ def _resolve_query_without_cache(
493499
return _resolve_query_without_cache_regular(
494500
sql=sql,
495501
database=database,
502+
data_source=data_source,
496503
s3_output=_s3_output,
497504
keep_files=keep_files,
498505
chunksize=chunksize,
@@ -523,6 +530,7 @@ def read_sql_query(
523530
boto3_session: Optional[boto3.Session] = None,
524531
max_cache_seconds: int = 0,
525532
max_cache_query_inspections: int = 50,
533+
data_source: Optional[str] = None,
526534
) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]:
527535
"""Execute any SQL query on AWS Athena and return the results as a Pandas DataFrame.
528536
@@ -662,6 +670,8 @@ def read_sql_query(
662670
Max number of queries that will be inspected from the history to try to find some result to reuse.
663671
The bigger the number of inspection, the bigger will be the latency for not cached queries.
664672
Only takes effect if max_cache_seconds > 0.
673+
data_source : str, optional
674+
Data Source / Catalog name. If None, 'AwsDataCatalog' will be used by default.
665675
666676
Returns
667677
-------
@@ -701,6 +711,7 @@ def read_sql_query(
701711
return _resolve_query_without_cache(
702712
sql=sql,
703713
database=database,
714+
data_source=data_source,
704715
ctas_approach=ctas_approach,
705716
categories=categories,
706717
chunksize=chunksize,
@@ -732,6 +743,7 @@ def read_sql_table(
732743
boto3_session: Optional[boto3.Session] = None,
733744
max_cache_seconds: int = 0,
734745
max_cache_query_inspections: int = 50,
746+
data_source: Optional[str] = None,
735747
) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]:
736748
"""Extract the full table AWS Athena and return the results as a Pandas DataFrame.
737749
@@ -868,6 +880,8 @@ def read_sql_table(
868880
Max number of queries that will be inspected from the history to try to find some result to reuse.
869881
The bigger the number of inspection, the bigger will be the latency for not cached queries.
870882
Only takes effect if max_cache_seconds > 0.
883+
data_source : str, optional
884+
Data Source / Catalog name. If None, 'AwsDataCatalog' will be used by default.
871885
872886
Returns
873887
-------
@@ -885,6 +899,7 @@ def read_sql_table(
885899
return read_sql_query(
886900
sql=f'SELECT * FROM "{table}"',
887901
database=database,
902+
data_source=data_source,
888903
ctas_approach=ctas_approach,
889904
categories=categories,
890905
chunksize=chunksize,

awswrangler/athena/_utils.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ def _start_query_execution(
5252
sql: str,
5353
wg_config: _WorkGroupConfig,
5454
database: Optional[str] = None,
55+
data_source: Optional[str] = None,
5556
s3_output: Optional[str] = None,
5657
workgroup: Optional[str] = None,
5758
encryption: Optional[str] = None,
@@ -81,6 +82,8 @@ def _start_query_execution(
8182
# database
8283
if database is not None:
8384
args["QueryExecutionContext"] = {"Database": database}
85+
if data_source is not None:
86+
args["QueryExecutionContext"]["Catalog"] = data_source
8487

8588
# workgroup
8689
if workgroup is not None:
@@ -312,6 +315,7 @@ def start_query_execution(
312315
encryption: Optional[str] = None,
313316
kms_key: Optional[str] = None,
314317
boto3_session: Optional[boto3.Session] = None,
318+
data_source: Optional[str] = None,
315319
) -> str:
316320
"""Start a SQL Query against AWS Athena.
317321
@@ -336,6 +340,8 @@ def start_query_execution(
336340
For SSE-KMS and CSE-KMS , this is the KMS key ARN or ID.
337341
boto3_session : boto3.Session(), optional
338342
Boto3 Session. The default boto3 session will be used if boto3_session receive None.
343+
data_source : str, optional
344+
Data Source / Catalog name. If None, 'AwsDataCatalog' will be used by default.
339345
340346
Returns
341347
-------
@@ -344,16 +350,24 @@ def start_query_execution(
344350
345351
Examples
346352
--------
353+
Querying the default data source (Amazon S3 - 'AwsDataCatalog')
354+
347355
>>> import awswrangler as wr
348356
>>> query_exec_id = wr.athena.start_query_execution(sql='...', database='...')
349357
358+
Querying another data source (PostgreSQL, Redshift, etc.)
359+
360+
>>> import awswrangler as wr
361+
>>> query_exec_id = wr.athena.start_query_execution(sql='...', database='...', data_source='...')
362+
350363
"""
351364
session: boto3.Session = _utils.ensure_session(session=boto3_session)
352365
wg_config: _WorkGroupConfig = _get_workgroup_config(session=session, workgroup=workgroup)
353366
return _start_query_execution(
354367
sql=sql,
355368
wg_config=wg_config,
356369
database=database,
370+
data_source=data_source,
357371
s3_output=s3_output,
358372
workgroup=workgroup,
359373
encryption=encryption,

tests/test_athena.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -775,10 +775,11 @@ def test_describe_table(path, glue_database, glue_table):
775775
assert wr.athena.describe_table(database=glue_database, table=glue_table).shape == (1, 4)
776776

777777

778+
@pytest.mark.parametrize("data_source", [None, "AwsDataCatalog"])
778779
@pytest.mark.parametrize("ctas_approach", [False, True])
779-
def test_athena_nan_inf(glue_database, ctas_approach):
780+
def test_athena_nan_inf(glue_database, ctas_approach, data_source):
780781
sql = "SELECT nan() AS nan, infinity() as inf, -infinity() as inf_n, 1.2 as regular"
781-
df = wr.athena.read_sql_query(sql, glue_database, ctas_approach)
782+
df = wr.athena.read_sql_query(sql, glue_database, ctas_approach, data_source=data_source)
782783
print(df)
783784
print(df.dtypes)
784785
assert df.shape == (1, 4)

tests/test_athena_cache.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from unittest.mock import patch
33

44
import pandas as pd
5+
import pytest
56

67
import awswrangler as wr
78

@@ -35,7 +36,8 @@ def test_athena_cache(path, glue_database, glue_table, workgroup1):
3536
assert len(list(dfs)) == 2
3637

3738

38-
def test_cache_query_ctas_approach_true(path, glue_database, glue_table):
39+
@pytest.mark.parametrize("data_source", [None, "AwsDataCatalog"])
40+
def test_cache_query_ctas_approach_true(path, glue_database, glue_table, data_source):
3941
df = pd.DataFrame({"c0": [0, None]}, dtype="Int64")
4042
paths = wr.s3.to_parquet(
4143
df=df,
@@ -54,20 +56,25 @@ def test_cache_query_ctas_approach_true(path, glue_database, glue_table):
5456
"awswrangler.athena._read._check_for_cached_results",
5557
return_value=wr.athena._read._CacheInfo(has_valid_cache=False),
5658
) as mocked_cache_attempt:
57-
df2 = wr.athena.read_sql_table(glue_table, glue_database, ctas_approach=True, max_cache_seconds=0)
59+
df2 = wr.athena.read_sql_table(
60+
glue_table, glue_database, ctas_approach=True, max_cache_seconds=0, data_source=data_source
61+
)
5862
mocked_cache_attempt.assert_called()
5963
assert df.shape == df2.shape
6064
assert df.c0.sum() == df2.c0.sum()
6165

6266
with patch("awswrangler.athena._read._resolve_query_without_cache") as resolve_no_cache:
63-
df3 = wr.athena.read_sql_table(glue_table, glue_database, ctas_approach=True, max_cache_seconds=900)
67+
df3 = wr.athena.read_sql_table(
68+
glue_table, glue_database, ctas_approach=True, max_cache_seconds=900, data_source=data_source
69+
)
6470
resolve_no_cache.assert_not_called()
6571
assert df.shape == df3.shape
6672
assert df.c0.sum() == df3.c0.sum()
6773
ensure_athena_query_metadata(df=df3, ctas_approach=True, encrypted=False)
6874

6975

70-
def test_cache_query_ctas_approach_false(path, glue_database, glue_table):
76+
@pytest.mark.parametrize("data_source", [None, "AwsDataCatalog"])
77+
def test_cache_query_ctas_approach_false(path, glue_database, glue_table, data_source):
7178
df = pd.DataFrame({"c0": [0, None]}, dtype="Int64")
7279
paths = wr.s3.to_parquet(
7380
df=df,
@@ -86,13 +93,17 @@ def test_cache_query_ctas_approach_false(path, glue_database, glue_table):
8693
"awswrangler.athena._read._check_for_cached_results",
8794
return_value=wr.athena._read._CacheInfo(has_valid_cache=False),
8895
) as mocked_cache_attempt:
89-
df2 = wr.athena.read_sql_table(glue_table, glue_database, ctas_approach=False, max_cache_seconds=0)
96+
df2 = wr.athena.read_sql_table(
97+
glue_table, glue_database, ctas_approach=False, max_cache_seconds=0, data_source=data_source
98+
)
9099
mocked_cache_attempt.assert_called()
91100
assert df.shape == df2.shape
92101
assert df.c0.sum() == df2.c0.sum()
93102

94103
with patch("awswrangler.athena._read._resolve_query_without_cache") as resolve_no_cache:
95-
df3 = wr.athena.read_sql_table(glue_table, glue_database, ctas_approach=False, max_cache_seconds=900)
104+
df3 = wr.athena.read_sql_table(
105+
glue_table, glue_database, ctas_approach=False, max_cache_seconds=900, data_source=data_source
106+
)
96107
resolve_no_cache.assert_not_called()
97108
assert df.shape == df3.shape
98109
assert df.c0.sum() == df3.c0.sum()

0 commit comments

Comments
 (0)