diff --git a/mambavision/tensorboard.py b/mambavision/tensorboard.py index 50de85c..e5e40d1 100644 --- a/mambavision/tensorboard.py +++ b/mambavision/tensorboard.py @@ -2,12 +2,43 @@ from tensorboardX import SummaryWriter class TensorboardLogger(object): - def __init__(self, log_dir): + """Logs data to TensorBoard. + + Attributes: + writer: SummaryWriter instance used for logging. + step: Current step in the training process. + """ +def __init__(self, log_dir: str): + """Initializes a SummaryWriter for tensorboard logging. + + Args: + log_dir (str): The directory where the logs will be written. + + Returns: + None + + Raises: + None + """ self.writer = SummaryWriter(logdir=log_dir) self.step = 0 - def set_step(self, step=None): +def set_step(self, step: int = None) -> None: + """Sets the step value. If no step is provided, increments the current step. + + Args: + step (int, optional): The new step value. If None, the current step is incremented. Defaults to None. + + Returns: + None. + + Raises: + TypeError: If step is provided and is not an integer. + + """ if step is not None: + if not isinstance(step, int): + raise TypeError("Step must be an integer or None.") self.step = step else: self.step += 1