Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 33 additions & 2 deletions mambavision/tensorboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,43 @@
from tensorboardX import SummaryWriter

class TensorboardLogger(object):
def __init__(self, log_dir):
"""Logs data to TensorBoard.

Attributes:
writer: SummaryWriter instance used for logging.
step: Current step in the training process.
"""
def __init__(self, log_dir: str):
"""Initializes a SummaryWriter for tensorboard logging.

Args:
log_dir (str): The directory where the logs will be written.

Returns:
None

Raises:
None
"""
self.writer = SummaryWriter(logdir=log_dir)
self.step = 0

def set_step(self, step=None):
def set_step(self, step: int = None) -> None:
"""Sets the step value. If no step is provided, increments the current step.

Args:
step (int, optional): The new step value. If None, the current step is incremented. Defaults to None.

Returns:
None.

Raises:
TypeError: If step is provided and is not an integer.

"""
if step is not None:
if not isinstance(step, int):
raise TypeError("Step must be an integer or None.")
self.step = step
else:
self.step += 1
Expand Down