@@ -725,3 +725,283 @@ def test_inverse_prop():
725725 assert all (isinstance (ax , plt .Axes ) for ax in axs )
726726 with pytest .raises (NotImplementedError ):
727727 result .get_plot_data ()
728+
729+
@pytest.fixture
def multi_unit_sc_data(rng):
    """Generate synthetic data for SyntheticControl with multiple treated units.

    Parameters
    ----------
    rng
        Random generator fixture; must provide ``normal`` and ``dirichlet``
        (presumably a ``numpy.random.Generator`` -- confirm against conftest).

    Returns
    -------
    tuple
        ``(df, treatment_time, control_units, treated_units)``: a daily-indexed
        DataFrame with one column per unit, the timestamp of the intervention,
        and the lists of control and treated column names.
    """
    n_obs = 60
    n_control = 4
    n_treated = 3
    # Single source of truth for the intervention position (day 40); it drives
    # both ``treatment_time`` and the slice where the effect is injected.
    treatment_idx = 40

    control_units = [f"control_{i}" for i in range(n_control)]
    treated_units = [f"treated_{j}" for j in range(n_treated)]

    # Create time index
    time_index = pd.date_range("2020-01-01", periods=n_obs, freq="D")
    treatment_time = time_index[treatment_idx]

    # Control unit data: independent noise around 10 plus a shared sine wave.
    seasonal = np.sin(np.arange(n_obs) * 0.1)
    control_data = {
        unit: rng.normal(10, 2, n_obs) + seasonal for unit in control_units
    }

    # Treated unit data (combinations of control units with some noise)
    treated_data = {}
    for unit in treated_units:
        # Each treated unit is a different weighted combination of controls
        weights = rng.dirichlet(np.ones(n_control))
        base_signal = sum(
            weight * control_data[control]
            for weight, control in zip(weights, control_units)
        )

        # Positive treatment effect, only from the intervention onwards
        treatment_effect = np.zeros(n_obs)
        treatment_effect[treatment_idx:] = rng.normal(5, 1, n_obs - treatment_idx)

        treated_data[unit] = (
            base_signal + treatment_effect + rng.normal(0, 0.5, n_obs)
        )

    # Controls first, then treated units, all sharing the same time index
    df = pd.DataFrame({**control_data, **treated_data}, index=time_index)

    return df, treatment_time, control_units, treated_units
774+
775+
@pytest.fixture
def single_unit_sc_data(rng):
    """Generate synthetic data for SyntheticControl with single treated unit.

    Parameters
    ----------
    rng
        Random generator fixture; must provide ``normal`` and ``dirichlet``
        (presumably a ``numpy.random.Generator`` -- confirm against conftest).

    Returns
    -------
    tuple
        ``(df, treatment_time, control_units, treated_units)``: a daily-indexed
        DataFrame with one column per unit, the timestamp of the intervention,
        the control column names, and a one-element list with the treated
        column name.
    """
    n_obs = 60
    n_control = 4
    # Single source of truth for the intervention position (day 40); it drives
    # both ``treatment_time`` and the slice where the effect is injected.
    treatment_idx = 40

    control_units = [f"control_{i}" for i in range(n_control)]
    treated_units = ["treated_0"]

    # Create time index
    time_index = pd.date_range("2020-01-01", periods=n_obs, freq="D")
    treatment_time = time_index[treatment_idx]

    # Control unit data: independent noise around 10 plus a shared sine wave.
    seasonal = np.sin(np.arange(n_obs) * 0.1)
    control_data = {
        unit: rng.normal(10, 2, n_obs) + seasonal for unit in control_units
    }

    # Single treated unit: a random convex combination of the controls
    weights = rng.dirichlet(np.ones(n_control))
    base_signal = sum(
        weight * control_data[control]
        for weight, control in zip(weights, control_units)
    )

    # Positive treatment effect, only from the intervention onwards
    treatment_effect = np.zeros(n_obs)
    treatment_effect[treatment_idx:] = rng.normal(5, 1, n_obs - treatment_idx)

    treated_data = {
        "treated_0": base_signal + treatment_effect + rng.normal(0, 0.5, n_obs)
    }

    # Controls first, then the treated unit, all sharing the same time index
    df = pd.DataFrame({**control_data, **treated_data}, index=time_index)

    return df, treatment_time, control_units, treated_units
816+
817+
class TestSyntheticControlMultiUnit:
    """Tests for SyntheticControl experiment with multiple treated units."""

    @staticmethod
    def _build_experiment(sc_data):
        """Build a SyntheticControl experiment from a fixture tuple.

        ``sc_data`` is the ``(df, treatment_time, control_units, treated_units)``
        tuple produced by the data fixtures. A fresh ``WeightedSumFitter`` is
        created for every call so tests do not share model state.
        """
        df, treatment_time, control_units, treated_units = sc_data
        model = cp.pymc_models.WeightedSumFitter(sample_kwargs=sample_kwargs)
        return cp.SyntheticControl(
            data=df,
            treatment_time=treatment_time,
            control_units=control_units,
            treated_units=treated_units,
            model=model,
        )

    @staticmethod
    def _assert_figure_and_axes(fig, ax):
        """Assert that ``fig`` is a Figure and ``ax`` is an array of Axes."""
        assert isinstance(fig, plt.Figure)
        assert isinstance(ax, np.ndarray) and all(
            isinstance(item, plt.Axes) for item in ax
        )

    @pytest.mark.integration
    def test_multi_unit_initialization(self, multi_unit_sc_data):
        """Test that SyntheticControl can initialize with multiple treated units."""
        _, treatment_time, control_units, treated_units = multi_unit_sc_data

        # Should initialize without error
        sc = self._build_experiment(multi_unit_sc_data)

        # Check basic attributes
        assert sc.treated_units == treated_units
        assert sc.control_units == control_units
        assert sc.treatment_time == treatment_time

        # Check data shapes: the fixture has 60 observations with the
        # intervention at position 40, i.e. 40 pre and 20 post rows.
        assert sc.datapre_treated.shape == (40, len(treated_units))
        assert sc.datapost_treated.shape == (20, len(treated_units))
        assert sc.datapre_control.shape == (40, len(control_units))
        assert sc.datapost_control.shape == (20, len(control_units))

    @pytest.mark.integration
    def test_multi_unit_scoring(self, multi_unit_sc_data):
        """Test that scoring works with multiple treated units."""
        _, _, _, treated_units = multi_unit_sc_data

        sc = self._build_experiment(multi_unit_sc_data)

        # Score should be a pandas Series with separate entries for each unit
        assert isinstance(sc.score, pd.Series)

        # Check that we have r2 and r2_std for each treated unit
        for unit in treated_units:
            assert f"{unit}_r2" in sc.score.index
            assert f"{unit}_r2_std" in sc.score.index

    @pytest.mark.integration
    def test_multi_unit_summary(self, multi_unit_sc_data, capsys):
        """Test that summary works with multiple treated units."""
        _, _, _, treated_units = multi_unit_sc_data

        sc = self._build_experiment(multi_unit_sc_data)

        # Test summary; it reports through stdout, which capsys captures
        sc.summary(round_to=3)
        output = capsys.readouterr().out

        # Check that output contains information for multiple treated units
        assert "Treated units:" in output
        for unit in treated_units:
            assert unit in output

    @pytest.mark.integration
    def test_single_unit_backward_compatibility(self, single_unit_sc_data):
        """Test that single treated unit still works (backward compatibility)."""
        _, treatment_time, control_units, treated_units = single_unit_sc_data

        sc = self._build_experiment(single_unit_sc_data)

        # Check basic attributes
        assert sc.treated_units == treated_units
        assert sc.control_units == control_units
        assert sc.treatment_time == treatment_time

    @pytest.mark.integration
    def test_multi_unit_plotting(self, multi_unit_sc_data):
        """Test that plotting works with multiple treated units."""
        _, _, _, treated_units = multi_unit_sc_data

        sc = self._build_experiment(multi_unit_sc_data)

        # Test plotting - should work for each treated unit individually
        for unit in treated_units:
            fig, ax = sc.plot(treated_unit=unit)
            self._assert_figure_and_axes(fig, ax)

        # Test default plotting (first unit)
        fig, ax = sc.plot()
        self._assert_figure_and_axes(fig, ax)

    @pytest.mark.integration
    def test_multi_unit_plot_data(self, multi_unit_sc_data):
        """Test that plot data generation works with multiple treated units."""
        _, _, _, treated_units = multi_unit_sc_data

        sc = self._build_experiment(multi_unit_sc_data)

        # Columns every per-unit plot-data frame must provide
        expected_columns = {
            "prediction",
            "pred_hdi_lower_94",
            "pred_hdi_upper_94",
            "impact",
            "impact_hdi_lower_94",
            "impact_hdi_upper_94",
        }

        # Test plot data generation for each treated unit
        for unit in treated_units:
            plot_data = sc.get_plot_data(treated_unit=unit)
            assert isinstance(plot_data, pd.DataFrame)
            assert expected_columns.issubset(set(plot_data.columns))

        # Test default plot data (first unit)
        plot_data = sc.get_plot_data()
        assert isinstance(plot_data, pd.DataFrame)

    @pytest.mark.integration
    def test_multi_unit_plotting_invalid_unit(self, multi_unit_sc_data):
        """Test that an invalid treated unit only fails with an expected error.

        The test is deliberately lenient: an unknown unit name may either be
        accepted or rejected with ``ValueError``/``KeyError``. Any other
        exception type propagates and fails the test.
        """
        sc = self._build_experiment(multi_unit_sc_data)

        # Note: Current implementation may not raise ValueError, so raising is
        # optional here; only the type of a raised error is constrained.
        for method in (sc.plot, sc.get_plot_data):
            try:
                method(treated_unit="invalid_unit")
            except (ValueError, KeyError):
                pass  # Either error type is acceptable
0 commit comments