@@ -31,7 +31,8 @@ def _write_batch(
3131 database : str ,
3232 table : str ,
3333 cols_names : List [str ],
34- measure_type : str ,
34+ measure_cols_names : List [str ],
35+ measure_types : List [str ],
3536 version : int ,
3637 batch : List [Any ],
3738 boto3_primitives : _utils .Boto3PrimitivesType ,
@@ -43,27 +44,41 @@ def _write_batch(
4344 botocore_config = Config (read_timeout = 20 , max_pool_connections = 5000 , retries = {"max_attempts" : 10 }),
4445 )
4546 try :
47+ time_loc = 0
48+ measure_cols_loc = 1
49+ dimensions_cols_loc = 1 + len (measure_cols_names )
50+ records : List [Dict [str , Any ]] = []
51+ for rec in batch :
52+ record : Dict [str , Any ] = {
53+ "Dimensions" : [
54+ {"Name" : name , "DimensionValueType" : "VARCHAR" , "Value" : str (value )}
55+ for name , value in zip (cols_names [dimensions_cols_loc :], rec [dimensions_cols_loc :])
56+ ],
57+ "Time" : str (round (rec [time_loc ].timestamp () * 1_000 )),
58+ "TimeUnit" : "MILLISECONDS" ,
59+ "Version" : version ,
60+ }
61+ if len (measure_cols_names ) == 1 :
62+ record ["MeasureName" ] = measure_cols_names [0 ]
63+ record ["MeasureValueType" ] = measure_types [0 ]
64+ record ["MeasureValue" ] = str (rec [measure_cols_loc ])
65+ else :
66+ record ["MeasureName" ] = measure_cols_names [0 ]
67+ record ["MeasureValueType" ] = "MULTI"
68+ record ["MeasureValues" ] = [
69+ {"Name" : measure_name , "Value" : str (measure_value ), "Type" : measure_value_type }
70+ for measure_name , measure_value , measure_value_type in zip (
71+ measure_cols_names , rec [measure_cols_loc :dimensions_cols_loc ], measure_types
72+ )
73+ ]
74+ records .append (record )
4675 _utils .try_it (
4776 f = client .write_records ,
4877 ex = (client .exceptions .ThrottlingException , client .exceptions .InternalServerException ),
4978 max_num_tries = 5 ,
5079 DatabaseName = database ,
5180 TableName = table ,
52- Records = [
53- {
54- "Dimensions" : [
55- {"Name" : name , "DimensionValueType" : "VARCHAR" , "Value" : str (value )}
56- for name , value in zip (cols_names [2 :], rec [2 :])
57- ],
58- "MeasureName" : cols_names [1 ],
59- "MeasureValueType" : measure_type ,
60- "MeasureValue" : str (rec [1 ]),
61- "Time" : str (round (rec [0 ].timestamp () * 1_000 )),
62- "TimeUnit" : "MILLISECONDS" ,
63- "Version" : version ,
64- }
65- for rec in batch
66- ],
81+ Records = records ,
6782 )
6883 except client .exceptions .RejectedRecordsException as ex :
6984 return cast (List [Dict [str , str ]], ex .response ["RejectedRecords" ])
@@ -148,7 +163,7 @@ def write(
148163 database : str ,
149164 table : str ,
150165 time_col : str ,
151- measure_col : str ,
166+ measure_col : Union [ str , List [ str ]] ,
152167 dimensions_cols : List [str ],
153168 version : int = 1 ,
154169 num_threads : int = 32 ,
@@ -166,8 +181,8 @@ def write(
166181 Amazon Timestream table name.
167182 time_col : str
168183 DataFrame column name to be used as time. MUST be a timestamp column.
169- measure_col : str
170- DataFrame column name to be used as measure.
184+ measure_col : Union[ str, List[str]]
185+ DataFrame column name(s) to be used as measure.
171186 dimensions_cols : List[str]
172187 List of DataFrame column names to be used as dimensions.
173188 version : int
@@ -208,9 +223,13 @@ def write(
208223 >>> assert len(rejected_records) == 0
209224
210225 """
211- measure_type : str = _data_types .timestream_type_from_pandas (df [[measure_col ]])
212- _logger .debug ("measure_type: %s" , measure_type )
213- cols_names : List [str ] = [time_col , measure_col ] + dimensions_cols
226+ measure_cols_names : List [str ] = measure_col if isinstance (measure_col , list ) else [measure_col ]
227+ _logger .debug ("measure_cols_names: %s" , measure_cols_names )
228+ measure_types : List [str ] = [
229+ _data_types .timestream_type_from_pandas (df [[measure_col_name ]]) for measure_col_name in measure_cols_names
230+ ]
231+ _logger .debug ("measure_types: %s" , measure_types )
232+ cols_names : List [str ] = [time_col ] + measure_cols_names + dimensions_cols
214233 _logger .debug ("cols_names: %s" , cols_names )
215234 batches : List [List [Any ]] = _utils .chunkify (lst = _df2list (df = df [cols_names ]), max_length = 100 )
216235 _logger .debug ("len(batches): %s" , len (batches ))
@@ -221,7 +240,8 @@ def write(
221240 itertools .repeat (database ),
222241 itertools .repeat (table ),
223242 itertools .repeat (cols_names ),
224- itertools .repeat (measure_type ),
243+ itertools .repeat (measure_cols_names ),
244+ itertools .repeat (measure_types ),
225245 itertools .repeat (version ),
226246 batches ,
227247 itertools .repeat (_utils .boto3_to_primitives (boto3_session = boto3_session )),
0 commit comments