Commit e848aaa

Feat/code carbon each node (#175)
* feat: update codecarbon
* feat: update codecarbon
* feat: added codecarbon
* Update optimizer_config.schema.json
* fix: fixed import mypy
* fix: codecarbon package
* fix: only float\integer log
* fix: codecarbon package
* fix: mypy
* fix: test
* fix: delete emissions
* fix: test

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 29de65d commit e848aaa

File tree

5 files changed: +113 -12 lines

autointent/nodes/_node_optimizer.py
autointent/nodes/emissions_tracker.py (new file)
docs/optimizer_config.schema.json
pyproject.toml
tests/callback/test_callback.py

autointent/nodes/_node_optimizer.py

Lines changed: 10 additions & 3 deletions

@@ -17,6 +17,7 @@
 from autointent import Dataset
 from autointent.context import Context
 from autointent.custom_types import NodeType, SamplerType, SearchSpaceValidationMode
+from autointent.nodes.emissions_tracker import EmissionsTracker
 from autointent.nodes.info import NODES_INFO


@@ -67,6 +68,7 @@ def __init__(
         self.node_type = node_type
         self.node_info = NODES_INFO[node_type]
         self.target_metric = target_metric
+        self.emissions_tracker = EmissionsTracker(project_name=f"{self.node_info.node_type}")

         self.metrics = metrics if metrics is not None else []
         if self.target_metric not in self.metrics:
@@ -141,8 +143,13 @@ def objective(
         context.callback_handler.start_module(module_name=module_name, num=self._counter, module_kwargs=config)

         self._logger.debug("Scoring %s module...", module_name)
-        all_metrics = module.score(context, metrics=self.metrics)
-        target_metric = all_metrics[self.target_metric]
+
+        self.emissions_tracker.start_task("module_scoring")
+        final_metrics = module.score(context, metrics=self.metrics)
+        emissions_metrics = self.emissions_tracker.stop_task()
+        all_metrics = {**final_metrics, **emissions_metrics}
+
+        target_metric = final_metrics[self.target_metric]

         context.callback_handler.log_metrics(all_metrics)
         context.callback_handler.end_module()
@@ -161,7 +168,7 @@ def objective(
             config,
             target_metric,
             self.target_metric,
-            all_metrics,
+            final_metrics,
             module.get_assets(),  # retriever name / scores / predictions
             module_dump_dir,
             module=module if not context.is_ram_to_clear() else None,
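
In effect, each trial's scoring call is now bracketed by the tracker, and the callback handler receives the module's quality metrics merged with the emissions metrics, while trial selection still reads the quality metrics alone. A minimal sketch of that pattern (the score_with_emissions helper is hypothetical, not part of the codebase):

from autointent.nodes.emissions_tracker import EmissionsTracker

def score_with_emissions(module, context, metrics, tracker: EmissionsTracker) -> dict[str, float]:
    """Bracket a module's scoring call with emissions tracking (illustrative only)."""
    tracker.start_task("module_scoring")
    final_metrics = module.score(context, metrics=metrics)  # quality metrics from the module
    emissions_metrics = tracker.stop_task()                 # "emissions/..."-prefixed numbers
    # The merged dict is what gets logged; optimization still selects on final_metrics only.
    return {**final_metrics, **emissions_metrics}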

autointent/nodes/emissions_tracker.py (new file)

Lines changed: 53 additions & 0 deletions

@@ -0,0 +1,53 @@
+"""Emissions tracking functionality for monitoring energy consumption and carbon emissions."""
+
+import json
+import logging
+
+from codecarbon import EmissionsTracker as CodeCarbonTracker  # type: ignore[import-untyped]
+from codecarbon.output import EmissionsData  # type: ignore[import-untyped]
+
+logger = logging.getLogger(__name__)
+
+
+class EmissionsTracker:
+    """Class for tracking energy consumption and carbon emissions."""
+
+    def __init__(self, project_name: str, measure_power_secs: int = 1) -> None:
+        """Initialize the emissions tracker.
+
+        Args:
+            project_name: Name of the project to track emissions for.
+            measure_power_secs: How often to measure power consumption in seconds.
+        """
+        self._logger = logger
+        self.tracker = CodeCarbonTracker(project_name=project_name, measure_power_secs=measure_power_secs)
+
+    def start_task(self, task_name: str) -> None:
+        """Start tracking emissions for a specific task.
+
+        Args:
+            task_name: Name of the task to track emissions for.
+        """
+        self.tracker.start_task(task_name)
+
+    def stop_task(self) -> dict[str, float]:
+        """Stop tracking emissions and return the emissions data.
+
+        Returns:
+            Dictionary containing emissions metrics.
+        """
+        emissions_data = self.tracker.stop_task()
+        _ = self.tracker.stop()
+        return self._process_metrics(emissions_data)
+
+    def _process_metrics(self, emissions_data: EmissionsData) -> dict[str, float]:
+        """Process emissions data into metrics with the 'emissions/' prefix.
+
+        Args:
+            emissions_data: Raw emissions data from the tracker.
+
+        Returns:
+            Dictionary of processed emissions metrics with the 'emissions/' prefix.
+        """
+        emissions_data_dict = json.loads(emissions_data.toJSON())
+        return {f"emissions/{k}": v for k, v in emissions_data_dict.items() if isinstance(v, int | float)}
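
Outside the optimizer, the wrapper can be exercised on its own. A minimal sketch, assuming codecarbon is installed and using a purely illustrative workload in place of module scoring; the exact emissions/* keys depend on the fields codecarbon reports:

from autointent.nodes.emissions_tracker import EmissionsTracker

def burn_cpu(n: int = 500_000) -> int:
    """Illustrative workload standing in for module scoring."""
    return sum(i * i for i in range(n))

tracker = EmissionsTracker(project_name="demo")
tracker.start_task("demo_task")
burn_cpu()
metrics = tracker.stop_task()  # also stops the underlying codecarbon tracker
# Numeric EmissionsData fields come back prefixed, e.g. "emissions/duration",
# "emissions/energy_consumed"; non-numeric fields are dropped.
print(sorted(metrics))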

docs/optimizer_config.schema.json

Lines changed: 5 additions & 4 deletions

@@ -66,16 +66,16 @@
     "validation_size": {
       "default": 0.2,
       "description": "Fraction of train samples to allocate for validation (if input dataset doesn't contain validation split).",
-      "maximum": 1.0,
-      "minimum": 0.0,
+      "maximum": 1,
+      "minimum": 0,
       "title": "Validation Size",
       "type": "number"
     },
     "separation_ratio": {
       "anyOf": [
         {
-          "maximum": 1.0,
-          "minimum": 0.0,
+          "maximum": 1,
+          "minimum": 0,
           "type": "number"
         },
         {
@@ -342,6 +342,7 @@
     },
     "search_space": {
       "items": {
+        "additionalProperties": true,
         "type": "object"
       },
       "title": "Search Space",

pyproject.toml

Lines changed: 1 addition & 0 deletions

@@ -45,6 +45,7 @@ dependencies = [
     "xxhash (>=3.5.0,<4.0.0)",
     "python-dotenv (>=1.0.1,<2.0.0)",
     "transformers[torch] (>=4.49.0,<5.0.0)",
+    "codecarbon (==2.6)",
 ]

 [project.urls]
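
To confirm the pinned dependency in an installed environment, a quick check (a sketch, assuming a standard pip or poetry installation):

from importlib.metadata import version

# Expect "2.6", matching the "codecarbon (==2.6)" constraint above.
print(version("codecarbon"))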

tests/callback/test_callback.py

Lines changed: 44 additions & 5 deletions

@@ -26,6 +26,7 @@ def log_value(self, **kwargs: dict[str, Any]) -> None:

     def log_metrics(self, **kwargs: dict[str, Any]) -> None:
         metrics = kwargs["metrics"]
+        metrics = {k: v for k, v in metrics.items() if not k.startswith("emissions/")}
         for metric_name, metric_value in metrics.items():
             if not isinstance(metric_value, str) and np.isnan(metric_value):
                 metrics[metric_name] = None
@@ -103,7 +104,14 @@ def test_pipeline_callbacks(dataset):
                 "num": 0,
             },
         ),
-        ("log_metric", {"metrics": {"retrieval_hit_rate": 1.0}}),
+        (
+            "log_metric",
+            {
+                "metrics": {
+                    "retrieval_hit_rate": 1.0,
+                }
+            },
+        ),
         ("end_module", {}),
         (
             "start_module",
@@ -113,7 +121,14 @@ def test_pipeline_callbacks(dataset):
                 "num": 1,
             },
         ),
-        ("log_metric", {"metrics": {"retrieval_hit_rate": 1.0}}),
+        (
+            "log_metric",
+            {
+                "metrics": {
+                    "retrieval_hit_rate": 1.0,
+                }
+            },
+        ),
         ("end_module", {}),
         (
             "start_module",
@@ -139,7 +154,15 @@ def test_pipeline_callbacks(dataset):
                 "num": 0,
             },
         ),
-        ("log_metric", {"metrics": {"scoring_accuracy": 1.0, "scoring_roc_auc": 1.0}}),
+        (
+            "log_metric",
+            {
+                "metrics": {
+                    "scoring_accuracy": 1.0,
+                    "scoring_roc_auc": 1.0,
+                }
+            },
+        ),
         ("end_module", {}),
         (
             "start_module",
@@ -165,7 +188,15 @@ def test_pipeline_callbacks(dataset):
                 "num": 1,
             },
         ),
-        ("log_metric", {"metrics": {"scoring_accuracy": 1.0, "scoring_roc_auc": 1.0}}),
+        (
+            "log_metric",
+            {
+                "metrics": {
+                    "scoring_accuracy": 1.0,
+                    "scoring_roc_auc": 1.0,
+                }
+            },
+        ),
         ("end_module", {}),
         (
             "start_module",
@@ -189,7 +220,15 @@ def test_pipeline_callbacks(dataset):
                 "num": 0,
             },
         ),
-        ("log_metric", {"metrics": {"scoring_accuracy": 0.75, "scoring_roc_auc": 1.0}}),
+        (
+            "log_metric",
+            {
+                "metrics": {
+                    "scoring_accuracy": 0.75,
+                    "scoring_roc_auc": 1.0,
+                }
+            },
+        ),
         ("end_module", {}),
         ("start_module", {"module_kwargs": {"thresh": 0.5}, "module_name": "threshold", "num": 0}),
         (
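
The test callback drops the emissions/*-prefixed metrics before comparing against the expected call log, presumably because emissions values vary across machines and runs. The same filtering pattern in isolation, with made-up values:

logged = {
    "scoring_accuracy": 1.0,
    "scoring_roc_auc": 1.0,
    "emissions/duration": 0.42,           # made-up value; varies per run
    "emissions/energy_consumed": 1.3e-6,  # made-up value; varies per machine
}
comparable = {k: v for k, v in logged.items() if not k.startswith("emissions/")}
assert comparable == {"scoring_accuracy": 1.0, "scoring_roc_auc": 1.0}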
