Commit a987364

Add Amazon Timestream support.
1 parent 98d4a68 commit a987364

File tree: 10 files changed, +395 -38 lines changed


awswrangler/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -19,6 +19,7 @@
     redshift,
     s3,
     sts,
+    timestream,
 )
 from awswrangler.__metadata__ import __description__, __license__, __title__, __version__  # noqa
 from awswrangler._config import config  # noqa
@@ -36,6 +37,7 @@
     "mysql",
     "postgresql",
     "config",
+    "timestream",
     "__description__",
     "__license__",
     "__title__",

awswrangler/_data_types.py

Lines changed: 33 additions & 0 deletions
@@ -166,6 +166,29 @@ def pyarrow2postgresql(  # pylint: disable=too-many-branches,too-many-return-statements
     raise exceptions.UnsupportedType(f"Unsupported PostgreSQL type: {dtype}")
 
 
+def pyarrow2timestream(dtype: pa.DataType) -> str:  # pylint: disable=too-many-branches,too-many-return-statements
+    """Pyarrow to Amazon Timestream data types conversion."""
+    if pa.types.is_int8(dtype):
+        return "BIGINT"
+    if pa.types.is_int16(dtype) or pa.types.is_uint8(dtype):
+        return "BIGINT"
+    if pa.types.is_int32(dtype) or pa.types.is_uint16(dtype):
+        return "BIGINT"
+    if pa.types.is_int64(dtype) or pa.types.is_uint32(dtype):
+        return "BIGINT"
+    if pa.types.is_uint64(dtype):
+        return "BIGINT"
+    if pa.types.is_float32(dtype):
+        return "DOUBLE"
+    if pa.types.is_float64(dtype):
+        return "DOUBLE"
+    if pa.types.is_boolean(dtype):
+        return "BOOLEAN"
+    if pa.types.is_string(dtype):
+        return "VARCHAR"
+    raise exceptions.UnsupportedType(f"Unsupported Amazon Timestream measure type: {dtype}")
+
+
 def athena2pyarrow(dtype: str) -> pa.DataType:  # pylint: disable=too-many-return-statements
     """Athena to PyArrow data types conversion."""
     dtype = dtype.lower().replace(" ", "")
@@ -587,3 +610,13 @@ def database_types_from_pandas(
         database_types[col_name] = converter_func(col_dtype, string_type)
     _logger.debug("database_types: %s", database_types)
     return database_types
+
+
+def timestream_type_from_pandas(df: pd.DataFrame) -> str:
+    """Extract Amazon Timestream types from a Pandas DataFrame."""
+    pyarrow_types: Dict[str, Optional[pa.DataType]] = pyarrow_types_from_pandas(df=df, index=False, ignore_cols=[])
+    if len(pyarrow_types) != 1 or list(pyarrow_types.values())[0] is None:
+        raise RuntimeError(f"Invalid pyarrow_types: {pyarrow_types}")
+    pyarrow_type: pa.DataType = list(pyarrow_types.values())[0]
+    _logger.debug("pyarrow_type: %s", pyarrow_type)
+    return pyarrow2timestream(dtype=pyarrow_type)
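
The mapping is deliberately coarse: every PyArrow integer width collapses onto Timestream's single BIGINT measure type, both float widths onto DOUBLE. A short sketch of the behavior (hedged: these are private helpers in `awswrangler._data_types`, shown here only for illustration):

    import pandas as pd
    import pyarrow as pa
    from awswrangler._data_types import pyarrow2timestream, timestream_type_from_pandas

    # All integer widths map to BIGINT, floats to DOUBLE, text to VARCHAR.
    assert pyarrow2timestream(pa.uint16()) == "BIGINT"
    assert pyarrow2timestream(pa.float32()) == "DOUBLE"
    assert pyarrow2timestream(pa.string()) == "VARCHAR"

    # timestream_type_from_pandas expects exactly one column: the measure column.
    assert timestream_type_from_pandas(pd.DataFrame({"measure": [1.5, 2.5]})) == "DOUBLE"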

awswrangler/_databases.py

Lines changed: 12 additions & 0 deletions
@@ -133,3 +133,15 @@ def read_sql_query(
     return _iterate_cursor(
         cursor=cursor, chunksize=chunksize, cols_names=cols_names, index=index_col, dtype=dtype, safe=safe
     )
+
+
+def extract_parameters(df: pd.DataFrame) -> List[List[Any]]:
+    """Extract Parameters."""
+    parameters: List[List[Any]] = df.values.tolist()
+    for i, row in enumerate(parameters):
+        for j, value in enumerate(row):
+            if pd.isna(value):
+                parameters[i][j] = None
+            elif hasattr(value, "to_pydatetime"):
+                parameters[i][j] = value.to_pydatetime()
+    return parameters
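
`extract_parameters` flattens a DataFrame into DB-API `executemany` parameters, mapping NaN/NaT to None (SQL NULL) and anything with a `to_pydatetime` method (pandas Timestamps) to plain `datetime` objects. An illustrative run (again a private helper, so this is a sketch of the behavior, not public API):

    import numpy as np
    import pandas as pd
    from awswrangler._databases import extract_parameters

    df = pd.DataFrame({"value": [1.0, np.nan], "ts": pd.to_datetime(["2020-11-01", "2020-11-02"])})
    print(extract_parameters(df=df))
    # [[1.0, datetime.datetime(2020, 11, 1, 0, 0)], [None, datetime.datetime(2020, 11, 2, 0, 0)]]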

awswrangler/_utils.py

Lines changed: 7 additions & 2 deletions
@@ -83,11 +83,16 @@ def _get_endpoint_url(service_name: str) -> Optional[str]:
     return endpoint_url
 
 
-def client(service_name: str, session: Optional[boto3.Session] = None) -> boto3.client:
+def client(
+    service_name: str, session: Optional[boto3.Session] = None, config: Optional[botocore.config.Config] = None
+) -> boto3.client:
     """Create a valid boto3.client."""
     endpoint_url: Optional[str] = _get_endpoint_url(service_name=service_name)
     return ensure_session(session=session).client(
-        service_name=service_name, endpoint_url=endpoint_url, use_ssl=True, config=botocore_config()
+        service_name=service_name,
+        endpoint_url=endpoint_url,
+        use_ssl=True,
+        config=botocore_config() if config is None else config,
     )
 
 
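
The new optional `config` parameter lets a caller hand a custom `botocore.config.Config` to an internally created client instead of the library-wide default from `botocore_config()`, presumably so the new Timestream code can tune its clients without changing the global settings. For example (a sketch; `_utils.client` is private and the timeout/retry values are illustrative, not taken from this commit):

    import botocore.config
    from awswrangler import _utils

    cfg = botocore.config.Config(read_timeout=20, retries={"max_attempts": 5})
    ts_write = _utils.client(service_name="timestream-write", config=cfg)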

awswrangler/mysql.py

Lines changed: 1 addition & 12 deletions
@@ -66,17 +66,6 @@ def _create_table(
     cursor.execute(sql)
 
 
-def _extract_parameters(df: pd.DataFrame) -> List[List[Any]]:
-    parameters: List[List[Any]] = df.values.tolist()
-    for i, row in enumerate(parameters):
-        for j, value in enumerate(row):
-            if pd.isna(value):
-                parameters[i][j] = None
-            elif hasattr(value, "to_pydatetime"):
-                parameters[i][j] = value.to_pydatetime()
-    return parameters
-
-
 def connect(
     connection: str,
     catalog_id: Optional[str] = None,
@@ -339,7 +328,7 @@ def to_sql(
         placeholders: str = ", ".join(["%s"] * len(df.columns))
         sql: str = f"INSERT INTO {schema}.{table} VALUES ({placeholders})"
         _logger.debug("sql: %s", sql)
-        parameters: List[List[Any]] = _extract_parameters(df=df)
+        parameters: List[List[Any]] = _db_utils.extract_parameters(df=df)
         cursor.executemany(sql, parameters)
         con.commit()  # type: ignore
     except Exception as ex:

awswrangler/postgresql.py

Lines changed: 1 addition & 12 deletions
@@ -70,17 +70,6 @@ def _create_table(
     cursor.execute(sql)
 
 
-def _extract_parameters(df: pd.DataFrame) -> List[List[Any]]:
-    parameters: List[List[Any]] = df.values.tolist()
-    for i, row in enumerate(parameters):
-        for j, value in enumerate(row):
-            if pd.isna(value):
-                parameters[i][j] = None
-            elif hasattr(value, "to_pydatetime"):
-                parameters[i][j] = value.to_pydatetime()
-    return parameters
-
-
 def connect(
     connection: str,
     catalog_id: Optional[str] = None,
@@ -343,7 +332,7 @@ def to_sql(
         placeholders: str = ", ".join(["%s"] * len(df.columns))
         sql: str = f"INSERT INTO {schema}.{table} VALUES ({placeholders})"
         _logger.debug("sql: %s", sql)
-        parameters: List[List[Any]] = _extract_parameters(df=df)
+        parameters: List[List[Any]] = _db_utils.extract_parameters(df=df)
         cursor.executemany(sql, parameters)
         con.commit()
     except Exception as ex:

awswrangler/redshift.py

Lines changed: 1 addition & 12 deletions
@@ -229,17 +229,6 @@ def _create_table(
     return table, schema
 
 
-def _extract_parameters(df: pd.DataFrame) -> List[List[Any]]:
-    parameters: List[List[Any]] = df.values.tolist()
-    for i, row in enumerate(parameters):
-        for j, value in enumerate(row):
-            if pd.isna(value):
-                parameters[i][j] = None
-            elif hasattr(value, "to_pydatetime"):
-                parameters[i][j] = value.to_pydatetime()
-    return parameters
-
-
 def _read_parquet_iterator(
     path: str,
     keep_files: bool,
@@ -664,7 +653,7 @@ def to_sql(
         schema_str = f"{created_schema}." if created_schema else ""
         sql: str = f"INSERT INTO {schema_str}{created_table} VALUES ({placeholders})"
         _logger.debug("sql: %s", sql)
-        parameters: List[List[Any]] = _extract_parameters(df=df)
+        parameters: List[List[Any]] = _db_utils.extract_parameters(df=df)
         cursor.executemany(sql, parameters)
         if table != created_table:  # upsert
             _upsert(cursor=cursor, schema=schema, table=table, temp_table=created_table, primary_keys=primary_keys)
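
The last three diffs are the same deduplication: `mysql.py`, `postgresql.py`, and `redshift.py` each carried an identical private `_extract_parameters`, which this commit replaces with the shared `_db_utils.extract_parameters` (i.e. `awswrangler._databases`), so all three `to_sql` implementations now use one parameter-normalization path. The surrounding insert pattern is unchanged; schematically (hypothetical schema/table names, `cursor` from an open connection):

    placeholders = ", ".join(["%s"] * len(df.columns))  # one %s per column
    sql = f"INSERT INTO my_schema.my_table VALUES ({placeholders})"  # hypothetical target
    parameters = _db_utils.extract_parameters(df=df)  # NaN -> None, Timestamp -> datetime
    cursor.executemany(sql, parameters)  # one INSERT per DataFrame row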
