Commit f170891

Fixing mypy issues
1 parent 97d492f commit f170891

File tree: 8 files changed, +60 −45 lines

awswrangler/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@
 import awswrangler.utils  # noqa
 import awswrangler.data_types  # noqa

-if importlib.util.find_spec("pyspark"):
+if importlib.util.find_spec("pyspark"):  # type: ignore
     from awswrangler.spark import Spark  # noqa

 logging.getLogger("awswrangler").addHandler(logging.NullHandler())
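Why the ignore is needed here: with only a bare `import importlib` in scope, mypy can report a missing `util` attribute on this line, since the submodule is not guaranteed to be loaded at type-check time. A minimal sketch of an alternative, assuming that is the cause (this commit keeps the inline ignore instead):

    import importlib.util

    PYSPARK_INSTALLED = False
    # find_spec returns a ModuleSpec when the package is importable, else None,
    # so mypy can check this without an inline ignore.
    if importlib.util.find_spec("pyspark") is not None:
        PYSPARK_INSTALLED = True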

awswrangler/data_types.py

Lines changed: 5 additions & 3 deletions
@@ -5,7 +5,7 @@
 import pyarrow as pa  # type: ignore
 import pandas as pd  # type: ignore

-from awswrangler.exceptions import UnsupportedType, UndetectedType  # type: ignore
+from awswrangler.exceptions import UnsupportedType, UndetectedType

 logger = logging.getLogger(__name__)

@@ -283,7 +283,8 @@ def spark2redshift(dtype: str) -> str:
     raise UnsupportedType("Unsupported Spark type: " + dtype)


-def convert_schema(func: Callable, schema: List[Tuple[str, str]]) -> Dict[str, str]:
+def convert_schema(func: Callable,
+                   schema: List[Tuple[str, str]]) -> Dict[str, str]:
     """
     Convert schema in the format of {"col name": "bigint", "col2 name": "int"}
     applying some data types conversion function (e.g. spark2redshift)
@@ -297,7 +298,8 @@ def convert_schema(func: Callable, schema: List[Tuple[str, str]]) -> Dict[str, str]:

 def extract_pyarrow_schema_from_pandas(dataframe: pd.DataFrame,
                                        preserve_index: bool,
-                                       indexes_position: str = "right") -> List[Tuple[str, str]]:
+                                       indexes_position: str = "right"
+                                       ) -> List[Tuple[str, str]]:
     """
     Extract the related Pyarrow schema from any Pandas DataFrame
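For context on the reflowed `convert_schema` signature, a hedged usage sketch based only on the signature and docstring above (column names and dtypes are illustrative; the function body is not part of this diff):

    from awswrangler import data_types

    spark_schema = [("id", "bigint"), ("name", "string")]
    # Applies the conversion function to every (column, dtype) pair and returns a
    # dict such as {"col name": "redshift type", ...}, per the docstring.
    redshift_schema = data_types.convert_schema(func=data_types.spark2redshift,
                                                schema=spark_schema)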

awswrangler/pandas.py

Lines changed: 5 additions & 4 deletions
@@ -6,9 +6,9 @@
 import csv
 from datetime import datetime

-import pandas as pd
-import pyarrow as pa
-from pyarrow import parquet as pq
+import pandas as pd  # type: ignore
+import pyarrow as pa  # type: ignore
+from pyarrow import parquet as pq  # type: ignore

 from awswrangler import data_types
 from awswrangler.exceptions import (UnsupportedWriteMode,
@@ -1058,7 +1058,8 @@ def normalize_columns_names_athena(dataframe, inplace=True):
         return dataframe

     @staticmethod
-    def drop_duplicated_columns(dataframe, inplace=True):
+    def drop_duplicated_columns(dataframe: pd.DataFrame,
+                                inplace: bool = True) -> pd.DataFrame:
         if inplace is False:
             dataframe = dataframe.copy(deep=True)
         duplicated_cols = dataframe.columns.duplicated()
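A hedged illustration of the newly annotated method, assuming the enclosing class is `Pandas` (the frame contents are made up):

    import pandas as pd
    from awswrangler.pandas import Pandas

    df = pd.DataFrame([[1, 2, 3]], columns=["a", "a", "b"])
    # columns.duplicated() flags the second "a"; the method presumably drops it and
    # returns a frame with unique column labels, copying first when inplace=False.
    df = Pandas.drop_duplicated_columns(dataframe=df, inplace=False)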

awswrangler/redshift.py

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 import json
 import logging

-import pg8000
+import pg8000  # type: ignore

 from awswrangler import data_types
 from awswrangler.exceptions import (
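The inline comment is the per-line way to quiet mypy for a third-party package that ships no type stubs. A hedged sketch of what is being silenced and of a config-level alternative (an assumption, not something this commit adds):

    # Without the ignore, mypy reports roughly:
    #   error: Cannot find implementation or library stub for module named 'pg8000'
    # (exact wording varies by mypy version).
    import pg8000  # type: ignore

    # Equivalent per-module setting in mypy.ini / setup.cfg, shown here only as a comment:
    #   [mypy-pg8000]
    #   ignore_missing_imports = True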

awswrangler/session.py

Lines changed: 3 additions & 3 deletions
@@ -2,8 +2,8 @@
 import logging
 import importlib

-import boto3
-from botocore.config import Config
+import boto3  # type: ignore
+from botocore.config import Config  # type: ignore

 from awswrangler.s3 import S3
 from awswrangler.athena import Athena
@@ -13,7 +13,7 @@
 from awswrangler.redshift import Redshift

 PYSPARK_INSTALLED = False
-if importlib.util.find_spec("pyspark"):
+if importlib.util.find_spec("pyspark"):  # type: ignore
     PYSPARK_INSTALLED = True
     from awswrangler.spark import Spark

awswrangler/spark.py

Lines changed: 42 additions & 31 deletions
@@ -1,11 +1,9 @@
 from typing import List, Tuple, Dict
 import logging

-import pandas as pd
+import pandas as pd  # type: ignore

-from pyspark.sql.functions import pandas_udf, PandasUDFType, spark_partition_id
-from pyspark.sql.types import TimestampType
-from pyspark.sql import DataFrame
+from pyspark import sql

 from awswrangler.exceptions import MissingBatchDetected, UnsupportedFileFormat

@@ -38,7 +36,7 @@ def date2timestamp(dataframe):
         for name, dtype in dataframe.dtypes:
             if dtype == "date":
                 dataframe = dataframe.withColumn(
-                    name, dataframe[name].cast(TimestampType()))
+                    name, dataframe[name].cast(sql.types.TimestampType()))
                 logger.warning(
                     f"Casting column {name} from date to timestamp!")
         return dataframe
@@ -98,8 +96,9 @@ def to_redshift(
         spark.conf.set("spark.sql.execution.arrow.enabled", "true")
         session_primitives = self._session.primitives

-        @pandas_udf(returnType="objects_paths string",
-                    functionType=PandasUDFType.GROUPED_MAP)
+        @sql.functions.pandas_udf(
+            returnType="objects_paths string",
+            functionType=sql.functions.PandasUDFType.GROUPED_MAP)
         def write(pandas_dataframe):
             del pandas_dataframe["aws_data_wrangler_internal_partition_id"]
             paths = session_primitives.session.pandas.to_parquet(
@@ -112,7 +111,7 @@ def write(pandas_dataframe):
             return pd.DataFrame.from_dict({"objects_paths": paths})

         df_objects_paths = dataframe.repartition(numPartitions=num_partitions) \
-            .withColumn("aws_data_wrangler_internal_partition_id", spark_partition_id()) \
+            .withColumn("aws_data_wrangler_internal_partition_id", sql.functions.spark_partition_id()) \
             .groupby("aws_data_wrangler_internal_partition_id") \
             .apply(write)

@@ -227,7 +226,8 @@ def _is_map(dtype: str) -> bool:

     @staticmethod
     def _is_array_or_map(dtype: str) -> bool:
-        return True if (dtype.startswith("array") or dtype.startswith("map")) else False
+        return True if (dtype.startswith("array")
+                        or dtype.startswith("map")) else False

     @staticmethod
     def _parse_aux(path: str, aux: str) -> Tuple[str, str]:
@@ -242,19 +242,22 @@ def _parse_aux(path: str, aux: str) -> Tuple[str, str]:

     @staticmethod
     def _flatten_struct_column(path: str, dtype: str) -> List[Tuple[str, str]]:
-        dtype: str = dtype[7:-1]  # Cutting off "struct<" and ">"
+        dtype = dtype[7:-1]  # Cutting off "struct<" and ">"
         cols: List[Tuple[str, str]] = []
         struct_acc: int = 0
         path_child: str
         dtype_child: str
         aux: str = ""
-        for c, i in zip(dtype, range(len(dtype), 0, -1)):  # Zipping a descendant ID for each letter
+        for c, i in zip(dtype,
+                        range(len(dtype), 0,
+                              -1)):  # Zipping a descendant ID for each letter
             if ((c == ",") and (struct_acc == 0)) or (i == 1):
                 if i == 1:
                     aux += c
                 path_child, dtype_child = Spark._parse_aux(path=path, aux=aux)
                 if Spark._is_struct(dtype=dtype_child):
-                    cols += Spark._flatten_struct_column(path=path_child, dtype=dtype_child)  # Recursion
+                    cols += Spark._flatten_struct_column(
+                        path=path_child, dtype=dtype_child)  # Recursion
                 elif Spark._is_array(dtype=dtype):
                     cols.append((path, "array"))
                 else:
@@ -271,10 +274,10 @@ def _flatten_struct_column(path: str, dtype: str) -> List[Tuple[str, str]]:
         return cols

     @staticmethod
-    def _flatten_struct_dataframe(
-            df: DataFrame,
-            explode_outer: bool = True,
-            explode_pos: bool = True) -> List[Tuple[str, str, str]]:
+    def _flatten_struct_dataframe(df: sql.DataFrame,
+                                  explode_outer: bool = True,
+                                  explode_pos: bool = True
+                                  ) -> List[Tuple[str, str, str]]:
         explode: str = "EXPLODE_OUTER" if explode_outer is True else "EXPLODE"
         explode = f"POS{explode}" if explode_pos is True else explode
         cols: List[Tuple[str, str]] = []
@@ -308,26 +311,34 @@ def _flatten_struct_dataframe(

     @staticmethod
     def _build_name(name: str, expr: str) -> str:
-        suffix: str = expr[expr.find("(") + 1: expr.find(")")]
+        suffix: str = expr[expr.find("(") + 1:expr.find(")")]
         return f"{name}_{suffix}".replace(".", "_")

     @staticmethod
-    def flatten(
-            df: DataFrame,
-            explode_outer: bool = True,
-            explode_pos: bool = True,
-            name: str = "root") -> Dict[str, DataFrame]:
-        cols_exprs: List[Tuple[str, str, str]] = Spark._flatten_struct_dataframe(
-            df=df,
-            explode_outer=explode_outer,
-            explode_pos=explode_pos)
-        exprs_arr: List[str] = [x[2] for x in cols_exprs if Spark._is_array_or_map(x[1])]
-        exprs: List[str] = [x[2] for x in cols_exprs if not Spark._is_array_or_map(x[1])]
-        dfs: Dict[str, DataFrame] = {name: df.selectExpr(exprs)}
-        exprs: List[str] = [x[2] for x in cols_exprs if not Spark._is_array_or_map(x[1]) and not x[0].endswith("_pos")]
+    def flatten(df: sql.DataFrame,
+                explode_outer: bool = True,
+                explode_pos: bool = True,
+                name: str = "root") -> Dict[str, sql.DataFrame]:
+        cols_exprs: List[
+            Tuple[str, str, str]] = Spark._flatten_struct_dataframe(
+                df=df, explode_outer=explode_outer, explode_pos=explode_pos)
+        exprs_arr: List[str] = [
+            x[2] for x in cols_exprs if Spark._is_array_or_map(x[1])
+        ]
+        exprs: List[str] = [
+            x[2] for x in cols_exprs if not Spark._is_array_or_map(x[1])
+        ]
+        dfs: Dict[str, sql.DataFrame] = {name: df.selectExpr(exprs)}
+        exprs = [
+            x[2] for x in cols_exprs
+            if not Spark._is_array_or_map(x[1]) and not x[0].endswith("_pos")
+        ]
         for expr in exprs_arr:
             df_arr = df.selectExpr(exprs + [expr])
             name_new: str = Spark._build_name(name=name, expr=expr)
-            dfs_new = Spark.flatten(df=df_arr, explode_outer=explode_outer, explode_pos=explode_pos, name=name_new)
+            dfs_new = Spark.flatten(df=df_arr,
+                                    explode_outer=explode_outer,
+                                    explode_pos=explode_pos,
+                                    name=name_new)
             dfs = {**dfs, **dfs_new}
         return dfs
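For orientation on the retyped `flatten`, a hedged, self-contained sketch; the local Spark session and the nested DataFrame are made up for illustration:

    from pyspark.sql import Row, SparkSession
    from awswrangler.spark import Spark

    spark = SparkSession.builder.master("local[1]").getOrCreate()
    nested_df = spark.createDataFrame([Row(id=1, payload=Row(a=1, b=[1, 2]))])

    # flatten() is a staticmethod returning Dict[str, sql.DataFrame]: the root frame
    # with struct fields flattened, plus additional frames for exploded array/map columns.
    dfs = Spark.flatten(df=nested_df, name="root")
    for table_name, flat_df in dfs.items():
        flat_df.printSchema()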

awswrangler/utils.py

Lines changed: 2 additions & 2 deletions
@@ -47,8 +47,8 @@ def wait_process_release(processes):
         sleep(0.1)


-def lcm(a, b):
+def lcm(a: int, b: int) -> int:
     """
     Least Common Multiple
     """
-    return abs(a * b) // gcd(a, b)
+    return int(abs(a * b) // gcd(a, b))
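A quick usage check of the newly typed helper (a sketch; assumes the package and its dependencies are installed, as in the test environment):

    from awswrangler.utils import lcm

    # gcd presumably comes from Python's math module (utils.py already calls it above);
    # lcm(4, 6) == |4 * 6| // gcd(4, 6) == 24 // 2 == 12.
    assert lcm(4, 6) == 12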

testing/run-tests.sh

Lines changed: 1 addition & 0 deletions
@@ -6,6 +6,7 @@ cd ..
 rm -rf *.pytest_cache
 yapf --in-place --recursive setup.py awswrangler testing/test_awswrangler
 flake8 setup.py awswrangler testing/test_awswrangler
+mypy awswrangler
 pip install -e .
 pytest testing/test_awswrangler awswrangler
 rm -rf *.pytest_cache
