Commit 58a2b2f

Fix apply_configs decorator causing function signature to be lost (#1858)
1 parent ac47d72 commit 58a2b2f

10 files changed: 65 additions & 37 deletions

awswrangler/_config.py

Lines changed: 6 additions & 3 deletions
@@ -3,7 +3,7 @@
 import inspect
 import logging
 import os
-from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Type, Union, cast
+from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Type, TypeVar, Union, cast

 import botocore.config
 import pandas as pd
@@ -466,7 +466,10 @@ def _inject_config_doc(doc: Optional[str], available_configs: Tuple[str, ...]) -
     return _insert_str(text=doc, token="\n    Parameters", insert=insertion)


-def apply_configs(function: Callable[..., Any]) -> Callable[..., Any]:
+FunctionType = TypeVar("FunctionType", bound=Callable[..., Any])
+
+
+def apply_configs(function: FunctionType) -> FunctionType:
     """Decorate some function with configs."""
     signature = inspect.signature(function)
     args_names: Tuple[str, ...] = tuple(signature.parameters.keys())
@@ -495,7 +498,7 @@ def wrapper(*args_raw: Any, **kwargs: Any) -> Any:
     wrapper.__doc__ = _inject_config_doc(doc=function.__doc__, available_configs=available_configs)
     wrapper.__name__ = function.__name__
     wrapper.__setattr__("__signature__", signature)  # pylint: disable=no-member
-    return wrapper
+    return wrapper  # type: ignore


 config: _Config = _Config()
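Why the TypeVar matters: the old annotation returned Callable[..., Any], so every function wrapped by apply_configs lost its signature for mypy and IDE autocomplete. Binding a TypeVar to Callable[..., Any] declares that the decorator hands back the same callable type it received. A minimal, self-contained sketch of the pattern; the names log_calls and read_table are illustrative and not part of awswrangler:

import functools
import logging
from typing import Any, Callable, TypeVar, cast

FunctionType = TypeVar("FunctionType", bound=Callable[..., Any])


def log_calls(function: FunctionType) -> FunctionType:
    """Toy decorator: logs each call while keeping the wrapped signature visible to type checkers."""

    @functools.wraps(function)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        logging.getLogger(__name__).info("calling %s", function.__name__)
        return function(*args, **kwargs)

    # wrapper is a new callable, so hand it back as the declared TypeVar;
    # the commit above does the equivalent with `return wrapper  # type: ignore`.
    return cast(FunctionType, wrapper)


@log_calls
def read_table(name: str, limit: int = 10) -> str:
    return f"{name}:{limit}"


# mypy/IDEs now see read_table as (name: str, limit: int = 10) -> str
# instead of the erased Callable[..., Any].
print(read_table("my_table", limit=5))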

awswrangler/athena/_read.py

Lines changed: 3 additions & 0 deletions
@@ -106,6 +106,9 @@ def _fetch_parquet_result(

     database, temp_table_name = map(lambda x: x.replace('"', ""), temp_table_fqn.split("."))
     dtype_dict = catalog.get_table_types(database=database, table=temp_table_name, boto3_session=boto3_session)
+    if dtype_dict is None:
+        raise exceptions.ResourceDoesNotExist(f"Temp table {temp_table_fqn} not found.")
+
     df = pd.DataFrame(columns=list(dtype_dict.keys()))
     df = cast_pandas_with_athena_types(df=df, dtype=dtype_dict)
     df = _apply_query_metadata(df=df, query_metadata=query_metadata)

awswrangler/athena/_utils.py

Lines changed: 40 additions & 27 deletions
@@ -554,15 +554,18 @@ def repair_table(
     if (database is not None) and (not database.startswith("`")):
         database = f"`{database}`"
     session: boto3.Session = _utils.ensure_session(session=boto3_session)
-    query_id = start_query_execution(
-        sql=query,
-        database=database,
-        data_source=data_source,
-        s3_output=s3_output,
-        workgroup=workgroup,
-        encryption=encryption,
-        kms_key=kms_key,
-        boto3_session=session,
+    query_id = cast(
+        str,
+        start_query_execution(
+            sql=query,
+            database=database,
+            data_source=data_source,
+            s3_output=s3_output,
+            workgroup=workgroup,
+            encryption=encryption,
+            kms_key=kms_key,
+            boto3_session=session,
+        ),
     )
     response: Dict[str, Any] = wait_query(query_execution_id=query_id, boto3_session=session)
     return cast(str, response["Status"]["State"])
@@ -624,14 +627,17 @@ def describe_table(
     if (database is not None) and (not database.startswith("`")):
         database = f"`{database}`"
     session: boto3.Session = _utils.ensure_session(session=boto3_session)
-    query_id = start_query_execution(
-        sql=query,
-        database=database,
-        s3_output=s3_output,
-        workgroup=workgroup,
-        encryption=encryption,
-        kms_key=kms_key,
-        boto3_session=session,
+    query_id = cast(
+        str,
+        start_query_execution(
+            sql=query,
+            database=database,
+            s3_output=s3_output,
+            workgroup=workgroup,
+            encryption=encryption,
+            kms_key=kms_key,
+            boto3_session=session,
+        ),
     )
     query_metadata: _QueryMetadata = _get_query_metadata(query_execution_id=query_id, boto3_session=session)
     raw_result = _fetch_txt_result(
@@ -643,7 +649,7 @@ def describe_table(
 @apply_configs
 def create_ctas_table(  # pylint: disable=too-many-locals
     sql: str,
-    database: str,
+    database: Optional[str] = None,
     ctas_table: Optional[str] = None,
     ctas_database: Optional[str] = None,
     s3_output: Optional[str] = None,
@@ -669,7 +675,7 @@ def create_ctas_table(  # pylint: disable=too-many-locals
     ----------
     sql : str
         SELECT SQL query.
-    database : str
+    database : Optional[str], optional
         The name of the database where the original table is stored.
     ctas_table : Optional[str], optional
         The name of the CTAS table.
@@ -756,6 +762,10 @@ def create_ctas_table(  # pylint: disable=too-many-locals
     """
     ctas_table = catalog.sanitize_table_name(ctas_table) if ctas_table else f"temp_table_{uuid.uuid4().hex}"
     ctas_database = ctas_database if ctas_database else database
+
+    if ctas_database is None:
+        raise exceptions.InvalidArgumentCombination("Either ctas_database or database must be defined.")
+
     fully_qualified_name = f'"{ctas_database}"."{ctas_table}"'

     wg_config: _WorkGroupConfig = _get_workgroup_config(session=boto3_session, workgroup=workgroup)
@@ -904,14 +914,17 @@ def show_create_table(
     if (database is not None) and (not database.startswith("`")):
         database = f"`{database}`"
     session: boto3.Session = _utils.ensure_session(session=boto3_session)
-    query_id = start_query_execution(
-        sql=query,
-        database=database,
-        s3_output=s3_output,
-        workgroup=workgroup,
-        encryption=encryption,
-        kms_key=kms_key,
-        boto3_session=session,
+    query_id = cast(
+        str,
+        start_query_execution(
+            sql=query,
+            database=database,
+            s3_output=s3_output,
+            workgroup=workgroup,
+            encryption=encryption,
+            kms_key=kms_key,
+            boto3_session=session,
+        ),
     )
     query_metadata: _QueryMetadata = _get_query_metadata(query_execution_id=query_id, boto3_session=session)
     raw_result = _fetch_txt_result(
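A note on the recurring `query_id = cast(str, start_query_execution(...))` change: typing.cast does nothing at runtime, it only tells the type checker to treat a union-typed return value as a plain string at this call site, presumably needed here because start_query_execution is annotated with more than one possible return type. A hedged, self-contained illustration using a hypothetical helper (start_job is not an awswrangler function):

from typing import Dict, Union, cast


def start_job(wait: bool = False) -> Union[str, Dict[str, str]]:
    """Hypothetical stand-in: returns a job ID, or the full response when waiting."""
    return {"Status": "SUCCEEDED"} if wait else "job-123"


# Without the cast, mypy reports `job_id` as Union[str, Dict[str, str]].
# cast() narrows it for the type checker only; no conversion happens at runtime.
job_id = cast(str, start_job())
print(job_id.upper())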

awswrangler/exceptions.py

Lines changed: 4 additions & 0 deletions
@@ -115,3 +115,7 @@ class FailedQualityCheck(Exception):

 class AlreadyExists(Exception):
     """AlreadyExists."""
+
+
+class ResourceDoesNotExist(Exception):
+    """ResourceDoesNotExist."""

awswrangler/quicksight/_utils.py

Lines changed: 4 additions & 2 deletions
@@ -1,7 +1,7 @@
 """Internal (private) Amazon QuickSight Utilities Module."""

 import logging
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, cast

 import boto3

@@ -29,7 +29,9 @@ def extract_athena_query_columns(
     data_source: Dict[str, Any] = [x for x in data_sources if x["Arn"] == data_source_arn][0]
     workgroup: str = data_source["DataSourceParameters"]["AthenaParameters"]["WorkGroup"]
     sql_wrapped: str = f"/* QuickSight */\nSELECT ds.* FROM ( {sql} ) ds LIMIT 0"
-    query_id: str = athena.start_query_execution(sql=sql_wrapped, workgroup=workgroup, boto3_session=boto3_session)
+    query_id = cast(
+        str, athena.start_query_execution(sql=sql_wrapped, workgroup=workgroup, boto3_session=boto3_session)
+    )
     athena.wait_query(query_execution_id=query_id, boto3_session=boto3_session)
     dtypes: Dict[str, str] = athena.get_query_columns_types(query_execution_id=query_id, boto3_session=boto3_session)
     return [{"Name": name, "Type": _data_types.athena2quicksight(dtype=dtype)} for name, dtype in dtypes.items()]

awswrangler/redshift.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,7 @@ def _validate_parameters(
254254

255255

256256
def _redshift_types_from_path(
257-
path: Optional[Union[str, List[str]]],
257+
path: Union[str, List[str]],
258258
varchar_lengths_default: int,
259259
varchar_lengths: Optional[Dict[str, int]],
260260
parquet_infer_sampling: float,

awswrangler/s3/_download.py

Lines changed: 2 additions & 2 deletions
@@ -1,7 +1,7 @@
 """Amazon S3 Download Module (PRIVATE)."""

 import logging
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, Optional, Union, cast

 import boto3

@@ -76,7 +76,7 @@ def download(
     if isinstance(local_file, str):
         _logger.debug("Downloading local_file: %s", local_file)
         with open(file=local_file, mode="wb") as local_f:
-            local_f.write(s3_f.read())
+            local_f.write(cast(bytes, s3_f.read()))
     else:
         _logger.debug("Downloading file-like object.")
         local_file.write(s3_f.read())

awswrangler/s3/_fs.py

Lines changed: 1 addition & 1 deletion
@@ -228,7 +228,7 @@ def __init__(
         else:
             raise RuntimeError(f"Invalid mode: {self._mode}")

-    def __enter__(self) -> Union["_S3ObjectBase"]:
+    def __enter__(self) -> "_S3ObjectBase":
         return self

     def __exit__(self, exc_type: Any, exc_value: Any, exc_traceback: Any) -> None:
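Union with a single member is redundant, hence the one-line fix above. For context-manager classes, annotating __enter__ with the concrete class keeps the variable bound by a with-block properly typed. A minimal sketch, unrelated to the library's internals:

from types import TracebackType
from typing import Optional, Type


class ManagedResource:
    """Toy context manager; __enter__ returns the concrete class, not Union[...]."""

    def __enter__(self) -> "ManagedResource":
        print("opened")
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        exc_traceback: Optional[TracebackType],
    ) -> None:
        print("closed")


with ManagedResource() as res:
    # `res` is typed as ManagedResource inside the with-block.
    print(type(res).__name__)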

awswrangler/s3/_merge_upsert_table.py

Lines changed: 3 additions & 0 deletions
@@ -77,6 +77,9 @@ def _generate_empty_frame_for_table(
     boto3_session: Optional[boto3.Session] = None,
 ) -> pandas.DataFrame:
     type_dict = wr.catalog.get_table_types(database=database, table=table, boto3_session=boto3_session)
+    if type_dict is None:
+        raise wr.exceptions.ResourceDoesNotExist(f"Table {table} from database {database} does not exist.")
+
     empty_frame = pandas.DataFrame(columns=type_dict.keys())
     return _data_types.cast_pandas_with_athena_types(empty_frame, type_dict)

awswrangler/s3/_upload.py

Lines changed: 1 addition & 1 deletion
@@ -71,7 +71,7 @@ def upload(
     if isinstance(local_file, str):
         _logger.debug("Uploading local_file: %s", local_file)
         with open(file=local_file, mode="rb") as local_f:
-            s3_f.write(local_f.read())
+            s3_f.write(local_f.read())  # type: ignore
     else:
         _logger.debug("Uploading file-like object.")
         s3_f.write(local_file.read())
