Skip to content

Commit 63385af

Browse files
committed
Applying mypy strict mode.
1 parent 3369123 commit 63385af

32 files changed

+284
-120
lines changed

awswrangler/__init__.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,21 @@
1111
from awswrangler.__metadata__ import __description__, __license__, __title__, __version__ # noqa
1212
from awswrangler._config import config # noqa
1313

14+
__all__ = [
15+
"athena",
16+
"catalog",
17+
"cloudwatch",
18+
"db",
19+
"emr",
20+
"exceptions",
21+
"quicksight",
22+
"s3",
23+
"sts",
24+
"config",
25+
"__description__",
26+
"__license__",
27+
"__title__",
28+
"__version__",
29+
]
30+
1431
_logging.getLogger("awswrangler").addHandler(_logging.NullHandler())

awswrangler/_config.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717

1818
class _ConfigArg(NamedTuple):
19-
dtype: Type
19+
dtype: Type[Union[str, bool, int]]
2020
nullable: bool
2121
enforced: bool = False
2222

@@ -36,7 +36,7 @@ class _ConfigArg(NamedTuple):
3636
class _Config:
3737
"""Wrangler's Configuration class."""
3838

39-
def __init__(self):
39+
def __init__(self) -> None:
4040
self._loaded_values: Dict[str, _ConfigValueType] = {}
4141
name: str
4242
for name in _CONFIG_ARGS:
@@ -127,11 +127,11 @@ def _reset_item(self, item: str) -> None:
127127
del self._loaded_values[item]
128128
self._load_config(name=item)
129129

130-
def _repr_html_(self):
130+
def _repr_html_(self) -> Any:
131131
return self.to_pandas().to_html()
132132

133133
@staticmethod
134-
def _apply_type(name: str, value: Any, dtype: Type, nullable: bool) -> _ConfigValueType:
134+
def _apply_type(name: str, value: Any, dtype: Type[Union[str, bool, int]], nullable: bool) -> _ConfigValueType:
135135
if _Config._is_null(value=value):
136136
if nullable is True:
137137
return None
@@ -215,7 +215,9 @@ def s3fs_block_size(self, value: int) -> None:
215215
self._set_config_value(key="s3fs_block_size", value=value)
216216

217217

218-
def _inject_config_doc(doc: str, available_configs: Tuple[str, ...]) -> str:
218+
def _inject_config_doc(doc: Optional[str], available_configs: Tuple[str, ...]) -> str:
219+
if doc is None:
220+
return "Undocumented function."
219221
if "\n Parameters" not in doc:
220222
return doc
221223
header: str = (
@@ -235,14 +237,14 @@ def _inject_config_doc(doc: str, available_configs: Tuple[str, ...]) -> str:
235237
return _utils.insert_str(text=doc, token="\n Parameters", insert=insertion)
236238

237239

238-
def apply_configs(function) -> Callable:
240+
def apply_configs(function: Callable[..., Any]) -> Callable[..., Any]:
239241
"""Decorate some function with configs."""
240242
signature = inspect.signature(function)
241243
args_names: Tuple[str, ...] = tuple(signature.parameters.keys())
242244
available_configs: Tuple[str, ...] = tuple(x for x in _CONFIG_ARGS if x in args_names)
243245

244-
def wrapper(*args, **kwargs):
245-
args: Dict[str, Any] = signature.bind_partial(*args, **kwargs).arguments
246+
def wrapper(*args_raw: Any, **kwargs: Any) -> Any:
247+
args: Dict[str, Any] = signature.bind_partial(*args_raw, **kwargs).arguments
246248
for name in available_configs:
247249
if hasattr(config, name) is True:
248250
value: _ConfigValueType = config[name]

awswrangler/_data_types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,7 @@ def process_not_inferred_dtype(ex: pa.ArrowInvalid) -> pa.DataType:
315315
"""Infer data type from PyArrow inference exception."""
316316
ex_str = str(ex)
317317
_logger.debug("PyArrow was not able to infer data type:\n%s", ex_str)
318-
match: Optional[Match] = re.search(
318+
match: Optional[Match[str]] = re.search(
319319
pattern="Could not convert (.*) with type (.*): did not recognize "
320320
"Python value type when inferring an Arrow data type",
321321
string=ex_str,

awswrangler/_utils.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import os
77
import random
88
import time
9-
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
9+
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union, cast
1010

1111
import boto3 # type: ignore
1212
import botocore.config # type: ignore
@@ -196,19 +196,19 @@ def get_fs(
196196
return fs
197197

198198

199-
def open_file(fs: s3fs.S3FileSystem, **kwargs) -> Any:
199+
def open_file(fs: s3fs.S3FileSystem, **kwargs: Any) -> Any:
200200
"""Open s3fs file with retries to overcome eventual consistency."""
201201
fs.invalidate_cache()
202202
fs.clear_instance_cache()
203203
return try_it(f=fs.open, ex=FileNotFoundError, **kwargs)
204204

205205

206-
def empty_generator() -> Generator:
206+
def empty_generator() -> Generator[None, None, None]:
207207
"""Empty Generator."""
208208
yield from ()
209209

210210

211-
def ensure_postgresql_casts():
211+
def ensure_postgresql_casts() -> None:
212212
"""Ensure that psycopg2 will handle some data types right."""
213213
psycopg2.extensions.register_adapter(bytes, psycopg2.Binary)
214214
typecast_bytea = lambda data, cur: None if data is None else bytes(psycopg2.BINARY(data, cur)) # noqa
@@ -225,7 +225,7 @@ def get_region_from_subnet(subnet_id: str, boto3_session: Optional[boto3.Session
225225
"""Extract region from Subnet ID."""
226226
session: boto3.Session = ensure_session(session=boto3_session)
227227
client_ec2: boto3.client = client(service_name="ec2", session=session)
228-
return client_ec2.describe_subnets(SubnetIds=[subnet_id])["Subnets"][0]["AvailabilityZone"][:-1]
228+
return cast(str, client_ec2.describe_subnets(SubnetIds=[subnet_id])["Subnets"][0]["AvailabilityZone"][:-1])
229229

230230

231231
def get_region_from_session(boto3_session: Optional[boto3.Session] = None, default_region: Optional[str] = None) -> str:
@@ -279,7 +279,7 @@ def check_duplicated_columns(df: pd.DataFrame) -> Any:
279279
raise exceptions.InvalidDataFrame(f"There is duplicated column names in your DataFrame: {duplicated}")
280280

281281

282-
def try_it(f: Callable, ex, base: float = 1.0, max_num_tries: int = 3, **kwargs) -> Any:
282+
def try_it(f: Callable[..., Any], ex: Any, base: float = 1.0, max_num_tries: int = 3, **kwargs: Any) -> Any:
283283
"""Run function with decorrelated Jitter.
284284
285285
Reference: https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/

awswrangler/athena/__init__.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,18 @@
1313
stop_query_execution,
1414
wait_query,
1515
)
16+
17+
__all__ = [
18+
"read_sql_query",
19+
"read_sql_table",
20+
"create_athena_bucket",
21+
"describe_table",
22+
"get_query_columns_types",
23+
"get_query_execution",
24+
"get_work_group",
25+
"repair_table",
26+
"show_create_table",
27+
"start_query_execution",
28+
"stop_query_execution",
29+
"wait_query",
30+
]

awswrangler/athena/_read.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def _get_last_query_executions(
111111
yield execution_data.get("QueryExecutions")
112112

113113

114-
def _sort_successful_executions_data(query_executions: List[Dict[str, Any]]):
114+
def _sort_successful_executions_data(query_executions: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
115115
"""
116116
Sorts `_get_last_query_executions`'s results based on query Completion DateTime.
117117
@@ -289,14 +289,16 @@ def _fetch_csv_result(
289289

290290

291291
def _resolve_query_with_cache(
292-
cache_info,
292+
cache_info: _CacheInfo,
293293
categories: Optional[List[str]],
294294
chunksize: Optional[Union[int, bool]],
295295
use_threads: bool,
296296
session: Optional[boto3.Session],
297-
):
297+
) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]:
298298
"""Fetch cached data and return it as a pandas DataFrame (or list of DataFrames)."""
299299
_logger.debug("cache_info:\n%s", cache_info)
300+
if cache_info.query_execution_id is None:
301+
raise RuntimeError("Trying to resolve with cache but w/o any query execution ID.")
300302
query_metadata: _QueryMetadata = _get_query_metadata(
301303
query_execution_id=cache_info.query_execution_id,
302304
boto3_session=session,
@@ -330,14 +332,14 @@ def _resolve_query_without_cache_ctas(
330332
keep_files: bool,
331333
chunksize: Union[int, bool, None],
332334
categories: Optional[List[str]],
333-
encryption,
335+
encryption: Optional[str],
334336
workgroup: Optional[str],
335337
kms_key: Optional[str],
336338
wg_config: _WorkGroupConfig,
337339
name: Optional[str],
338340
use_threads: bool,
339341
boto3_session: boto3.Session,
340-
):
342+
) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]:
341343
path: str = f"{s3_output}/{name}"
342344
ext_location: str = "\n" if wg_config.enforced is True else f",\n external_location = '{path}'\n"
343345
sql = (
@@ -406,13 +408,13 @@ def _resolve_query_without_cache_regular(
406408
keep_files: bool,
407409
chunksize: Union[int, bool, None],
408410
categories: Optional[List[str]],
409-
encryption,
411+
encryption: Optional[str],
410412
workgroup: Optional[str],
411413
kms_key: Optional[str],
412414
wg_config: _WorkGroupConfig,
413415
use_threads: bool,
414416
boto3_session: boto3.Session,
415-
):
417+
) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]:
416418
_logger.debug("sql: %s", sql)
417419
query_id: str = _start_query_execution(
418420
sql=sql,
@@ -698,7 +700,7 @@ def read_sql_table(
698700
table: str,
699701
database: str,
700702
ctas_approach: bool = True,
701-
categories: List[str] = None,
703+
categories: Optional[List[str]] = None,
702704
chunksize: Optional[Union[int, bool]] = None,
703705
s3_output: Optional[str] = None,
704706
workgroup: Optional[str] = None,

awswrangler/athena/_utils.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import time
66
import warnings
77
from decimal import Decimal
8-
from typing import Any, Dict, List, NamedTuple, Optional, Union
8+
from typing import Any, Dict, Generator, List, NamedTuple, Optional, Union, cast
99

1010
import boto3 # type: ignore
1111
import pandas as pd # type: ignore
@@ -89,7 +89,7 @@ def _start_query_execution(
8989
client_athena: boto3.client = _utils.client(service_name="athena", session=session)
9090
_logger.debug("args: \n%s", pprint.pformat(args))
9191
response: Dict[str, Any] = client_athena.start_query_execution(**args)
92-
return response["QueryExecutionId"]
92+
return cast(str, response["QueryExecutionId"])
9393

9494

9595
def _get_workgroup_config(session: boto3.Session, workgroup: Optional[str] = None) -> _WorkGroupConfig:
@@ -137,7 +137,7 @@ def _fetch_txt_result(query_metadata: _QueryMetadata, keep_files: bool, boto3_se
137137

138138
def _parse_describe_table(df: pd.DataFrame) -> pd.DataFrame:
139139
origin_df_dict = df.to_dict()
140-
target_df_dict: Dict[str, List] = {"Column Name": [], "Type": [], "Partition": [], "Comment": []}
140+
target_df_dict: Dict[str, List[Union[str, bool]]] = {"Column Name": [], "Type": [], "Partition": [], "Comment": []}
141141
for index, col_name in origin_df_dict["col_name"].items():
142142
col_name = col_name.strip()
143143
if col_name.startswith("#") or col_name == "":
@@ -156,7 +156,7 @@ def _parse_describe_table(df: pd.DataFrame) -> pd.DataFrame:
156156
def _get_query_metadata( # pylint: disable=too-many-statements
157157
query_execution_id: str,
158158
boto3_session: boto3.Session,
159-
categories: List[str] = None,
159+
categories: Optional[List[str]] = None,
160160
query_execution_payload: Optional[Dict[str, Any]] = None,
161161
) -> _QueryMetadata:
162162
"""Get query metadata."""
@@ -226,7 +226,9 @@ def _get_query_metadata( # pylint: disable=too-many-statements
226226
return query_metadata
227227

228228

229-
def _empty_dataframe_response(chunked: bool, query_metadata: _QueryMetadata):
229+
def _empty_dataframe_response(
230+
chunked: bool, query_metadata: _QueryMetadata
231+
) -> Union[pd.DataFrame, Generator[None, None, None]]:
230232
"""Generate an empty dataframe response."""
231233
if chunked is False:
232234
df = pd.DataFrame()
@@ -425,7 +427,7 @@ def repair_table(
425427
boto3_session=session,
426428
)
427429
response: Dict[str, Any] = wait_query(query_execution_id=query_id, boto3_session=session)
428-
return response["Status"]["State"]
430+
return cast(str, response["Status"]["State"])
429431

430432

431433
@apply_configs
@@ -556,7 +558,7 @@ def show_create_table(
556558
)
557559
query_metadata: _QueryMetadata = _get_query_metadata(query_execution_id=query_id, boto3_session=session)
558560
raw_result = _fetch_txt_result(query_metadata=query_metadata, keep_files=True, boto3_session=session,)
559-
return raw_result.createtab_stmt.str.strip().str.cat(sep=" ")
561+
return cast(str, raw_result.createtab_stmt.str.strip().str.cat(sep=" "))
560562

561563

562564
def get_work_group(workgroup: str, boto3_session: Optional[boto3.Session] = None) -> Dict[str, Any]:
@@ -581,7 +583,7 @@ def get_work_group(workgroup: str, boto3_session: Optional[boto3.Session] = None
581583
582584
"""
583585
client_athena: boto3.client = _utils.client(service_name="athena", session=boto3_session)
584-
return client_athena.get_work_group(WorkGroup=workgroup)
586+
return cast(Dict[str, Any], client_athena.get_work_group(WorkGroup=workgroup))
585587

586588

587589
def stop_query_execution(query_execution_id: str, boto3_session: Optional[boto3.Session] = None) -> None:
@@ -645,7 +647,7 @@ def wait_query(query_execution_id: str, boto3_session: Optional[boto3.Session] =
645647
raise exceptions.QueryFailed(response["QueryExecution"]["Status"].get("StateChangeReason"))
646648
if state == "CANCELLED":
647649
raise exceptions.QueryCancelled(response["QueryExecution"]["Status"].get("StateChangeReason"))
648-
return response["QueryExecution"]
650+
return cast(Dict[str, Any], response["QueryExecution"])
649651

650652

651653
def get_query_execution(query_execution_id: str, boto3_session: Optional[boto3.Session] = None) -> Dict[str, Any]:
@@ -673,4 +675,4 @@ def get_query_execution(query_execution_id: str, boto3_session: Optional[boto3.S
673675
"""
674676
client_athena: boto3.client = _utils.client(service_name="athena", session=boto3_session)
675677
response: Dict[str, Any] = client_athena.get_query_execution(QueryExecutionId=query_execution_id)
676-
return response["QueryExecution"]
678+
return cast(Dict[str, Any], response["QueryExecution"])

awswrangler/catalog/__init__.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,42 @@
4040
sanitize_dataframe_columns_names,
4141
sanitize_table_name,
4242
)
43+
44+
__all__ = [
45+
"add_csv_partitions",
46+
"add_parquet_partitions",
47+
"does_table_exist",
48+
"drop_duplicated_columns",
49+
"extract_athena_types",
50+
"sanitize_column_name",
51+
"sanitize_dataframe_columns_names",
52+
"sanitize_table_name",
53+
"_create_csv_table",
54+
"_create_parquet_table",
55+
"create_csv_table",
56+
"create_database",
57+
"create_parquet_table",
58+
"overwrite_table_parameters",
59+
"upsert_table_parameters",
60+
"_get_table_input",
61+
"databases",
62+
"get_columns_comments",
63+
"get_connection",
64+
"get_csv_partitions",
65+
"get_databases",
66+
"get_engine",
67+
"get_parquet_partitions",
68+
"get_partitions",
69+
"get_table_description",
70+
"get_table_location",
71+
"get_table_number_of_versions",
72+
"get_table_parameters",
73+
"get_table_types",
74+
"get_table_versions",
75+
"get_tables",
76+
"search_tables",
77+
"table",
78+
"tables",
79+
"delete_database",
80+
"delete_table_if_exists",
81+
]

awswrangler/catalog/_add.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def _add_partitions(
1919
boto3_session: Optional[boto3.Session],
2020
inputs: List[Dict[str, Any]],
2121
catalog_id: Optional[str] = None,
22-
):
22+
) -> None:
2323
chunks: List[List[Dict[str, Any]]] = _utils.chunkify(lst=inputs, max_length=100)
2424
client_glue: boto3.client = _utils.client(service_name="glue", session=boto3_session)
2525
for chunk in chunks: # pylint: disable=too-many-nested-blocks

awswrangler/catalog/_create.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def _create_table( # pylint: disable=too-many-branches,too-many-statements
4343
projection_intervals: Optional[Dict[str, str]],
4444
projection_digits: Optional[Dict[str, str]],
4545
catalog_id: Optional[str],
46-
):
46+
) -> None:
4747
# Description
4848
mode = _update_if_necessary(dic=table_input, key="Description", value=description, mode=mode)
4949

0 commit comments

Comments
 (0)