Skip to content

Commit 32fcf63

Browse files
authored
Merge pull request #1821 from aws/timestream-query-nexttoken
timestream.query: add QueryId and NextToken to df attributes
2 parents 2099ea6 + bf52c42 commit 32fcf63

File tree

2 files changed

+17
-4
lines changed

2 files changed

+17
-4
lines changed

awswrangler/timestream.py

Lines changed: 15 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -122,8 +122,12 @@ def _process_row(schema: List[Dict[str, str]], row: Dict[str, Any]) -> List[Any]
122122
return row_processed
123123

124124

125-
def _rows_to_df(rows: List[List[Any]], schema: List[Dict[str, str]]) -> pd.DataFrame:
125+
def _rows_to_df(
126+
rows: List[List[Any]], schema: List[Dict[str, str]], df_metadata: Optional[Dict[str, str]] = None
127+
) -> pd.DataFrame:
126128
df = pd.DataFrame(data=rows, columns=[c["name"] for c in schema])
129+
if df_metadata:
130+
df.attrs = df_metadata
127131
for col in schema:
128132
if col["type"] == "VARCHAR":
129133
df[col["name"]] = df[col["name"]].astype("string")
@@ -143,7 +147,7 @@ def _process_schema(page: Dict[str, Any]) -> List[Dict[str, str]]:
143147

144148

145149
def _paginate_query(
146-
sql: str, pagination_config: Optional[Dict[str, Any]], boto3_session: Optional[boto3.Session] = None
150+
sql: str, chunked: bool, pagination_config: Optional[Dict[str, Any]], boto3_session: Optional[boto3.Session] = None
147151
) -> Iterator[pd.DataFrame]:
148152
client: boto3.client = _utils.client(
149153
service_name="timestream-query",
@@ -161,7 +165,13 @@ def _paginate_query(
161165
for row in page["Rows"]:
162166
rows.append(_process_row(schema=schema, row=row))
163167
if len(rows) > 0:
164-
yield _rows_to_df(rows, schema)
168+
df_metadata = {}
169+
if chunked:
170+
if "NextToken" in page:
171+
df_metadata["NextToken"] = page["NextToken"]
172+
df_metadata["QueryId"] = page["QueryId"]
173+
174+
yield _rows_to_df(rows, schema, df_metadata)
165175
rows = []
166176

167177

@@ -289,9 +299,10 @@ def query(
289299
>>> df = wr.timestream.query('SELECT * FROM "sampleDB"."sampleTable" ORDER BY time DESC LIMIT 10')
290300
291301
"""
292-
result_iterator = _paginate_query(sql, pagination_config, boto3_session)
302+
result_iterator = _paginate_query(sql, chunked, pagination_config, boto3_session)
293303
if chunked:
294304
return result_iterator
305+
295306
# Prepending an empty DataFrame ensures returning an empty DataFrame if result_iterator is empty
296307
return pd.concat(itertools.chain([pd.DataFrame()], result_iterator), ignore_index=True)
297308

tests/test_timestream.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -47,6 +47,7 @@ def test_basic_scenario(timestream_database_and_table, pagination):
4747
pagination_config=pagination,
4848
)
4949
assert df.shape == (3, 8)
50+
assert df.attrs == {}
5051

5152

5253
@pytest.mark.parametrize("chunked", [False, True])
@@ -114,6 +115,7 @@ def test_chunked_scenario(timestream_database_and_table):
114115
),
115116
shapes,
116117
):
118+
assert "QueryId" in df.attrs
117119
assert df.shape == shape
118120

119121

0 commit comments

Comments
 (0)