Skip to content

Commit d328da3

Browse files
committed
model: Make get_loss a standalone pure function
1 parent 6d4913f commit d328da3

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

src/penn_chime/models.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,8 @@ def get_argmin_doubling_time(self, p: Parameters, dts):
167167
if peak_admits_day < 0:
168168
continue
169169

170-
loss = self.get_loss(raw)
170+
predicted = raw["census_hospitalized"][self.i_day]
171+
loss = get_loss(self.current_hospitalized, predicted)
171172
losses[i] = loss
172173

173174
min_loss = pd.Series(losses).argmin()
@@ -208,10 +209,10 @@ def run_projection(self, p: Parameters, policy: Sequence[Tuple[float, int]]):
208209

209210
return raw
210211

211-
def get_loss(self, raw) -> float:
212-
"""Squared error: predicted vs. actual current hospitalized."""
213-
predicted = raw["census_hospitalized"][self.i_day]
214-
return (self.current_hospitalized - predicted) ** 2.0
212+
213+
def get_loss(current_hospitalized, predicted) -> float:
214+
"""Squared error: predicted vs. actual current hospitalized."""
215+
return (current_hospitalized - predicted) ** 2.0
215216

216217

217218
def get_argmin_ds(census, current_hospitalized: float) -> float:

0 commit comments

Comments
 (0)