Skip to content

Commit 5ab8bda

Browse files
committed
improves structure
1 parent 6eaf095 commit 5ab8bda

File tree

2 files changed

+21
-21
lines changed

2 files changed

+21
-21
lines changed

climada/trajectories/risk_trajectory.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,7 @@ def _compute_period_metrics(
288288
df = self._generic_metrics(
289289
npv=npv, metric_name=metric_name, metric_meth=metric_meth, **kwargs
290290
)
291-
return self._date_to_period_agg(df)
291+
return self._date_to_period_agg(df, grouper=self._grouper)
292292

293293
def _compute_metrics(
294294
self, metric_name: str, metric_meth: str, npv: bool = True, **kwargs
@@ -506,47 +506,47 @@ def _get_risk_periods(
506506
)
507507
]
508508

509+
@staticmethod
510+
def identify_continuous_periods(group, time_unit):
511+
# Calculate the difference between consecutive dates
512+
if time_unit == "year":
513+
group["date_diff"] = group["date"].dt.year.diff()
514+
if time_unit == "month":
515+
group["date_diff"] = group["date"].dt.month.diff()
516+
if time_unit == "day":
517+
group["date_diff"] = group["date"].dt.day.diff()
518+
if time_unit == "hour":
519+
group["date_diff"] = group["date"].dt.hour.diff()
520+
# Identify breaks in continuity
521+
group["period_id"] = (group["date_diff"] != 1).cumsum()
522+
return group
523+
509524
@classmethod
510525
def _date_to_period_agg(
511526
cls,
512527
df: pd.DataFrame,
528+
grouper: list[str],
513529
time_unit: str = "year",
514530
colname: str | list[str] = "risk",
515531
) -> pd.DataFrame | pd.Series:
516532
"""Groups per date risk metric to periods."""
517533

518534
## I'm thinking this does not work with RPs... As you can't just sum impacts
519535
## Not sure what to do with it. -> Fixed I take the avg RP impact of the period
520-
521-
def identify_continuous_periods(group, time_unit):
522-
# Calculate the difference between consecutive dates
523-
if time_unit == "year":
524-
group["date_diff"] = group["date"].dt.year.diff()
525-
if time_unit == "month":
526-
group["date_diff"] = group["date"].dt.month.diff()
527-
if time_unit == "day":
528-
group["date_diff"] = group["date"].dt.day.diff()
529-
if time_unit == "hour":
530-
group["date_diff"] = group["date"].dt.hour.diff()
531-
# Identify breaks in continuity
532-
group["period_id"] = (group["date_diff"] != 1).cumsum()
533-
return group
534-
535536
def conditional_agg(group):
536537
if "rp" in group.name[2]:
537538
return group.mean()
538539
else:
539540
return group.sum()
540541

541-
grouper = cls._grouper
542542
if "group" in df.columns and "group" not in grouper:
543543
grouper = ["group"] + grouper
544544

545545
df_sorted = df.sort_values(by=cls._grouper + ["date"])
546546
# Apply the function to identify continuous periods
547547
df_periods = df_sorted.groupby(
548548
grouper, dropna=False, group_keys=False, observed=True
549-
).apply(identify_continuous_periods, time_unit)
549+
).apply(cls.identify_continuous_periods, time_unit)
550550

551551
if isinstance(colname, str):
552552
colname = [colname]
@@ -572,7 +572,7 @@ def conditional_agg(group):
572572
df_periods.groupby(grouper + ["period_id"], dropna=False, observed=True)[
573573
colname
574574
]
575-
.apply(lambda group: conditional_agg(group))
575+
.apply(conditional_agg)
576576
.reset_index()
577577
)
578578
df_periods = pd.merge(
@@ -588,7 +588,7 @@ def per_period_risk_metrics(
588588
) -> pd.DataFrame | pd.Series:
589589
"""Returns a tidy dataframe of the risk metrics with the total for each different period."""
590590
df = self.per_date_risk_metrics(metrics=metrics, **kwargs)
591-
return self._date_to_period_agg(df, **kwargs)
591+
return self._date_to_period_agg(df, grouper=self._grouper, **kwargs)
592592

593593
def _calc_waterfall_plot_data(
594594
self,

climada/trajectories/test/test_risk_trajectory.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -660,7 +660,7 @@ def test_per_period_risk_multiple_risk_cols(self):
660660
}
661661
)
662662
result_df = RiskTrajectory._date_to_period_agg(
663-
df_input, colname=["base risk", "exposure contribution"]
663+
df_input, col_agg_dict=["base risk", "exposure contribution"]
664664
)
665665

666666
expected_df = pd.DataFrame(

0 commit comments

Comments
 (0)