Skip to content

Commit ebab82f

Browse files
committed
Fixed all ruff and mypy errors
1 parent 80d3233 commit ebab82f

File tree

15 files changed

+587
-370
lines changed

15 files changed

+587
-370
lines changed

examples/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""PLACEHOLDER."""

examples/tutorial/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""PLACEHOLDER."""

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ ignore = [
125125
# Ignore import violations in all `__init__.py` files.
126126
[tool.ruff.lint.per-file-ignores]
127127
"__init__.py" = ["E402", "F401", "F403", "F811"]
128+
"test_*.py" = ["D103"]
128129

129130
[tool.ruff.lint.pep8-naming]
130131
ignore-names = ["X*", "setUp"]

src/midst_toolkit/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""PLACEHOLDER."""

src/midst_toolkit/core/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
"""PLACEHOLDER."""

src/midst_toolkit/core/data_loaders.py

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import json
22
import os
3+
from typing import Any
34

45
import numpy as np
56
import pandas as pd
@@ -9,7 +10,7 @@ def load_multi_table(data_dir, verbose=True):
910
dataset_meta = json.load(open(os.path.join(data_dir, "dataset_meta.json"), "r"))
1011

1112
relation_order = dataset_meta["relation_order"]
12-
relation_order_reversed = relation_order[::-1]
13+
# relation_order_reversed = relation_order[::-1]
1314

1415
tables = {}
1516

@@ -21,6 +22,7 @@ def load_multi_table(data_dir, verbose=True):
2122
tables[table] = {
2223
"df": train_df,
2324
"domain": json.load(open(os.path.join(data_dir, f"{table}_domain.json"))),
25+
# ruff: noqa: SIM115
2426
"children": meta["children"],
2527
"parents": meta["parents"],
2628
}
@@ -42,8 +44,9 @@ def load_multi_table(data_dir, verbose=True):
4244
return tables, relation_order, dataset_meta
4345

4446

45-
def get_info_from_domain(data_df, domain_dict):
46-
info = {}
47+
def get_info_from_domain(data_df: pd.DataFrame, domain_dict: dict[str, Any]) -> dict[str, Any]:
48+
# ruff: noqa: D103
49+
info: dict[str, Any] = {}
4750
info["num_col_idx"] = []
4851
info["cat_col_idx"] = []
4952
columns = data_df.columns.tolist()
@@ -60,7 +63,16 @@ def get_info_from_domain(data_df, domain_dict):
6063
return info
6164

6265

63-
def pipeline_process_data(name, data_df, info, ratio=0.9, save=False, verbose=True):
66+
def pipeline_process_data(
67+
# ruff: noqa: PLR0915, PLR0912
68+
name: str,
69+
data_df: pd.DataFrame,
70+
info: dict[str, Any],
71+
ratio: float = 0.9,
72+
save: bool = False,
73+
verbose: bool = True,
74+
) -> tuple[dict[str, Any], dict[str, Any]]:
75+
# ruff: noqa: D103
6476
num_data = data_df.shape[0]
6577

6678
column_names = info["column_names"] if info["column_names"] else data_df.columns.tolist()
@@ -91,7 +103,7 @@ def pipeline_process_data(name, data_df, info, ratio=0.9, save=False, verbose=Tr
91103
if ratio < 1:
92104
test_df.columns = range(len(test_df.columns))
93105

94-
col_info = {}
106+
col_info: dict[Any, Any] = {}
95107

96108
for col_idx in num_col_idx:
97109
col_info[col_idx] = {}
@@ -181,7 +193,7 @@ def pipeline_process_data(name, data_df, info, ratio=0.9, save=False, verbose=Tr
181193
info["inverse_idx_mapping"] = inverse_idx_mapping
182194
info["idx_name_mapping"] = idx_name_mapping
183195

184-
metadata = {"columns": {}}
196+
metadata: dict[str, Any] = {"columns": {}}
185197
task_type = info["task_type"]
186198
num_col_idx = info["num_col_idx"]
187199
cat_col_idx = info["cat_col_idx"]
@@ -257,9 +269,16 @@ def pipeline_process_data(name, data_df, info, ratio=0.9, save=False, verbose=Tr
257269
return data, info
258270

259271

260-
def get_column_name_mapping(data_df, num_col_idx, cat_col_idx, target_col_idx, column_names=None):
272+
def get_column_name_mapping(
273+
data_df: pd.DataFrame,
274+
num_col_idx: list[int],
275+
cat_col_idx: list[int],
276+
target_col_idx: list[int],
277+
column_names: list[str] | None = None,
278+
) -> tuple[dict[int, int], dict[int, int], dict[int, str]]:
279+
# ruff: noqa: D103
261280
if not column_names:
262-
column_names = np.array(data_df.columns.tolist())
281+
column_names = data_df.columns.tolist()
263282

264283
idx_mapping = {}
265284

@@ -290,7 +309,13 @@ def get_column_name_mapping(data_df, num_col_idx, cat_col_idx, target_col_idx, c
290309
return idx_mapping, inverse_idx_mapping, idx_name_mapping
291310

292311

293-
def train_val_test_split(data_df, cat_columns, num_train=0, num_test=0):
312+
def train_val_test_split(
313+
data_df: pd.DataFrame,
314+
cat_columns: list[str],
315+
num_train: int = 0,
316+
num_test: int = 0,
317+
) -> tuple[pd.DataFrame, pd.DataFrame, int]:
318+
# ruff: noqa: D103
294319
total_num = data_df.shape[0]
295320
idx = np.arange(total_num)
296321

0 commit comments

Comments
 (0)