22
33import gc
44import itertools as it
5- import json
65import logging
76from copy import deepcopy
87from functools import partial
1110
1211import optuna
1312import torch
14- from codecarbon import EmissionsTracker
1513from optuna .trial import Trial
1614from pydantic import BaseModel , Field
1715from typing_extensions import assert_never
1816
1917from autointent import Dataset
2018from autointent .context import Context
2119from autointent .custom_types import NodeType , SamplerType , SearchSpaceValidationMode
20+ from autointent .nodes .emissions_tracker import EmissionsTracker
2221from autointent .nodes .info import NODES_INFO
2322
2423
@@ -69,7 +68,7 @@ def __init__(
6968 self .node_type = node_type
7069 self .node_info = NODES_INFO [node_type ]
7170 self .target_metric = target_metric
72- self .tracker = EmissionsTracker (project_name = f"{ self .node_info .node_type } " , measure_power_secs = 1 )
71+ self .emissions_tracker = EmissionsTracker (project_name = f"{ self .node_info .node_type } " )
7372
7473 self .metrics = metrics if metrics is not None else []
7574 if self .target_metric not in self .metrics :
@@ -78,36 +77,6 @@ def __init__(
7877 self .validate_search_space (search_space )
7978 self .modules_search_spaces = search_space
8079
81- def _start_emissions_tracking (self , task_name : str ) -> None :
82- """Start tracking emissions for a specific task.
83-
84- Args:
85- task_name: Name of the task to track emissions for.
86- """
87- self .tracker .start_task (task_name )
88-
89- def _stop_emissions_tracking (self ) -> dict [str , float ]:
90- """Stop tracking emissions and return the emissions data.
91-
92- Returns:
93- Dictionary containing emissions metrics.
94- """
95- emissions_data = self .tracker .stop_task ()
96- emissions_data_dict = json .loads (emissions_data .toJSON ())
97- _ = self .tracker .stop ()
98- return emissions_data_dict
99-
100- def _process_emissions_metrics (self , emissions_data : dict [str , float ]) -> dict [str , float ]:
101- """Process emissions data into metrics with the 'emissions/' prefix.
102-
103- Args:
104- emissions_data: Raw emissions data from the tracker.
105-
106- Returns:
107- Dictionary of processed emissions metrics with the 'emissions/' prefix.
108- """
109- return {f"emissions/{ k } " : v for k , v in emissions_data .items ()}
110-
11180 def fit (self , context : Context , sampler : SamplerType = "brute" ) -> None :
11281 """Performs the optimization process for the node.
11382
@@ -175,12 +144,12 @@ def objective(
175144
176145 self ._logger .debug ("Scoring %s module..." , module_name )
177146
178- self ._start_emissions_tracking ("module_scoring" )
179- all_metrics = module .score (context , metrics = self .metrics )
180- emissions_data = self ._stop_emissions_tracking ()
181- all_metrics . update ( self . _process_emissions_metrics ( emissions_data ))
147+ self .emissions_tracker . start_task ("module_scoring" )
148+ final_metrics = module .score (context , metrics = self .metrics )
149+ emissions_metrics = self .emissions_tracker . stop_task ()
150+ all_metrics = { ** final_metrics , ** emissions_metrics }
182151
183- target_metric = all_metrics [self .target_metric ]
152+ target_metric = final_metrics [self .target_metric ]
184153
185154 context .callback_handler .log_metrics (all_metrics )
186155 context .callback_handler .end_module ()
@@ -199,7 +168,7 @@ def objective(
199168 config ,
200169 target_metric ,
201170 self .target_metric ,
202- all_metrics ,
171+ final_metrics ,
203172 module .get_assets (), # retriever name / scores / predictions
204173 module_dump_dir ,
205174 module = module if not context .is_ram_to_clear () else None ,
0 commit comments