diff --git a/sparse_autoencoder/train/pipeline.py b/sparse_autoencoder/train/pipeline.py
index e9a63f5d..9a73fd41 100644
--- a/sparse_autoencoder/train/pipeline.py
+++ b/sparse_autoencoder/train/pipeline.py
@@ -1,9 +1,14 @@
 """Default pipeline."""
 from collections.abc import Iterable
+from datetime import datetime, timezone
 from functools import partial
+import importlib.metadata
+from json import dumps
 from pathlib import Path
+import subprocess
 from typing import final
 from urllib.parse import quote_plus
+import warnings
 
 from jaxtyping import Int, Int64
 import torch
@@ -121,7 +126,7 @@ def __init__( # noqa: PLR0913
         self.loss = loss
         self.metrics = metrics
         self.optimizer = optimizer
-        self.run_name = run_name
+        self.run_name = run_name + datetime.now(tz=timezone.utc).strftime("-%Y-%m-%d-%H-%M-%S")
         self.source_data_batch_size = source_data_batch_size
         self.source_dataset = source_dataset
         self.source_model = source_model
@@ -338,14 +343,68 @@ def validate_sae(self, validation_number_activations: int) -> None:
         if wandb.run is not None:
             wandb.log(data=calculated, commit=False)
 
+    @staticmethod
+    def get_git_commit_hash() -> str | None:
+        """Get the Git commit hash of the current directory.
+
+        Returns:
+            The hash of `HEAD`, or `None` (after a warning) if Git could not be run or the current
+            directory is not a Git repository.
+        """
+        try:
+            return (
+                subprocess.check_output(["/usr/bin/git", "rev-parse", "HEAD"])  # noqa: S603
+                .decode("ascii")
+                .strip()
+            )
+        except OSError:
+            # Raised when the Git executable is missing or cannot be executed.
+            warnings.warn(
+                "Git executable could not be run, not logging commit hash", stacklevel=2
+            )
+            return None
+        except subprocess.CalledProcessError:
+            warnings.warn(
+                "Directory is not a Git repository, not logging commit hash", stacklevel=2
+            )
+            return None
+
+    @staticmethod
+    def get_package_version() -> str | None:
+        """Get the version of the package.
+
+        Returns:
+            The installed version of `sparse-autoencoder`, or `None` (after a warning) if the
+            package is not installed.
+        """
+        try:
+            return importlib.metadata.version("sparse-autoencoder")
+        except importlib.metadata.PackageNotFoundError:
+            warnings.warn("Package not found, not logging version", stacklevel=2)
+            return None
+
     @final
     def save_checkpoint(self) -> None:
         """Save the model as a checkpoint."""
         if self.checkpoint_directory:
             run_name_file_system_safe = quote_plus(self.run_name)
+            run_directory = self.checkpoint_directory / run_name_file_system_safe
+            # `exist_ok=True` replaces a separate existence check, so there is no check-then-act race.
+            run_directory.mkdir(parents=True, exist_ok=True)
+
+            # Write the run config once, the first time a checkpoint is saved into this directory.
+            config_path = run_directory / "config.json"
+            if not config_path.exists():
+                # Only read `wandb.config` when a run is active (same guard as used for `wandb.log`).
+                config_dict = dict(wandb.config) if wandb.run is not None else {}
+                config_dict["git_hash"] = self.get_git_commit_hash()
+                config_dict["package_version"] = self.get_package_version()
+                config_path.write_text(dumps(config_dict, indent=4))
+
             file_path: Path = (
                 self.checkpoint_directory
-                / f"{run_name_file_system_safe}-{self.total_activations_trained_on}.pt"
+                / run_name_file_system_safe
+                / f"checkpoint-{self.total_activations_trained_on}activations.pt"
             )
             torch.save(self.autoencoder.state_dict(), file_path)
 