Skip to content

Commit 0b4a4dd

Browse files
committed
log all metrics on sweep run
1 parent 8fb05a6 commit 0b4a4dd

File tree

2 files changed

+6
-2
lines changed

2 files changed

+6
-2
lines changed

framework3/plugins/optimizer/wandb_optimizer.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
from typing import Any, Dict, cast
2+
3+
import wandb
24
from framework3 import Container
35
from framework3.base import BaseMetric
46
from framework3.base.base_clases import BaseFilter, BasePlugin
@@ -152,8 +154,8 @@ def exec(
152154
match pipeline.fit(x, y):
153155
case None:
154156
losses = pipeline.evaluate(x, y, pipeline.predict(x))
155-
156-
loss = losses.get(self.scorer.__class__.__name__, 0.0)
157+
loss = losses.pop(self.scorer.__class__.__name__, 0.0)
158+
wandb.log(dict(losses)) # type: ignore[attr-defined]
157159

158160
return {self.scorer.__class__.__name__: float(loss)}
159161
case float() as loss:

framework3/utils/wandb.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,8 @@ def get_grid(aux: Dict[str, Any], config: Dict[str, Any]):
7070
case {"filters": filters, **r}:
7171
for filter_config in filters:
7272
WandbSweepManager.get_grid(filter_config, config)
73+
case {"filter": filter, **r}:
74+
WandbSweepManager.get_grid(filter, config)
7375
case {"pipeline": pipeline, **r}: # noqa: F841
7476
WandbSweepManager.get_grid(pipeline, config)
7577
case p_params:

0 commit comments

Comments
 (0)