Skip to content

Commit 949d87c

Browse files
committed
postprocessing to add the NaN values is added
1 parent 78d550b commit 949d87c

File tree

3 files changed

+23
-4
lines changed

3 files changed

+23
-4
lines changed

main.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,10 @@ def synBar():
1919

2020
synth_df = spop.generate(len(df))
2121

22+
print("synthetische data:")
2223
print(synth_df.head())
24+
25+
print("aantal NaNs:")
2326
print(synth_df.isna().sum())
2427

2528
def synSD2011():
@@ -50,4 +53,4 @@ def synSD2011():
5053
print(synth_df.head())
5154

5255

53-
synBar()
56+
synSD2011()

synthpop/synthpop.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def __init__(self,
3737
self.numtocat = numtocat
3838
self.catgroups = catgroups
3939
self.seed = seed
40-
40+
self.map_column_to_NaN_column = {}
4141
# check init
4242
self.validator.check_init()
4343

@@ -54,13 +54,22 @@ def pre_preprocess(self,df,dtypes,nan_fill):
5454

5555
nan_col_name = column+"_NaN"
5656
df.loc[:,nan_col_name] = maybe_nans
57+
self.map_column_to_NaN_column[column] = nan_col_name
5758

5859
dtypes[nan_col_name] = 'category'
5960

6061

6162
return df,dtypes
6263

6364
def post_postprocessing(self,syn_df):
65+
for column in syn_df:
66+
67+
if column in self.map_column_to_NaN_column.keys():
68+
nan_col_name = self.map_column_to_NaN_column[column]
69+
column_NaN_at = syn_df[nan_col_name]
70+
syn_df.loc[column_NaN_at,column] = None
71+
syn_df = syn_df.drop(columns=nan_col_name)
72+
6473
return syn_df
6574
def fit(self, df, dtypes=None):
6675
# TODO check df and check/EXTRACT dtypes
@@ -117,7 +126,7 @@ def generate(self, k=None):
117126
# postprocess
118127
processed_synth_df = self.processor.postprocess(synth_df)
119128

120-
return processed_synth_df
129+
return self.post_postprocessing(processed_synth_df)
121130

122131
def _generate(self):
123132
synth_df = pd.DataFrame(data=np.zeros([self.k, len(self.visit_sequence)]), columns=self.visit_sequence.index)

tests_processing.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,17 @@ def test_add_NaN_columns_for_numeric_columns(self):
1717
self.assertTrue(res['a_NaN'][2])
1818
self.assertEqual(res['a'][2], -8)
1919
self.assertEqual(dtype_res['a_NaN'],'category')
20+
self.assertEqual(spop.map_column_to_NaN_column['a'],'a_NaN')
2021

2122
def test_apply_and_remove_added_NaN_columns(self):
22-
df = pd.DataFrame({'a':[1,2,np.nan],'a_NaN':[False,False,True], 'b':[1,1,1], 'c':['x','y',None]})
23+
df = pd.DataFrame({'a':[1,2,-8],'a_NaN':[False,True,False], 'b':[1,1,1], 'c':['x','y',None]})
24+
2325
spop = Synthpop()
26+
spop.map_column_to_NaN_column = {'a':'a_NaN'}
27+
28+
res = spop.post_postprocessing(df)
29+
self.assertTrue(np.isnan(res['a'][1]), "NaNs should be placed where indicated")
30+
self.assertFalse('a_NaN' in res, "indicator columns should be removed")
2431

2532

2633
if __name__ == '__main__':

0 commit comments

Comments
 (0)