@@ -97,11 +97,14 @@ def duration2rate(x):
9797
9898# HIV params---------------------------------------------
9999# Parameters we have to decide
100- k_on = fraction2rate (0.3 ) # annual PrEP uptake rate
101- k_off = duration2rate (5.0 ) # average duration of taking PrEP per year
100+ # k_on = fraction2rate(0.3) # annual PrEP uptake rate
101+ # k_off = duration2rate(5.0) # average duration of taking PrEP per year
102+ k_on = jnp .array ([0 ,0 ,0 ,- jnp .log (1 - 0.3 ) / 365 ]) # annual PrEP uptake rate
103+ k_off = jnp .array ([0 ,0 ,0 , 1 / 5.0 / 365 ]) # average duration of taking PrEP per year
104+ tau_p = jnp .array ([0 ,0 ,0 ,- jnp .log (1 - 0.95 ) / 365 ]) # annual ART uptake rate
102105
103106# from GannaRozhnova paper (Elimination prospects of the Dutch HIV epidemic)
104- tau_p = fraction2rate (0.95 ) # annual ART uptake rate
107+ # tau_p = fraction2rate(0.95) # annual ART uptake rate
105108c = jnp .array ([0.13 , 1.43 , 5.44 , 18.21 ]) / 365.0 # per year, average number of partners in risk group l
106109h = jnp .array ([0.62 , 0.12 , 0.642 , 0.0 ]) # infectivity of untreated individuals in stage k of infection
107110phis = fraction2rate (0.05 ) # per year, annual ART dropout rate
@@ -165,14 +168,6 @@ def duration2rate(x):
165168
166169
167170
168- # TODO thisn logging stuff is also done in the beginning already, delete here??
169- # Configure logging
170- logging .basicConfig (level = logging .INFO )
171- logger = logging .getLogger (__name__ )
172- # Global flag to track logging
173- logged_exp_logis = False
174- logged_tau = False
175-
176171def m (args , y ):
177172 """
178173 Exponential function with three parameters: minimum value, maximum value, and rate/tau.
@@ -241,6 +236,7 @@ def contact_matrix(y, args):
241236
242237 return mixing + diagonal
243238
239+
244240def hazard (y , args ):
245241 """
246242 Calculates the hazard from the HIV model used for risk-perception.
@@ -268,7 +264,7 @@ def alive_fraction_HIV(y):
268264 Returns:
269265 Value as float64.
270266 """
271- return jnp .sum ( jnp . array ([ y [ comp ] for comp in [ "S" , "SP" , "I1" , "IP" , "I2" , "I3" , "I4" , "A1" , "A2" , "A3" , "A4" ]]) )
267+ return 1 - jnp .sum (y [ "D" ] )
272268
273269def alive_fraction_STI (y ):
274270 """
@@ -277,7 +273,7 @@ def alive_fraction_STI(y):
277273 Returns:
278274 Value as float64.
279275 """
280- return jnp .sum ( jnp . array ([ y [ comp ] for comp in [ "S_STI" , "Ia_STI" , "Is_STI" , "T_STI" ]]) )
276+ return 1 - jnp .sum (y [ "D_STI" ] )
281277
282278def lambda_a (y , args ):
283279 """
0 commit comments