Skip to content

Commit 149167f

Browse files
committed
Increment random seed and update output handling
1 parent 1981fe9 commit 149167f

File tree

1 file changed

+36
-9
lines changed

1 file changed

+36
-9
lines changed

climada/util/calibrate/cross_calibrate.py

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
from abc import ABC, abstractmethod
2222
from dataclasses import dataclass, InitVar, field
23-
from typing import Optional, List, Mapping, Any, Tuple, Union, Sequence
23+
from typing import Optional, List, Mapping, Any, Tuple, Union, Sequence, Dict
2424
from copy import copy, deepcopy
2525
from pathlib import Path
2626

@@ -46,12 +46,25 @@ def sample_data(data: pd.DataFrame, sample: List[Tuple[int, int]]):
4646
return data_sampled
4747

4848

49+
@dataclass
50+
class SingleEnsembleOptimizerOutput(Output):
51+
"""Output for a single member of an ensemble optimizer
52+
53+
Attributes
54+
----------
55+
event_info : dict(str, any)
56+
Information on the events for this calibration instance
57+
"""
58+
59+
event_info: Dict[str, Any] = field(default_factory=dict)
60+
61+
4962
@dataclass
5063
class EnsembleOptimizerOutput:
5164
data: pd.DataFrame
5265

5366
@classmethod
54-
def from_outputs(cls, outputs: Sequence[Output]):
67+
def from_outputs(cls, outputs: Sequence[SingleEnsembleOptimizerOutput]):
5568
"""Build data from a list of outputs"""
5669
cols = pd.MultiIndex.from_tuples(
5770
[("Parameters", p_name) for p_name in outputs[0].params.keys()]
@@ -64,7 +77,6 @@ def from_outputs(cls, outputs: Sequence[Output]):
6477
data["Event"] = pd.DataFrame.from_records([out.event_info for out in outputs])
6578

6679
return cls(data=data)
67-
# return cls(data=pd.DataFrame.from_records([out.params for out in outputs]))
6880

6981
@classmethod
7082
def from_csv(cls, filepath):
@@ -75,10 +87,10 @@ def to_csv(self, filepath):
7587
"""Store data as CSV"""
7688
self.data.to_csv(filepath, index=None)
7789

78-
def to_input_var(self, impact_func_gen, **impfset_kwargs):
90+
def to_input_var(self, impact_func_creator, **impfset_kwargs):
7991
"""Build Unsequa InputVar from the parameters stored in this object"""
8092
impf_set_list = [
81-
impact_func_gen(**row["Parameters"]) for _, row in self.data.iterrows()
93+
impact_func_creator(**row["Parameters"]) for _, row in self.data.iterrows()
8294
]
8395
return InputVar.impfset(impf_set_list, **impfset_kwargs)
8496

@@ -119,20 +131,35 @@ def run(self, **optimizer_run_kwargs) -> EnsembleOptimizerOutput:
119131
input = self.input_from_sample(sample)
120132

121133
# Run optimizer
122-
opt = self.optimizer_type(input, **self.optimizer_init_kwargs)
134+
opt = self.optimizer_type(
135+
input, **self._update_init_kwargs(self.optimizer_init_kwargs, idx)
136+
)
123137
out = opt.run(**optimizer_run_kwargs)
138+
out = SingleEnsembleOptimizerOutput(
139+
params=out.params,
140+
target=out.target,
141+
event_info=self.event_info_from_input(input),
142+
)
124143

125-
out.event_info = self.event_info_from_input(input)
126144
print(f"Ensemble: {idx}, Params: {out.params}")
127145
outputs.append(out)
128146

129147
return EnsembleOptimizerOutput.from_outputs(outputs)
130148

131149
@abstractmethod
132-
def input_from_sample(self, sample: List[Tuple[int, int]]):
150+
def input_from_sample(self, sample: List[Tuple[int, int]]) -> Input:
133151
""""""
134152

135-
def event_info_from_input(self, input: Input):
153+
def _update_init_kwargs(
154+
self, init_kwargs: Dict[str, Any], iteration: int
155+
) -> Dict[str, Any]:
156+
"""Copy settings in the init_kwargs and update for each iteration"""
157+
kwargs = copy(init_kwargs) # Maybe deepcopy?
158+
if "random_state" in kwargs:
159+
kwargs["random_state"] = kwargs["random_state"] + iteration
160+
return kwargs
161+
162+
def event_info_from_input(self, input: Input) -> Dict[str, Any]:
136163
"""Get information on the event(s) for which we calibrated"""
137164
# Get region and event IDs
138165
data = input.data.dropna(axis="columns", how="all").dropna(

0 commit comments

Comments
 (0)