Skip to content

Commit 66c1a80

Browse files
change emission tracker to callback (#228)
* change emission tracker to callback
* Update optimizer_config.schema.json
* fix callback
* fix types
* make update_* not abstract
* try print metrics
* fix metrics
* fix test

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent d15f1e2 commit 66c1a80

File tree

14 files changed

+454
-624
lines changed

14 files changed

+454
-624
lines changed

autointent/_callbacks/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22

33
from autointent._callbacks.base import OptimizerCallback
44
from autointent._callbacks.callback_handler import CallbackHandler
5+
from autointent._callbacks.emissions_tracker import EmissionsTrackerCallback
56
from autointent._callbacks.tensorboard import TensorBoardCallback
67
from autointent._callbacks.wandb import WandbCallback
78

8-
REPORTERS = {cb.name: cb for cb in [WandbCallback, TensorBoardCallback]}
9+
REPORTERS = {cb.name: cb for cb in [WandbCallback, TensorBoardCallback, EmissionsTrackerCallback]}
910

1011
REPORTERS_NAMES = Literal[tuple(REPORTERS.keys())] # type: ignore[valid-type]
1112

@@ -34,6 +35,7 @@ def get_callbacks(reporters: list[str] | None) -> CallbackHandler:
3435
__all__ = [
3536
"REPORTERS_NAMES",
3637
"CallbackHandler",
38+
"EmissionsTrackerCallback",
3739
"OptimizerCallback",
3840
"TensorBoardCallback",
3941
"WandbCallback",

autointent/_callbacks/base.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,3 +66,22 @@ def log_final_metrics(self, metrics: dict[str, Any]) -> None:
6666
Args:
6767
metrics: Final metrics.
6868
"""
69+
70+
def update_metrics(self, metrics: dict[str, Any]) -> dict[str, Any]:
    """Give the callback a chance to amend per-module metrics.

    The default implementation is a pass-through: the mapping is handed
    back unchanged.

    Args:
        metrics: Metrics to update.

    Returns:
        The ``metrics`` mapping that was passed in, unmodified.
    """
    return metrics
77+
78+
def update_final_metrics(self, metrics: dict[str, Any]) -> dict[str, Any]:
    """Give the callback a chance to amend the final metrics.

    The default implementation is a pass-through: the mapping is handed
    back unchanged.

    Args:
        metrics: Final metrics to update.

    Returns:
        The ``metrics`` mapping that was passed in, unmodified.
    """
    return metrics

autointent/_callbacks/callback_handler.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,3 +82,31 @@ def call_events(self, event: str, **kwargs: Any) -> None: # noqa: ANN401
8282
"""
8383
for callback in self.callbacks:
8484
getattr(callback, event)(**kwargs)
85+
86+
def update_metrics(self, metrics: dict[str, Any]) -> dict[str, Any]:
    """Thread the metrics through every registered callback, in order.

    Each callback that exposes ``update_metrics`` receives the mapping
    produced by the previous one; callbacks without that attribute are
    skipped.

    Args:
        metrics: Metrics to update.

    Returns:
        The metrics after the last callback has processed them.
    """
    current = metrics
    for cb in self.callbacks:
        if not hasattr(cb, "update_metrics"):
            continue
        current = cb.update_metrics(current)
    return current
99+
100+
def update_final_metrics(self, metrics: dict[str, Any]) -> dict[str, Any]:
    """Thread the final metrics through every registered callback, in order.

    Each callback that exposes ``update_final_metrics`` receives the
    mapping produced by the previous one; callbacks without that attribute
    are skipped.

    Args:
        metrics: Final metrics to update.

    Returns:
        The final metrics after the last callback has processed them.
    """
    current = metrics
    for cb in self.callbacks:
        if not hasattr(cb, "update_final_metrics"):
            continue
        current = cb.update_final_metrics(current)
    return current

autointent/_callbacks/emissions_tracker.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
"""Emissions tracking functionality for monitoring energy consumption and carbon emissions."""
2+
3+
import json
4+
import logging
5+
from pathlib import Path
6+
from typing import Any
7+
8+
from autointent._callbacks import OptimizerCallback
9+
10+
logger = logging.getLogger(__name__)
11+
12+
13+
class EmissionsTrackerCallback(OptimizerCallback):
    """Callback that tracks energy consumption and carbon emissions via codecarbon.

    A codecarbon tracker is created and started in ``start_run``, a codecarbon
    task is opened for each module in ``start_module``, and the numeric fields
    of the collected emissions data are merged into the metrics passed to
    ``update_metrics`` and ``update_final_metrics``. The logging hooks are
    no-ops: this callback only contributes metrics, it does not report them.
    """

    name = "emissions_tracker"

    # Name of the codecarbon task opened by the latest ``start_module`` call;
    # ``None`` until a module has been started in the current run.
    current_module_name: str | None = None

    def __init__(self) -> None:
        """Initialize the emission tracker.

        Raises:
            ImportError: If the optional ``codecarbon`` package is not installed.
        """
        try:
            from codecarbon import EmissionsTracker
        except ImportError as e:
            msg = (
                "EmissionsTrackerCallback requires the codecarbon package to be installed. "
                "Please install it with `pip install autointent[codecarbon]`."
            )
            raise ImportError(msg) from e
        # Store the class, not an instance: a fresh tracker is built per run in ``start_run``.
        self.emission_tracker = EmissionsTracker

    def start_run(self, run_name: str, dirpath: Path, log_interval_time: float) -> None:  # noqa: ARG002
        """Start tracking emissions for the entire run.

        Args:
            run_name: Name of the run.
            dirpath: Path to the directory where the logs will be saved (unused here).
            log_interval_time: Sampling interval for the system monitor in seconds.
        """
        self.tracker = self.emission_tracker(project_name=run_name, measure_power_secs=log_interval_time)

        self.tracker.start()
        self.current_module_name = None

    def start_module(self, module_name: str, num: int, module_kwargs: dict[str, Any]) -> None:  # noqa: ARG002
        """Start tracking emissions for a specific task.

        Args:
            module_name: Name of the task to track emissions for.
            num: Number of the module.
            module_kwargs: Module parameters (unused here).
        """
        self.current_module_name = f"{module_name}_{num}"
        self.tracker.start_task(self.current_module_name)

    @staticmethod
    def _numeric_emissions(emissions_data: Any) -> dict[str, Any]:  # noqa: ANN401
        """Return the numeric fields of an emissions record under ``emissions/<field>`` keys.

        Args:
            emissions_data: Object exposing ``toJSON()`` that returns a JSON object string.

        Returns:
            Mapping ``"emissions/<field>" -> value`` for every int/float field.
        """
        fields = json.loads(emissions_data.toJSON())
        return {f"emissions/{k}": v for k, v in fields.items() if isinstance(v, int | float)}

    def update_metrics(self, metrics: dict[str, Any]) -> dict[str, Any]:
        """Stop the current module's task and merge its emissions into the metrics.

        Args:
            metrics: Metrics of the current module.

        Returns:
            Emissions metrics merged with ``metrics``; on a key clash the value
            from ``metrics`` wins. If no module task has been started,
            ``metrics`` is returned unchanged.
        """
        if self.current_module_name is None:
            # No task was opened through ``start_module``, so there is nothing to stop.
            return metrics
        emissions_data = self.tracker.stop_task(self.current_module_name)
        return self._numeric_emissions(emissions_data) | metrics

    def update_final_metrics(self, metrics: dict[str, Any]) -> dict[str, Any]:
        """Update final metrics with emissions data.

        Stops the run-level tracker and adds its numeric emissions fields under
        the ``"emissions"`` key.

        Args:
            metrics: Final metrics to update.

        Returns:
            Updated metrics including emissions data; on a key clash the value
            from ``metrics`` wins.
        """
        self.tracker.stop()
        # NOTE(review): the nested keys keep their "emissions/" prefix, i.e.
        # {"emissions": {"emissions/<field>": ...}} - preserved as the existing output shape.
        return {"emissions": self._numeric_emissions(self.tracker.final_emissions_data)} | metrics

    def log_value(self, **kwargs: dict[str, Any]) -> None:
        """Do nothing: this callback does not log individual values."""

    def log_metrics(self, metrics: dict[str, Any]) -> None:
        """Do nothing: emissions are contributed through ``update_metrics`` instead."""

    def end_module(self) -> None:
        """Do nothing: the module task is stopped in ``update_metrics``."""

    def end_run(self) -> None:
        """Do nothing: the run-level tracker is stopped in ``update_final_metrics``."""

    def log_final_metrics(self, metrics: dict[str, Any]) -> None:
        """Do nothing: emissions are contributed through ``update_final_metrics`` instead."""

autointent/_callbacks/wandb.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def log_final_metrics(self, metrics: dict[str, Any]) -> None:
112112
}
113113

114114
try:
115-
config = metrics["configs"]
115+
config = metrics.get("configs")
116116
self.wandb.init(config=config, **wandb_run_init_args)
117117
self.wandb.log(metrics)
118118
except Exception as e:

autointent/_dataset/_dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def from_hub(cls, repo_name: str) -> "Dataset":
102102
"""
103103
from ._reader import DictReader
104104

105-
splits = load_dataset(repo_name)
105+
splits = load_dataset(repo_name, "default")
106106
mapping = dict(**splits)
107107
if Split.INTENTS in get_dataset_config_names(repo_name):
108108
mapping["intents"] = load_dataset(repo_name, Split.INTENTS)[Split.INTENTS].to_list()

autointent/_pipeline/_pipeline.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,10 @@ def fit(
252252
context.data_handler.test_labels(),
253253
predictions,
254254
)
255-
context.callback_handler.log_final_metrics(context.optimization_info.dump_evaluation_results())
255+
all_final_metrics = context.callback_handler.update_final_metrics(
256+
context.optimization_info.dump_evaluation_results(),
257+
)
258+
context.callback_handler.log_final_metrics(all_final_metrics)
256259

257260
return context
258261

autointent/modules/scoring/_sklearn/sklearn_scorer.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import numpy as np
77
import numpy.typing as npt
8+
from sklearn.base import BaseEstimator
89
from sklearn.multioutput import MultiOutputClassifier
910
from sklearn.utils import all_estimators
1011
from typing_extensions import Self
@@ -15,7 +16,7 @@
1516
from autointent.modules.base import BaseScorer
1617

1718
logger = logging.getLogger(__name__)
18-
AVAILABLE_CLASSIFIERS = {
19+
AVAILABLE_CLASSIFIERS: dict[str, type[BaseEstimator]] = {
1920
name: class_
2021
for name, class_ in all_estimators(
2122
type_filter=[
@@ -61,7 +62,7 @@ def __init__(
6162
self,
6263
clf_name: str = "LogisticRegression",
6364
embedder_config: EmbedderConfig | str | dict[str, Any] | None = None,
64-
**clf_args: Any, # noqa: ANN401
65+
**clf_args: dict[str, float | str | bool],
6566
) -> None:
6667
"""Initialize the SklearnScorer.
6768
@@ -89,7 +90,7 @@ def from_context(
8990
context: Context,
9091
clf_name: str = "LogisticRegression",
9192
embedder_config: EmbedderConfig | str | None = None,
92-
**clf_args: float | str | bool,
93+
**clf_args: dict[str, float | str | bool],
9394
) -> Self:
9495
"""Create a SklearnScorer instance using a Context object.
9596

autointent/nodes/_node_optimizer.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from autointent import Dataset
1818
from autointent.context import Context
1919
from autointent.custom_types import NodeType, SearchSpaceValidationMode
20-
from autointent.nodes.emissions_tracker import EmissionsTracker
2120
from autointent.nodes.info import NODES_INFO
2221
from autointent.schemas.node_validation import ParamSpaceFloat, ParamSpaceInt, ParamSpaceT, SearchSpaceConfig
2322

@@ -50,7 +49,6 @@ def __init__(
5049
self.node_type = node_type
5150
self.node_info = NODES_INFO[node_type]
5251
self.target_metric = target_metric
53-
self.emissions_tracker = EmissionsTracker(project_name=f"{self.node_info.node_type}")
5452

5553
self.metrics = metrics if metrics is not None else []
5654
if self.target_metric not in self.metrics:
@@ -143,13 +141,10 @@ def objective(
143141

144142
self._logger.debug("Scoring %s module...", module_name)
145143

146-
self.emissions_tracker.start_task("module_scoring")
147144
quality_metrics = module.score(context, metrics=self.metrics)
148-
emissions_metrics = self.emissions_tracker.stop_task()
149-
all_metrics = {**quality_metrics, **emissions_metrics}
150145

151146
target_metric = quality_metrics[self.target_metric]
152-
147+
all_metrics = context.callback_handler.update_metrics(quality_metrics)
153148
context.callback_handler.log_metrics(all_metrics)
154149
context.callback_handler.end_module()
155150

autointent/nodes/emissions_tracker.py

Lines changed: 0 additions & 63 deletions
This file was deleted.

0 commit comments

Comments
 (0)