55
66
77class TensorBoardCallback (OptimizerCallback ):
8- """
9- TensorBoard callback.
10-
11- This callback logs the optimization process to TensorBoard.
12- """
8+ """TensorBoard callback for logging the optimization process."""
139
1410 name = "tensorboard"
1511
1612 def __init__ (self ) -> None :
17- """Initialize the callback."""
13+ """Initializes the TensorBoard callback.
14+
15+ Attempts to import `torch.utils.tensorboard` first. If unavailable, tries to import `tensorboardX`.
16+ Raises an ImportError if neither are installed.
17+ """
1818 try :
1919 from torch .utils .tensorboard import SummaryWriter # type: ignore[attr-defined]
2020
@@ -32,22 +32,22 @@ def __init__(self) -> None:
3232 raise ImportError (msg ) from None
3333
3434 def start_run (self , run_name : str , dirpath : Path ) -> None :
35- """
36- Start a new run.
35+ """Starts a new run and sets the directory for storing logs.
3736
38- :param run_name: Name of the run.
39- :param dirpath: Path to the directory where the logs will be saved.
37+ Args:
38+ run_name: Name of the run.
39+ dirpath: Path to the directory where logs will be saved.
4040 """
4141 self .run_name = run_name
4242 self .dirpath = dirpath
4343
4444 def start_module (self , module_name : str , num : int , module_kwargs : dict [str , Any ]) -> None :
45- """
46- Start a new module.
45+ """Starts a new module and initializes a TensorBoard writer for it.
4746
48- :param module_name: Name of the module.
49- :param num: Number of the module.
50- :param module_kwargs: Module parameters.
47+ Args:
48+ module_name: Name of the module.
49+ num: Identifier number of the module.
50+ module_kwargs: Dictionary containing module parameters.
5151 """
5252 module_run_name = f"{ self .run_name } _{ module_name } _{ num } "
5353 log_dir = Path (self .dirpath ) / module_run_name
@@ -57,43 +57,38 @@ def start_module(self, module_name: str, num: int, module_kwargs: dict[str, Any]
5757 for key , value in module_kwargs .items ():
5858 self .module_writer .add_text (f"module_params/{ key } " , str (value )) # type: ignore[no-untyped-call]
5959
60- def log_value (self , ** kwargs : dict [str , Any ]) -> None :
61- """
62- Log data.
60+ def log_value (self , ** kwargs : dict [str , int | float | Any ]) -> None :
61+ """Logs scalar or text values.
6362
64- :param kwargs: Data to log.
63+ Args:
64+ **kwargs: Key-value pairs of data to log. Scalars will be logged as numerical values, others as text.
6565 """
66- if self .module_writer is None :
67- msg = "start_run must be called before log_value."
68- raise RuntimeError (msg )
69-
7066 for key , value in kwargs .items ():
7167 if isinstance (value , int | float ):
7268 self .module_writer .add_scalar (key , value )
7369 else :
7470 self .module_writer .add_text (key , str (value )) # type: ignore[no-untyped-call]
7571
7672 def log_metrics (self , metrics : dict [str , Any ]) -> None :
77- """
78- Log metrics during training.
73+ """Logs training metrics.
7974
80- :param metrics: Metrics to log.
75+ Args:
76+ metrics: Dictionary of metrics to log.
8177 """
82- if self .module_writer is None :
83- msg = "start_run must be called before log_value."
84- raise RuntimeError (msg )
85-
8678 for key , value in metrics .items ():
8779 if isinstance (value , int | float ):
8880 self .module_writer .add_scalar (key , value ) # type: ignore[no-untyped-call]
8981 else :
9082 self .module_writer .add_text (key , str (value )) # type: ignore[no-untyped-call]
9183
9284 def log_final_metrics (self , metrics : dict [str , Any ]) -> None :
93- """
94- Log final metrics.
85+ """Logs final metrics at the end of training.
86+
87+ Args:
88+ metrics: Dictionary of final metrics.
9589
96- :param metrics: Final metrics.
90+ Raises:
91+ RuntimeError: If `start_run` has not been called before logging final metrics.
9792 """
9893 if self .module_writer is None :
9994 msg = "start_run must be called before log_final_metrics."
@@ -109,7 +104,11 @@ def log_final_metrics(self, metrics: dict[str, Any]) -> None:
109104 self .module_writer .add_text (key , str (value )) # type: ignore[no-untyped-call]
110105
111106 def end_module (self ) -> None :
112- """End a module."""
107+ """Ends the current module and closes the TensorBoard writer.
108+
109+ Raises:
110+ RuntimeError: If `start_run` has not been called before ending the module.
111+ """
113112 if self .module_writer is None :
114113 msg = "start_run must be called before end_module."
115114 raise RuntimeError (msg )
@@ -118,4 +117,4 @@ def end_module(self) -> None:
118117 self .module_writer .close () # type: ignore[no-untyped-call]
119118
120119 def end_run (self ) -> None :
121- pass
120+ """Ends the current run. This method is currently a placeholder."""
0 commit comments