2929import matplotlib .pyplot as plt
3030import matplotlib .ticker as ticker
3131import pandas as pd
32+ from pandas .tseries .frequencies import to_offset
3233
3334from climada .entity .disc_rates .base import DiscRates
3435from climada .trajectories .interpolation import InterpolationStrategyBase
@@ -62,7 +63,7 @@ def __init__(
6263 self ,
6364 snapshots_list : list [Snapshot ],
6465 * ,
65- time_resolution : str = "YS " ,
66+ time_resolution : str = "Y " ,
6667 all_groups_name : str = "All" ,
6768 risk_disc : DiscRates | None = None ,
6869 interpolation_strategy : InterpolationStrategyBase | None = None ,
@@ -409,7 +410,10 @@ def risk_components_metrics(self, npv: bool = True, **kwargs) -> pd.DataFrame:
409410 if len (self ._snapshots ) > 2 :
410411 tmp .set_index (["group" , "date" , "measure" , "metric" ], inplace = True )
411412 start_dates = [snap .date for snap in self ._snapshots [:- 1 ]]
412- end_dates = [snap .date for snap in self ._snapshots [1 :]]
413+ end_dates = [
414+ snap .date - to_offset (self ._time_resolution )
415+ for snap in self ._snapshots [1 :]
416+ ]
413417 periods_dates = list (zip (start_dates , end_dates ))
414418 tmp .loc [pd .IndexSlice [:, :, :, "base risk" ]] = tmp .loc [
415419 pd .IndexSlice [:, str (self .start_date ), :, "base risk" ]
@@ -435,7 +439,7 @@ def risk_components_metrics(self, npv: bool = True, **kwargs) -> pd.DataFrame:
435439 ].iloc [0 ]
436440
437441 tmp .reset_index (inplace = True )
438- return tmp
442+ return tmp . drop ( "index" , axis = 1 , errors = "ignore" )
439443
440444 def per_date_risk_metrics (
441445 self ,
@@ -486,6 +490,7 @@ def _get_risk_periods(
486490 risk_periods : list [CalcRiskPeriod ],
487491 start_date : datetime .date ,
488492 end_date : datetime .date ,
493+ strict : bool = True ,
489494 ):
490495 """Returns risk periods from the given list that are within `start_date` and `end_date`.
491496
@@ -495,16 +500,28 @@ def _get_risk_periods(
495500 The list of risk periods to look through
496501 start_date : datetime.date
497502 end_date : datetime.date
498-
503+ strict: bool, default True
504+ If true, only returns periods stricly within start and end dates. Else,
505+ returns periods that have an overlap within start and end.
499506 """
500- return [
501- period
502- for period in risk_periods
503- if (
504- start_date <= period .snapshot_start .date
505- or end_date >= period .snapshot_end .date
506- )
507- ]
507+ if strict :
508+ return [
509+ period
510+ for period in risk_periods
511+ if (
512+ start_date <= period .snapshot_start .date
513+ and end_date >= period .snapshot_end .date
514+ )
515+ ]
516+ else :
517+ return [
518+ period
519+ for period in risk_periods
520+ if not (
521+ start_date >= period .snapshot_end .date
522+ or end_date <= period .snapshot_start .date
523+ )
524+ ]
508525
509526 @staticmethod
510527 def identify_continuous_periods (group , time_unit ):
@@ -605,8 +622,8 @@ def _calc_waterfall_plot_data(
605622 end_date = self .end_date if end_date is None else end_date
606623 risk_components = self .risk_components_metrics (npv )
607624 risk_components = risk_components .loc [
608- (risk_components ["date" ]. dt . date >= start_date )
609- & (risk_components ["date" ]. dt . date <= end_date )
625+ (risk_components ["date" ] >= str ( start_date ) )
626+ & (risk_components ["date" ] <= str ( end_date ) )
610627 ]
611628 risk_components = risk_components .set_index (["date" , "metric" ])[
612629 "risk"
@@ -664,7 +681,7 @@ def plot_per_date_waterfall(
664681 risk_component ["base risk" ] = risk_component .iloc [0 ]["base risk" ]
665682 # risk_component.plot(x="date", ax=ax, kind="bar", stacked=True)
666683 ax .stackplot (
667- risk_component .index ,
684+ risk_component .index . to_timestamp () ,
668685 [risk_component [col ] for col in risk_component .columns ],
669686 labels = risk_component .columns ,
670687 )
@@ -717,23 +734,25 @@ def plot_waterfall(
717734 """
718735 start_date = self .start_date if start_date is None else start_date
719736 end_date = self .end_date if end_date is None else end_date
737+ start_date_p = pd .to_datetime (start_date ).to_period (self ._time_resolution )
738+ end_date_p = pd .to_datetime (end_date ).to_period (self ._time_resolution )
720739 risk_component = self ._calc_waterfall_plot_data (
721740 start_date = start_date , end_date = end_date , npv = npv
722741 )
723742 if ax is None :
724743 _ , ax = plt .subplots (figsize = (8 , 5 ))
725744
726745 risk_component = risk_component .loc [
727- (risk_component .index . date == end_date )
746+ (risk_component .index == str ( end_date ) )
728747 ].squeeze ()
729748
730749 labels = [
731- f"Risk { start_date } " ,
732- f"Exposure contribution { end_date } " ,
733- f"Hazard contribution { end_date } " ,
734- f"Vulnerability contribution { end_date } " ,
735- f"Interaction contribution { end_date } " ,
736- f"Total Risk { end_date } " ,
750+ f"Risk { start_date_p } " ,
751+ f"Exposure contribution { end_date_p } " ,
752+ f"Hazard contribution { end_date_p } " ,
753+ f"Vulnerability contribution { end_date_p } " ,
754+ f"Interaction contribution { end_date_p } " ,
755+ f"Total Risk { end_date_p } " ,
737756 ]
738757 values = [
739758 risk_component ["base risk" ],
@@ -783,7 +802,7 @@ def plot_waterfall(
783802
784803 # Construct y-axis label and title based on parameters
785804 value_label = "USD"
786- title_label = f"Evolution of the components of risk between { start_date } and { end_date } (Average impact)"
805+ title_label = f"Evolution of the components of risk between { start_date_p } and { end_date_p } (Average impact)"
787806 ax .yaxis .set_major_formatter (ticker .EngFormatter ())
788807 ax .set_title (title_label )
789808 ax .set_ylabel (value_label )
0 commit comments