Skip to content

Commit 0b1e566

Browse files
committed
visit sequence adjusted for nan columns
1 parent 474caf3 commit 0b1e566

File tree

3 files changed

+23
-2
lines changed

3 files changed

+23
-2
lines changed

main.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def synSD2011():
3535
df0 = pyreadr.read_r("SD2011.rda")['SD2011']
3636
#pd.read_csv("bar_pass_prediction.csv")
3737
#print(df0.dtypes)
38-
df = df0#df0[['sex', 'race1', 'ugpa', 'bar']]
38+
df = df0[['age', 'unempdur', 'income', 'sex']]
3939
#print(df.isna().sum())
4040
#df.to_excel("inputData.xlsx")
4141
dtype_map ={
@@ -64,7 +64,7 @@ def synSD2011():
6464

6565

6666
r = df.dtypes.keys()
67-
spop = Synthpop()
67+
spop = Synthpop(visit_sequence=['age', 'unempdur', 'income_NaN','income', 'sex'])
6868
spop.fit(df,dtype_map)
6969

7070
synth_df = spop.generate(len(df))

synthpop/synthpop.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,14 @@ def __init__(self,
4040
self.map_column_to_NaN_column = {}
4141
# check init
4242
self.validator.check_init()
43+
def include_nan_columns(self):
44+
for (col,nan_col) in self.map_column_to_NaN_column.items():
45+
46+
if col not in self.visit_sequence:
47+
continue
48+
49+
index_of_col = self.visit_sequence.index(col)
50+
self.visit_sequence.insert(index_of_col,nan_col)
4351

4452
def pre_preprocess(self,df,dtypes,nan_fill):
4553

@@ -79,7 +87,10 @@ def fit(self, df, dtypes=None):
7987
# - can map dtypes (if given) correctly to df
8088
# should create map col: dtype (self.df_dtypes)
8189
df,dtypes = self.pre_preprocess(df,dtypes,-8)
90+
8291
self.df_columns = df.columns.tolist()
92+
self.visit_sequence = df.columns.tolist()
93+
self.include_nan_columns()
8394
self.n_df_rows, self.n_df_columns = np.shape(df)
8495
self.df_dtypes = dtypes
8596

tests_processing.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,16 @@ def test_add_NaN_columns_for_numeric_columns(self):
1818
self.assertEqual(res['a'][2], -8)
1919
self.assertEqual(dtype_res['a_NaN'],'category')
2020
self.assertEqual(spop.map_column_to_NaN_column['a'],'a_NaN')
21+
def test_make_visit_sequence_when_one_is_given(self):
22+
23+
visit_seq = ['x','a','b']
24+
spop = Synthpop(visit_sequence=visit_seq)
25+
spop.map_column_to_NaN_column = {'a':'a_NaN','c':'c_NaN'}
26+
27+
spop.include_nan_columns()
28+
29+
self.assertSequenceEqual(spop.visit_sequence,['x','a_NaN','a','b'])
30+
2131

2232
def test_apply_and_remove_added_NaN_columns(self):
2333
df = pd.DataFrame({'a':[1,2,-8],'a_NaN':[False,True,False], 'b':[1,1,1], 'c':['x','y',None]})

0 commit comments

Comments
 (0)