Skip to content

Commit 03359f1

Browse files
committed
feat: update codecarbon
1 parent 0c71358 commit 03359f1

File tree

2 files changed

+61
-39
lines changed

2 files changed

+61
-39
lines changed

autointent/nodes/_node_optimizer.py

Lines changed: 8 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import gc
44
import itertools as it
5-
import json
65
import logging
76
from copy import deepcopy
87
from functools import partial
@@ -11,14 +10,14 @@
1110

1211
import optuna
1312
import torch
14-
from codecarbon import EmissionsTracker
1513
from optuna.trial import Trial
1614
from pydantic import BaseModel, Field
1715
from typing_extensions import assert_never
1816

1917
from autointent import Dataset
2018
from autointent.context import Context
2119
from autointent.custom_types import NodeType, SamplerType, SearchSpaceValidationMode
20+
from autointent.nodes.emissions_tracker import EmissionsTracker
2221
from autointent.nodes.info import NODES_INFO
2322

2423

@@ -69,7 +68,7 @@ def __init__(
6968
self.node_type = node_type
7069
self.node_info = NODES_INFO[node_type]
7170
self.target_metric = target_metric
72-
self.tracker = EmissionsTracker(project_name=f"{self.node_info.node_type}", measure_power_secs=1)
71+
self.emissions_tracker = EmissionsTracker(project_name=f"{self.node_info.node_type}")
7372

7473
self.metrics = metrics if metrics is not None else []
7574
if self.target_metric not in self.metrics:
@@ -78,36 +77,6 @@ def __init__(
7877
self.validate_search_space(search_space)
7978
self.modules_search_spaces = search_space
8079

81-
def _start_emissions_tracking(self, task_name: str) -> None:
82-
"""Start tracking emissions for a specific task.
83-
84-
Args:
85-
task_name: Name of the task to track emissions for.
86-
"""
87-
self.tracker.start_task(task_name)
88-
89-
def _stop_emissions_tracking(self) -> dict[str, float]:
90-
"""Stop tracking emissions and return the emissions data.
91-
92-
Returns:
93-
Dictionary containing emissions metrics.
94-
"""
95-
emissions_data = self.tracker.stop_task()
96-
emissions_data_dict = json.loads(emissions_data.toJSON())
97-
_ = self.tracker.stop()
98-
return emissions_data_dict
99-
100-
def _process_emissions_metrics(self, emissions_data: dict[str, float]) -> dict[str, float]:
101-
"""Process emissions data into metrics with the 'emissions/' prefix.
102-
103-
Args:
104-
emissions_data: Raw emissions data from the tracker.
105-
106-
Returns:
107-
Dictionary of processed emissions metrics with the 'emissions/' prefix.
108-
"""
109-
return {f"emissions/{k}": v for k, v in emissions_data.items()}
110-
11180
def fit(self, context: Context, sampler: SamplerType = "brute") -> None:
11281
"""Performs the optimization process for the node.
11382
@@ -175,12 +144,12 @@ def objective(
175144

176145
self._logger.debug("Scoring %s module...", module_name)
177146

178-
self._start_emissions_tracking("module_scoring")
179-
all_metrics = module.score(context, metrics=self.metrics)
180-
emissions_data = self._stop_emissions_tracking()
181-
all_metrics.update(self._process_emissions_metrics(emissions_data))
147+
self.emissions_tracker.start_task("module_scoring")
148+
final_metrics = module.score(context, metrics=self.metrics)
149+
emissions_metrics = self.emissions_tracker.stop_task()
150+
all_metrics = {**final_metrics, **emissions_metrics}
182151

183-
target_metric = all_metrics[self.target_metric]
152+
target_metric = final_metrics[self.target_metric]
184153

185154
context.callback_handler.log_metrics(all_metrics)
186155
context.callback_handler.end_module()
@@ -199,7 +168,7 @@ def objective(
199168
config,
200169
target_metric,
201170
self.target_metric,
202-
all_metrics,
171+
final_metrics,
203172
module.get_assets(), # retriever name / scores / predictions
204173
module_dump_dir,
205174
module=module if not context.is_ram_to_clear() else None,
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
"""Emissions tracking functionality for monitoring energy consumption and carbon emissions."""
2+
3+
import json
4+
import logging
5+
6+
from codecarbon import EmissionsTracker as CodeCarbonTracker
7+
from codecarbon.output import EmissionsData
8+
9+
logger = logging.getLogger(__name__)
10+
11+
12+
class EmissionsTracker:
13+
"""Class for tracking energy consumption and carbon emissions."""
14+
15+
def __init__(self, project_name: str, measure_power_secs: int = 1) -> None:
16+
"""Initialize the emissions tracker.
17+
18+
Args:
19+
project_name: Name of the project to track emissions for.
20+
measure_power_secs: How often to measure power consumption in seconds.
21+
"""
22+
self._logger = logger
23+
self.tracker = CodeCarbonTracker(project_name=project_name, measure_power_secs=measure_power_secs)
24+
25+
def start_task(self, task_name: str) -> None:
26+
"""Start tracking emissions for a specific task.
27+
28+
Args:
29+
task_name: Name of the task to track emissions for.
30+
"""
31+
self.tracker.start_task(task_name)
32+
33+
def stop_task(self) -> dict[str, float]:
34+
"""Stop tracking emissions and return the emissions data.
35+
36+
Returns:
37+
Dictionary containing emissions metrics.
38+
"""
39+
emissions_data = self.tracker.stop_task()
40+
_ = self.tracker.stop()
41+
return self._process_metrics(emissions_data)
42+
43+
def _process_metrics(self, emissions_data: EmissionsData) -> dict[str, float]:
44+
"""Process emissions data into metrics with the 'emissions/' prefix.
45+
46+
Args:
47+
emissions_data: Raw emissions data from the tracker.
48+
49+
Returns:
50+
Dictionary of processed emissions metrics with the 'emissions/' prefix.
51+
"""
52+
emissions_data_dict = json.loads(emissions_data.toJSON())
53+
return {f"emissions/{k}": v for k, v in emissions_data_dict.items()}

0 commit comments

Comments
 (0)