Skip to content

Commit 380b0c7

Browse files
springcoiltwiecki
authored andcommitted
PEP8: Optimizing imports and improving the PEP8 compliance of GHME (#1328)
1 parent 8936438 commit 380b0c7

File tree

1 file changed

+16
-16
lines changed

1 file changed

+16
-16
lines changed

pymc3/examples/GHME_2013.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22
# <nbformat>3.0</nbformat>
33

44
# <codecell>
5-
from pylab import *
65
import pandas as pd
7-
from pymc3 import *
8-
from pymc3.distributions.timeseries import *
6+
from pylab import *
7+
8+
from pymc3 import StudentT, Model, NUTS, Normal, find_MAP, trace, get_data_file
9+
from pymc3.distributions.timeseries import GaussianRandomWalk
910

1011
# <markdowncell>
1112

@@ -15,7 +16,7 @@
1516
# <codecell>
1617

1718
data = pd.read_csv(get_data_file('pymc3.examples', 'data/pancreatitis.csv'))
18-
countries = ['CYP', 'DNK', 'ESP', 'FIN','GBR', 'ISL']
19+
countries = ['CYP', 'DNK', 'ESP', 'FIN', 'GBR', 'ISL']
1920
data = data[data.area.isin(countries)]
2021

2122
age = data['age'] = np.array(data.age_start + data.age_end)/2
@@ -28,12 +29,12 @@
2829
# <codecell>
2930

3031
for i, country in enumerate(countries):
31-
subplot(2,3,i+1)
32+
subplot(2, 3, i+1)
3233
title(country)
3334
d = data[data.area == country]
3435
plot(d.age, d.value, '.')
3536

36-
ylim(0,rate.max())
37+
ylim(0, rate.max())
3738

3839
# <markdowncell>
3940

@@ -43,17 +44,17 @@
4344
# <codecell>
4445

4546
nknots = 10
46-
knots = np.linspace(data.age_start.min(),data.age_end.max(), nknots)
47+
knots = np.linspace(data.age_start.min(), data.age_end.max(), nknots)
4748

4849

49-
def interpolate(x0,y0, x, group):
50+
def interpolate(x0, y0, x, group):
5051
x = np.array(x)
5152
group = np.array(group)
5253

5354
idx = np.searchsorted(x0, x)
5455
dl = np.array(x - x0[idx - 1])
5556
dr = np.array(x0[idx] - x)
56-
d=dl+dr
57+
d = dl + dr
5758
wl = dr/d
5859

5960
return wl*y0[idx-1, group] + (1-wl)*y0[idx, group]
@@ -68,7 +69,7 @@ def interpolate(x0,y0, x, group):
6869

6970
sd = StudentT('sd', 10, 2, 5**-2)
7071

71-
vals = Normal('vals', p, sd=sd, observed = rate)
72+
vals = Normal('vals', p, sd=sd, observed=rate)
7273

7374
# <markdowncell>
7475

@@ -80,31 +81,30 @@ def interpolate(x0,y0, x, group):
8081
with model:
8182
s = find_MAP(vars=[sd, y])
8283

83-
step = NUTS(scaling = s)
84+
step = NUTS(scaling=s)
8485
trace = sample(100, step, s)
8586

8687
s = trace[-1]
8788

8889
step = NUTS(scaling=s)
8990

91+
9092
def run(n=3000):
9193
if n == "short":
9294
n = 150
9395
with model:
9496
trace = sample(n, step, s)
95-
96-
9797
# <codecell>
9898

9999
for i, country in enumerate(countries):
100-
subplot(2,3,i+1)
100+
subplot(2, 3, i+1)
101101
title(country)
102102

103103
d = data[data.area == country]
104104
plot(d.age, d.value, '.')
105-
plot(knots, trace[y][::5,:,i].T, color ='r', alpha =.01);
105+
plot(knots, trace[y][::5, :, i].T, color='r', alpha=.01)
106106

107-
ylim(0,rate.max())
107+
ylim(0, rate.max())
108108

109109

110110
if __name__ == '__main__':

0 commit comments

Comments
 (0)