Skip to content

Commit a368a49

Browse files
committed
synthesised categorical variable without encoding for bar dataset works if we don't include sex
1 parent c3b05fc commit a368a49

File tree

3 files changed

+25
-4
lines changed

3 files changed

+25
-4
lines changed

main.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,28 @@
22
import pandas as pd
33
import pyreadr
44

5+
def synBar():
6+
df = pd.read_csv("bar_pass_prediction.csv")[[ 'race1', 'ugpa', 'bar']]
7+
print(df.dtypes)
8+
dtype_map = { 'race1': 'category', 'ugpa': 'float', 'bar': 'category'}
9+
10+
for (k,v) in dtype_map.items():
11+
12+
if v == 'category':
13+
df = df.astype({k : "category"})
14+
15+
print(df.dtypes)
16+
spop = Synthpop()
17+
spop.fit(df,dtype_map)
18+
19+
synth_df = spop.generate(len(df))
20+
21+
print(synth_df.head())
522

623
def synSD2011():
724
df0 = pyreadr.read_r("SD2011.rda")['SD2011']
825
#pd.read_csv("bar_pass_prediction.csv")
9-
print(df0)
26+
print(df0.dtypes)
1027
df = df0[['age', 'unempdur', 'income', 'sex']]#df0[['sex', 'race1', 'ugpa', 'bar']]
1128
#df.to_excel("inputData.xlsx")
1229
dtype_map ={
@@ -27,4 +44,7 @@ def synSD2011():
2744

2845
synth_df = spop.generate(len(df))
2946

30-
print(synth_df.head())
47+
print(synth_df.head())
48+
49+
50+
synBar()

synthpop/processor/processor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def preprocess(self, df, dtypes):
5757
'nan_value': col_nan_category
5858
}
5959

60-
df[col].cat.add_categories(col_nan_category, inplace=True)
60+
df[col] = df[col].cat.add_categories(col_nan_category) #argument 'inplace' is deprecated and removed
6161
df[col].fillna(col_nan_category, inplace=True)
6262

6363
# NaNs in numerical columns
@@ -79,7 +79,7 @@ def preprocess(self, df, dtypes):
7979
df.loc[bool_series, col_nan_name] = cat_index
8080
df.loc[col_all_nan_indices, col] = 0
8181

82-
df[col_nan_name] = df[col_nan_name].astype('category')
82+
df.loc[:,col_nan_name] = df[col_nan_name].astype('category')
8383
self.spop.df_dtypes[col_nan_name] = 'category'
8484

8585
return df

synthpop/synthpop.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ def fit(self, df, dtypes=None):
5757
self.validator.check_processor()
5858
# preprocess
5959
processed_df = self.processor.preprocess(df, self.df_dtypes)
60+
print(processed_df)
6061
self.processed_df_columns = processed_df.columns.tolist()
6162
self.n_processed_df_columns = len(self.processed_df_columns)
6263

0 commit comments

Comments
 (0)