Skip to content

Commit 5eabe82

Browse files
committed
fix callback
1 parent b9f65a9 commit 5eabe82

File tree

2 files changed

+9
-5
lines changed

2 files changed

+9
-5
lines changed

autointent/_callbacks/emissions_tracker.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ class EmissionsTrackerCallback(OptimizerCallback):
1717

1818
current_module_name: str | None = None
1919

20-
def __init__(self, project_name: str, measure_power_secs: int = 1) -> None:
20+
def __init__(self) -> None:
2121
"""Initialize the emission tracker.
2222
2323
Args:
@@ -32,8 +32,9 @@ def __init__(self, project_name: str, measure_power_secs: int = 1) -> None:
3232
"Please install it with `pip install codecarbon`."
3333
)
3434
raise ImportError(msg) from e
35+
self.tracker: EmissionsTracker | None = None
36+
self.emission_tracker = EmissionsTracker
3537
logger.info("Emissions tracking is enabled via TRACK_EMISSIONS environment variable")
36-
self.tracker = EmissionsTracker(project_name=project_name, measure_power_secs=measure_power_secs)
3738

3839
def start_run(self, run_name: str, dirpath: Path, log_interval_time: float) -> None: # noqa: ARG002
3940
"""Start tracking emissions for the entire run.
@@ -43,6 +44,8 @@ def start_run(self, run_name: str, dirpath: Path, log_interval_time: float) -> N
4344
dirpath: Path to the directory where the logs will be saved.
4445
log_interval_time: Sampling interval for the system monitor in seconds.
4546
"""
47+
self.tracker = self.emission_tracker(project_name=run_name, measure_power_secs=log_interval_time)
48+
4649
self.tracker.start()
4750
self.current_module_name = None
4851

@@ -79,8 +82,8 @@ def update_final_metrics(self, metrics: dict[str, Any]) -> dict[str, Any]:
7982
Returns:
8083
Updated metrics including emissions data.
8184
"""
82-
emissions_data = self.tracker.stop()
83-
emissions_data_json = json.loads(emissions_data.toJSON())
85+
_ = self.tracker.stop()
86+
emissions_data_json = json.loads(self.tracker.final_emissions_data.toJSON())
8487
emissions_data_dict = {
8588
f"emissions/{k}": v for k, v in emissions_data_json.items() if isinstance(v, int | float)
8689
}

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,8 @@ wandb = [
8888
"wandb (>=0.19.10,<1.0.0)",
8989
]
9090
codecarbon = [
91-
"codecarbon (==2.6)",
91+
"codecarbon (>=3.0.2, <3.1.0)",
92+
"pynvml ( <=12 )",
9293
]
9394

9495
[project.urls]

0 commit comments

Comments
 (0)