22
33import gc
44import itertools as it
5+ import json
56import logging
67from copy import deepcopy
78from functools import partial
1011
1112import optuna
1213import torch
14+ from codecarbon import EmissionsTracker
1315from optuna .trial import Trial
1416from pydantic import BaseModel , Field
1517from typing_extensions import assert_never
@@ -67,6 +69,7 @@ def __init__(
6769 self .node_type = node_type
6870 self .node_info = NODES_INFO [node_type ]
6971 self .target_metric = target_metric
72+ self .tracker = EmissionsTracker (project_name = f"{ self .node_info .node_type } " , measure_power_secs = 1 )
7073
7174 self .metrics = metrics if metrics is not None else []
7275 if self .target_metric not in self .metrics :
@@ -75,6 +78,36 @@ def __init__(
7578 self .validate_search_space (search_space )
7679 self .modules_search_spaces = search_space
7780
81+ def _start_emissions_tracking (self , task_name : str ) -> None :
82+ """Start tracking emissions for a specific task.
83+
84+ Args:
85+ task_name: Name of the task to track emissions for.
86+ """
87+ self .tracker .start_task (task_name )
88+
89+ def _stop_emissions_tracking (self ) -> dict [str , float ]:
90+ """Stop tracking emissions and return the emissions data.
91+
92+ Returns:
93+ Dictionary containing emissions metrics.
94+ """
95+ emissions_data = self .tracker .stop_task ()
96+ emissions_data_dict = json .loads (emissions_data .toJSON ())
97+ _ = self .tracker .stop ()
98+ return emissions_data_dict
99+
100+ def _process_emissions_metrics (self , emissions_data : dict [str , float ]) -> dict [str , float ]:
101+ """Process emissions data into metrics with the 'emissions/' prefix.
102+
103+ Args:
104+ emissions_data: Raw emissions data from the tracker.
105+
106+ Returns:
107+ Dictionary of processed emissions metrics with the 'emissions/' prefix.
108+ """
109+ return {f"emissions/{ k } " : v for k , v in emissions_data .items ()}
110+
78111 def fit (self , context : Context , sampler : SamplerType = "brute" ) -> None :
79112 """Performs the optimization process for the node.
80113
@@ -141,7 +174,12 @@ def objective(
141174 context .callback_handler .start_module (module_name = module_name , num = self ._counter , module_kwargs = config )
142175
143176 self ._logger .debug ("Scoring %s module..." , module_name )
177+
178+ self ._start_emissions_tracking ("module_scoring" )
144179 all_metrics = module .score (context , metrics = self .metrics )
180+ emissions_data = self ._stop_emissions_tracking ()
181+ all_metrics .update (self ._process_emissions_metrics (emissions_data ))
182+
145183 target_metric = all_metrics [self .target_metric ]
146184
147185 context .callback_handler .log_metrics (all_metrics )
0 commit comments