Skip to content

Commit 0c71358

Browse files
committed
feat: update codecarbon
1 parent 29de65d commit 0c71358

File tree

1 file changed

+38
-0
lines changed

1 file changed

+38
-0
lines changed

autointent/nodes/_node_optimizer.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import gc
44
import itertools as it
5+
import json
56
import logging
67
from copy import deepcopy
78
from functools import partial
@@ -10,6 +11,7 @@
1011

1112
import optuna
1213
import torch
14+
from codecarbon import EmissionsTracker
1315
from optuna.trial import Trial
1416
from pydantic import BaseModel, Field
1517
from typing_extensions import assert_never
@@ -67,6 +69,7 @@ def __init__(
6769
self.node_type = node_type
6870
self.node_info = NODES_INFO[node_type]
6971
self.target_metric = target_metric
72+
self.tracker = EmissionsTracker(project_name=f"{self.node_info.node_type}", measure_power_secs=1)
7073

7174
self.metrics = metrics if metrics is not None else []
7275
if self.target_metric not in self.metrics:
@@ -75,6 +78,36 @@ def __init__(
7578
self.validate_search_space(search_space)
7679
self.modules_search_spaces = search_space
7780

81+
def _start_emissions_tracking(self, task_name: str) -> None:
82+
"""Start tracking emissions for a specific task.
83+
84+
Args:
85+
task_name: Name of the task to track emissions for.
86+
"""
87+
self.tracker.start_task(task_name)
88+
89+
def _stop_emissions_tracking(self) -> dict[str, float]:
90+
"""Stop tracking emissions and return the emissions data.
91+
92+
Returns:
93+
Dictionary containing emissions metrics.
94+
"""
95+
emissions_data = self.tracker.stop_task()
96+
emissions_data_dict = json.loads(emissions_data.toJSON())
97+
_ = self.tracker.stop()
98+
return emissions_data_dict
99+
100+
def _process_emissions_metrics(self, emissions_data: dict[str, float]) -> dict[str, float]:
101+
"""Process emissions data into metrics with the 'emissions/' prefix.
102+
103+
Args:
104+
emissions_data: Raw emissions data from the tracker.
105+
106+
Returns:
107+
Dictionary of processed emissions metrics with the 'emissions/' prefix.
108+
"""
109+
return {f"emissions/{k}": v for k, v in emissions_data.items()}
110+
78111
def fit(self, context: Context, sampler: SamplerType = "brute") -> None:
79112
"""Performs the optimization process for the node.
80113
@@ -141,7 +174,12 @@ def objective(
141174
context.callback_handler.start_module(module_name=module_name, num=self._counter, module_kwargs=config)
142175

143176
self._logger.debug("Scoring %s module...", module_name)
177+
178+
self._start_emissions_tracking("module_scoring")
144179
all_metrics = module.score(context, metrics=self.metrics)
180+
emissions_data = self._stop_emissions_tracking()
181+
all_metrics.update(self._process_emissions_metrics(emissions_data))
182+
145183
target_metric = all_metrics[self.target_metric]
146184

147185
context.callback_handler.log_metrics(all_metrics)

0 commit comments

Comments
 (0)