@@ -1660,8 +1660,8 @@ def plot_ATE(self, idata=None, method=None, prop_draws=100, ate_draws=300):
1660
1660
1661
1661
def plot_weights (bins , top0 , top1 , ax , color = "population" ):
1662
1662
colors_dict = {
1663
- "population" : ["lightcoral " , "skyblue" , 0.6 ],
1664
- "pseudo_population" : ["purple " , "purple " , 0.1 ],
1663
+ "population" : ["orange " , "skyblue" , 0.6 ],
1664
+ "pseudo_population" : ["grey " , "grey " , 0.1 ],
1665
1665
}
1666
1666
1667
1667
ax .axhline (0 , c = "gray" , linewidth = 1 )
@@ -1724,16 +1724,17 @@ def make_hists(idata, i, axs, method=method):
1724
1724
axs [0 ].set_title (
1725
1725
"Draws from the Posterior \n Propensity Scores Distribution" , fontsize = 20
1726
1726
)
1727
+ axs [0 ].set_xlabel ("Propensity Scores" )
1727
1728
custom_lines = [
1728
1729
Line2D ([0 ], [0 ], color = "skyblue" , lw = 2 ),
1729
- Line2D ([0 ], [0 ], color = "lightcoral " , lw = 2 ),
1730
- Line2D ([0 ], [0 ], color = "purple " , lw = 2 ),
1730
+ Line2D ([0 ], [0 ], color = "orange " , lw = 2 ),
1731
+ Line2D ([0 ], [0 ], color = "grey " , lw = 2 ),
1731
1732
Line2D ([0 ], [0 ], color = "black" , lw = 2 , linestyle = "--" ),
1732
1733
]
1733
1734
1734
1735
axs [0 ].legend (
1735
1736
custom_lines ,
1736
- ["Control PS" , "Treatment PS" , "Weighted Pseudo Population" , "Extreme PS" ],
1737
+ ["Treatment PS" , "Control PS" , "Weighted Pseudo Population" , "Extreme PS" ],
1737
1738
)
1738
1739
1739
1740
[make_hists (idata , i , axs ) for i in range (prop_draws )]
@@ -1755,9 +1756,10 @@ def make_hists(idata, i, axs, method=method):
1755
1756
ec = "black" ,
1756
1757
bins = 10 ,
1757
1758
alpha = 0.6 ,
1758
- color = "lightcoral " ,
1759
+ color = "orange " ,
1759
1760
)
1760
1761
axs [1 ].legend ()
1762
+ axs [1 ].set_xlabel (self .outcome_variable )
1761
1763
axs [1 ].set_title (
1762
1764
f"The Outcomes \n Under the { method } re-weighting scheme" , fontsize = 20
1763
1765
)
@@ -1766,9 +1768,10 @@ def make_hists(idata, i, axs, method=method):
1766
1768
label = "ATE" ,
1767
1769
ec = "black" ,
1768
1770
bins = 10 ,
1769
- color = "slateblue " ,
1771
+ color = "grey " ,
1770
1772
alpha = 0.6 ,
1771
1773
)
1774
+ axs [2 ].set_xlabel ("Difference" )
1772
1775
axs [2 ].axvline (ate_df ["ATE" ].mean (), label = "E(ATE)" )
1773
1776
axs [2 ].legend ()
1774
1777
axs [2 ].set_title ("Average Treatment Effect" , fontsize = 20 )
@@ -1832,7 +1835,7 @@ def plot_balance_ecdf(self, covariate, idata=None, weighting_scheme=None):
1832
1835
np .linspace (0 , 1 , 1000 ), raw_trt , color = "skyblue" , label = "Raw Treated"
1833
1836
)
1834
1837
axs [0 ].plot (
1835
- np .linspace (0 , 1 , 1000 ), raw_ntrt , color = "lightcoral " , label = "Raw Control"
1838
+ np .linspace (0 , 1 , 1000 ), raw_ntrt , color = "orange " , label = "Raw Control"
1836
1839
)
1837
1840
axs [0 ].set_title (f"ECDF \n Raw: { covariate } " )
1838
1841
axs [1 ].set_title (
@@ -1844,7 +1847,7 @@ def plot_balance_ecdf(self, covariate, idata=None, weighting_scheme=None):
1844
1847
axs [1 ].plot (
1845
1848
np .linspace (0 , 1 , 1000 ),
1846
1849
w_ntrt ,
1847
- color = "lightcoral " ,
1850
+ color = "orange " ,
1848
1851
label = "Reweighted Control" ,
1849
1852
)
1850
1853
axs [1 ].set_xlabel ("Quantiles" )
0 commit comments