Skip to content

Commit 3ceb016

Browse files
the callable validator design pattern
1 parent 81e3762 commit 3ceb016

File tree

4 files changed

+75
-8
lines changed

4 files changed

+75
-8
lines changed

src/penn_chime/parameters.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from datetime import date
99
from typing import Optional
1010

11+
from .validators import Positive, StrictlyPositive, Rate
1112

1213
# Parameters for each disposition (hospitalized, icu, ventilated)
1314
# The rate of disposition within the population of infected
@@ -64,30 +65,36 @@ def __init__(
6465
recovered: int = 0,
6566
region: Optional[Regions] = None,
6667
):
67-
self.current_hospitalized = current_hospitalized
68-
self.relative_contact_rate = relative_contact_rate
68+
self.current_hospitalized = StrictlyPositive(value=current_hospitalized)
69+
self.relative_contact_rate = Rate(value=relative_contact_rate)
70+
71+
Rate(value=hospitalized.rate), Rate(value=icu.rate), Rate(value=ventilated.rate)
72+
StrictlyPositive(value=hospitalized.days), StrictlyPositive(value=icu.days),
73+
StrictlyPositive(value=ventilated.days)
6974

7075
self.hospitalized = hospitalized
7176
self.icu = icu
7277
self.ventilated = ventilated
7378

7479
if region is not None and population is None:
7580
self.region = region
76-
self.population = region.population
81+
self.population = StrictlyPositive(value=region.population)
7782
elif population is not None:
7883
self.region = None
79-
self.population = population
84+
self.population = StrictlyPositive(value=population)
8085
else:
8186
raise AssertionError('population or regions must be provided.')
8287

8388
self.current_date = current_date
89+
8490
self.date_first_hospitalized = date_first_hospitalized
8591
self.doubling_time = doubling_time
86-
self.infectious_days = infectious_days
87-
self.market_share = market_share
92+
93+
self.infectious_days = StrictlyPositive(value=infectious_days)
94+
self.market_share = Rate(value=market_share)
8895
self.max_y_axis = max_y_axis
89-
self.n_days = n_days
90-
self.recovered = recovered
96+
self.n_days = StrictlyPositive(value=n_days)
97+
self.recovered = Positive(value=recovered)
9198

9299
self.labels = {
93100
"hospitalized": "Hospitalized",

src/penn_chime/validators/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
"""the callable validator design pattern"""
2+
3+
from .validators import Bounded, Rate
4+
5+
EPSILON = 1.e-7
6+
7+
StrictlyPositive = Bounded(lower_bound=EPSILON)
8+
Positive = Bounded(lower_bound=-EPSILON)
9+
Rate = Rate() # type: ignore

src/penn_chime/validators/base.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
"""design pattern via https://youtu.be/S_ipdVNSFlo?t=2153, modified such that validators are _callable_"""
2+
3+
from abc import ABC, abstractmethod
4+
5+
class Validator(ABC):
6+
def __set_name__(self, owner, name):
7+
self.private_name = f"_{name}"
8+
9+
def __call__(self, *, value):
10+
self.validate(value)
11+
return value
12+
13+
@abstractmethod
14+
def validate(self, value):
15+
pass
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
"""design pattern via https://youtu.be/S_ipdVNSFlo?t=2153"""
2+
3+
from typing import Optional
4+
5+
from .base import Validator
6+
7+
8+
class Bounded(Validator):
9+
10+
def __init__(
11+
self,
12+
lower_bound: Optional[float] = None,
13+
upper_bound: Optional[float] = None) -> None:
14+
self.lower_bound = lower_bound
15+
self.upper_bound = upper_bound
16+
self.message = {
17+
(lower_bound, upper_bound): f"in ({self.lower_bound}, {self.upper_bound})",
18+
(None, upper_bound): f"less than {self.upper_bound}",
19+
(lower_bound, None): f"greater than {self.lower_bound}",
20+
(None, None): "ACTUALLY the value is unbounded"
21+
}
22+
23+
def validate(self, value):
24+
if (self.upper_bound is not None and value > self.upper_bound) \
25+
or (self.lower_bound is not None and value < self.lower_bound):
26+
raise ValueError(f"{value} needs to be {self.message[(self.lower_bound, self.upper_bound)]}.")
27+
28+
29+
class Rate(Validator):
30+
def __init__(self) -> None:
31+
pass
32+
33+
def validate(self, value):
34+
if 0 >= value or value >= 1:
35+
raise ValueError(f"{value} needs to be a rate (i.e. in [0,1])")
36+

0 commit comments

Comments
 (0)