Skip to content

Commit 0aa7942

Browse files
committed
Adding BSTS changes and notebook updates
1 parent f4da3fb commit 0aa7942

File tree

4 files changed

+1501
-113
lines changed

4 files changed

+1501
-113
lines changed

causalpy/pymc_models.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1208,7 +1208,9 @@ def _prepare_idata(self):
12081208

12091209
new_idata = self.idata.copy()
12101210
# Get smoothed posterior and sum over state dimension
1211-
smoothed = self.conditional_idata.rename({"smoothed_posterior": "y_hat"})
1211+
smoothed = self.conditional_idata.isel(observed_state=0).rename(
1212+
{"smoothed_posterior": "y_hat"}
1213+
)
12121214
y_hat_summed = smoothed.y_hat.sum(dim="state")
12131215

12141216
# Rename 'time' to 'obs_ind' to match CausalPy conventions

docs/source/notebooks/its_pymc copy.ipynb

Lines changed: 1151 additions & 0 deletions
Large diffs are not rendered by default.

docs/source/notebooks/its_pymc.ipynb

Lines changed: 220 additions & 112 deletions
Large diffs are not rendered by default.

docs/source/notebooks/model.dot

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
digraph G {
2+
graph [bb="0,0,753.87,640"];
3+
node [label="\N"];
4+
13798618208 [height=0.5,
5+
label="ExpandDims{axis=0}",
6+
pos="593.79,268",
7+
shape=ellipse,
8+
width=2.6064];
9+
14281294576 [fillcolor="#FFAABB",
10+
height=0.5,
11+
label="Add id=6",
12+
pos="496.79,179.5",
13+
shape=ellipse,
14+
style=filled,
15+
width=1.2985];
16+
13798618208 -> 14281294576 [label="1 Matrix(float64, shape=(1, ?))",
17+
lp="668.37,223.75",
18+
pos="e,532.96,191.19 589.88,249.92 586.62,238.94 580.99,225 571.79,215.5 563.85,207.29 553.61,200.8 543.26,195.76"];
19+
14280830768 [fillcolor=cyan,
20+
height=0.5,
21+
label="Vector(float64, shape=(?,))",
22+
pos="593.79,356.5",
23+
shape=box,
24+
style=filled,
25+
width=2.2743];
26+
14280830768 -> 13798618208 [color=blue,
27+
pos="e,593.79,286.35 593.79,338.41 593.79,326.76 593.79,311.05 593.79,297.52"];
28+
13798620448 [height=0.5,
29+
label="ExpandDims{axis=0} id=1",
30+
pos="113.79,533.5",
31+
shape=ellipse,
32+
width=3.1609];
33+
14281294240 [fillcolor="#FFAABB",
34+
height=0.5,
35+
label=Add,
36+
pos="210.79,445",
37+
shape=ellipse,
38+
style=filled,
39+
width=0.75];
40+
13798620448 -> 14281294240 [label="1 Matrix(float64, shape=(1, ?))",
41+
lp="196.29,489.25",
42+
pos="e,183.62,447.24 107.67,515.23 105,504.44 103.85,490.79 110.79,481 124.69,461.42 150.41,452.7 172.29,448.87"];
43+
14281120016 [fillcolor=cyan,
44+
height=0.5,
45+
label="Vector(float64, shape=(?,))",
46+
pos="113.79,622",
47+
shape=box,
48+
style=filled,
49+
width=2.2743];
50+
14281120016 -> 13798620448 [color=blue,
51+
pos="e,113.79,551.85 113.79,603.91 113.79,592.26 113.79,576.55 113.79,563.02"];
52+
13798617312 [height=0.5,
53+
label=dot,
54+
pos="307.79,533.5",
55+
shape=ellipse,
56+
width=0.75];
57+
13798617312 -> 14281294240 [label="0 Matrix(float64, shape=(?, ?))",
58+
lp="382,489.25",
59+
pos="e,235.82,452.6 303.88,515.42 300.62,504.44 294.99,490.5 285.79,481 275.07,469.92 260.16,461.98 246.45,456.49"];
60+
14281233232 [fillcolor=green,
61+
height=0.5,
62+
label=x,
63+
pos="273.79,622",
64+
shape=box,
65+
style=filled,
66+
width=0.75];
67+
14281233232 -> 13798617312 [label=0,
68+
lp="297.21,577.75",
69+
pos="e,301.2,551.26 280.51,603.91 285.24,591.87 291.68,575.49 297.11,561.67"];
70+
14281120368 [fillcolor=cyan,
71+
height=0.5,
72+
label="Matrix(float64, shape=(?, ?))",
73+
pos="406.79,622",
74+
shape=box,
75+
style=filled,
76+
width=2.441];
77+
14281120368 -> 13798617312 [label=1,
78+
lp="369.53,577.75",
79+
pos="e,323.72,548.42 387.23,603.91 371.4,590.07 349.01,570.52 332.04,555.69"];
80+
14281294352 [fillcolor="#FFAABB",
81+
height=0.5,
82+
label=Sigmoid,
83+
pos="210.79,356.5",
84+
shape=ellipse,
85+
style=filled,
86+
width=1.1847];
87+
14281294240 -> 14281294352 [label="Matrix(float64, shape=(?, ?))",
88+
lp="290.67,400.75",
89+
pos="e,210.79,374.85 210.79,426.91 210.79,415.26 210.79,399.55 210.79,386.02"];
90+
14281294464 [height=0.5,
91+
label="dot id=5",
92+
pos="400.79,268",
93+
shape=ellipse,
94+
width=1.1847];
95+
14281294352 -> 14281294464 [label="0 Matrix(float64, shape=(?, ?))",
96+
lp="312.67,312.25",
97+
pos="e,359.06,272.61 212.01,338.26 213.74,326.92 217.85,312.63 227.54,304 245.09,288.37 303.66,278.91 347.88,273.84"];
98+
14281294464 -> 14281294576 [label="0 Matrix(float64, shape=(?, ?))",
99+
lp="486.67,223.75",
100+
pos="e,452.6,185.61 396.42,249.72 394.73,238.93 394.61,225.28 401.54,215.5 411.06,202.07 426.32,193.78 441.86,188.67"];
101+
4424577616 [fillcolor=cyan,
102+
height=0.5,
103+
label="Matrix(float64, shape=(?, ?))",
104+
pos="403.79,356.5",
105+
shape=box,
106+
style=filled,
107+
width=2.441];
108+
4424577616 -> 14281294464 [label=1,
109+
lp="405.94,312.25",
110+
pos="e,401.39,286.35 403.2,338.41 402.8,326.76 402.25,311.05 401.78,297.52"];
111+
14281294688 [height=0.5,
112+
label="Softmax{axis=None}",
113+
pos="496.79,91",
114+
shape=ellipse,
115+
width=2.5638];
116+
14281294576 -> 14281294688 [label="Matrix(float64, shape=(?, ?))",
117+
lp="576.67,135.25",
118+
pos="e,496.79,109.35 496.79,161.41 496.79,149.76 496.79,134.05 496.79,120.52"];
119+
14281241552 [fillcolor=blue,
120+
height=0.5,
121+
label="Matrix(float64, shape=(?, ?))",
122+
pos="496.79,18",
123+
shape=box,
124+
style=filled,
125+
width=2.441];
126+
14281294688 -> 14281241552 [pos="e,496.79,36.029 496.79,72.813 496.79,65.226 496.79,56.101 496.79,47.539"];
127+
}

0 commit comments

Comments
 (0)