Skip to content

Commit e2cc537

Browse files
pcuencastevhliuabidlabs
authored
trackio (#3669)
* trackio * Apply suggestions from code review Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> Co-authored-by: Abubakar Abid <abubakar@huggingface.co> * seven -> eight * Add trackio as a real tracker instead * Sort * Style * Style * Remove step * Disable trackio on Python < 3.10 * Update src/accelerate/tracking.py Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * More style --------- Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> Co-authored-by: Abubakar Abid <abubakar@huggingface.co>
1 parent 847ae58 commit e2cc537

File tree

10 files changed

+128
-3
lines changed

10 files changed

+128
-3
lines changed

docs/source/package_reference/tracking.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,11 @@ rendered properly in your Markdown viewer.
2929
[[autodoc]] tracking.WandBTracker
3030
- __init__
3131

32+
## TrackioTracker
33+
34+
[[autodoc]] tracking.TrackioTracker
35+
- __init__
36+
3237
## CometMLTracker
3338

3439
[[autodoc]] tracking.CometMLTracker

docs/source/usage_guides/tracking.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,11 @@ Accelerate provides a general tracking API that can be used to log useful items
2020

2121
## Integrated Trackers
2222

23-
Currently `Accelerate` supports seven trackers out-of-the-box:
23+
Currently `Accelerate` supports eight trackers out-of-the-box:
2424

2525
- TensorBoard
26-
- WandB
26+
- WandB
27+
- Trackio
2728
- CometML
2829
- Aim
2930
- MLFlow

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
"mlflow",
5050
"matplotlib",
5151
"swanlab",
52+
"trackio",
5253
]
5354
extras["dev"] = extras["quality"] + extras["testing"] + extras["rich"]
5455

src/accelerate/accelerator.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,11 @@ class Accelerator:
233233
- `"all"`
234234
- `"tensorboard"`
235235
- `"wandb"`
236+
- `"trackio"`
237+
- `"aim"`
236238
- `"comet_ml"`
239+
- `"mlflow"`
240+
- `"dvclive"`
237241
- `"swanlab"`
238242
If `"all"` is selected, will pick up all available trackers in the environment and initialize them. Can
239243
also accept implementations of `GeneralTracker` for custom trackers, and can be combined with `"all"`.

src/accelerate/test_utils/testing.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@
6969
is_torchao_available,
7070
is_torchdata_stateful_dataloader_available,
7171
is_torchvision_available,
72+
is_trackio_available,
7273
is_transformer_engine_available,
7374
is_transformers_available,
7475
is_triton_available,
@@ -459,6 +460,13 @@ def require_wandb(test_case):
459460
return unittest.skipUnless(is_wandb_available(), "test requires wandb")(test_case)
460461

461462

463+
def require_trackio(test_case):
    """
    Decorator for tests that depend on trackio. The decorated test is skipped whenever
    `is_trackio_available()` reports that trackio cannot be used.
    """
    skip_unless_trackio = unittest.skipUnless(is_trackio_available(), "test requires trackio")
    return skip_unless_trackio(test_case)
468+
469+
462470
def require_comet_ml(test_case):
463471
"""
464472
Decorator marking a test that requires comet_ml installed. These tests are skipped when comet_ml isn't installed
@@ -548,7 +556,8 @@ def require_matplotlib(test_case):
548556

549557

550558
_atleast_one_tracker_available = (
551-
any([is_wandb_available(), is_tensorboard_available(), is_swanlab_available()]) and not is_comet_ml_available()
559+
any([is_wandb_available(), is_tensorboard_available(), is_trackio_available(), is_swanlab_available()])
560+
and not is_comet_ml_available()
552561
)
553562

554563

src/accelerate/tracking.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
is_mlflow_available,
3737
is_swanlab_available,
3838
is_tensorboard_available,
39+
is_trackio_available,
3940
is_wandb_available,
4041
listify,
4142
)
@@ -67,6 +68,9 @@
6768
if is_swanlab_available():
6869
_available_trackers.append(LoggerType.SWANLAB)
6970

71+
if is_trackio_available():
72+
_available_trackers.append(LoggerType.TRACKIO)
73+
7074
logger = get_logger(__name__)
7175

7276

@@ -415,6 +419,83 @@ def finish(self):
415419
logger.debug("WandB run closed")
416420

417421

422+
class TrackioTracker(GeneralTracker):
    """
    A `Tracker` class that supports `trackio`. Should be initialized at the start of your script.

    Args:
        run_name (`str`):
            The name of the experiment run. Will be used as the `project` name when instantiating trackio.
        **kwargs (additional keyword arguments, *optional*):
            Additional key word arguments passed along to the `trackio.init` method. Refer to this
            [init](https://github.com/gradio-app/trackio/blob/814809552310468b13f84f33764f1369b4e5136c/trackio/__init__.py#L22)
            to see all supported key word arguments.
    """

    name = "trackio"
    requires_logging_directory = False
    main_process_only = False

    def __init__(self, run_name: str, **kwargs):
        super().__init__()
        self.run_name = run_name
        # Kept as-is and forwarded to `trackio.init` when `start()` is called.
        self.init_kwargs = kwargs

    @on_main_process
    def start(self):
        """
        Initializes the trackio run, using `run_name` as the trackio `project` together with the keyword arguments
        given at construction time. The resulting run is stored on `self.run`.
        """
        # Imported lazily so that trackio is only needed once the tracker is actually started.
        import trackio

        self.run = trackio.init(project=self.run_name, **self.init_kwargs)
        logger.debug(f"Initialized trackio project {self.run_name}")
        logger.debug(
            "Make sure to log any initial configurations with `self.store_init_configuration` before training!"
        )

    @property
    def tracker(self):
        """
        The run object returned by `trackio.init`. It is only set by `start()`, so accessing it earlier raises an
        `AttributeError`.
        """
        return self.run

    @on_main_process
    def store_init_configuration(self, values: dict):
        """
        Logs `values` as hyperparameters for the run. Should be run at the beginning of your experiment.

        Args:
            values (Dictionary `str` to `bool`, `str`, `float` or `int`):
                Values to be stored as initial hyperparameters as key-value pairs. The values need to have type `bool`,
                `str`, `float`, `int`, or `None`.
        """
        import trackio

        trackio.config.update(values, allow_val_change=True)
        logger.debug("Stored initial configuration hyperparameters to trackio")

    @on_main_process
    def log(self, values: dict, step: Optional[int] = None, **kwargs):
        """
        Logs `values` to the current run.

        Args:
            values (Dictionary `str` to `str`, `float`, `int` or `dict` of `str` to `float`/`int`):
                Values to be logged as key-value pairs. The values need to have type `str`, `float`, `int` or `dict` of
                `str` to `float`/`int`.
            step (`int`, *optional*):
                Currently ignored: the value is accepted but is not forwarded to trackio, so passing it has no effect
                on the logged entry.
            kwargs:
                Additional key word arguments passed along to the `log` method of the trackio run.
        """
        # NOTE: `step` is intentionally not passed through; only `values` and `kwargs` reach the run.
        self.run.log(values, **kwargs)
        logger.debug("Successfully logged to trackio")

    @on_main_process
    def finish(self):
        """
        Closes `trackio` run
        """
        self.run.finish()
        logger.debug("trackio run closed")
497+
498+
418499
class CometMLTracker(GeneralTracker):
419500
"""
420501
A `Tracker` class that supports `comet_ml`. Should be initialized at the start of your script.
@@ -1174,6 +1255,7 @@ def finish(self):
11741255
"clearml": ClearMLTracker,
11751256
"dvclive": DVCLiveTracker,
11761257
"swanlab": SwanLabTracker,
1258+
"trackio": TrackioTracker,
11771259
}
11781260

11791261

@@ -1195,6 +1277,8 @@ def filter_trackers(
11951277
- `"all"`
11961278
- `"tensorboard"`
11971279
- `"wandb"`
1280+
- `"trackio"`
1281+
- `"aim"`
11981282
- `"comet_ml"`
11991283
- `"mlflow"`
12001284
- `"dvclive"`

src/accelerate/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@
129129
is_torchdata_available,
130130
is_torchdata_stateful_dataloader_available,
131131
is_torchvision_available,
132+
is_trackio_available,
132133
is_transformer_engine_available,
133134
is_transformers_available,
134135
is_triton_available,

src/accelerate/utils/dataclasses.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -701,7 +701,10 @@ class LoggerType(BaseEnum):
701701
- **ALL** -- all available trackers in the environment that are supported
702702
- **TENSORBOARD** -- TensorBoard as an experiment tracker
703703
- **WANDB** -- wandb as an experiment tracker
704+
- **TRACKIO** -- trackio as an experiment tracker
704705
- **COMETML** -- comet_ml as an experiment tracker
706+
- **MLFLOW** -- mlflow as an experiment tracker
707+
- **CLEARML** -- clearml as an experiment tracker
705708
- **DVCLIVE** -- dvclive as an experiment tracker
706709
- **SWANLAB** -- swanlab as an experiment tracker
707710
"""
@@ -710,6 +713,7 @@ class LoggerType(BaseEnum):
710713
AIM = "aim"
711714
TENSORBOARD = "tensorboard"
712715
WANDB = "wandb"
716+
TRACKIO = "trackio"
713717
COMETML = "comet_ml"
714718
MLFLOW = "mlflow"
715719
CLEARML = "clearml"

src/accelerate/utils/imports.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import importlib
1616
import importlib.metadata
1717
import os
18+
import sys
1819
import warnings
1920
from functools import lru_cache, wraps
2021

@@ -285,6 +286,10 @@ def is_swanlab_available():
285286
return _is_package_available("swanlab")
286287

287288

289+
def is_trackio_available():
    """
    Whether `trackio` can be used. It is reported as unavailable on Python versions older than 3.10, regardless of
    whether the package is installed.
    """
    if sys.version_info < (3, 10):
        return False
    return _is_package_available("trackio")
291+
292+
288293
def is_boto3_available():
289294
return _is_package_available("boto3")
290295

tests/test_tracking.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
require_pandas,
4646
require_swanlab,
4747
require_tensorboard,
48+
require_trackio,
4849
require_wandb,
4950
skip,
5051
)
@@ -57,6 +58,7 @@
5758
MLflowTracker,
5859
SwanLabTracker,
5960
TensorBoardTracker,
61+
TrackioTracker,
6062
WandBTracker,
6163
)
6264
from accelerate.utils import (
@@ -801,6 +803,15 @@ def test_wandb_deferred_init(self):
801803
_ = Accelerator(log_with=tracker)
802804
self.assertNotEqual(PartialState._shared_state, {})
803805

806+
@require_trackio
def test_trackio_deferred_init(self):
    """Building a `TrackioTracker` must leave `PartialState` empty; creating the `Accelerator` populates it."""
    PartialState._reset_state()
    trackio_tracker = TrackioTracker(run_name="test_trackio")
    # Constructing the tracker alone must not touch the shared state.
    self.assertEqual(PartialState._shared_state, {})
    _ = Accelerator(log_with=trackio_tracker)
    # Once the Accelerator exists, the shared state is expected to be filled in.
    self.assertNotEqual(PartialState._shared_state, {})
814+
804815
@require_comet_ml
805816
def test_comet_ml_deferred_init(self):
806817
"""Test that CometML tracker initialization doesn't initialize distributed"""

0 commit comments

Comments
 (0)