Skip to content

Commit 93d209a

Browse files
authored
Merge pull request #2 from becksgld/type-hints
Added type hints
2 parents 45bd96e + e44bbef commit 93d209a

File tree

3 files changed

+64
-39
lines changed

3 files changed

+64
-39
lines changed

src/epidemik/EpiModel.py

Lines changed: 34 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,18 @@
33
# @author Bruno Goncalves
44
######################################################
55

6+
from typing import Dict, List, Set, Union
67
import warnings
8+
import string
9+
710
import networkx as nx
811
import numpy as np
912
from numpy import linalg
1013
from numpy import random
1114
import scipy.integrate
1215
import pandas as pd
1316
import matplotlib.pyplot as plt
14-
import string
17+
1518
from .utils import *
1619

1720
class EpiModel(object):
@@ -38,7 +41,7 @@ def __init__(self, compartments=None):
3841
if compartments is not None:
3942
self.transitions.add_nodes_from([comp for comp in compartments])
4043

41-
def add_interaction(self, source, target, agent, rate):
44+
def add_interaction(self, source: str, target: str, agent: str, rate: float) -> None:
4245
"""
4346
Add an interaction between two compartments
4447
@@ -57,7 +60,7 @@ def add_interaction(self, source, target, agent, rate):
5760
"""
5861
self.transitions.add_edge(source, target, agent=agent, rate=rate)
5962

60-
def add_spontaneous(self, source, target, rate):
63+
def add_spontaneous(self, source: str, target: str, rate: float) -> None:
6164
"""
6265
Add a spontaneous transition between two compartments
6366
@@ -74,7 +77,7 @@ def add_spontaneous(self, source, target, rate):
7477
"""
7578
self.transitions.add_edge(source, target, rate=rate)
7679

77-
def add_vaccination(self, source, target, rate, start):
80+
def add_vaccination(self, source: str, target: str, rate: float, start: int) -> None:
7881
"""
7982
Add a vaccination transition between two compartments
8083
@@ -93,7 +96,14 @@ def add_vaccination(self, source, target, rate, start):
9396
"""
9497
self.transitions.add_edge(source, target, rate=rate, start=start)
9598

96-
def add_age_structure(self, matrix, population):
99+
def add_age_structure(self, matrix: List, population: List) -> List[List]:
100+
"""
101+
Add a vaccination transition between two compartments
102+
103+
Parameters:
104+
- matrix: List
105+
- population: List
106+
"""
97107
self.contact = np.asarray(matrix)
98108
self.population = np.asarray(population).flatten()
99109

@@ -131,7 +141,7 @@ def add_age_structure(self, matrix, population):
131141

132142
self.transitions = model.transitions
133143

134-
def _new_cases(self, population, time, pos):
144+
def _new_cases(self, population: np.ndarray, time: float, pos: Dict) -> np.ndarray:
135145
"""
136146
Internal function used by integration routine
137147
@@ -188,7 +198,7 @@ def _new_cases(self, population, time, pos):
188198

189199
return diff
190200

191-
def plot(self, title=None, normed=True, show=True, ax=None, **kwargs):
201+
def plot(self, title: Union[str, None]= None, normed: bool = True, show: bool = True, ax: Union[plt.Axes, None] = None, **kwargs):
192202
"""
193203
Convenience function for plotting
194204
@@ -235,7 +245,7 @@ def plot(self, title=None, normed=True, show=True, ax=None, **kwargs):
235245
print(e)
236246
raise NotInitialized('You must call integrate() or simulate() first')
237247

238-
def __getattr__(self, name):
248+
def __getattr__(self, name: str) -> pd.Series:
239249
"""
240250
Dynamic method to return the individual compartment values
241251
@@ -252,7 +262,7 @@ def __getattr__(self, name):
252262
else:
253263
raise AttributeError("'EpiModel' object has no attribute '%s'" % name)
254264

255-
def simulate(self, timesteps, t_min=1, seasonality=None, **kwargs):
265+
def simulate(self, timesteps: int, t_min: int = 1, seasonality: Union[np.ndarray, None] = None, **kwargs) -> None:
256266
"""
257267
Stochastically simulate the epidemic model
258268
@@ -334,7 +344,7 @@ def simulate(self, timesteps, t_min=1, seasonality=None, **kwargs):
334344
values = np.array(values)
335345
self.values_ = pd.DataFrame(values[1:], columns=comps, index=time)
336346

337-
def integrate(self, timesteps, t_min=1, seasonality=None, **kwargs):
347+
def integrate(self, timesteps: int , t_min: int = 1, seasonality: Union[np.ndarray, None] = None, **kwargs) -> None:
338348
"""
339349
Numerically integrate the epidemic model
340350
@@ -375,7 +385,14 @@ def integrate(self, timesteps, t_min=1, seasonality=None, **kwargs):
375385
time = np.arange(t_min, t_min+timesteps, 1)
376386

377387
self.seasonality = seasonality
378-
values = pd.DataFrame(scipy.integrate.odeint(self._new_cases, population, time, args=(pos,)), columns=pos.keys(), index=time)
388+
values = pd.DataFrame(
389+
scipy.integrate.odeint(
390+
self._new_cases,
391+
population,
392+
time,
393+
args=(pos,)
394+
), columns=pos.keys(), index=time
395+
)
379396

380397
if self.population is None:
381398
self.values_ = values
@@ -398,7 +415,7 @@ def single_step(self, seasonality=None, **kwargs):
398415
new_values = pd.concat([old_values, self.values_.iloc[[-1]]])
399416
self.values_ = new_values
400417

401-
def __repr__(self):
418+
def __repr__(self) -> str:
402419
"""
403420
Return a string representation of the EpiModel object
404421
@@ -433,7 +450,7 @@ def __repr__(self):
433450

434451
return text
435452

436-
def _get_active(self):
453+
def _get_active(self) -> Set:
437454
active = set()
438455

439456
for node_i, node_j, data in self.transitions.edges(data=True):
@@ -444,7 +461,7 @@ def _get_active(self):
444461

445462
return active
446463

447-
def _get_susceptible(self):
464+
def _get_susceptible(self) -> Set:
448465
susceptible = set([node for node, deg in self.transitions.in_degree() if deg==0])
449466

450467
if len(susceptible) == 0:
@@ -454,7 +471,7 @@ def _get_susceptible(self):
454471

455472
return susceptible
456473

457-
def _get_infections(self):
474+
def _get_infections(self) -> Dict:
458475
inf = {}
459476

460477
for node_i, node_j, data in self.transitions.edges(data=True):
@@ -472,7 +489,7 @@ def _get_infections(self):
472489

473490
return inf
474491

475-
def draw_model(self, ax=None, show=True):
492+
def draw_model(self, ax: Union[plt.Axes, None] = None, show: bool = True) -> None:
476493
"""
477494
Plot the model structure
478495
@@ -516,7 +533,7 @@ def draw_model(self, ax=None, show=True):
516533
if show:
517534
plt.show()
518535

519-
def R0(self):
536+
def R0(self) -> Union[float, None]:
520537
"""
521538
Return the value of the basic reproductive ratio, $R_0$, for the model as defined
522539

src/epidemik/MetaEpiModel.py

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ class MetaEpiModel:
2222
2323
Provides a way to implement and numerically integrate
2424
"""
25-
def __init__(self, travel_graph, populations, population='Population'):
25+
def __init__(self, travel_graph: pd.DataFrame, populations: pd.DataFrame, population: str ='Population'):
2626
"""
2727
Initialize the EpiModel object
2828
@@ -51,7 +51,7 @@ def __init__(self, travel_graph, populations, population='Population'):
5151

5252
self.models = models
5353

54-
def __repr__(self):
54+
def __repr__(self) -> str:
5555
"""
5656
Return a string representation of the EpiModel object
5757
@@ -65,7 +65,7 @@ def __repr__(self):
6565
text = "Metapopulation model with %u populations\n\nThe disease is defined by an %s" % (self.travel_graph.shape[0], model_text)
6666
return text
6767

68-
def add_interaction(self, source, target, agent, rate):
68+
def add_interaction(self, source: str, target: str, agent: str, rate: float) -> None:
6969
"""
7070
Add an interaction between two compartments_
7171
@@ -85,7 +85,7 @@ def add_interaction(self, source, target, agent, rate):
8585
for state in self.models:
8686
self.models[state].add_interaction(source, target, agent, rate)
8787

88-
def add_spontaneous(self, source, target, rate):
88+
def add_spontaneous(self, source: str, target: str, rate: float) -> None:
8989
"""
9090
Add a spontaneous transition between two compartments_
9191
@@ -103,7 +103,7 @@ def add_spontaneous(self, source, target, rate):
103103
for state in self.models:
104104
self.models[state].add_spontaneous(source, target, rate)
105105

106-
def add_vaccination(self, source, target, rate, start):
106+
def add_vaccination(self, source: str, target: str, rate: float, start: int) -> None:
107107
"""
108108
Add a vaccination transition between two compartments_
109109
@@ -123,11 +123,11 @@ def add_vaccination(self, source, target, rate, start):
123123
for state in self.models:
124124
self.models[state].add_vaccination(source, target, rate, start)
125125

126-
def R0(self):
126+
def R0(self) -> Union[float, None]:
127127
key = list(self.models.keys())[0]
128128
return self.models[key].R0()
129129

130-
def get_state(self, state):
130+
def get_state(self, state: str) -> EpiModel:
131131
"""
132132
Return a reference to a state EpiModel object
133133
@@ -138,7 +138,7 @@ def get_state(self, state):
138138

139139
return self.models[state]
140140

141-
def _initialize_populations(self, susceptible, population=None):
141+
def _initialize_populations(self, susceptible: str, population: Union[pd.DataFrame, None] =None) -> None:
142142
columns = list(self.transitions.nodes())
143143
self.compartments_ = pd.DataFrame(np.zeros((self.travel_graph.shape[0], len(columns)), dtype='int'), columns=columns)
144144
self.compartments_.index = self.populations.index
@@ -149,8 +149,8 @@ def _initialize_populations(self, susceptible, population=None):
149149
for state in self.compartments_.index:
150150
self.compartments_.loc[state, susceptible] = self.populations.loc[state, population]
151151

152-
def _run_travel(self, compartments_, travel):
153-
def travel_step(x, populations):
152+
def _run_travel(self, compartments_: pd.DataFrame, travel: pd.DataFrame) -> pd.DataFrame:
153+
def travel_step(x, populations: pd.DataFrame) -> pd.Series:
154154
n = populations.loc[x.name]
155155
p = travel.loc[x.name].values.tolist()
156156
output = np.random.multinomial(n, p)
@@ -163,17 +163,24 @@ def travel_step(x, populations):
163163
# Travel occurs independently for each compartment
164164
# since we don't allow in-flight transitions
165165
for comp in compartments_.columns:
166-
new_compartments[comp] = travel.apply(travel_step, populations=compartments_[comp]).sum(axis=1)
166+
new_compartments[comp] = travel.apply(
167+
travel_step,
168+
populations=compartments_[comp]
169+
).sum(axis=1)
167170

168171
return new_compartments
169172

170-
def _run_spread(self):
173+
def _run_spread(self) -> None:
171174
for state in self.compartments_.index:
172175
pop = self.compartments_.loc[state].to_dict()
173176
self.models[state].single_step(**pop)
174177
self.compartments_.loc[state] = self.models[state].values_.iloc[[-1]].values[0]
175178

176-
def simulate(self, timestamp, t_min=1, seasonality=None, seed_state=None, susceptible='S', **kwargs):
179+
def simulate(
180+
self, timestamp: int, t_min: int = 1,
181+
seasonality=None, seed_state: [str, None] = None,
182+
susceptible: str ='S', **kwargs
183+
) -> None:
177184
if seed_state is None:
178185
raise NotInitialized("You have to specify the seed_state")
179186

@@ -193,10 +200,10 @@ def simulate(self, timestamp, t_min=1, seasonality=None, seed_state=None, suscep
193200
def integrate(self, **kwargs):
194201
raise NotImplementedError("MetaEpiModel doesn't support direct integration of the ODE")
195202

196-
def draw_model(self):
203+
def draw_model(self) -> None:
197204
return self.models.iloc[0].draw_model()
198205

199-
def plot(self, title=None, normed=True, layout=None, **kwargs):
206+
def plot(self, title: Union[str, None] = None, normed: bool = True, layout=None, **kwargs) -> None:
200207
if layout is None:
201208
n_pop = self.travel_graph.shape[0]
202209
N = int(np.round(np.sqrt(n_pop), 0)+1)
@@ -270,7 +277,7 @@ def plot(self, title=None, normed=True, layout=None, **kwargs):
270277
fig.patch.set_facecolor('#FFFFFF')
271278
fig.tight_layout()
272279

273-
def plot_peaks(self):
280+
def plot_peaks(self) -> None:
274281
peaks = None
275282

276283
for state in self.models.values():
@@ -301,4 +308,4 @@ def plot_peaks(self):
301308
ax.set_xticks(np.arange(0, peaks.shape[1], 3))
302309
ax.set_xticklabels(np.arange(0, peaks.shape[1], 3), fontsize=10)
303310
# ax.set_aspect(1)
304-
fig.patch.set_facecolor('#FFFFFF')
311+
fig.patch.set_facecolor('#FFFFFF')

src/epidemik/NetworkEpiModel.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# @author Bruno Goncalves
44
######################################################
55

6+
from typing import Union
67
import networkx as nx
78
import numpy as np
89
from numpy import linalg
@@ -24,7 +25,7 @@ def __init__(self, network, compartments=None):
2425
def integrate(self, timesteps, **kwargs):
2526
raise NotImplementedError("Network Models don't support numerical integration")
2627

27-
def add_interaction(self, source, target, agent, rate, rescale=False):
28+
def add_interaction(self, source: str, target: str, agent: str, rate: float, rescale: bool = False) -> None:
2829
if rescale:
2930
rate /= self.kavg_
3031

@@ -38,7 +39,7 @@ def add_interaction(self, source, target, agent, rate, rescale=False):
3839

3940
self.interactions[source][agent] = {'target': target, 'rate': rate}
4041

41-
def add_spontaneous(self, source, target, rate):
42+
def add_spontaneous(self, source: str, target: str, rate: float) -> None:
4243
super(NetworkEpiModel, self).add_spontaneous(source, target, rate=rate)
4344
if source not in self.spontaneous:
4445
self.spontaneous[source] = {}
@@ -49,7 +50,7 @@ def add_spontaneous(self, source, target, rate):
4950
self.spontaneous[source][target] = rate
5051

5152

52-
def simulate(self, timesteps, seeds, **kwargs):
53+
def simulate(self, timesteps: int, seeds, **kwargs) -> None:
5354
"""Stochastically simulate the epidemic model"""
5455
pos = {comp: i for i, comp in enumerate(self.transitions.nodes())}
5556
N = self.network.number_of_nodes()
@@ -130,7 +131,7 @@ def simulate(self, timesteps, seeds, **kwargs):
130131
self.population_ = pd.DataFrame(population)
131132
self.values_ = pd.DataFrame.from_records(self.population_.apply(lambda x: Counter(x), axis=1)).fillna(0).astype('int')
132133

133-
def R0(self):
134+
def R0(self) -> Union[float, None]:
134135
if 'R' not in set(self.transitions.nodes):
135136
return None
136-
return np.round(super(NetworkEpiModel, self).R0()*self.kavg_, 2)
137+
return np.round(super(NetworkEpiModel, self).R0()*self.kavg_, 2)

0 commit comments

Comments
 (0)