Skip to content

Commit 33ea06d

Browse files
Merge branch 'develop' into use-pytest-fixtures
2 parents 99c84a3 + 2fa994f commit 33ea06d

File tree

4 files changed

+119
-12
lines changed

4 files changed

+119
-12
lines changed

src/penn_chime/parameters.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88
from datetime import date
99
from typing import Optional
1010

11+
from .validators import (
12+
Positive, OptionalStrictlyPositive, StrictlyPositive, Rate, Date, OptionalDate
13+
)
1114

1215
# Parameters for each disposition (hospitalized, icu, ventilated)
1316
# The rate of disposition within the population of infected
@@ -64,30 +67,36 @@ def __init__(
6467
recovered: int = 0,
6568
region: Optional[Regions] = None,
6669
):
67-
self.current_hospitalized = current_hospitalized
68-
self.relative_contact_rate = relative_contact_rate
70+
self.current_hospitalized = StrictlyPositive(value=current_hospitalized)
71+
self.relative_contact_rate = Rate(value=relative_contact_rate)
72+
73+
Rate(value=hospitalized.rate), Rate(value=icu.rate), Rate(value=ventilated.rate)
74+
StrictlyPositive(value=hospitalized.days), StrictlyPositive(value=icu.days),
75+
StrictlyPositive(value=ventilated.days)
6976

7077
self.hospitalized = hospitalized
7178
self.icu = icu
7279
self.ventilated = ventilated
7380

7481
if region is not None and population is None:
7582
self.region = region
76-
self.population = region.population
83+
self.population = StrictlyPositive(value=region.population)
7784
elif population is not None:
7885
self.region = None
79-
self.population = population
86+
self.population = StrictlyPositive(value=population)
8087
else:
8188
raise AssertionError('population or regions must be provided.')
8289

83-
self.current_date = current_date
84-
self.date_first_hospitalized = date_first_hospitalized
85-
self.doubling_time = doubling_time
86-
self.infectious_days = infectious_days
87-
self.market_share = market_share
88-
self.max_y_axis = max_y_axis
89-
self.n_days = n_days
90-
self.recovered = recovered
90+
self.current_date = Date(value=current_date)
91+
92+
self.date_first_hospitalized = OptionalDate(value=date_first_hospitalized)
93+
self.doubling_time = OptionalStrictlyPositive(value=doubling_time)
94+
95+
self.infectious_days = StrictlyPositive(value=infectious_days)
96+
self.market_share = Rate(value=market_share)
97+
self.max_y_axis = OptionalStrictlyPositive(value=max_y_axis)
98+
self.n_days = StrictlyPositive(value=n_days)
99+
self.recovered = Positive(value=recovered)
91100

92101
self.labels = {
93102
"hospitalized": "Hospitalized",

src/penn_chime/validators/__init__.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
"""the callable validator design pattern"""
2+
3+
from .validators import Bounded, OptionalBounded, Rate, Date, OptionalDate
4+
5+
EPSILON = 1.e-7
6+
7+
OptionalStrictlyPositive = OptionalBounded(lower_bound=EPSILON)
8+
StrictlyPositive = Bounded(lower_bound=EPSILON)
9+
Positive = Bounded(lower_bound=-EPSILON)
10+
Rate = Rate() # type: ignore
11+
Date = Date() # type: ignore
12+
OptionalDate = OptionalDate() # type: ignore
13+
# # rolling a custom validator for doubling time in case DS wants to add upper bound
14+
# DoublingTime = OptionalBounded(lower_bound=0-EPSILON, upper_bound=None)

src/penn_chime/validators/base.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
"""design pattern via https://youtu.be/S_ipdVNSFlo?t=2153, modified such that validators are _callable_"""
2+
3+
from abc import ABC, abstractmethod
4+
5+
class Validator(ABC):
6+
def __set_name__(self, owner, name):
7+
self.private_name = f"_{name}"
8+
9+
def __call__(self, *, value):
10+
self.validate(value)
11+
return value
12+
13+
@abstractmethod
14+
def validate(self, value):
15+
pass
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
"""design pattern via https://youtu.be/S_ipdVNSFlo?t=2153"""
2+
3+
from typing import Optional
4+
from datetime import date, datetime
5+
6+
from .base import Validator
7+
8+
9+
class Bounded(Validator):
10+
"""A bounded number."""
11+
def __init__(
12+
self,
13+
lower_bound: Optional[float] = None,
14+
upper_bound: Optional[float] = None) -> None:
15+
assert lower_bound is not None or upper_bound is not None, "Do not use this object to create an unbounded validator."
16+
self.lower_bound = lower_bound
17+
self.upper_bound = upper_bound
18+
self.message = {
19+
(lower_bound, upper_bound): f"in ({self.lower_bound}, {self.upper_bound})",
20+
(None, upper_bound): f"less than {self.upper_bound}",
21+
(lower_bound, None): f"greater than {self.lower_bound}",
22+
}
23+
24+
def validate(self, value):
25+
"""This method implicitly validates isinstance(value, (float, int)) because it will throw a TypeError on comparison"""
26+
if (self.upper_bound is not None and value > self.upper_bound) \
27+
or (self.lower_bound is not None and value < self.lower_bound):
28+
raise ValueError(f"{value} needs to be {self.message[(self.lower_bound, self.upper_bound)]}.")
29+
30+
31+
class OptionalBounded(Bounded):
32+
"""a bounded number or a None."""
33+
def __init__(
34+
self,
35+
lower_bound: Optional[float] = None,
36+
upper_bound: Optional[float] = None) -> None:
37+
super().__init__(lower_bound=lower_bound, upper_bound=upper_bound)
38+
39+
def validate(self, value):
40+
if value is None:
41+
return None
42+
super().validate(value)
43+
44+
class Rate(Validator):
45+
"""A rate in [0,1]."""
46+
def __init__(self) -> None:
47+
pass
48+
49+
def validate(self, value):
50+
if 0 >= value or value >= 1:
51+
raise ValueError(f"{value} needs to be a rate (i.e. in [0,1]).")
52+
53+
class Date(Validator):
54+
"""A date of some sort."""
55+
def __init__(self) -> None:
56+
pass
57+
58+
def validate(self, value):
59+
if not isinstance(value, (date, datetime)):
60+
raise (ValueError(f"{value} must be a date or datetime object."))
61+
62+
class OptionalDate(Date):
63+
def __init__(self) -> None:
64+
super().__init__()
65+
66+
def validate(self, value):
67+
if value is None:
68+
return None
69+
super().validate(value)

0 commit comments

Comments
 (0)