33# @author Bruno Goncalves
44######################################################
55
6+ from typing import Dict , List , Set , Union
67import warnings
8+ import string
9+
710import networkx as nx
811import numpy as np
912from numpy import linalg
1013from numpy import random
1114import scipy .integrate
1215import pandas as pd
1316import matplotlib .pyplot as plt
14- import string
17+
1518from .utils import *
1619
1720class EpiModel (object ):
@@ -38,7 +41,7 @@ def __init__(self, compartments=None):
3841 if compartments is not None :
3942 self .transitions .add_nodes_from ([comp for comp in compartments ])
4043
41- def add_interaction (self , source , target , agent , rate ) :
44+ def add_interaction (self , source : str , target : str , agent : str , rate : float ) -> None :
4245 """
4346 Add an interaction between two compartments
4447
@@ -57,7 +60,7 @@ def add_interaction(self, source, target, agent, rate):
5760 """
5861 self .transitions .add_edge (source , target , agent = agent , rate = rate )
5962
60- def add_spontaneous (self , source , target , rate ) :
63+ def add_spontaneous (self , source : str , target : str , rate : float ) -> None :
6164 """
6265 Add a spontaneous transition between two compartments
6366
@@ -74,7 +77,7 @@ def add_spontaneous(self, source, target, rate):
7477 """
7578 self .transitions .add_edge (source , target , rate = rate )
7679
77- def add_vaccination (self , source , target , rate , start ) :
80+ def add_vaccination (self , source : str , target : str , rate : float , start : int ) -> None :
7881 """
7982 Add a vaccination transition between two compartments
8083
@@ -93,7 +96,14 @@ def add_vaccination(self, source, target, rate, start):
9396 """
9497 self .transitions .add_edge (source , target , rate = rate , start = start )
9598
96- def add_age_structure (self , matrix , population ):
99+ def add_age_structure (self , matrix : List , population : List ) -> List [List ]:
100+ """
101+ Add a vaccination transition between two compartments
102+
103+ Parameters:
104+ - matrix: List
105+ - population: List
106+ """
97107 self .contact = np .asarray (matrix )
98108 self .population = np .asarray (population ).flatten ()
99109
@@ -131,7 +141,7 @@ def add_age_structure(self, matrix, population):
131141
132142 self .transitions = model .transitions
133143
134- def _new_cases (self , population , time , pos ) :
144+ def _new_cases (self , population : np . ndarray , time : float , pos : Dict ) -> np . ndarray :
135145 """
136146 Internal function used by integration routine
137147
@@ -188,7 +198,7 @@ def _new_cases(self, population, time, pos):
188198
189199 return diff
190200
191- def plot (self , title = None , normed = True , show = True , ax = None , ** kwargs ):
201+ def plot (self , title : Union [ str , None ] = None , normed : bool = True , show : bool = True , ax : Union [ plt . Axes , None ] = None , ** kwargs ):
192202 """
193203 Convenience function for plotting
194204
@@ -235,7 +245,7 @@ def plot(self, title=None, normed=True, show=True, ax=None, **kwargs):
235245 print (e )
236246 raise NotInitialized ('You must call integrate() or simulate() first' )
237247
238- def __getattr__ (self , name ) :
248+ def __getattr__ (self , name : str ) -> pd . Series :
239249 """
240250 Dynamic method to return the individual compartment values
241251
@@ -252,7 +262,7 @@ def __getattr__(self, name):
252262 else :
253263 raise AttributeError ("'EpiModel' object has no attribute '%s'" % name )
254264
255- def simulate (self , timesteps , t_min = 1 , seasonality = None , ** kwargs ):
265+ def simulate (self , timesteps : int , t_min : int = 1 , seasonality : Union [ np . ndarray , None ] = None , ** kwargs ) -> None :
256266 """
257267 Stochastically simulate the epidemic model
258268
@@ -334,7 +344,7 @@ def simulate(self, timesteps, t_min=1, seasonality=None, **kwargs):
334344 values = np .array (values )
335345 self .values_ = pd .DataFrame (values [1 :], columns = comps , index = time )
336346
337- def integrate (self , timesteps , t_min = 1 , seasonality = None , ** kwargs ):
347+ def integrate (self , timesteps : int , t_min : int = 1 , seasonality : Union [ np . ndarray , None ] = None , ** kwargs ) -> None :
338348 """
339349 Numerically integrate the epidemic model
340350
@@ -375,7 +385,14 @@ def integrate(self, timesteps, t_min=1, seasonality=None, **kwargs):
375385 time = np .arange (t_min , t_min + timesteps , 1 )
376386
377387 self .seasonality = seasonality
378- values = pd .DataFrame (scipy .integrate .odeint (self ._new_cases , population , time , args = (pos ,)), columns = pos .keys (), index = time )
388+ values = pd .DataFrame (
389+ scipy .integrate .odeint (
390+ self ._new_cases ,
391+ population ,
392+ time ,
393+ args = (pos ,)
394+ ), columns = pos .keys (), index = time
395+ )
379396
380397 if self .population is None :
381398 self .values_ = values
@@ -398,7 +415,7 @@ def single_step(self, seasonality=None, **kwargs):
398415 new_values = pd .concat ([old_values , self .values_ .iloc [[- 1 ]]])
399416 self .values_ = new_values
400417
401- def __repr__ (self ):
418+ def __repr__ (self ) -> str :
402419 """
403420 Return a string representation of the EpiModel object
404421
@@ -433,7 +450,7 @@ def __repr__(self):
433450
434451 return text
435452
436- def _get_active (self ):
453+ def _get_active (self ) -> Set :
437454 active = set ()
438455
439456 for node_i , node_j , data in self .transitions .edges (data = True ):
@@ -444,7 +461,7 @@ def _get_active(self):
444461
445462 return active
446463
447- def _get_susceptible (self ):
464+ def _get_susceptible (self ) -> Set :
448465 susceptible = set ([node for node , deg in self .transitions .in_degree () if deg == 0 ])
449466
450467 if len (susceptible ) == 0 :
@@ -454,7 +471,7 @@ def _get_susceptible(self):
454471
455472 return susceptible
456473
457- def _get_infections (self ):
474+ def _get_infections (self ) -> Dict :
458475 inf = {}
459476
460477 for node_i , node_j , data in self .transitions .edges (data = True ):
@@ -472,7 +489,7 @@ def _get_infections(self):
472489
473490 return inf
474491
475- def draw_model (self , ax = None , show = True ):
492+ def draw_model (self , ax : Union [ plt . Axes , None ] = None , show : bool = True ) -> None :
476493 """
477494 Plot the model structure
478495
@@ -516,7 +533,7 @@ def draw_model(self, ax=None, show=True):
516533 if show :
517534 plt .show ()
518535
519- def R0 (self ):
536+ def R0 (self ) -> Union [ float , None ] :
520537 """
521538 Return the value of the basic reproductive ratio, $R_0$, for the model as defined
522539
0 commit comments