44import itertools
55import logging
66from datetime import datetime
7- from typing import Any , Dict , List , Optional , cast
7+ from typing import Any , Dict , Iterator , List , Optional , Union , cast
88
99import boto3
1010import pandas as pd
@@ -103,6 +103,14 @@ def _process_row(schema: List[Dict[str, str]], row: Dict[str, Any]) -> List[Any]
103103 return row_processed
104104
105105
106+ def _rows_to_df (rows : List [List [Any ]], schema : List [Dict [str , str ]]) -> pd .DataFrame :
107+ df = pd .DataFrame (data = rows , columns = [c ["name" ] for c in schema ])
108+ for col in schema :
109+ if col ["type" ] == "VARCHAR" :
110+ df [col ["name" ]] = df [col ["name" ]].astype ("string" )
111+ return df
112+
113+
106114def _process_schema (page : Dict [str , Any ]) -> List [Dict [str , str ]]:
107115 schema : List [Dict [str , str ]] = []
108116 for col in page ["ColumnInfo" ]:
@@ -112,6 +120,29 @@ def _process_schema(page: Dict[str, Any]) -> List[Dict[str, str]]:
112120 return schema
113121
114122
123+ def _paginate_query (
124+ sql : str , pagination_config : Optional [Dict [str , Any ]], boto3_session : Optional [boto3 .Session ] = None
125+ ) -> Iterator [pd .DataFrame ]:
126+ client : boto3 .client = _utils .client (
127+ service_name = "timestream-query" ,
128+ session = boto3_session ,
129+ botocore_config = Config (read_timeout = 60 , retries = {"max_attempts" : 10 }),
130+ )
131+ paginator = client .get_paginator ("query" )
132+ rows : List [List [Any ]] = []
133+ schema : List [Dict [str , str ]] = []
134+ page_iterator = paginator .paginate (QueryString = sql , PaginationConfig = pagination_config or {})
135+ for page in page_iterator :
136+ if not schema :
137+ schema = _process_schema (page = page )
138+ _logger .debug ("schema: %s" , schema )
139+ for row in page ["Rows" ]:
140+ rows .append (_process_row (schema = schema , row = row ))
141+ if len (rows ) > 0 :
142+ yield _rows_to_df (rows , schema )
143+ rows = []
144+
145+
115146def write (
116147 df : pd .DataFrame ,
117148 database : str ,
@@ -200,14 +231,19 @@ def write(
200231
201232
202233def query (
203- sql : str , pagination_config : Optional [Dict [str , Any ]] = None , boto3_session : Optional [boto3 .Session ] = None
204- ) -> pd .DataFrame :
234+ sql : str ,
235+ chunked : bool = False ,
236+ pagination_config : Optional [Dict [str , Any ]] = None ,
237+ boto3_session : Optional [boto3 .Session ] = None ,
238+ ) -> Union [pd .DataFrame , Iterator [pd .DataFrame ]]:
205239 """Run a query and retrieve the result as a Pandas DataFrame.
206240
207241 Parameters
208242 ----------
209243 sql: str
210244 SQL query.
245+ chunked: bool
246+ If True returns dataframe iterator, and a single dataframe otherwise. False by default.
211247 pagination_config: Dict[str, Any], optional
212248 Pagination configuration dictionary of a form {'MaxItems': 10, 'PageSize': 10, 'StartingToken': '...'}
213249 boto3_session : boto3.Session(), optional
@@ -220,31 +256,16 @@ def query(
220256
221257 Examples
222258 --------
223- Running a query and storing the result as a Pandas DataFrame
259+ Run a query and return the result as a Pandas DataFrame or an iterable.
224260
225261 >>> import awswrangler as wr
226262 >>> df = wr.timestream.query('SELECT * FROM "sampleDB"."sampleTable" ORDER BY time DESC LIMIT 10')
227263
228264 """
229- client : boto3 .client = _utils .client (
230- service_name = "timestream-query" ,
231- session = boto3_session ,
232- botocore_config = Config (read_timeout = 60 , retries = {"max_attempts" : 10 }),
233- )
234- paginator = client .get_paginator ("query" )
235- rows : List [List [Any ]] = []
236- schema : List [Dict [str , str ]] = []
237- for page in paginator .paginate (QueryString = sql , PaginationConfig = pagination_config or {}):
238- if not schema :
239- schema = _process_schema (page = page )
240- for row in page ["Rows" ]:
241- rows .append (_process_row (schema = schema , row = row ))
242- _logger .debug ("schema: %s" , schema )
243- df = pd .DataFrame (data = rows , columns = [c ["name" ] for c in schema ])
244- for col in schema :
245- if col ["type" ] == "VARCHAR" :
246- df [col ["name" ]] = df [col ["name" ]].astype ("string" )
247- return df
265+ result_iterator = _paginate_query (sql , pagination_config , boto3_session )
266+ if chunked :
267+ return result_iterator
268+ return pd .concat (result_iterator , ignore_index = True )
248269
249270
250271def create_database (
0 commit comments