Skip to content

Commit bf01269

Browse files
committed
fixes in metrics logging
1 parent 0b4a4dd commit bf01269

File tree

9 files changed

+33
-25
lines changed

9 files changed

+33
-25
lines changed

framework3/base/base_clases.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -648,7 +648,7 @@ def _pre_predict(self, x: XYData) -> XYData:
648648
except Exception:
649649
raise ValueError("Trainable filter model not trained or loaded")
650650

651-
def _pre_fit_wrapp(self, x: XYData, y: Optional[XYData]) -> Optional[float]:
651+
def _pre_fit_wrapp(self, x: XYData, y: Optional[XYData]) -> Optional[float | dict]:
652652
"""
653653
Wrapper method for the fit function.
654654
@@ -712,7 +712,7 @@ def __setstate__(self, state: Dict[str, Any]):
712712
self.__dict__["fit"] = self._pre_fit_wrapp
713713
self.__dict__["predict"] = self._pre_predict_wrapp
714714

715-
def fit(self, x: XYData, y: Optional[XYData]) -> Optional[float]:
715+
def fit(self, x: XYData, y: Optional[XYData]) -> Optional[float | dict]:
716716
"""
717717
Method for fitting the filter to the data.
718718

framework3/base/base_splitter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ def evaluate(
150150
"""
151151
...
152152

153-
def _pre_fit_wrapp(self, x: XYData, y: Optional[XYData]) -> float | None:
153+
def _pre_fit_wrapp(self, x: XYData, y: Optional[XYData]) -> float | None | dict:
154154
"""
155155
Wrapper method for pre-fitting.
156156

framework3/plugins/filters/cache/cached_filter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ def init(self) -> None:
137137
self.filter.init()
138138
super().init()
139139

140-
def _pre_fit_wrapp(self, x: XYData, y: Optional[XYData]) -> float | None:
140+
def _pre_fit_wrapp(self, x: XYData, y: Optional[XYData]) -> float | None | dict:
141141
"""
142142
Wrapper method for the pre-fit stage.
143143

framework3/plugins/optimizer/grid_optimizer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,11 +183,11 @@ def fit(self, x: XYData, y: Optional[XYData]):
183183
match pipeline.fit(x, y):
184184
case None:
185185
losses = pipeline.evaluate(x, y, pipeline.predict(x))
186-
187186
score = losses.get(self.scorer.__class__.__name__, 0.0)
188-
189187
case float() as score:
190188
pass
189+
case dict() as losses:
190+
score = losses.get(self.scorer.__class__.__name__, 0.0)
191191
case _:
192192
raise ValueError("Unexpected return type from pipeline.fit()")
193193

framework3/plugins/optimizer/optuna_optimizer.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import optuna
22

33
from typing import Any, Callable, Dict, Sequence, Union, cast
4+
from framework3 import F1
45
from framework3.container import Container
5-
from framework3.base import BasePlugin, XYData
6+
from framework3.base import BaseMetric, BasePlugin, XYData
67

78
from rich import print
89

@@ -69,6 +70,7 @@ def __init__(
6970
pipeline: BaseFilter | None = None,
7071
study_name: str | None = None,
7172
storage: str | None = None,
73+
scorer: BaseMetric = F1(),
7274
):
7375
"""
7476
Initialize the OptunaOptimizer.
@@ -90,6 +92,7 @@ def __init__(
9092
self.n_trials = n_trials
9193
self.load_if_exists = load_if_exists
9294
self.reset_study = reset_study
95+
self.scorer = scorer
9396

9497
def optimize(self, pipeline: BaseFilter):
9598
"""
@@ -227,17 +230,12 @@ def matcher(k, v):
227230

228231
match pipeline.fit(x, y):
229232
case None:
230-
return float(
231-
next(
232-
iter(
233-
pipeline.evaluate(
234-
x, y, pipeline.predict(x)
235-
).values()
236-
)
237-
)
238-
)
233+
metrics = pipeline.evaluate(x, y, pipeline.predict(x))
234+
return float(metrics.get(self.scorer.__class__.__name__, 0.0))
239235
case float() as loss:
240236
return loss
237+
case dict() as losses:
238+
return float(losses.get(self.scorer.__class__.__name__, 0.0))
241239
case _:
242240
raise ValueError("Unsupported type in pipeline.fit")
243241

framework3/plugins/optimizer/wandb_optimizer.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,10 @@ def exec(
160160
return {self.scorer.__class__.__name__: float(loss)}
161161
case float() as loss:
162162
return {self.scorer.__class__.__name__: loss}
163+
case dict() as losses:
164+
loss = losses.pop(self.scorer.__class__.__name__, 0.0)
165+
wandb.log(dict(losses)) # type: ignore[attr-defined]
166+
return {self.scorer.__class__.__name__: loss}
163167
case _:
164168
raise ValueError("Unexpected return type from pipeline.fit()")
165169

framework3/plugins/pipelines/sequential/f3_pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def start(
134134
print(f"Error during pipeline execution: {e}")
135135
raise e
136136

137-
def fit(self, x: XYData, y: Optional[XYData]) -> None | float:
137+
def fit(self, x: XYData, y: Optional[XYData]) -> None | float | dict:
138138
"""
139139
Fit the pipeline to the input data.
140140

framework3/plugins/splitter/cross_validation_splitter.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def split(self, pipeline: BaseFilter):
103103
self.pipeline = pipeline
104104
self.pipeline.verbose(False)
105105

106-
def fit(self, x: XYData, y: XYData | None) -> Optional[float]:
106+
def fit(self, x: XYData, y: XYData | None) -> Optional[float | dict]:
107107
"""
108108
Perform K-Fold cross-validation on the given data.
109109
@@ -131,7 +131,7 @@ def fit(self, x: XYData, y: XYData | None) -> Optional[float]:
131131
if self.pipeline is None:
132132
raise ValueError("Pipeline must be fitted before splitting")
133133

134-
losses = []
134+
losses: dict = {}
135135
splits = self._kfold.split(X)
136136
for train_idx, val_idx in tqdm(
137137
splits, total=self._kfold.get_n_splits(X), disable=not self._verbose
@@ -151,11 +151,14 @@ def fit(self, x: XYData, y: XYData | None) -> Optional[float]:
151151
_y = pipeline.predict(X_val)
152152

153153
loss = pipeline.evaluate(X_val, y_val, _y)
154-
losses.append(float(next(iter(loss.values()))))
154+
for metric, value in loss.items():
155+
v = losses.get(metric, [])
156+
v.append(value)
157+
losses[metric] = v
155158

156159
self.clear_memory()
157160

158-
return float(np.mean(losses) if losses else 0.0)
161+
return dict(map(lambda item: (item[0], np.mean(item[1])), losses.items()))
159162

160163
def start(
161164
self, x: XYData, y: Optional[XYData], X_: Optional[XYData]

framework3/plugins/splitter/stratified_cross_validation_splitter.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def split(self, pipeline: BaseFilter):
8484
self.pipeline = pipeline
8585
self.pipeline.verbose(False)
8686

87-
def fit(self, x: XYData, y: XYData | None) -> Optional[float]:
87+
def fit(self, x: XYData, y: XYData | None) -> Optional[float | dict]:
8888
"""
8989
Perform Stratified K-Fold cross-validation on the given data.
9090
@@ -111,7 +111,7 @@ def fit(self, x: XYData, y: XYData | None) -> Optional[float]:
111111
X = x.value
112112
Y = y.value
113113

114-
losses = []
114+
losses: dict = {}
115115
splits = self._skf.split(X, Y)
116116
for train_idx, val_idx in tqdm(
117117
splits, total=self._skf.get_n_splits(X, Y), disable=not self._verbose
@@ -131,11 +131,14 @@ def fit(self, x: XYData, y: XYData | None) -> Optional[float]:
131131
_y = pipeline.predict(X_val)
132132

133133
loss = pipeline.evaluate(X_val, y_val, _y)
134-
losses.append(float(next(iter(loss.values()))))
134+
for metric, value in loss.items():
135+
v = losses.get(metric, [])
136+
v.append(value)
137+
losses[metric] = v
135138

136139
self.clear_memory()
137140

138-
return float(np.mean(losses) if losses else 0.0)
141+
return dict(map(lambda item: (item[0], np.mean(item[1])), losses.items()))
139142

140143
def start(
141144
self, x: XYData, y: Optional[XYData], X_: Optional[XYData]

0 commit comments

Comments (0)