Skip to content

Commit 76d8d2a

Browse files
committed
improved color scheme
Signed-off-by: Nathaniel <[email protected]>
1 parent 37c022c commit 76d8d2a

File tree

3 files changed

+89
-65
lines changed

3 files changed

+89
-65
lines changed

causalpy/pymc_experiments.py

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
import xarray as xr
2323
from patsy import build_design_matrices, dmatrices
2424
from sklearn.linear_model import LinearRegression as sk_lin_reg
25+
from matplotlib.lines import Line2D
26+
2527

2628
from causalpy.data_validation import (
2729
PrePostFitDataValidator,
@@ -1658,7 +1660,7 @@ def plot_ATE(self, idata=None, method=None, prop_draws=100, ate_draws=300):
16581660

16591661
def plot_weights(bins, top0, top1, ax, color="population"):
16601662
colors_dict = {
1661-
"population": ["red", "blue", 0.9],
1663+
"population": ["lightcoral", "skyblue", 0.6],
16621664
"pseudo_population": ["purple", "purple", 0.1],
16631665
}
16641666

@@ -1722,7 +1724,17 @@ def make_hists(idata, i, axs, method=method):
17221724
axs[0].set_title(
17231725
"Draws from the Posterior \n Propensity Scores Distribution", fontsize=20
17241726
)
1725-
axs[0].legend()
1727+
custom_lines = [
1728+
Line2D([0], [0], color="skyblue", lw=2),
1729+
Line2D([0], [0], color="lightcoral", lw=2),
1730+
Line2D([0], [0], color="purple", lw=2),
1731+
Line2D([0], [0], color="black", lw=2, linestyle="--"),
1732+
]
1733+
1734+
axs[0].legend(
1735+
custom_lines,
1736+
["Control PS", "Treatment PS", "Weighted Pseudo Population", "Extreme PS"],
1737+
)
17261738

17271739
[make_hists(idata, i, axs) for i in range(prop_draws)]
17281740
ate_df = pd.DataFrame(
@@ -1734,11 +1746,16 @@ def make_hists(idata, i, axs, method=method):
17341746
label="E(Y(1))",
17351747
ec="black",
17361748
bins=10,
1737-
alpha=0.8,
1738-
color="blue",
1749+
alpha=0.6,
1750+
color="skyblue",
17391751
)
17401752
axs[1].hist(
1741-
ate_df["Y(0)"], label="E(Y(0))", ec="black", bins=10, alpha=0.8, color="red"
1753+
ate_df["Y(0)"],
1754+
label="E(Y(0))",
1755+
ec="black",
1756+
bins=10,
1757+
alpha=0.6,
1758+
color="lightcoral",
17421759
)
17431760
axs[1].legend()
17441761
axs[1].set_title(
@@ -1811,17 +1828,24 @@ def plot_balance_ecdf(self, covariate, idata=None, weighting_scheme=None):
18111828
self.weighted_percentile(X[t == 0][covariate].values, w0, p)
18121829
for p in np.linspace(0, 1, 1000)
18131830
]
1814-
axs[0].plot(np.linspace(0, 1, 1000), raw_trt, color="blue", label="Raw Treated")
1815-
axs[0].plot(np.linspace(0, 1, 1000), raw_ntrt, color="red", label="Raw Control")
1831+
axs[0].plot(
1832+
np.linspace(0, 1, 1000), raw_trt, color="skyblue", label="Raw Treated"
1833+
)
1834+
axs[0].plot(
1835+
np.linspace(0, 1, 1000), raw_ntrt, color="lightcoral", label="Raw Control"
1836+
)
18161837
axs[0].set_title(f"ECDF \n Raw: {covariate}")
18171838
axs[1].set_title(
18181839
f"ECDF \n Weighted {weighting_scheme} adjustment for {covariate}"
18191840
)
18201841
axs[1].plot(
1821-
np.linspace(0, 1, 1000), w_trt, color="blue", label="Reweighted Treated"
1842+
np.linspace(0, 1, 1000), w_trt, color="skyblue", label="Reweighted Treated"
18221843
)
18231844
axs[1].plot(
1824-
np.linspace(0, 1, 1000), w_ntrt, color="red", label="Reweighted Control"
1845+
np.linspace(0, 1, 1000),
1846+
w_ntrt,
1847+
color="lightcoral",
1848+
label="Reweighted Control",
18251849
)
18261850
axs[1].set_xlabel("Quantiles")
18271851
axs[0].set_xlabel("Quantiles")
187 KB
Loading

docs/source/notebooks/inv_prop_pymc.ipynb

Lines changed: 56 additions & 56 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)