11"""Amazon Timestream Module."""
22
3- import concurrent .futures
43import itertools
54import logging
65from datetime import datetime
1110from botocore .config import Config
1211
1312from awswrangler import _data_types , _utils
13+ from awswrangler ._distributed import engine
14+ from awswrangler ._threading import _get_executor
15+ from awswrangler .distributed .ray import ray_get
1416
1517_logger : logging .Logger = logging .getLogger (__name__ )
1618
1719
20+ def _flatten_list (elements : List [List [Any ]]) -> List [Any ]:
21+ return [item for sublist in elements for item in sublist ]
22+
23+
1824def _df2list (df : pd .DataFrame ) -> List [List [Any ]]:
1925 """Extract Parameters."""
2026 parameters : List [List [Any ]] = df .values .tolist ()
@@ -27,17 +33,17 @@ def _df2list(df: pd.DataFrame) -> List[List[Any]]:
2733 return parameters
2834
2935
36+ @engine .dispatch_on_engine
3037def _write_batch (
38+ boto3_session : Optional [boto3 .Session ],
3139 database : str ,
3240 table : str ,
3341 cols_names : List [str ],
3442 measure_cols_names : List [str ],
3543 measure_types : List [str ],
3644 version : int ,
3745 batch : List [Any ],
38- boto3_primitives : _utils .Boto3PrimitivesType ,
3946) -> List [Dict [str , str ]]:
40- boto3_session : boto3 .Session = _utils .boto3_from_primitives (primitives = boto3_primitives )
4147 client : boto3 .client = _utils .client (
4248 service_name = "timestream-write" ,
4349 session = boto3_session ,
@@ -85,6 +91,33 @@ def _write_batch(
8591 return []
8692
8793
@engine.dispatch_on_engine
def _write_df(
    df: pd.DataFrame,
    executor: Any,
    database: str,
    table: str,
    cols_names: List[str],
    measure_cols_names: List[str],
    measure_types: List[str],
    version: int,
    boto3_session: Optional[boto3.Session],
) -> List[Dict[str, str]]:
    """Split *df* into 100-record batches and fan the writes out through *executor*.

    Returns the rejected-record dictionaries collected from every batch write.
    """
    record_batches: List[List[Any]] = _utils.chunkify(lst=_df2list(df=df), max_length=100)
    _logger.debug("len(batches): %s", len(record_batches))
    # Every batch shares the same table metadata; only the batch payload varies,
    # so all metadata arguments are repeated across the mapped calls.
    shared_args = (database, table, cols_names, measure_cols_names, measure_types, version)
    return executor.map(  # type: ignore
        _write_batch,
        boto3_session,
        *(itertools.repeat(arg) for arg in shared_args),
        record_batches,
    )
119+
120+
88121def _cast_value (value : str , dtype : str ) -> Any : # pylint: disable=too-many-branches,too-many-return-statements
89122 if dtype == "VARCHAR" :
90123 return value
@@ -173,14 +206,18 @@ def write(
173206 measure_col : Union [str , List [str ]],
174207 dimensions_cols : List [str ],
175208 version : int = 1 ,
176- num_threads : int = 32 ,
209+ use_threads : Union [ bool , int ] = True ,
177210 boto3_session : Optional [boto3 .Session ] = None ,
178211) -> List [Dict [str , str ]]:
179212 """Store a Pandas DataFrame into a Amazon Timestream table.
180213
214+ Note
215+ ----
216+ In case `use_threads=True`, the number of threads from os.cpu_count() is used.
217+
181218 Parameters
182219 ----------
183- df: pandas.DataFrame
220+ df : pandas.DataFrame
184221 Pandas DataFrame https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.html
185222 database : str
186223 Amazon Timestream database name.
@@ -195,8 +232,10 @@ def write(
195232 version : int
196233 Version number used for upserts.
197234 Documentation https://docs.aws.amazon.com/timestream/latest/developerguide/API_WriteRecords.html.
198- num_threads : str
199- Number of thread to be used for concurrent writing.
235+ use_threads : bool, int
236+ True to enable concurrent writing, False to disable multiple threads.
237+ If enabled, os.cpu_count() is used as the number of threads.
238+ If integer is provided, specified number is used.
200239 boto3_session : boto3.Session(), optional
201240 Boto3 Session. The default boto3 Session will be used if boto3_session receive None.
202241
@@ -232,29 +271,33 @@ def write(
232271 """
233272 measure_cols_names : List [str ] = measure_col if isinstance (measure_col , list ) else [measure_col ]
234273 _logger .debug ("measure_cols_names: %s" , measure_cols_names )
235- measure_types : List [str ] = [
236- _data_types .timestream_type_from_pandas (df [[measure_col_name ]]) for measure_col_name in measure_cols_names
237- ]
274+ measure_types : List [str ] = _data_types .timestream_type_from_pandas (df .loc [:, measure_cols_names ])
238275 _logger .debug ("measure_types: %s" , measure_types )
239276 cols_names : List [str ] = [time_col ] + measure_cols_names + dimensions_cols
240277 _logger .debug ("cols_names: %s" , cols_names )
241- batches : List [List [Any ]] = _utils .chunkify (lst = _df2list (df = df [cols_names ]), max_length = 100 )
242- _logger .debug ("len(batches): %s" , len (batches ))
243- with concurrent .futures .ThreadPoolExecutor (max_workers = num_threads ) as executor :
244- res : List [List [Any ]] = list (
245- executor .map (
246- _write_batch ,
247- itertools .repeat (database ),
248- itertools .repeat (table ),
249- itertools .repeat (cols_names ),
250- itertools .repeat (measure_cols_names ),
251- itertools .repeat (measure_types ),
252- itertools .repeat (version ),
253- batches ,
254- itertools .repeat (_utils .boto3_to_primitives (boto3_session = boto3_session )),
255- )
278+ dfs = _utils .split_pandas_frame (df .loc [:, cols_names ], _utils .ensure_cpu_count (use_threads = use_threads ))
279+ _logger .debug ("len(dfs): %s" , len (dfs ))
280+
281+ executor = _get_executor (use_threads = use_threads )
282+ errors = _flatten_list (
283+ ray_get (
284+ [
285+ _write_df (
286+ df = df ,
287+ executor = executor ,
288+ database = database ,
289+ table = table ,
290+ cols_names = cols_names ,
291+ measure_cols_names = measure_cols_names ,
292+ measure_types = measure_types ,
293+ version = version ,
294+ boto3_session = boto3_session ,
295+ )
296+ for df in dfs
297+ ]
256298 )
257- return [item for sublist in res for item in sublist ]
299+ )
300+ return _flatten_list (ray_get (errors ))
258301
259302
260303def query (
0 commit comments