Skip to content

Commit 4df8f26

Browse files
enhancement: Support optional measure_name in wr.timestream.write() (#1925)
* adding measure name for MULTI records * support scalar value measure_name * doc formatting * Consolidating unit test
1 parent 91682ce commit 4df8f26

File tree

2 files changed

+71
-3
lines changed

2 files changed

+71
-3
lines changed

awswrangler/timestream.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def _write_batch(
4444
version: int,
4545
batch: List[Any],
4646
boto3_primitives: _utils.Boto3PrimitivesType,
47+
measure_name: Optional[str] = None,
4748
) -> List[Dict[str, str]]:
4849
boto3_session: boto3.Session = _utils.boto3_from_primitives(primitives=boto3_primitives)
4950
client: boto3.client = _utils.client(
@@ -67,11 +68,11 @@ def _write_batch(
6768
"Version": version,
6869
}
6970
if len(measure_cols_names) == 1:
70-
record["MeasureName"] = measure_cols_names[0]
71+
record["MeasureName"] = measure_name if measure_name else measure_cols_names[0]
7172
record["MeasureValueType"] = measure_types[0]
7273
record["MeasureValue"] = str(rec[measure_cols_loc])
7374
else:
74-
record["MeasureName"] = measure_cols_names[0]
75+
record["MeasureName"] = measure_name if measure_name else measure_cols_names[0]
7576
record["MeasureValueType"] = "MULTI"
7677
record["MeasureValues"] = [
7778
_format_measure(measure_name, measure_value, measure_value_type)
@@ -192,6 +193,7 @@ def write(
192193
dimensions_cols: List[str],
193194
version: int = 1,
194195
num_threads: int = 32,
196+
measure_name: Optional[str] = None,
195197
boto3_session: Optional[boto3.Session] = None,
196198
) -> List[Dict[str, str]]:
197199
"""Store a Pandas DataFrame into a Amazon Timestream table.
@@ -213,6 +215,9 @@ def write(
213215
version : int
214216
Version number used for upserts.
215217
Documentation https://docs.aws.amazon.com/timestream/latest/developerguide/API_WriteRecords.html.
218+
measure_name : Optional[str]
219+
Name that represents the data attribute of the time series.
220+
If specified, it is used as the record's measure name in place of the first ``measure_col`` name.
216221
num_threads : int
217222
Number of threads to be used for concurrent writing.
218223
boto3_session : boto3.Session(), optional
@@ -248,8 +253,9 @@ def write(
248253
>>> assert len(rejected_records) == 0
249254
250255
"""
251-
measure_cols_names: List[str] = measure_col if isinstance(measure_col, list) else [measure_col]
256+
measure_cols_names = measure_col if isinstance(measure_col, list) else [measure_col]
252257
_logger.debug("measure_cols_names: %s", measure_cols_names)
258+
253259
measure_types: List[str] = [
254260
_data_types.timestream_type_from_pandas(df[[measure_col_name]]) for measure_col_name in measure_cols_names
255261
]
@@ -270,6 +276,7 @@ def write(
270276
itertools.repeat(version),
271277
batches,
272278
itertools.repeat(_utils.boto3_to_primitives(boto3_session=boto3_session)),
279+
itertools.repeat(measure_name),
273280
)
274281
)
275282
return [item for sublist in res for item in sublist]

tests/test_timestream.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,3 +345,64 @@ def test_timestamp_measure_column(timestream_database_and_table):
345345
""",
346346
)
347347
assert df["measure_t"].dtype == "datetime64[ns]"
348+
349+
350+
@pytest.mark.parametrize(
    "record_type",
    ["MULTI", "SCALAR"],
)
def test_measure_name(timestream_database_and_table, record_type):
    """Check that an explicit ``measure_name`` is stored on every written record.

    The fixture value is used as both the database and the table name.
    For ``"MULTI"`` the frame carries two measure columns and two dimension
    columns; for ``"SCALAR"`` it carries a single measure column and a single
    dimension column. In both cases the rows are written with
    ``measure_name="example"`` and then read back with a ``SELECT *`` query.
    """
    if record_type == "MULTI":
        columns = {
            "dim0": ["foo", "boo", "bar"],
            "dim1": [1, None, 3],
            "measure_0": [1.1, 1.2, 1.3],
            "measure_1": [2.1, 2.2, 2.3],
        }
        measure_col = ["measure_0", "measure_1"]
        dimensions_cols = ["dim0", "dim1"]
    else:
        columns = {
            "dim": ["foo", "boo", "bar"],
            "measure": [1.1, 1.2, 1.3],
        }
        measure_col = ["measure"]
        dimensions_cols = ["dim"]

    # "time" stays the first column, followed by the per-record-type columns.
    df = pd.DataFrame({"time": [datetime.now()] * 3, **columns})
    rejected_records = wr.timestream.write(
        df=df,
        database=timestream_database_and_table,
        table=timestream_database_and_table,
        time_col="time",
        measure_col=measure_col,
        dimensions_cols=dimensions_cols,
        measure_name="example",
    )

    assert len(rejected_records) == 0

    df = wr.timestream.query(
        f"""
        SELECT
            *
        FROM "{timestream_database_and_table}"."{timestream_database_and_table}"
        """,
    )
    # Without this guard the per-row check below would pass vacuously on an
    # empty result, hiding the case where nothing was actually stored.
    assert not df.empty
    for measure_name in df["measure_name"].tolist():
        assert measure_name == "example"

0 commit comments

Comments
 (0)