 import json
 import os
+from typing import Any

 import numpy as np
 import pandas as pd
@@ -9,7 +10,7 @@ def load_multi_table(data_dir, verbose=True):
     dataset_meta = json.load(open(os.path.join(data_dir, "dataset_meta.json"), "r"))

     relation_order = dataset_meta["relation_order"]
-    relation_order_reversed = relation_order[::-1]
+    # relation_order_reversed = relation_order[::-1]

     tables = {}

@@ -21,6 +22,7 @@ def load_multi_table(data_dir, verbose=True):
         tables[table] = {
             "df": train_df,
             "domain": json.load(open(os.path.join(data_dir, f"{table}_domain.json"))),
+            # ruff: noqa: SIM115
             "children": meta["children"],
             "parents": meta["parents"],
         }
@@ -42,8 +44,9 @@ def load_multi_table(data_dir, verbose=True):
     return tables, relation_order, dataset_meta


-def get_info_from_domain(data_df, domain_dict):
-    info = {}
+def get_info_from_domain(data_df: pd.DataFrame, domain_dict: dict[str, Any]) -> dict[str, Any]:
+    # ruff: noqa: D103
+    info: dict[str, Any] = {}
     info["num_col_idx"] = []
     info["cat_col_idx"] = []
     columns = data_df.columns.tolist()
@@ -60,7 +63,16 @@ def get_info_from_domain(data_df, domain_dict):
     return info


-def pipeline_process_data(name, data_df, info, ratio=0.9, save=False, verbose=True):
+def pipeline_process_data(
+    # ruff: noqa: PLR0915, PLR0912
+    name: str,
+    data_df: pd.DataFrame,
+    info: dict[str, Any],
+    ratio: float = 0.9,
+    save: bool = False,
+    verbose: bool = True,
+) -> tuple[dict[str, Any], dict[str, Any]]:
+    # ruff: noqa: D103
     num_data = data_df.shape[0]

     column_names = info["column_names"] if info["column_names"] else data_df.columns.tolist()
@@ -91,7 +103,7 @@ def pipeline_process_data(name, data_df, info, ratio=0.9, save=False, verbose=Tr
     if ratio < 1:
         test_df.columns = range(len(test_df.columns))

-    col_info = {}
+    col_info: dict[Any, Any] = {}

     for col_idx in num_col_idx:
         col_info[col_idx] = {}
@@ -181,7 +193,7 @@ def pipeline_process_data(name, data_df, info, ratio=0.9, save=False, verbose=Tr
     info["inverse_idx_mapping"] = inverse_idx_mapping
     info["idx_name_mapping"] = idx_name_mapping

-    metadata = {"columns": {}}
+    metadata: dict[str, Any] = {"columns": {}}
     task_type = info["task_type"]
     num_col_idx = info["num_col_idx"]
     cat_col_idx = info["cat_col_idx"]
@@ -257,9 +269,16 @@ def pipeline_process_data(name, data_df, info, ratio=0.9, save=False, verbose=Tr
     return data, info


-def get_column_name_mapping(data_df, num_col_idx, cat_col_idx, target_col_idx, column_names=None):
+def get_column_name_mapping(
+    data_df: pd.DataFrame,
+    num_col_idx: list[int],
+    cat_col_idx: list[int],
+    target_col_idx: list[int],
+    column_names: list[str] | None = None,
+) -> tuple[dict[int, int], dict[int, int], dict[int, str]]:
+    # ruff: noqa: D103
     if not column_names:
-        column_names = np.array(data_df.columns.tolist())
+        column_names = data_df.columns.tolist()

     idx_mapping = {}

@@ -290,7 +309,13 @@ def get_column_name_mapping(data_df, num_col_idx, cat_col_idx, target_col_idx, c
     return idx_mapping, inverse_idx_mapping, idx_name_mapping


-def train_val_test_split(data_df, cat_columns, num_train=0, num_test=0):
+def train_val_test_split(
+    data_df: pd.DataFrame,
+    cat_columns: list[str],
+    num_train: int = 0,
+    num_test: int = 0,
+) -> tuple[pd.DataFrame, pd.DataFrame, int]:
+    # ruff: noqa: D103
     total_num = data_df.shape[0]
     idx = np.arange(total_num)

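For context on how the newly typed signatures fit together, here is a minimal usage sketch. It assumes the three functions are importable from the module changed in this commit, that data_dir follows the layout load_multi_table expects (a dataset_meta.json plus one <table>_domain.json per table), and that get_info_from_domain fills in every field pipeline_process_data later reads (e.g. column_names, task_type), which is not fully visible in this diff. The path and variable names below are illustrative only, not part of the change.

data_dir = "datasets/example"  # illustrative path, not from the commit

# Load every table plus the relation order described in dataset_meta.json.
tables, relation_order, dataset_meta = load_multi_table(data_dir, verbose=False)

processed = {}
for table, meta in tables.items():
    # Derive per-column metadata from the table's domain file, then run the
    # preprocessing pipeline without persisting the train/test split to disk.
    info = get_info_from_domain(meta["df"], meta["domain"])
    data, info = pipeline_process_data(
        name=table,
        data_df=meta["df"],
        info=info,
        ratio=0.9,
        save=False,
        verbose=False,
    )
    processed[table] = (data, info)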