Skip to content

Commit 01cf1d5

Browse files
committed
change emission tracker to callback
1 parent 95be22a commit 01cf1d5

File tree

12 files changed

+183
-75
lines changed

12 files changed

+183
-75
lines changed

autointent/_callbacks/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22

33
from autointent._callbacks.base import OptimizerCallback
44
from autointent._callbacks.callback_handler import CallbackHandler
5+
from autointent._callbacks.emissions_tracker import EmissionsTrackerCallback
56
from autointent._callbacks.tensorboard import TensorBoardCallback
67
from autointent._callbacks.wandb import WandbCallback
78

8-
REPORTERS = {cb.name: cb for cb in [WandbCallback, TensorBoardCallback]}
9+
REPORTERS = {cb.name: cb for cb in [WandbCallback, TensorBoardCallback, EmissionsTrackerCallback]}
910

1011
REPORTERS_NAMES = Literal[tuple(REPORTERS.keys())] # type: ignore[valid-type]
1112

@@ -34,6 +35,7 @@ def get_callbacks(reporters: list[str] | None) -> CallbackHandler:
3435
__all__ = [
3536
"REPORTERS_NAMES",
3637
"CallbackHandler",
38+
"EmissionsTrackerCallback",
3739
"OptimizerCallback",
3840
"TensorBoardCallback",
3941
"WandbCallback",

autointent/_callbacks/base.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,3 +66,22 @@ def log_final_metrics(self, metrics: dict[str, Any]) -> None:
6666
Args:
6767
metrics: Final metrics.
6868
"""
69+
70+
@abstractmethod
71+
def update_metrics(self, metrics: dict[str, Any]) -> dict[str, Any]:
72+
"""Update metrics during training.
73+
74+
Args:
75+
metrics: Metrics to update.
76+
"""
77+
78+
@abstractmethod
79+
def update_final_metrics(self, metrics: dict[str, Any]) -> dict[str, Any]:
80+
"""Update final metrics.
81+
82+
Args:
83+
metrics: Final metrics to update.
84+
85+
Returns:
86+
Updated final metrics.
87+
"""

autointent/_callbacks/callback_handler.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,3 +82,31 @@ def call_events(self, event: str, **kwargs: Any) -> None: # noqa: ANN401
8282
"""
8383
for callback in self.callbacks:
8484
getattr(callback, event)(**kwargs)
85+
86+
def update_metrics(self, metrics: dict[str, Any]) -> dict[str, Any]:
87+
"""Update metrics during training.
88+
89+
Args:
90+
metrics: Metrics to update.
91+
92+
Returns:
93+
Updated metrics.
94+
"""
95+
for callback in self.callbacks:
96+
if hasattr(callback, "update_metrics"):
97+
metrics = callback.update_metrics(metrics)
98+
return metrics
99+
100+
def update_final_metrics(self, metrics: dict[str, Any]) -> dict[str, Any]:
101+
"""Update final metrics.
102+
103+
Args:
104+
metrics: Final metrics to update.
105+
106+
Returns:
107+
Updated final metrics.
108+
"""
109+
for callback in self.callbacks:
110+
if hasattr(callback, "update_final_metrics"):
111+
metrics = callback.update_final_metrics(metrics)
112+
return metrics
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
"""Emissions tracking functionality for monitoring energy consumption and carbon emissions."""
2+
3+
import json
4+
import logging
5+
from pathlib import Path
6+
from typing import Any
7+
8+
from autointent._callbacks import OptimizerCallback
9+
10+
logger = logging.getLogger(__name__)
11+
12+
13+
class EmissionsTrackerCallback(OptimizerCallback):
14+
"""Class for tracking energy consumption and carbon emissions."""
15+
16+
name = "emissions_tracker"
17+
18+
current_module_name: str | None = None
19+
20+
def __init__(self, project_name: str, measure_power_secs: int = 1) -> None:
21+
"""Initialize the emission tracker.
22+
23+
Args:
24+
project_name: Name of the project to track emissions for.
25+
measure_power_secs: How often to measure power consumption in seconds.
26+
"""
27+
try:
28+
from codecarbon import EmissionsTracker
29+
except ImportError as e:
30+
msg = (
31+
"EmissionsTrackerCallback requires the codecarbon package to be installed. "
32+
"Please install it with `pip install codecarbon`."
33+
)
34+
raise ImportError(msg) from e
35+
logger.info("Emissions tracking is enabled via TRACK_EMISSIONS environment variable")
36+
self.tracker = EmissionsTracker(project_name=project_name, measure_power_secs=measure_power_secs)
37+
38+
def start_run(self, run_name: str, dirpath: Path, log_interval_time: float) -> None: # noqa: ARG002
39+
"""Start tracking emissions for the entire run.
40+
41+
Args:
42+
run_name: Name of the run.
43+
dirpath: Path to the directory where the logs will be saved.
44+
log_interval_time: Sampling interval for the system monitor in seconds.
45+
"""
46+
self.tracker.start()
47+
self.current_module_name = None
48+
49+
def start_module(self, module_name: str, num: int, module_kwargs: dict[str, Any]) -> None: # noqa: ARG002
50+
"""Start tracking emissions for a specific task.
51+
52+
Args:
53+
module_name: Name of the task to track emissions for.
54+
num: Number of the module.
55+
module_kwargs: Module parameters.
56+
"""
57+
self.current_module_name = f"{module_name}_{num}"
58+
self.tracker.start_task(self.current_module_name)
59+
60+
def update_metrics(self, metrics: dict[str, Any]) -> dict[str, float]:
61+
"""Stop tracking emissions and return the emissions data.
62+
63+
Returns:
64+
Dictionary containing emissions metrics.
65+
"""
66+
emissions_data = self.tracker.stop_task(self.current_module_name)
67+
emissions_data_json = json.loads(emissions_data.toJSON())
68+
emissions_data_dict = {
69+
f"emissions/{k}": v for k, v in emissions_data_json.items() if isinstance(v, int | float)
70+
}
71+
return emissions_data_dict | metrics
72+
73+
def update_final_metrics(self, metrics: dict[str, Any]) -> dict[str, Any]:
74+
"""Update final metrics with emissions data.
75+
76+
Args:
77+
metrics: Final metrics to update.
78+
79+
Returns:
80+
Updated metrics including emissions data.
81+
"""
82+
emissions_data = self.tracker.stop()
83+
emissions_data_json = json.loads(emissions_data.toJSON())
84+
emissions_data_dict = {
85+
f"emissions/{k}": v for k, v in emissions_data_json.items() if isinstance(v, int | float)
86+
}
87+
return emissions_data_dict | metrics
88+
89+
def log_value(self, **kwargs: dict[str, Any]) -> None:
90+
pass
91+
92+
def log_metrics(self, metrics: dict[str, Any]) -> None:
93+
pass
94+
95+
def end_module(self) -> None:
96+
pass
97+
98+
def end_run(self) -> None:
99+
pass
100+
101+
def log_final_metrics(self, metrics: dict[str, Any]) -> None:
102+
pass

autointent/_callbacks/tensorboard.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,3 +120,9 @@ def end_module(self) -> None:
120120

121121
def end_run(self) -> None:
122122
"""Ends the current run. This method is currently a placeholder."""
123+
124+
def update_metrics(self, metrics: dict[str, Any]) -> dict[str, Any]:
125+
return metrics
126+
127+
def update_final_metrics(self, metrics: dict[str, Any]) -> dict[str, Any]:
128+
return metrics

autointent/_callbacks/wandb.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,3 +143,9 @@ def end_run(self) -> None:
143143
144144
This method is currently a placeholder and does not perform additional operations.
145145
"""
146+
147+
def update_metrics(self, metrics: dict[str, Any]) -> dict[str, Any]:
148+
return metrics
149+
150+
def update_final_metrics(self, metrics: dict[str, Any]) -> dict[str, Any]:
151+
return metrics

autointent/_pipeline/_pipeline.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,10 @@ def fit(
258258
context.data_handler.test_labels(),
259259
predictions,
260260
)
261-
context.callback_handler.log_final_metrics(context.optimization_info.dump_evaluation_results())
261+
all_final_metrics = context.callback_handler.update_final_metrics(
262+
context.optimization_info.pipeline_metrics,
263+
)
264+
context.callback_handler.log_final_metrics(all_final_metrics)
262265

263266
return context
264267

autointent/modules/scoring/_sklearn/sklearn_scorer.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import numpy as np
77
import numpy.typing as npt
8+
from sklearn.base import BaseEstimator
89
from sklearn.multioutput import MultiOutputClassifier
910
from sklearn.utils import all_estimators
1011
from typing_extensions import Self
@@ -15,7 +16,7 @@
1516
from autointent.modules.base import BaseScorer
1617

1718
logger = logging.getLogger(__name__)
18-
AVAILABLE_CLASSIFIERS = {
19+
AVAILABLE_CLASSIFIERS: dict[str, type[BaseEstimator]] = {
1920
name: class_
2021
for name, class_ in all_estimators(
2122
type_filter=[
@@ -61,7 +62,7 @@ def __init__(
6162
self,
6263
clf_name: str = "LogisticRegression",
6364
embedder_config: EmbedderConfig | str | dict[str, Any] | None = None,
64-
**clf_args: Any, # noqa: ANN401
65+
**clf_args: dict[str, float | str | bool],
6566
) -> None:
6667
"""Initialize the SklearnScorer.
6768
@@ -89,7 +90,7 @@ def from_context(
8990
context: Context,
9091
clf_name: str = "LogisticRegression",
9192
embedder_config: EmbedderConfig | str | None = None,
92-
**clf_args: float | str | bool,
93+
**clf_args: dict[str, float | str | bool],
9394
) -> Self:
9495
"""Create a SklearnScorer instance using a Context object.
9596

autointent/nodes/_node_optimizer.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from autointent import Dataset
1818
from autointent.context import Context
1919
from autointent.custom_types import NodeType, SamplerType, SearchSpaceValidationMode
20-
from autointent.nodes.emissions_tracker import EmissionsTracker
2120
from autointent.nodes.info import NODES_INFO
2221
from autointent.schemas.node_validation import ParamSpaceFloat, ParamSpaceInt, ParamSpaceT, SearchSpaceConfig
2322

@@ -50,7 +49,6 @@ def __init__(
5049
self.node_type = node_type
5150
self.node_info = NODES_INFO[node_type]
5251
self.target_metric = target_metric
53-
self.emissions_tracker = EmissionsTracker(project_name=f"{self.node_info.node_type}")
5452

5553
self.metrics = metrics if metrics is not None else []
5654
if self.target_metric not in self.metrics:
@@ -135,13 +133,10 @@ def objective(
135133

136134
self._logger.debug("Scoring %s module...", module_name)
137135

138-
self.emissions_tracker.start_task("module_scoring")
139136
quality_metrics = module.score(context, metrics=self.metrics)
140-
emissions_metrics = self.emissions_tracker.stop_task()
141-
all_metrics = {**quality_metrics, **emissions_metrics}
142137

143138
target_metric = quality_metrics[self.target_metric]
144-
139+
all_metrics = context.callback_handler.update_metrics(quality_metrics)
145140
context.callback_handler.log_metrics(all_metrics)
146141
context.callback_handler.end_module()
147142

autointent/nodes/emissions_tracker.py

Lines changed: 0 additions & 63 deletions
This file was deleted.

0 commit comments

Comments
 (0)