22
22
import xarray as xr
23
23
from patsy import build_design_matrices , dmatrices
24
24
from sklearn .linear_model import LinearRegression as sk_lin_reg
25
+ from matplotlib .lines import Line2D
26
+
25
27
26
28
from causalpy .data_validation import (
27
29
PrePostFitDataValidator ,
@@ -1658,7 +1660,7 @@ def plot_ATE(self, idata=None, method=None, prop_draws=100, ate_draws=300):
1658
1660
1659
1661
def plot_weights (bins , top0 , top1 , ax , color = "population" ):
1660
1662
colors_dict = {
1661
- "population" : ["red " , "blue " , 0.9 ],
1663
+ "population" : ["lightcoral " , "skyblue " , 0.6 ],
1662
1664
"pseudo_population" : ["purple" , "purple" , 0.1 ],
1663
1665
}
1664
1666
@@ -1722,7 +1724,17 @@ def make_hists(idata, i, axs, method=method):
1722
1724
axs [0 ].set_title (
1723
1725
"Draws from the Posterior \n Propensity Scores Distribution" , fontsize = 20
1724
1726
)
1725
- axs [0 ].legend ()
1727
+ custom_lines = [
1728
+ Line2D ([0 ], [0 ], color = "skyblue" , lw = 2 ),
1729
+ Line2D ([0 ], [0 ], color = "lightcoral" , lw = 2 ),
1730
+ Line2D ([0 ], [0 ], color = "purple" , lw = 2 ),
1731
+ Line2D ([0 ], [0 ], color = "black" , lw = 2 , linestyle = "--" ),
1732
+ ]
1733
+
1734
+ axs [0 ].legend (
1735
+ custom_lines ,
1736
+ ["Control PS" , "Treatment PS" , "Weighted Pseudo Population" , "Extreme PS" ],
1737
+ )
1726
1738
1727
1739
[make_hists (idata , i , axs ) for i in range (prop_draws )]
1728
1740
ate_df = pd .DataFrame (
@@ -1734,11 +1746,16 @@ def make_hists(idata, i, axs, method=method):
1734
1746
label = "E(Y(1))" ,
1735
1747
ec = "black" ,
1736
1748
bins = 10 ,
1737
- alpha = 0.8 ,
1738
- color = "blue " ,
1749
+ alpha = 0.6 ,
1750
+ color = "skyblue " ,
1739
1751
)
1740
1752
axs [1 ].hist (
1741
- ate_df ["Y(0)" ], label = "E(Y(0))" , ec = "black" , bins = 10 , alpha = 0.8 , color = "red"
1753
+ ate_df ["Y(0)" ],
1754
+ label = "E(Y(0))" ,
1755
+ ec = "black" ,
1756
+ bins = 10 ,
1757
+ alpha = 0.6 ,
1758
+ color = "lightcoral" ,
1742
1759
)
1743
1760
axs [1 ].legend ()
1744
1761
axs [1 ].set_title (
@@ -1811,17 +1828,24 @@ def plot_balance_ecdf(self, covariate, idata=None, weighting_scheme=None):
1811
1828
self .weighted_percentile (X [t == 0 ][covariate ].values , w0 , p )
1812
1829
for p in np .linspace (0 , 1 , 1000 )
1813
1830
]
1814
- axs [0 ].plot (np .linspace (0 , 1 , 1000 ), raw_trt , color = "blue" , label = "Raw Treated" )
1815
- axs [0 ].plot (np .linspace (0 , 1 , 1000 ), raw_ntrt , color = "red" , label = "Raw Control" )
1831
+ axs [0 ].plot (
1832
+ np .linspace (0 , 1 , 1000 ), raw_trt , color = "skyblue" , label = "Raw Treated"
1833
+ )
1834
+ axs [0 ].plot (
1835
+ np .linspace (0 , 1 , 1000 ), raw_ntrt , color = "lightcoral" , label = "Raw Control"
1836
+ )
1816
1837
axs [0 ].set_title (f"ECDF \n Raw: { covariate } " )
1817
1838
axs [1 ].set_title (
1818
1839
f"ECDF \n Weighted { weighting_scheme } adjustment for { covariate } "
1819
1840
)
1820
1841
axs [1 ].plot (
1821
- np .linspace (0 , 1 , 1000 ), w_trt , color = "blue " , label = "Reweighted Treated"
1842
+ np .linspace (0 , 1 , 1000 ), w_trt , color = "skyblue " , label = "Reweighted Treated"
1822
1843
)
1823
1844
axs [1 ].plot (
1824
- np .linspace (0 , 1 , 1000 ), w_ntrt , color = "red" , label = "Reweighted Control"
1845
+ np .linspace (0 , 1 , 1000 ),
1846
+ w_ntrt ,
1847
+ color = "lightcoral" ,
1848
+ label = "Reweighted Control" ,
1825
1849
)
1826
1850
axs [1 ].set_xlabel ("Quantiles" )
1827
1851
axs [0 ].set_xlabel ("Quantiles" )
0 commit comments