
Commit fdc7bef

(perf): Distribute timestream write with executor (#1715)
* Distribute timestream write method
1 parent: f10f952

8 files changed (+173 additions, -82 deletions)


README.md

Lines changed: 1 addition & 0 deletions
@@ -128,6 +128,7 @@ As a result existing scripts can run on significantly larger datasets with no co
 | | `unload` ||
 | `Athena` | `read_sql_query` ||
 | `LakeFormation` | `read_sql_query` ||
+| `Timestream` | `write` ||
 </p>

 ## [Read The Docs](https://aws-sdk-pandas.readthedocs.io/)

awswrangler/_data_types.py

Lines changed: 5 additions & 7 deletions
@@ -779,14 +779,12 @@ def database_types_from_pandas(
     return database_types


-def timestream_type_from_pandas(df: pd.DataFrame) -> str:
+def timestream_type_from_pandas(df: pd.DataFrame) -> List[str]:
     """Extract Amazon Timestream types from a Pandas DataFrame."""
-    pyarrow_types: Dict[str, Optional[pa.DataType]] = pyarrow_types_from_pandas(df=df, index=False, ignore_cols=[])
-    if len(pyarrow_types) != 1 or list(pyarrow_types.values())[0] is None:
-        raise RuntimeError(f"Invalid pyarrow_types: {pyarrow_types}")
-    pyarrow_type: pa.DataType = list(pyarrow_types.values())[0]
-    _logger.debug("pyarrow_type: %s", pyarrow_type)
-    return pyarrow2timestream(dtype=pyarrow_type)
+    return [
+        pyarrow2timestream(pyarrow_type)
+        for pyarrow_type in pyarrow_types_from_pandas(df=df, index=False, ignore_cols=[]).values()
+    ]


 def get_arrow_timestamp_unit(data_type: pa.lib.DataType) -> Any:
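
With this change, `timestream_type_from_pandas` returns one Timestream type per column instead of accepting only a single-column frame. A minimal usage sketch (the column names and the expected output are illustrative assumptions, not taken from the commit):

```python
import pandas as pd

# Assumes awswrangler from this commit is installed; _data_types is a private module.
from awswrangler._data_types import timestream_type_from_pandas

df = pd.DataFrame({"cpu": [0.5, 0.7], "hostname": ["host-1", "host-2"]})

# One Timestream type per column, e.g. ["DOUBLE", "VARCHAR"] for float64 and string columns.
print(timestream_type_from_pandas(df))
```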

awswrangler/_utils.py

Lines changed: 6 additions & 0 deletions
@@ -408,6 +408,12 @@ def check_schema_changes(columns_types: Dict[str, str], table_input: Optional[Di
     )


+@engine.dispatch_on_engine
+def split_pandas_frame(df: pd.DataFrame, splits: int) -> List[pd.DataFrame]:
+    """Split a DataFrame into n chunks."""
+    return [sub_df for sub_df in np.array_split(df, splits) if not sub_df.empty]  # type: ignore
+
+
 @engine.dispatch_on_engine
 def table_refs_to_df(tables: List[pa.Table], kwargs: Dict[str, Any]) -> pd.DataFrame:  # type: ignore
     """Build Pandas dataframe from list of PyArrow tables."""

awswrangler/distributed/ray/_register.py

Lines changed: 11 additions & 3 deletions
@@ -2,7 +2,7 @@
 # pylint: disable=import-outside-toplevel
 from awswrangler._data_types import pyarrow_types_from_pandas
 from awswrangler._distributed import MemoryFormatEnum, engine, memory_format
-from awswrangler._utils import is_pandas_frame, table_refs_to_df
+from awswrangler._utils import is_pandas_frame, split_pandas_frame, table_refs_to_df
 from awswrangler.distributed.ray import ray_remote
 from awswrangler.lakeformation._read import _get_work_unit_results
 from awswrangler.s3._delete import _delete_objects
@@ -13,6 +13,7 @@
 from awswrangler.s3._write_dataset import _to_buckets, _to_partitions
 from awswrangler.s3._write_parquet import _to_parquet
 from awswrangler.s3._write_text import _to_text
+from awswrangler.timestream import _write_batch, _write_df


 def register_ray() -> None:
@@ -24,12 +25,18 @@ def register_ray() -> None:
         _select_query,
         _select_object_content,
         _wait_object_batch,
+        _write_batch,
+        _write_df,
     ]:
         engine.register_func(func, ray_remote(func))

     if memory_format.get() == MemoryFormatEnum.MODIN:
         from awswrangler.distributed.ray.modin._data_types import pyarrow_types_from_pandas_distributed
-        from awswrangler.distributed.ray.modin._utils import _arrow_refs_to_df, _is_pandas_or_modin_frame
+        from awswrangler.distributed.ray.modin._utils import (
+            _arrow_refs_to_df,
+            _is_pandas_or_modin_frame,
+            _split_modin_frame,
+        )
         from awswrangler.distributed.ray.modin.s3._read_parquet import _read_parquet_distributed
         from awswrangler.distributed.ray.modin.s3._read_text import _read_text_distributed
         from awswrangler.distributed.ray.modin.s3._write_dataset import (
@@ -47,7 +54,8 @@ def register_ray() -> None:
             _to_parquet: _to_parquet_distributed,
             _to_partitions: _to_partitions_distributed,
             _to_text: _to_text_distributed,
-            table_refs_to_df: _arrow_refs_to_df,
             is_pandas_frame: _is_pandas_or_modin_frame,
+            split_pandas_frame: _split_modin_frame,
+            table_refs_to_df: _arrow_refs_to_df,
         }.items():
             engine.register_func(o_f, d_f)  # type: ignore
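
As background on the registration above (a generic sketch of the dispatch pattern, not awswrangler's actual `engine` implementation): a dispatch decorator keeps a registry keyed by function name, so the default implementation can be swapped for a distributed one when Ray/Modin is active.

```python
import functools
from typing import Any, Callable, Dict

_registry: Dict[str, Callable[..., Any]] = {}

def dispatch_on_engine(func: Callable[..., Any]) -> Callable[..., Any]:
    @functools.wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        # Route to the registered override if one exists, otherwise run the default.
        return _registry.get(func.__name__, func)(*args, **kwargs)
    return wrapper

def register_func(original: Callable[..., Any], distributed: Callable[..., Any]) -> None:
    _registry[original.__name__] = distributed

@dispatch_on_engine
def split_pandas_frame(df: Any, splits: int) -> str:
    return "local split"

def _split_modin_frame(df: Any, splits: int) -> str:
    return "distributed split"

register_func(split_pandas_frame, _split_modin_frame)
print(split_pandas_frame(None, 2))  # -> "distributed split"
```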

awswrangler/distributed/ray/modin/_utils.py

Lines changed: 6 additions & 0 deletions
@@ -8,6 +8,7 @@
 from modin.distributed.dataframe.pandas import from_partitions
 from ray.data._internal.arrow_block import ArrowBlockAccessor, ArrowRow
 from ray.data._internal.remote_fn import cached_remote_fn
+from ray.types import ObjectRef

 from awswrangler import exceptions
 from awswrangler._arrow import _table_to_df
@@ -43,6 +44,11 @@ def _to_modin(
     )


+def _split_modin_frame(df: modin_pd.DataFrame, splits: int) -> List[ObjectRef[Any]]:  # pylint: disable=unused-argument
+    object_refs: List[ObjectRef[Any]] = ray.data.from_modin(df).get_internal_block_refs()
+    return object_refs
+
+
 def _arrow_refs_to_df(arrow_refs: List[Callable[..., Any]], kwargs: Optional[Dict[str, Any]]) -> modin_pd.DataFrame:
     return _to_modin(dataset=ray.data.from_arrow_refs(arrow_refs), to_pandas_kwargs=kwargs)
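
The Modin override above ignores the requested `splits` count: a Modin frame is already partitioned, so it hands back Ray object references to the underlying Arrow blocks. A rough sketch of the same call pattern, assuming Ray and Modin are installed and a local Ray instance is available:

```python
import modin.pandas as modin_pd
import ray

ray.init(ignore_reinit_error=True)

df = modin_pd.DataFrame({"a": range(10)})

# Equivalent to _split_modin_frame(df, splits): the existing block layout defines the chunks.
block_refs = ray.data.from_modin(df).get_internal_block_refs()
print(len(block_refs), "block refs")
```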

awswrangler/timestream.py

Lines changed: 69 additions & 26 deletions
@@ -1,6 +1,5 @@
 """Amazon Timestream Module."""

-import concurrent.futures
 import itertools
 import logging
 from datetime import datetime
@@ -11,10 +10,17 @@
 from botocore.config import Config

 from awswrangler import _data_types, _utils
+from awswrangler._distributed import engine
+from awswrangler._threading import _get_executor
+from awswrangler.distributed.ray import ray_get

 _logger: logging.Logger = logging.getLogger(__name__)


+def _flatten_list(elements: List[List[Any]]) -> List[Any]:
+    return [item for sublist in elements for item in sublist]
+
+
 def _df2list(df: pd.DataFrame) -> List[List[Any]]:
     """Extract Parameters."""
     parameters: List[List[Any]] = df.values.tolist()
@@ -27,17 +33,17 @@ def _df2list(df: pd.DataFrame) -> List[List[Any]]:
     return parameters


+@engine.dispatch_on_engine
 def _write_batch(
+    boto3_session: Optional[boto3.Session],
     database: str,
     table: str,
     cols_names: List[str],
     measure_cols_names: List[str],
     measure_types: List[str],
     version: int,
     batch: List[Any],
-    boto3_primitives: _utils.Boto3PrimitivesType,
 ) -> List[Dict[str, str]]:
-    boto3_session: boto3.Session = _utils.boto3_from_primitives(primitives=boto3_primitives)
     client: boto3.client = _utils.client(
         service_name="timestream-write",
         session=boto3_session,
@@ -85,6 +91,33 @@ def _write_batch(
     return []


+@engine.dispatch_on_engine
+def _write_df(
+    df: pd.DataFrame,
+    executor: Any,
+    database: str,
+    table: str,
+    cols_names: List[str],
+    measure_cols_names: List[str],
+    measure_types: List[str],
+    version: int,
+    boto3_session: Optional[boto3.Session],
+) -> List[Dict[str, str]]:
+    batches: List[List[Any]] = _utils.chunkify(lst=_df2list(df=df), max_length=100)
+    _logger.debug("len(batches): %s", len(batches))
+    return executor.map(  # type: ignore
+        _write_batch,
+        boto3_session,
+        itertools.repeat(database),
+        itertools.repeat(table),
+        itertools.repeat(cols_names),
+        itertools.repeat(measure_cols_names),
+        itertools.repeat(measure_types),
+        itertools.repeat(version),
+        batches,
+    )
+
+
 def _cast_value(value: str, dtype: str) -> Any:  # pylint: disable=too-many-branches,too-many-return-statements
     if dtype == "VARCHAR":
         return value
@@ -173,14 +206,18 @@ def write(
     measure_col: Union[str, List[str]],
     dimensions_cols: List[str],
     version: int = 1,
-    num_threads: int = 32,
+    use_threads: Union[bool, int] = True,
     boto3_session: Optional[boto3.Session] = None,
 ) -> List[Dict[str, str]]:
     """Store a Pandas DataFrame into a Amazon Timestream table.

+    Note
+    ----
+    In case `use_threads=True`, the number of threads from os.cpu_count() is used.
+
     Parameters
     ----------
-    df: pandas.DataFrame
+    df : pandas.DataFrame
         Pandas DataFrame https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.html
     database : str
         Amazon Timestream database name.
@@ -195,8 +232,10 @@ def write(
     version : int
         Version number used for upserts.
         Documentation https://docs.aws.amazon.com/timestream/latest/developerguide/API_WriteRecords.html.
-    num_threads : str
-        Number of thread to be used for concurrent writing.
+    use_threads : bool, int
+        True to enable concurrent writing, False to disable multiple threads.
+        If enabled, os.cpu_count() is used as the number of threads.
+        If integer is provided, specified number is used.
     boto3_session : boto3.Session(), optional
         Boto3 Session. The default boto3 Session will be used if boto3_session receive None.

@@ -232,29 +271,33 @@
     """
     measure_cols_names: List[str] = measure_col if isinstance(measure_col, list) else [measure_col]
     _logger.debug("measure_cols_names: %s", measure_cols_names)
-    measure_types: List[str] = [
-        _data_types.timestream_type_from_pandas(df[[measure_col_name]]) for measure_col_name in measure_cols_names
-    ]
+    measure_types: List[str] = _data_types.timestream_type_from_pandas(df.loc[:, measure_cols_names])
     _logger.debug("measure_types: %s", measure_types)
     cols_names: List[str] = [time_col] + measure_cols_names + dimensions_cols
     _logger.debug("cols_names: %s", cols_names)
-    batches: List[List[Any]] = _utils.chunkify(lst=_df2list(df=df[cols_names]), max_length=100)
-    _logger.debug("len(batches): %s", len(batches))
-    with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor:
-        res: List[List[Any]] = list(
-            executor.map(
-                _write_batch,
-                itertools.repeat(database),
-                itertools.repeat(table),
-                itertools.repeat(cols_names),
-                itertools.repeat(measure_cols_names),
-                itertools.repeat(measure_types),
-                itertools.repeat(version),
-                batches,
-                itertools.repeat(_utils.boto3_to_primitives(boto3_session=boto3_session)),
-            )
+    dfs = _utils.split_pandas_frame(df.loc[:, cols_names], _utils.ensure_cpu_count(use_threads=use_threads))
+    _logger.debug("len(dfs): %s", len(dfs))
+
+    executor = _get_executor(use_threads=use_threads)
+    errors = _flatten_list(
+        ray_get(
+            [
+                _write_df(
+                    df=df,
+                    executor=executor,
+                    database=database,
+                    table=table,
+                    cols_names=cols_names,
+                    measure_cols_names=measure_cols_names,
+                    measure_types=measure_types,
+                    version=version,
+                    boto3_session=boto3_session,
+                )
+                for df in dfs
+            ]
         )
-    return [item for sublist in res for item in sublist]
+    )
+    return _flatten_list(ray_get(errors))


 def query(
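
For reference, a short usage sketch of the reworked public API (the database and table names and the sample rows are illustrative, not from the commit): `num_threads` is gone, and `use_threads` accepts either a boolean or an explicit worker count.

```python
from datetime import datetime

import pandas as pd
import awswrangler as wr

df = pd.DataFrame(
    {
        "time": [datetime.now(), datetime.now()],
        "measure": [0.5, 0.9],
        "hostname": ["host-1", "host-2"],
    }
)

rejected_records = wr.timestream.write(
    df=df,
    database="sample_db",      # illustrative names
    table="sample_table",
    time_col="time",
    measure_col="measure",
    dimensions_cols=["hostname"],
    use_threads=4,             # True -> os.cpu_count(), int -> explicit worker count
)
assert len(rejected_records) == 0
```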

tests/load/test_database.py

Lines changed: 67 additions & 0 deletions
@@ -0,0 +1,67 @@
+from datetime import datetime
+
+import pytest
+import ray
+from pyarrow import csv
+
+import awswrangler as wr
+
+from .._utils import ExecutionTimer
+
+
+@pytest.mark.parametrize("benchmark_time", [180])
+def test_real_csv_load_scenario(benchmark_time: int, timestream_database_and_table: str) -> None:
+    name = timestream_database_and_table
+    df = (
+        ray.data.read_csv(
+            "https://raw.githubusercontent.com/awslabs/amazon-timestream-tools/mainline/sample_apps/data/sample.csv",
+            **{
+                "read_options": csv.ReadOptions(
+                    column_names=[
+                        "ignore0",
+                        "region",
+                        "ignore1",
+                        "az",
+                        "ignore2",
+                        "hostname",
+                        "measure_kind",
+                        "measure",
+                        "ignore3",
+                        "ignore4",
+                        "ignore5",
+                    ]
+                )
+            },
+        )
+        .to_modin()
+        .loc[:, ["region", "az", "hostname", "measure_kind", "measure"]]
+    )
+
+    df["time"] = datetime.now()
+    df.reset_index(inplace=True, drop=False)
+    df_cpu = df[df.measure_kind == "cpu_utilization"]
+    df_memory = df[df.measure_kind == "memory_utilization"]
+
+    with ExecutionTimer("elapsed time of wr.timestream.write()") as timer:
+        rejected_records = wr.timestream.write(
+            df=df_cpu,
+            database=name,
+            table=name,
+            time_col="time",
+            measure_col="measure",
+            dimensions_cols=["index", "region", "az", "hostname"],
+        )
+        assert len(rejected_records) == 0
+        rejected_records = wr.timestream.write(
+            df=df_memory,
+            database=name,
+            table=name,
+            time_col="time",
+            measure_col="measure",
+            dimensions_cols=["index", "region", "az", "hostname"],
+        )
+        assert len(rejected_records) == 0
+    assert timer.elapsed_time < benchmark_time
+
+    df = wr.timestream.query(f'SELECT COUNT(*) AS counter FROM "{name}"."{name}"')
+    assert df["counter"].iloc[0] == 126_000
