Skip to content

Commit 9e3383e

Browse files
committed
2 parents 888abe0 + b9e9d55 commit 9e3383e

File tree

2 files changed

+70
-11
lines changed

2 files changed

+70
-11
lines changed

synthpop/synthpop.py

Lines changed: 41 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -77,13 +77,48 @@ def post_postprocessing(self,syn_df):
7777

7878
return syn_df
7979

80+
def _infer_dtypes(self, df):
    """Infer a coarse type label for every column of a DataFrame.

    Classification uses the ``pandas.api.types`` predicates instead of
    matching on the dtype's string name, so unsigned integers (``uint8``)
    and pandas nullable extension dtypes (``Int64``, ``Float64``,
    ``boolean``) are recognised as well as the plain NumPy dtypes.

    Args:
        df: pandas DataFrame

    Returns:
        dict: Mapping of column names to inferred types
            ('int', 'float', 'datetime', 'category', 'bool').
            Any dtype that is not bool, integer, float or datetime
            (object, string, categorical, ...) is labelled 'category'.
    """
    # Local import keeps this method self-contained regardless of how the
    # enclosing module imports pandas.
    from pandas.api import types as pd_types

    dtypes = {}
    # Iterate df.dtypes positionally rather than indexing df[column]:
    # with duplicate column labels df[column] is a DataFrame (no .dtype),
    # which would raise AttributeError here instead of letting the caller
    # report the duplicate names itself.
    for column, col_dtype in zip(df.columns, df.dtypes):
        # bool is tested first so it can never be picked up by a numeric check.
        if pd_types.is_bool_dtype(col_dtype):
            dtypes[column] = 'bool'
        elif pd_types.is_integer_dtype(col_dtype):
            dtypes[column] = 'int'
        elif pd_types.is_float_dtype(col_dtype):
            dtypes[column] = 'float'
        elif pd_types.is_datetime64_any_dtype(col_dtype):
            dtypes[column] = 'datetime'
        else:
            # No further inspection is done: everything else is treated
            # as categorical.
            dtypes[column] = 'category'

    return dtypes
106+
80107
def fit(self, df, dtypes=None):
81-
# TODO check df and check/EXTRACT dtypes
82-
# - all column names of df are unique
83-
# - all columns data of df are consistent
84-
# - all dtypes of df are correct ('int', 'float', 'datetime', 'category', 'bool'; no object)
85-
# - can map dtypes (if given) correctly to df
86-
# should create map col: dtype (self.df_dtypes)
108+
"""Fit the synthetic data generator.
109+
110+
Args:
111+
df: pandas DataFrame to learn from
112+
dtypes: Optional dict mapping column names to types. If not provided, types will be inferred.
113+
"""
114+
# Infer dtypes if not provided
115+
if dtypes is None:
116+
dtypes = self._infer_dtypes(df)
117+
118+
# Validate DataFrame
119+
if not df.columns.is_unique:
120+
raise ValueError("DataFrame column names must be unique")
121+
87122
df,dtypes = self.pre_preprocess(df,dtypes,-8)
88123

89124
self.df_columns = df.columns.tolist()

tests/test_synthpop.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@
44
from datasets.adult import df, dtypes
55

66
def test_synthpop_default_parameters():
7-
"""Test Synthpop with default parameters using Adult dataset."""
7+
"""Test Synthpop with default parameters and automatic type inference."""
88
# Initialize Synthpop
99
spop = Synthpop()
1010

11-
# Fit the model
12-
spop.fit(df, dtypes)
11+
# Fit the model with automatic type inference
12+
spop.fit(df)
1313

1414
# Generate synthetic data
1515
synth_df = spop.generate(len(df))
@@ -20,6 +20,11 @@ def test_synthpop_default_parameters():
2020
# Verify the synthetic dataframe has the same columns as original
2121
assert all(synth_df.columns == df.columns)
2222

23+
# Verify inferred dtypes match expected types
24+
assert spop.df_dtypes['age'] == 'int'
25+
assert spop.df_dtypes['workclass'] == 'category'
26+
assert spop.df_dtypes['education'] == 'category'
27+
2328
# Verify the method attribute contains expected default values
2429
assert isinstance(spop.method, pd.Series)
2530
assert 'age' in spop.method.index
@@ -37,6 +42,25 @@ def test_synthpop_default_parameters():
3742
assert all(spop.predictor_matrix.index == df.columns)
3843
assert all(spop.predictor_matrix.columns == df.columns)
3944

45+
def test_synthpop_with_manual_dtypes():
    """Fit Synthpop with caller-supplied dtypes and check the fitted result."""
    model = Synthpop()

    # Pass the dtype mapping explicitly instead of relying on inference.
    model.fit(df, dtypes)

    # Every supplied dtype must be what the fitted model reports for that column.
    for column_name in dtypes:
        assert model.df_dtypes[column_name] == dtypes[column_name]

    # Sample as many synthetic rows as the original data has.
    synthetic = model.generate(len(df))

    # Shape and column labels must match the original frame.
    assert synthetic.shape == df.shape
    assert all(synthetic.columns == df.columns)
62+
assert all(synth_df.columns == df.columns)
63+
4064
def test_synthpop_custom_visit_sequence():
4165
"""Test Synthpop with custom visit sequence using Adult dataset."""
4266
# Define custom visit sequence
@@ -45,8 +69,8 @@ def test_synthpop_custom_visit_sequence():
4569
# Initialize Synthpop with custom visit sequence
4670
spop = Synthpop(visit_sequence=visit_sequence)
4771

48-
# Fit the model
49-
spop.fit(df, dtypes)
72+
# Fit the model with automatic type inference
73+
spop.fit(df)
5074

5175
# Generate synthetic data
5276
synth_df = spop.generate(len(df))

0 commit comments

Comments
 (0)