Skip to content

Commit 8f016e8

Browse files
authored
fix/wandb-final-metrics-skipped (#212)
* fix * sklearn scorer proper name * fix typing errors * try to fix pydantic errors
1 parent a04587b commit 8f016e8

File tree

7 files changed

+21
-5
lines changed

7 files changed

+21
-5
lines changed

autointent/_callbacks/tensorboard.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def __init__(self) -> None:
1616
Raises an ImportError if neither are installed.
1717
"""
1818
try:
19-
from torch.utils.tensorboard import SummaryWriter
19+
from torch.utils.tensorboard import SummaryWriter # type: ignore[attr-defined]
2020

2121
self.writer = SummaryWriter
2222
except ImportError:

autointent/_callbacks/wandb.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,11 @@ def log_metrics(self, metrics: dict[str, Any]) -> None:
8989
"""
9090
self.wandb.log(metrics)
9191

92+
def _close_current_run(self) -> None:
93+
"""Close the current W&B run if open."""
94+
if self.wandb.run is not None:
95+
self.wandb.finish()
96+
9297
def log_final_metrics(self, metrics: dict[str, Any]) -> None:
9398
"""Logs final evaluation metrics to W&B.
9499
@@ -97,6 +102,8 @@ def log_final_metrics(self, metrics: dict[str, Any]) -> None:
97102
Args:
98103
metrics: A dictionary of final performance metrics.
99104
"""
105+
self._close_current_run()
106+
100107
wandb_run_init_args = {
101108
"project": self.project_name,
102109
"group": self.group,

autointent/_dump_tools.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -229,15 +229,15 @@ def load( # noqa: C901, PLR0912, PLR0915
229229
elif child.name == Dumper.ptuning_models:
230230
for model_dir in child.iterdir():
231231
try:
232-
model = AutoModelForSequenceClassification.from_pretrained(model_dir / "base_model")
232+
model = AutoModelForSequenceClassification.from_pretrained(model_dir / "base_model") # type: ignore[no-untyped-call]
233233
hf_models[model_dir.name] = PeftModel.from_pretrained(model, model_dir / "peft")
234234
except Exception as e: # noqa: PERF203
235235
msg = f"Error loading PeftModel {model_dir.name}: {e}"
236236
logger.exception(msg)
237237
elif child.name == Dumper.hf_models:
238238
for model_dir in child.iterdir():
239239
try:
240-
hf_models[model_dir.name] = AutoModelForSequenceClassification.from_pretrained(model_dir)
240+
hf_models[model_dir.name] = AutoModelForSequenceClassification.from_pretrained(model_dir) # type: ignore[no-untyped-call]
241241
except Exception as e: # noqa: PERF203
242242
msg = f"Error loading HF model {model_dir.name}: {e}"
243243
logger.exception(msg)

autointent/modules/base/_base.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,11 @@ class BaseModule(ABC):
3232
name: str
3333
"""Name of the module."""
3434

35+
@property
36+
def trial_name(self) -> str:
37+
"""Name of the module for logging."""
38+
return self.name
39+
3540
@abstractmethod
3641
def fit(self, *args: tuple[Any], **kwargs: dict[str, Any]) -> None:
3742
"""Fit the model.

autointent/modules/scoring/_bert.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def _initialize_model(self) -> Any: # noqa: ANN401
7676
label2id = {i: i for i in range(self._n_classes)}
7777
id2label = {i: i for i in range(self._n_classes)}
7878

79-
return AutoModelForSequenceClassification.from_pretrained(
79+
return AutoModelForSequenceClassification.from_pretrained( # type: ignore[no-untyped-call]
8080
self.classification_model_config.model_name,
8181
trust_remote_code=self.classification_model_config.trust_remote_code,
8282
num_labels=self._n_classes,

autointent/modules/scoring/_sklearn/sklearn_scorer.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,10 @@ def __init__(
7979
logger.error(msg)
8080
raise ValueError(msg)
8181

82+
@property
83+
def trial_name(self) -> str:
84+
return f"sklearn_{self.clf_name}"
85+
8286
@classmethod
8387
def from_context(
8488
cls,

autointent/nodes/_node_optimizer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def objective(
131131
module = self.node_info.modules_available[module_name].from_context(context, **config)
132132
config.update(module.get_implicit_initialization_params())
133133

134-
context.callback_handler.start_module(module_name=module_name, num=self._counter, module_kwargs=config)
134+
context.callback_handler.start_module(module_name=module.trial_name, num=self._counter, module_kwargs=config)
135135

136136
self._logger.debug("Scoring %s module...", module_name)
137137

0 commit comments

Comments
 (0)