44from datasets .adult import df , dtypes
55
66def test_synthpop_default_parameters ():
7- """Test Synthpop with default parameters using Adult dataset ."""
7+ """Test Synthpop with default parameters and automatic type inference ."""
88 # Initialize Synthpop
99 spop = Synthpop ()
1010
11- # Fit the model
12- spop .fit (df , dtypes )
11+ # Fit the model with automatic type inference
12+ spop .fit (df )
1313
1414 # Generate synthetic data
1515 synth_df = spop .generate (len (df ))
@@ -20,6 +20,11 @@ def test_synthpop_default_parameters():
2020 # Verify the synthetic dataframe has the same columns as original
2121 assert all (synth_df .columns == df .columns )
2222
23+ # Verify inferred dtypes match expected types
24+ assert spop .df_dtypes ['age' ] == 'int'
25+ assert spop .df_dtypes ['workclass' ] == 'category'
26+ assert spop .df_dtypes ['education' ] == 'category'
27+
2328 # Verify the method attribute contains expected default values
2429 assert isinstance (spop .method , pd .Series )
2530 assert 'age' in spop .method .index
@@ -37,6 +42,25 @@ def test_synthpop_default_parameters():
3742 assert all (spop .predictor_matrix .index == df .columns )
3843 assert all (spop .predictor_matrix .columns == df .columns )
3944
def test_synthpop_with_manual_dtypes():
    """Fit Synthpop with caller-supplied dtypes and check they are kept.

    The model is fitted on ``df`` together with the explicit ``dtypes``
    mapping. Every entry of that mapping must then be found unchanged in
    ``df_dtypes`` on the fitted model, and a synthetic frame generated
    with as many rows as ``df`` must have the same shape and the same
    column labels as ``df``.
    """
    model = Synthpop()

    # Pass the dtypes explicitly rather than calling fit(df) alone.
    model.fit(df, dtypes)

    # Each manually specified dtype must be recorded as given.
    for column, expected_dtype in dtypes.items():
        assert model.df_dtypes[column] == expected_dtype

    # Request as many synthetic rows as the original frame holds.
    n_rows = len(df)
    synthetic = model.generate(n_rows)

    # Same dimensions first, then the same column labels position by position.
    assert synthetic.shape == df.shape
    columns_match = synthetic.columns == df.columns
    assert all(columns_match)
63+
4064def test_synthpop_custom_visit_sequence ():
4165 """Test Synthpop with custom visit sequence using Adult dataset."""
4266 # Define custom visit sequence
@@ -45,8 +69,8 @@ def test_synthpop_custom_visit_sequence():
4569 # Initialize Synthpop with custom visit sequence
4670 spop = Synthpop (visit_sequence = visit_sequence )
4771
48- # Fit the model
49- spop .fit (df , dtypes )
72+ # Fit the model with automatic type inference
73+ spop .fit (df )
5074
5175 # Generate synthetic data
5276 synth_df = spop .generate (len (df ))
0 commit comments