diff --git a/pyproject.toml b/pyproject.toml
index 6a9c317..f13dbc8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -24,6 +24,8 @@ dependencies = [
     "schedulefree>=1.4",
     "pfns==0.3.0",
     "openml==0.15.1",
+    "tqdm",
+    "hydra-core"
 ]

 [project.optional-dependencies]
diff --git a/scripts/training/configs/train_classification.yaml b/scripts/training/configs/train_classification.yaml
new file mode 100644
index 0000000..b8b6631
--- /dev/null
+++ b/scripts/training/configs/train_classification.yaml
@@ -0,0 +1,22 @@
+#### Hydra specific configs, do not change ####
+###############################################
+
+model:
+  num_attention_heads: 6
+  embedding_size: 192
+  mlp_hidden_size: 768
+  num_layers: 6
+
+dataset:
+  filename: ./50x3_3_100k_classification.h5
+  device: cuda
+  batch_size: 50
+  starting_index: 0
+
+training:
+  steps: 25
+  accumulate_gradients: 1
+  dataloader_num_workers: 0 # keep 0 for now
+  lr: 1e-4
+  epochs: 80
+  run_name: nanotabpfn
\ No newline at end of file
diff --git a/pretrain_classification.py b/scripts/training/pretrain_classification.py
similarity index 86%
rename from pretrain_classification.py
rename to scripts/training/pretrain_classification.py
index 6cea228..94c74f6 100644
--- a/pretrain_classification.py
+++ b/scripts/training/pretrain_classification.py
@@ -4,16 +4,17 @@
 from sklearn.metrics import accuracy_score, roc_auc_score
 from torch import nn

-from tfmplayground.callbacks import ConsoleLoggerCallback, WandbLoggerCallback
+from tfmplayground.training.callbacks import ConsoleLoggerCallback, WandbLoggerCallback
 from tfmplayground.evaluation import get_openml_predictions, TOY_TASKS_CLASSIFICATION, TABARENA_TASKS
 from tfmplayground.interface import NanoTabPFNClassifier
 from tfmplayground.model import NanoTabPFNModel
-from tfmplayground.priors import PriorDumpDataLoader
+from tfmplayground.priors.dataloader import PriorDumpDataLoader, PriorDataLoader
+from tfmplayground.priors.tabicl import TabICLPriorDataLoader
 from tfmplayground.train import train
 from tfmplayground.utils import get_default_device, set_randomness_seed

 parser = argparse.ArgumentParser()
-parser.add_argument("--priordump", type=str, default="/50x3_3_100k_classification.h5", help="path to the prior dump")
+parser.add_argument("--priordump", type=str, default="./50x3_3_100k_classification.h5", help="path to the prior dump")
 parser.add_argument("--heads", type=int, default=6, help="number of attention heads")
 parser.add_argument("--embeddingsize", type=int, default=192, help="the size of the embeddings used for the cells")
 parser.add_argument("--hiddensize", type=int, default=768, help="size of the hidden layer of the mlps")
@@ -36,6 +37,18 @@
 if args.loadcheckpoint:
     ckpt = torch.load(args.loadcheckpoint)

+
+# prior = TabICLPriorDataLoader(
+#     num_steps=args.steps,
+#     batch_size=args.batchsize,
+#     num_datapoints_min=100,
+#     num_datapoints_max=1000,
+#     min_features=3,
+#     max_features=15,
+#     max_num_classes=2,
+#     device=device,
+# )
+
 prior = PriorDumpDataLoader(filename=args.priordump, num_steps=args.steps, batch_size=args.batchsize, device=device, starting_index=args.steps*(ckpt['epoch'] if ckpt else 0))

 criterion = nn.CrossEntropyLoss()
@@ -85,7 +98,7 @@ def on_epoch_end(self, epoch: int, epoch_time: float, loss: float, model, **kwar
         print(f'epoch {epoch:5d} | time {epoch_time:5.2f}s | mean loss {loss:5.2f} | avg roc auc {avg_score:.3f}',
               flush=True)

-#callbacks = [ProductionEvaluationLoggerCallback('nanoTFM', args.runname)]
+# callbacks = [ProductionEvaluationLoggerCallback('nanoTFM', args.runname)]
 callbacks = [ToyEvaluationLoggerCallback(TOY_TASKS_CLASSIFICATION)]

 trained_model, loss = train(
diff --git a/scripts/training/pretrain_classification_new.py b/scripts/training/pretrain_classification_new.py
new file mode 100644
index 0000000..fd0b76d
--- /dev/null
+++ b/scripts/training/pretrain_classification_new.py
@@ -0,0 +1,80 @@
+import hydra
+from omegaconf import DictConfig
+
+from tqdm import tqdm
+
+from sklearn.metrics import accuracy_score, roc_auc_score
+from torch import nn
+import torch.multiprocessing as mp
+
+from tfmplayground.training.callbacks import ConsoleLoggerCallback, WandbLoggerCallback
+from tfmplayground.evaluation import get_openml_predictions, TOY_TASKS_CLASSIFICATION, TABARENA_TASKS
+from tfmplayground.interface import NanoTabPFNClassifier
+from tfmplayground.model import NanoTabPFNModel
+from tfmplayground.priors.tabicl import TabICLPriorDataLoader
+from tfmplayground.priors.dataloader import PriorDumpDataLoader
+from tfmplayground.priors.dataset import PriorDumpDataset
+from tfmplayground.utils import set_randomness_seed
+from tfmplayground.training.trainer import BaseTrainer
+from tfmplayground.training.util import tqdm_on_main
+
+set_randomness_seed(2402)
+
+class ToyEvaluationLoggerCallback(ConsoleLoggerCallback):
+    def __init__(self, tasks):
+        self.tasks = tasks
+
+    def on_epoch_end(self, epoch: int, epoch_time: float, loss: float, model, **kwargs):
+        classifier = NanoTabPFNClassifier(model, "cuda")
+        predictions = get_openml_predictions(model=classifier, tasks=self.tasks)
+        scores = []
+        for dataset_name, (y_true, y_pred, y_proba) in predictions.items():
+            scores.append(accuracy_score(y_true, y_pred))
+        avg_score = sum(scores) / len(scores)
+        tqdm_on_main(f'epoch {epoch:5d} | time {epoch_time:5.2f}s | mean loss {loss:5.2f} | avg accuracy {avg_score:.3f}')
+
+class ProductionEvaluationLoggerCallback(WandbLoggerCallback):
+    def __init__(self, project: str, name: str = None, config: dict = None, log_dir: str = None):
+        super().__init__(project, name, config, log_dir)
+
+    def on_epoch_end(self, epoch: int, epoch_time: float, loss: float, model, **kwargs):
+        classifier = NanoTabPFNClassifier(model, "cuda")
+        predictions = get_openml_predictions(model=classifier, classification=True, tasks=TABARENA_TASKS)
+        scores = []
+        for dataset_name, (y_true, y_pred, y_proba) in predictions.items():
+            scores.append(roc_auc_score(y_true, y_proba, multi_class='ovr'))
+        avg_score = sum(scores) / len(scores)
+        self.wandb.log({
+            'epoch': epoch,
+            'epoch_time': epoch_time,
+            'mean_loss': loss,
+            'tabarena_avg_roc_auc': avg_score
+        })
+        print(f'epoch {epoch:5d} | time {epoch_time:5.2f}s | mean loss {loss:5.2f} | avg roc auc {avg_score:.3f}',
+              flush=True)
+
+@hydra.main(version_base=None, config_path="configs", config_name="train_classification")
+def main(cfg: DictConfig):
+    dataset = PriorDumpDataset(
+        **cfg.dataset,
+        num_steps=cfg.training.steps
+    )
+    model = NanoTabPFNModel(
+        **cfg.model,
+        num_outputs=dataset.max_num_classes,
+    )
+    # dataset = TabICLPriorDataLoader(
+    #     **cfg.dataset
+    # )
+    callbacks = [ToyEvaluationLoggerCallback(TOY_TASKS_CLASSIFICATION)]
+    trainer = BaseTrainer(
+        model=model,
+        train_dataset=dataset,
+        criterion=nn.CrossEntropyLoss(),
+        callbacks=callbacks,
+        **cfg.training
+    )
+    model = trainer.train()
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
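As a point of reference for the new entry point above: the same scripts/training/configs/train_classification.yaml can also be composed programmatically with Hydra's compose API. This is an illustrative sketch only, not part of the patch; the overrides are hypothetical and it assumes being run from scripts/training/ so the relative config_path resolves. It is equivalent in spirit to `python scripts/training/pretrain_classification_new.py training.epochs=5`.

    # Illustrative sketch: compose the packaged config and tweak a few values.
    from hydra import compose, initialize

    with initialize(version_base=None, config_path="configs"):
        cfg = compose(
            config_name="train_classification",
            overrides=["training.epochs=5", "dataset.batch_size=25"],  # hypothetical overrides
        )
    print(cfg.model.embedding_size, cfg.training.lr)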
diff --git a/pretrain_regression.py b/scripts/training/pretrain_regression.py
similarity index 98%
rename from pretrain_regression.py
rename to scripts/training/pretrain_regression.py
index 311559c..a2c4bcb 100644
--- a/pretrain_regression.py
+++ b/scripts/training/pretrain_regression.py
@@ -4,7 +4,7 @@
 from pfns.bar_distribution import FullSupportBarDistribution
 from sklearn.metrics import r2_score

-from tfmplayground.callbacks import ConsoleLoggerCallback
+from tfmplayground.training.callbacks import ConsoleLoggerCallback
 from tfmplayground.evaluation import get_openml_predictions, TOY_TASKS_REGRESSION
 from tfmplayground.interface import NanoTabPFNRegressor
 from tfmplayground.model import NanoTabPFNModel
diff --git a/tfmplayground/evaluation.py b/tfmplayground/evaluation.py
index 3f205a3..072271a 100644
--- a/tfmplayground/evaluation.py
+++ b/tfmplayground/evaluation.py
@@ -3,6 +3,7 @@
 import numpy as np
 import openml
 import torch
+import logging
 from openml.config import set_root_cache_directory
 from openml.tasks import TaskType
 from sklearn.metrics import balanced_accuracy_score, roc_auc_score, r2_score
@@ -10,6 +11,7 @@

 from tfmplayground.interface import NanoTabPFNRegressor, NanoTabPFNClassifier

+openml.config.set_console_log_level(logging.WARNING)
 TOY_TASKS_REGRESSION = [
     362443, # diabetes
 ]
diff --git a/tfmplayground/priors/__init__.py b/tfmplayground/priors/__init__.py
new file mode 100644
index 0000000..b28b04f
--- /dev/null
+++ b/tfmplayground/priors/__init__.py
@@ -0,0 +1,3 @@
+
+
+
diff --git a/tfmplayground/priors.py b/tfmplayground/priors/dataloader.py
similarity index 100%
rename from tfmplayground/priors.py
rename to tfmplayground/priors/dataloader.py
diff --git a/tfmplayground/priors/dataset.py b/tfmplayground/priors/dataset.py
new file mode 100644
index 0000000..d2441eb
--- /dev/null
+++ b/tfmplayground/priors/dataset.py
@@ -0,0 +1,45 @@
+from torch.utils.data import IterableDataset
+import torch
+import h5py
+
+class PriorDumpDataset(IterableDataset):
+    def __init__(self, filename, num_steps, batch_size, device, starting_index=0):
+        self.filename = filename
+        self.num_steps = num_steps
+        self.batch_size = batch_size
+        with h5py.File(self.filename, "r") as f:
+            self.num_datapoints_max = f['X'].shape[0]
+            if "max_num_classes" in f:
+                self.max_num_classes = f["max_num_classes"][0]
+            else:
+                self.max_num_classes = None
+            self.problem_type = f['problem_type'][()].decode('utf-8')
+        self.device = device
+        self.pointer = starting_index
+
+    def __iter__(self):
+        with h5py.File(self.filename, "r") as f:
+            for _ in range(self.num_steps):
+                self.data = f
+                end = self.pointer + self.batch_size
+
+                num_features=self.data['num_features'][self.pointer:end].max()
+                x = torch.from_numpy(self.data['X'][self.pointer:end,:,:num_features])
+                y = torch.from_numpy(self.data['y'][self.pointer:end])
+                single_eval_pos = self.data['single_eval_pos'][self.pointer:end]
+
+                self.pointer += self.batch_size
+                if self.pointer >= self.data['X'].shape[0]:
+                    print("""Finished iteration over all stored datasets! """
+                          """Will start reusing the same data with different splits now.""")
+                    self.pointer = 0
+
+                yield dict(
+                    x=x.to(self.device),
+                    y=y.to(self.device),
+                    target_y=y.to(self.device),
+                    single_eval_pos=single_eval_pos[0].item()
+                )
+
+    def __len__(self):
+        return self.num_steps
\ No newline at end of file
diff --git a/tfmplayground/priors/tabicl.py b/tfmplayground/priors/tabicl.py
new file mode 100644
index 0000000..2872486
--- /dev/null
+++ b/tfmplayground/priors/tabicl.py
@@ -0,0 +1,68 @@
+import torch
+from tabicl.prior.dataset import PriorDataset as TabICLPriorDataset
+from torch.utils.data import DataLoader
+
+
+class TabICLPriorDataLoader(DataLoader):
+    """DataLoader sampling synthetic prior data on-the-fly from TabICL's PriorDataset.
+
+    Args:
+        num_steps (int): Number of batches to generate per epoch.
+        batch_size (int): Number of functions per batch.
+        num_datapoints_min (int): Minimum number of datapoints per function.
+        num_datapoints_max (int): Maximum number of datapoints per function.
+        min_features (int): Minimum number of features in x.
+        max_features (int): Maximum number of features in x.
+        max_num_classes (int): Maximum number of classes (for classification tasks).
+        device (torch.device): Target device for tensors.
+    """
+
+    def __init__(
+        self,
+        num_steps: int,
+        batch_size: int,
+        num_datapoints_min: int,
+        num_datapoints_max: int,
+        min_features: int,
+        max_features: int,
+        max_num_classes: int,
+        device: torch.device,
+    ):
+        self.num_steps = num_steps
+        self.batch_size = batch_size
+        self.num_datapoints_min = num_datapoints_min
+        self.num_datapoints_max = num_datapoints_max
+        self.min_features = min_features
+        self.max_features = max_features
+        self.max_num_classes = max_num_classes
+        self.device = device
+
+        self.pd = TabICLPriorDataset(
+            batch_size=batch_size,
+            batch_size_per_gp=batch_size,
+            min_features=min_features,
+            max_features=max_features,
+            max_classes=max_num_classes,
+            min_seq_len=num_datapoints_min,
+            max_seq_len=num_datapoints_max,
+        )
+
+    def tabicl_to_ours(self, d):
+        x, y, active_features, seqlen, train_size = d
+        active_features = active_features[
+            0
+        ].item()  # should be all the same since we use batch_size_per_gp=batch_size (not true in practice!)
+        x = x[:, :, :active_features]
+        single_eval_pos = train_size[0].item()  # should be all the same since we use batch_size_per_gp=batch_size
+        return dict(
+            x=x.to(self.device),
+            y=y.to(self.device),
+            target_y=y.to(self.device),  # target_y is identical to y (for downstream compatibility)
+            single_eval_pos=single_eval_pos,
+        )
+
+    def __iter__(self):
+        return iter(self.tabicl_to_ours(next(self.pd)) for _ in range(self.num_steps))
+
+    def __len__(self):
+        return self.num_steps
\ No newline at end of file
diff --git a/tfmplayground/train.py b/tfmplayground/train.py
index 3973b87..1970cf6 100644
--- a/tfmplayground/train.py
+++ b/tfmplayground/train.py
@@ -7,7 +7,7 @@
 import schedulefree
 import os

-from tfmplayground.callbacks import Callback
+from tfmplayground.training.callbacks import Callback
 from tfmplayground.model import NanoTabPFNModel
 from tfmplayground.utils import get_default_device
diff --git a/tfmplayground/training/__init__.py b/tfmplayground/training/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tfmplayground/training/base.py b/tfmplayground/training/base.py
new file mode 100644
index 0000000..93b6fb1
--- /dev/null
+++ b/tfmplayground/training/base.py
@@ -0,0 +1,48 @@
+from abc import ABC, abstractmethod
+
+import torch.nn as nn
+
+class Callback(ABC):
+    """ Abstract base class for callbacks."""
+
+    @abstractmethod
+    def on_epoch_end(self, epoch: int, epoch_time: float, loss: float, model, **kwargs):
+        """
+        Called at the end of each epoch.
+
+        Args:
+            epoch (int): The current epoch number.
+            epoch_time (float): Time of the epoch in seconds.
+            loss (float): Mean loss for the epoch.
+            model: The model being trained.
+            **kwargs: Additional arguments.
+        """
+        pass
+
+    @abstractmethod
+    def close(self):
+        """
+        Called to release any resources or perform cleanup.
+        """
+        pass
+
+
+class Trainer(ABC):
+    """Trainer class for training models."""
+
+    @abstractmethod
+    def train(self) -> nn.Module:
+        """
+        Trains the model on the provided training dataset.
+
+        Implementations are expected to be configured at construction time with:
+            model: The model to be trained.
+            train_dataset: The dataset to train the model on.
+            callbacks (list[Callback]): List of callback instances to be used during training.
+            run_dir (str): Directory for saving training outputs.
+            run_name (str): Name of the training run.
+
+        Returns:
+            The trained model.
+        """
+        pass
\ No newline at end of file
diff --git a/tfmplayground/callbacks.py b/tfmplayground/training/callbacks.py
similarity index 77%
rename from tfmplayground/callbacks.py
rename to tfmplayground/training/callbacks.py
index 8297bef..bcd63ff 100644
--- a/tfmplayground/callbacks.py
+++ b/tfmplayground/training/callbacks.py
@@ -1,29 +1,4 @@
-from abc import ABC, abstractmethod
-
-
-class Callback(ABC):
-    """ Abstract base class for callbacks."""
-
-    @abstractmethod
-    def on_epoch_end(self, epoch: int, epoch_time: float, loss: float, model, **kwargs):
-        """
-        Called at the end of each epoch.
-
-        Args:
-            epoch (int): The current epoch number.
-            epoch_time (float): Time of the epoch in seconds.
-            loss (float): Mean loss for the epoch.
-            model: The model being trained.
-            **kwargs: Additional arguments.
-        """
-        pass
-
-    @abstractmethod
-    def close(self):
-        """
-        Called to release any resources or perform cleanup.
-        """
-        pass
+from .base import Callback


 class BaseLoggerCallback(Callback):
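The BaseTrainer introduced in the next file can also be driven directly, without the Hydra script. The following is a minimal single-process sketch; the dump path and hyperparameters mirror train_classification.yaml but are assumptions, and a multi-GPU run would instead be launched through torchrun so that RANK/LOCAL_RANK/WORLD_SIZE are set for the trainer's DDP path.

    from torch import nn

    from tfmplayground.model import NanoTabPFNModel
    from tfmplayground.priors.dataset import PriorDumpDataset
    from tfmplayground.training.trainer import BaseTrainer

    # Assumed prior dump path; values mirror the defaults in train_classification.yaml.
    dataset = PriorDumpDataset(
        filename="./50x3_3_100k_classification.h5",
        num_steps=25,
        batch_size=50,
        device="cuda",
    )
    model = NanoTabPFNModel(
        num_attention_heads=6,
        embedding_size=192,
        mlp_hidden_size=768,
        num_layers=6,
        num_outputs=dataset.max_num_classes,
    )
    trainer = BaseTrainer(
        model=model,
        train_dataset=dataset,
        criterion=nn.CrossEntropyLoss(),
        initial_lr=1e-4,
        epochs=80,
        steps=25,
        run_name="nanotabpfn",
    )
    trained_model = trainer.train()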
diff --git a/tfmplayground/training/trainer.py b/tfmplayground/training/trainer.py
new file mode 100644
index 0000000..6318f1a
--- /dev/null
+++ b/tfmplayground/training/trainer.py
@@ -0,0 +1,280 @@
+from typing import Dict, Any, Optional, Union
+
+import logging
+import os
+import time
+
+from pfns.bar_distribution import FullSupportBarDistribution
+import schedulefree
+import torch
+import torch.nn as nn
+from torch.distributed import init_process_group
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.data import IterableDataset, DataLoader
+from tqdm import tqdm
+
+from .base import Callback, Trainer
+from .util import (
+    infer_device,
+    log_on_main,
+    ddp_teardown,
+    generate_random_run_name,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class BaseTrainer(Trainer):
+    """Trainer class for training models."""
+
+    model: nn.Module
+    train_dataset: IterableDataset
+    callbacks: list[Callback]
+    run_dir: Union[str, None]
+    run_name: Union[str, None]
+
+    # DDP
+    ddp: bool
+    ddp_rank: int
+    ddp_local_rank: int
+    ddp_world_size: int
+    master_process: bool
+
+    def __init__(
+        self,
+        model: nn.Module,
+        train_dataset: IterableDataset,
+        criterion: nn.Module,
+        initial_lr: float = 1e-4,
+        weight_decay: float = 0.0,
+        accumulate_gradients: int = 1,
+        epochs: int = 10000,
+        steps: int = 100,
+        callbacks: list[Callback] = [],
+        run_dir: Optional[str] = None,
+        run_name: Optional[str] = None,
+        use_cpu: bool = False,
+        dataloader_num_workers: int = 0,
+        **kwargs,
+    ) -> None:
+        self.model = model
+        self.train_dataset = train_dataset
+
+        self.callbacks = callbacks
+
+        # output setup
+        self.run_dir = run_dir
+        self.run_name = run_name
+        self._setup_output_dir()
+
+        # device setup
+        self.device, self.ddp = infer_device(use_cpu)
+
+        if self.ddp:
+            self._configure_ddp()
+            self.model.to(self.device)
+            self.model = DDP(
+                self.model,
+                device_ids=[self.ddp_local_rank],
+                broadcast_buffers=False,
+            )
+            self.raw_model = self.model.module
+
+            # find a better solution for this
+            if train_dataset.batch_size % self.ddp_world_size != 0:
+                raise ValueError(
+                    f"Dataset batch size {train_dataset.batch_size} not divisible by DDP world size {self.ddp_world_size}"
+                )
+        else:
+            self.master_process = True
+            self.raw_model = self.model
+            self.model.to(self.device)
+
+        # training setup
+        self.initial_lr = initial_lr
+        self.weight_decay = weight_decay
+        self.accumulate_gradients = accumulate_gradients
+        self.epochs = epochs
+        self.steps = steps
+
+        # TODO: probably want to allow the user to specify more here?
+        self.optimizer = schedulefree.AdamWScheduleFree(
+            self.raw_model.parameters(),
+            lr=self.initial_lr,
+            weight_decay=self.weight_decay,
+        )
+        self.criterion = criterion
+        if dataloader_num_workers > 0:
+            # NOTE: not yet supported; there is some difficulty getting torch multiprocessing to work here,
+            # possibly because the dataset already holds GPU-resident tensors (not 100% sure).
+            raise NotImplementedError("Dataloader num_workers > 0 not supported yet, requires MP.")
+
+        # shard train dataset if DDP
+        self.train_dataloader = DataLoader(
+            dataset=train_dataset,
+            batch_size=None,
+            num_workers=dataloader_num_workers,
+        )
+
+    # TODO: handle checkpoint loading
+    def _configure_ddp(self) -> None:
+        # right now we only support DDP on CUDA
+        init_process_group(backend="nccl")
+        self.ddp_rank = int(os.environ["RANK"])
+        self.ddp_local_rank = int(os.environ["LOCAL_RANK"])
+        self.ddp_world_size = int(os.environ["WORLD_SIZE"])
+        self.master_process = self.ddp_rank == 0
+        self.device = f"cuda:{self.ddp_local_rank}"
+        torch.cuda.set_device(self.device)
+
+        log_on_main(logger, f"Running DDP with {self.ddp_world_size} Processes", logging.INFO)
+
+    def _load_checkpoint(self) -> Dict[str, Any]:
+        pass  # TODO
+
+    def _setup_output_dir(self) -> None:
+        if self.run_dir is None:
+            self.run_dir = "training_outputs/"
+        os.makedirs(self.run_dir, exist_ok=True)
+
+        if self.run_name is None:
+            # TODO: generate random name
+            self.run_name = generate_random_run_name()
+
+        # name of each subdir should be self.run_name-run-id where id starts at 1 and increments
+        # if the dir already exists
+        run_id = 1
+        while os.path.exists(os.path.join(self.run_dir, f"{self.run_name}-run-{run_id}")):
+            run_id += 1
+        self.run_name = f"{self.run_name}-run-{run_id}"
+        self.run_dir = os.path.join(self.run_dir, self.run_name)
+        os.makedirs(self.run_dir, exist_ok=True)
+
+    def _loss(self, output: torch.Tensor, targets: torch.Tensor) -> float:
+        losses = self.criterion(output, targets)
+        loss = losses.mean() / self.accumulate_gradients
+        loss.backward()
+        return loss.cpu().detach().item() * self.accumulate_gradients
+
+    @ddp_teardown
+    def train(self, resume_from_checkpoint: bool = False) -> nn.Module:
+        if resume_from_checkpoint:
+            checkpoint = self._load_checkpoint()
+            if checkpoint is not None:
+                # TODO: this is not really safe, what if we load an empty ckpt?
+                # also, do we need model or raw model here?
+                self.raw_model.load_state_dict(checkpoint["model"])
+                self.optimizer.load_state_dict(checkpoint["optimizer"])
+        else:
+            checkpoint = None
+
+        classification_task = isinstance(self.criterion, nn.CrossEntropyLoss)
+        regression_task = not classification_task
+
+        assert self.steps % self.accumulate_gradients == 0, (
+            "num_steps must be divisible by accumulate_gradients"
+        )
+
+        progress_bar = (
+            tqdm(range(self.epochs), desc="Training", leave=True)
+            if self.master_process
+            else range(self.epochs)
+        )
+        try:
+            # actual training loop
+            for epoch in range(checkpoint["epoch"] + 1 if checkpoint else 1, self.epochs + 1):
+                epoch_start_time = time.time()
+                self.model.train()
+                self.optimizer.train()
+                total_loss = 0.0
+
+                # NOTE: not sure if we want to rebuild the iterator every epoch; right now we have to,
+                # as otherwise it is exhausted after one epoch, but this might lead to
+                # weird behavior with on-the-fly datasets.
+                train_dataloader = iter(self.train_dataloader)
+                for step in range(self.steps):
+                    # yields a batch of data where the shape of x is (batch_size, num_samples, num_features)
+                    full_data = next(train_dataloader)
+                    single_eval_pos = full_data["single_eval_pos"]
+                    data = (full_data["x"].to(self.device), full_data["y"][:, :single_eval_pos].to(self.device))
+                    if torch.isnan(data[0]).any() or torch.isnan(data[1]).any():
+                        continue
+                    targets = full_data["target_y"].to(self.device)
+
+                    if regression_task:
+                        y_mean = data[1].mean(dim=1, keepdim=True)
+                        y_std = data[1].std(dim=1, keepdim=True) + 1e-8
+                        y_norm = (data[1] - y_mean) / y_std
+                        data = (data[0], y_norm)
+
+                    output = self.model(data, single_eval_pos=single_eval_pos)
+                    targets = targets[:, single_eval_pos:]
+                    if regression_task:
+                        targets = (targets - y_mean) / y_std
+                    if classification_task:
+                        targets = targets.reshape((-1,)).to(torch.long)
+                        output = output.view(-1, output.shape[-1])
+
+                    if self.ddp and step % self.accumulate_gradients != 0:
+                        with self.model.no_sync():
+                            total_loss += self._loss(output, targets)
+                    else:
+                        total_loss += self._loss(output, targets)
+
+                    if (step + 1) % self.accumulate_gradients == 0:
+                        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
+                        self.optimizer.step()
+                        self.optimizer.zero_grad()
+
+                    # update progress bar
+                    if self.master_process:
+                        # in this case progress bar is tqdm
+                        progress_bar.set_postfix({
+                            "epoch": f"{epoch}/{self.epochs}",
+                            "step": f"{step+1}/{self.steps}",
+                        })
+
+                end_time = time.time()
+                mean_loss = total_loss / self.steps
+                self.model.eval()
+                self.optimizer.eval()
+
+                training_state = {
+                    "epoch": epoch,
+                    "architecture": {
+                        "num_layers": int(self.raw_model.num_layers),
+                        "embedding_size": int(self.raw_model.embedding_size),
+                        "num_attention_heads": int(self.raw_model.num_attention_heads),
+                        "mlp_hidden_size": int(self.raw_model.mlp_hidden_size),
+                        "num_outputs": int(self.raw_model.num_outputs),
+                    },
+                    "model": self.raw_model.state_dict(),
+                    "optimizer": self.optimizer.state_dict(),
+                }
+                torch.save(training_state, os.path.join(self.run_dir, "latest_checkpoint.pth"))
+
+                for callback in self.callbacks:
+                    if type(self.criterion) is FullSupportBarDistribution:
+                        callback.on_epoch_end(
+                            epoch,
+                            end_time - epoch_start_time,
+                            mean_loss,
+                            self.raw_model,
+                            dist=self.criterion,
+                        )
+                    else:
+                        callback.on_epoch_end(epoch, end_time - epoch_start_time, mean_loss, self.raw_model)
+                if self.master_process:
+                    progress_bar.update(1)
+
+            if self.master_process:
+                progress_bar.close()
+        except KeyboardInterrupt:
+            if self.master_process:
+                progress_bar.close()
+            pass
+        finally:
+            for callback in self.callbacks:
+                callback.close()
+
+        return self.model
diff --git a/tfmplayground/training/util.py b/tfmplayground/training/util.py
new file mode 100644
index 0000000..b18f439
--- /dev/null
+++ b/tfmplayground/training/util.py
@@ -0,0 +1,157 @@
+from functools import wraps
+import logging
+from typing import Tuple
+import os
+
+import random
+import torch
+from torch.distributed import destroy_process_group
+from tqdm import tqdm
+
+logger = logging.getLogger(__name__)
+
+RANDOM_IDENTIFIES = [
+    "dog",
+    "cat",
+    "car",
+    "tree",
+    "house",
+    "computer",
+    "phone",
+    "bicycle",
+    "river",
+    "mountain",
+    "ocean",
+    "cloud",
+    "star",
+    "moon",
+    "sun",
+    "flower",
+    "bird",
+    "fish",
+    "butterfly",
+    "rainbow",
+    "forest",
+    "desert",
+    "island",
+    "valley",
+    "canyon",
+    "waterfall",
+    "volcano",
+]
+
+RANDOM_ATTRIBUTES = [
+    "red",
+    "blue",
+    "green",
+    "yellow",
+    "fast",
+    "slow",
+    "bright",
+    "dark",
+    "loud",
+    "quiet",
+    "happy",
+    "sad",
+    "strong",
+    "weak",
+    "hot",
+    "cold",
+    "soft",
+    "smooth",
+    "rough",
+    "shiny",
+    "dull",
+    "fresh",
+    "stale",
+]
+
+
+def generate_random_run_name() -> str:
+    identify = random.choice(RANDOM_IDENTIFIES)
+    attribute = random.choice(RANDOM_ATTRIBUTES)
+    number = random.randint(0, 999)
+    return f"{attribute}_{identify}_{number:03d}"
+
+
+def check_ddp_availability() -> bool:
+    """
+    Check whether DDP is available.
+    """
+
+    ddp_available = torch.distributed.is_available() and (torch.cuda.device_count() > 1)
+    if not ddp_available:
+        return ddp_available
+
+    assert int(os.environ.get("WORLD_SIZE", 0)) <= torch.cuda.device_count(), (
+        f"Number of GPUs ({torch.cuda.device_count()}) is less than the number of processes \
+({os.environ.get('WORLD_SIZE', 0)})"
+    )
+
+    if int(os.environ.get("WORLD_SIZE", 0)) != torch.cuda.device_count():
+        raise RuntimeError(
+            f"Number of GPUs available ({torch.cuda.device_count()}) is not equal to the number of processes \
+({os.environ.get('WORLD_SIZE', 0)}). Please specify torchrun --nproc-per-node=NUM_GPUS. If you have more GPUs \
+on your machine than you want to use, set CUDA_VISIBLE_DEVICES."
+        )
+
+    return ddp_available
+
+def infer_device(use_cpu: bool) -> Tuple[str, bool]:
+    """
+    Automatically infer the device to use for training. If DDP is available,
+    it will be reported so that the trainer can set it up.
+
+    Parameters
+    ----------
+    use_cpu: bool
+        Force the use of CPU. If no CUDA is available CPU will automatically
+        be used.
+
+    Returns
+    -------
+    Tuple[str, bool]
+        - The device to use for training. Either "cuda" or "cpu".
+        - Whether DDP is available and used.
+    """
+
+    device = "cuda"
+    ddp = False
+    if use_cpu or not torch.cuda.is_available():
+        device = "cpu"
+        logger.info("Using CPU for training.")
+        return device, ddp
+
+    ddp = check_ddp_availability()
+
+    return device, ddp
+
+def ddp_teardown(func):
+    """
+    Decorator to ensure ddp process group is cleaned up even if training fails.
+    """
+
+    @wraps(func)
+    def wrapper(self, *args, **kwargs):
+        try:
+            return func(self, *args, **kwargs)
+        finally:
+            if self.ddp:
+                print("Cleaning up ddp process group...")
+                destroy_process_group()
+
+    return wrapper
+
+def log_on_main(logger: logging.Logger, message: str, level: int) -> None:
+    """
+    Simple function to log only on main process in ddp setting.
+    """
+    if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
+        logger.log(level, message)
+
+def tqdm_on_main(message: str) -> None:
+    """
+    tqdm write only on main process in ddp setting.
+    """
+    if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
+        tqdm.write(message)
\ No newline at end of file