@@ -1656,25 +1656,57 @@ def plot_ATE(self, idata=None, method=None, prop_draws=100, ate_draws=300):
1656
1656
if method is None :
1657
1657
method = self .weighting_scheme
1658
1658
1659
- def plot_weights (bins , top0 , top1 , ax ):
1659
+ def plot_weights (bins , top0 , top1 , ax , color = "population" ):
1660
+ colors_dict = {
1661
+ "population" : ["red" , "blue" , 0.9 ],
1662
+ "pseudo_population" : ["purple" , "purple" , 0.1 ],
1663
+ }
1664
+
1660
1665
ax .axhline (0 , c = "gray" , linewidth = 1 )
1661
1666
bars0 = ax .bar (
1662
- bins [:- 1 ] + 0.025 , top0 , width = 0.04 , facecolor = "red" , alpha = 0.3
1667
+ bins [:- 1 ] + 0.025 ,
1668
+ top0 ,
1669
+ width = 0.04 ,
1670
+ facecolor = colors_dict [color ][0 ],
1671
+ alpha = colors_dict [color ][2 ],
1663
1672
)
1664
1673
bars1 = ax .bar (
1665
- bins [:- 1 ] + 0.025 , - top1 , width = 0.04 , facecolor = "blue" , alpha = 0.3
1674
+ bins [:- 1 ] + 0.025 ,
1675
+ - top1 ,
1676
+ width = 0.04 ,
1677
+ facecolor = colors_dict [color ][1 ],
1678
+ alpha = colors_dict [color ][2 ],
1666
1679
)
1667
1680
1668
1681
for bars in (bars0 , bars1 ):
1669
1682
for bar in bars :
1670
1683
bar .set_edgecolor ("black" )
1671
1684
1672
- def make_hists (idata , i , axs ):
1685
+ def make_hists (idata , i , axs , method = method ):
1673
1686
p_i = az .extract (idata )["p" ][:, i ].values
1687
+ if method == "raw" :
1688
+ weight0 = 1 / (1 - p_i [self .t .flatten () == 0 ])
1689
+ weight1 = 1 / (p_i [self .t .flatten () == 1 ])
1690
+ elif method == "overlap" :
1691
+ t = self .t .flatten ()
1692
+ weight1 = (1 - p_i [t == 1 ]) * t [t == 1 ]
1693
+ weight0 = p_i [t == 0 ] * (1 - t [t == 0 ])
1694
+ else :
1695
+ t = self .t .flatten ()
1696
+ p_of_t = np .mean (t )
1697
+ weight1 = p_of_t / p_i [t == 1 ]
1698
+ weight0 = (1 - p_of_t ) / (1 - p_i [t == 0 ])
1674
1699
bins = np .arange (0.025 , 0.99 , 0.005 )
1675
1700
top0 , _ = np .histogram (p_i [self .t .flatten () == 0 ], bins = bins )
1676
1701
top1 , _ = np .histogram (p_i [self .t .flatten () == 1 ], bins = bins )
1677
1702
plot_weights (bins , top0 , top1 , axs [0 ])
1703
+ top0 , _ = np .histogram (
1704
+ p_i [self .t .flatten () == 0 ], bins = bins , weights = weight0
1705
+ )
1706
+ top1 , _ = np .histogram (
1707
+ p_i [self .t .flatten () == 1 ], bins = bins , weights = weight1
1708
+ )
1709
+ plot_weights (bins , top0 , top1 , axs [0 ], color = "pseudo_population" )
1678
1710
1679
1711
mosaic = """AAAAAA
1680
1712
BBBBCC"""
@@ -1690,6 +1722,7 @@ def make_hists(idata, i, axs):
1690
1722
axs [0 ].set_title (
1691
1723
"Draws from the Posterior \n Propensity Scores Distribution" , fontsize = 20
1692
1724
)
1725
+ axs [0 ].legend ()
1693
1726
1694
1727
[make_hists (idata , i , axs ) for i in range (prop_draws )]
1695
1728
ate_df = pd .DataFrame (
0 commit comments