Skip to content

Commit 9de83bb

Browse files
authored
Add chunked mode to timestream.query (#864)
1 parent 8839726 commit 9de83bb

File tree

2 files changed

+79
-23
lines changed

2 files changed

+79
-23
lines changed

awswrangler/timestream.py

Lines changed: 44 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import itertools
55
import logging
66
from datetime import datetime
7-
from typing import Any, Dict, List, Optional, cast
7+
from typing import Any, Dict, Iterator, List, Optional, Union, cast
88

99
import boto3
1010
import pandas as pd
@@ -103,6 +103,14 @@ def _process_row(schema: List[Dict[str, str]], row: Dict[str, Any]) -> List[Any]
103103
return row_processed
104104

105105

106+
def _rows_to_df(rows: List[List[Any]], schema: List[Dict[str, str]]) -> pd.DataFrame:
107+
df = pd.DataFrame(data=rows, columns=[c["name"] for c in schema])
108+
for col in schema:
109+
if col["type"] == "VARCHAR":
110+
df[col["name"]] = df[col["name"]].astype("string")
111+
return df
112+
113+
106114
def _process_schema(page: Dict[str, Any]) -> List[Dict[str, str]]:
107115
schema: List[Dict[str, str]] = []
108116
for col in page["ColumnInfo"]:
@@ -112,6 +120,29 @@ def _process_schema(page: Dict[str, Any]) -> List[Dict[str, str]]:
112120
return schema
113121

114122

123+
def _paginate_query(
    sql: str, pagination_config: Optional[Dict[str, Any]], boto3_session: Optional[boto3.Session] = None
) -> Iterator[pd.DataFrame]:
    """Yield one DataFrame per Timestream Query API page that contains rows.

    Parameters
    ----------
    sql: str
        SQL query.
    pagination_config: Dict[str, Any], optional
        Botocore pagination configuration (``MaxItems``, ``PageSize``, ``StartingToken``).
    boto3_session : boto3.Session(), optional
        Boto3 Session. The default boto3 Session will be used if boto3_session receive None.

    Yields
    ------
    pd.DataFrame
        One frame per page; pages without rows are skipped.
    """
    client: boto3.client = _utils.client(
        service_name="timestream-query",
        session=boto3_session,
        botocore_config=Config(read_timeout=60, retries={"max_attempts": 10}),
    )
    paginator = client.get_paginator("query")
    schema: List[Dict[str, str]] = []
    for page in paginator.paginate(QueryString=sql, PaginationConfig=pagination_config or {}):
        # The schema is identical on every page; extract it once from the first one.
        if not schema:
            schema = _process_schema(page=page)
        _logger.debug("schema: %s", schema)
        page_rows = [_process_row(schema=schema, row=raw_row) for raw_row in page["Rows"]]
        if page_rows:
            yield _rows_to_df(page_rows, schema)
144+
145+
115146
def write(
116147
df: pd.DataFrame,
117148
database: str,
@@ -200,14 +231,19 @@ def write(
200231

201232

202233
def query(
    sql: str,
    chunked: bool = False,
    pagination_config: Optional[Dict[str, Any]] = None,
    boto3_session: Optional[boto3.Session] = None,
) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]:
    """Run a query and retrieve the result as a Pandas DataFrame.

    Parameters
    ----------
    sql: str
        SQL query.
    chunked: bool
        If True returns a DataFrame iterator, and a single DataFrame otherwise. False by default.
    pagination_config: Dict[str, Any], optional
        Pagination configuration dictionary of a form {'MaxItems': 10, 'PageSize': 10, 'StartingToken': '...'}
    boto3_session : boto3.Session(), optional
        Boto3 Session. The default boto3 Session will be used if boto3_session receive None.

    Returns
    -------
    Union[pd.DataFrame, Iterator[pd.DataFrame]]
        Pandas DataFrame with the query result, or an iterator of DataFrames when ``chunked=True``.

    Examples
    --------
    Run a query and return the result as a Pandas DataFrame or an iterable.

    >>> import awswrangler as wr
    >>> df = wr.timestream.query('SELECT * FROM "sampleDB"."sampleTable" ORDER BY time DESC LIMIT 10')

    """
    result_iterator = _paginate_query(sql, pagination_config, boto3_session)
    if chunked:
        return result_iterator
    # _paginate_query yields nothing when the query matches no rows, and
    # pd.concat raises "No objects to concatenate" on an empty iterable —
    # return an empty DataFrame instead of propagating that ValueError.
    results = list(result_iterator)
    if not results:
        return pd.DataFrame()
    return pd.concat(results, ignore_index=True)
248269

249270

250271
def create_database(

tests/test_timestream.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,41 @@ def test_basic_scenario(timestream_database_and_table, pagination):
4949
assert df.shape == (3, 8)
5050

5151

52+
def test_chunked_scenario(timestream_database_and_table):
    """Write 5 records, then read them back chunked with PageSize=3 and check chunk shapes."""
    df = pd.DataFrame(
        {
            "time": [datetime.now() for _ in range(5)],
            "dim0": ["foo", "boo", "bar", "fizz", "buzz"],
            "dim1": [1, 2, 3, 4, 5],
            "measure": [1.0, 1.1, 1.2, 1.3, 1.4],
        }
    )
    rejected_records = wr.timestream.write(
        df=df,
        database=timestream_database_and_table,
        table=timestream_database_and_table,
        time_col="time",
        measure_col="measure",
        dimensions_cols=["dim0", "dim1"],
    )
    assert len(rejected_records) == 0
    chunk_iterator = wr.timestream.query(
        f"""
        SELECT
            *
        FROM "{timestream_database_and_table}"."{timestream_database_and_table}"
        ORDER BY time ASC
        """,
        chunked=True,
        pagination_config={"MaxItems": 5, "PageSize": 3},
    )
    expected_shapes = [(3, 5), (2, 5)]
    # Renamed the loop variable so it no longer shadows the fixture frame `df`.
    for chunk, expected_shape in zip(chunk_iterator, expected_shapes):
        assert chunk.shape == expected_shape
85+
86+
5287
def test_versioned(timestream_database_and_table):
5388
name = timestream_database_and_table
5489
time = [datetime.now(), datetime.now(), datetime.now()]

0 commit comments

Comments
 (0)