2 changes: 2 additions & 0 deletions pyproject.toml
@@ -24,6 +24,8 @@ dependencies = [
"schedulefree>=1.4",
"pfns==0.3.0",
"openml==0.15.1",
"tqdm",
"hydra-core"
]

[project.optional-dependencies]
22 changes: 22 additions & 0 deletions scripts/training/configs/train_classification.yaml
@@ -0,0 +1,22 @@
#### Hydra specific configs, do not change ####
###############################################

model:
num_attention_heads: 6
embedding_size: 192
mlp_hidden_size: 768
num_layers: 6

dataset:
filename: ./50x3_3_100k_classification.h5
device: cuda
batch_size: 50
starting_index: 0

training:
steps: 25
accumulate_gradients: 1
dataloader_num_workers: 0 # keep 0 for now
lr: 1e-4
epochs: 80
run_name: nanotabpfn
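
These keys map one-to-one onto the cfg.model, cfg.dataset, and cfg.training groups consumed by pretrain_classification_new.py below. As a minimal sketch (not part of the diff), the same file can be read outside the decorated entry point with Hydra's compose API; the config_path assumes the snippet runs from scripts/training/, and the override shown uses the same key=value syntax accepted on the command line:

from hydra import compose, initialize

# Assumes the working directory is scripts/training/, next to configs/.
with initialize(version_base=None, config_path="configs"):
    cfg = compose(
        config_name="train_classification",
        overrides=["training.lr=3e-4"],  # example override, same syntax as the CLI
    )
print(cfg.model.embedding_size)  # -> 192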
@@ -4,16 +4,17 @@
from sklearn.metrics import accuracy_score, roc_auc_score
from torch import nn

-from tfmplayground.callbacks import ConsoleLoggerCallback, WandbLoggerCallback
+from tfmplayground.training.callbacks import ConsoleLoggerCallback, WandbLoggerCallback
from tfmplayground.evaluation import get_openml_predictions, TOY_TASKS_CLASSIFICATION, TABARENA_TASKS
from tfmplayground.interface import NanoTabPFNClassifier
from tfmplayground.model import NanoTabPFNModel
-from tfmplayground.priors import PriorDumpDataLoader
+from tfmplayground.priors.dataloader import PriorDumpDataLoader, PriorDataLoader
+from tfmplayground.priors.tabicl import TabICLPriorDataLoader
from tfmplayground.train import train
from tfmplayground.utils import get_default_device, set_randomness_seed

parser = argparse.ArgumentParser()
parser.add_argument("--priordump", type=str, default="/50x3_3_100k_classification.h5", help="path to the prior dump")
parser.add_argument("--priordump", type=str, default="./50x3_3_100k_classification.h5", help="path to the prior dump")
parser.add_argument("--heads", type=int, default=6, help="number of attention heads")
parser.add_argument("--embeddingsize", type=int, default=192, help="the size of the embeddings used for the cells")
parser.add_argument("--hiddensize", type=int, default=768, help="size of the hidden layer of the mlps")
@@ -36,6 +37,18 @@
if args.loadcheckpoint:
ckpt = torch.load(args.loadcheckpoint)


# prior = TabICLPriorDataLoader(
# num_steps=args.steps,
# batch_size=args.batchsize,
# num_datapoints_min=100,
# num_datapoints_max=1000,
# min_features=3,
# max_features=15,
# max_num_classes=2,
# device=device,
# )

prior = PriorDumpDataLoader(filename=args.priordump, num_steps=args.steps, batch_size=args.batchsize, device=device, starting_index=args.steps*(ckpt['epoch'] if ckpt else 0))

criterion = nn.CrossEntropyLoss()
@@ -85,7 +98,7 @@ def on_epoch_end(self, epoch: int, epoch_time: float, loss: float, model, **kwargs):
print(f'epoch {epoch:5d} | time {epoch_time:5.2f}s | mean loss {loss:5.2f} | avg roc auc {avg_score:.3f}',
flush=True)

-#callbacks = [ProductionEvaluationLoggerCallback('nanoTFM', args.runname)]
+# callbacks = [ProductionEvaluationLoggerCallback('nanoTFM', args.runname)]
callbacks = [ToyEvaluationLoggerCallback(TOY_TASKS_CLASSIFICATION)]

trained_model, loss = train(
80 changes: 80 additions & 0 deletions scripts/training/pretrain_classification_new.py
@@ -0,0 +1,80 @@
import hydra
from omegaconf import DictConfig

from tqdm import tqdm

from sklearn.metrics import accuracy_score, roc_auc_score
from torch import nn
import torch.multiprocessing as mp

from tfmplayground.training.callbacks import ConsoleLoggerCallback, WandbLoggerCallback
from tfmplayground.evaluation import get_openml_predictions, TOY_TASKS_CLASSIFICATION, TABARENA_TASKS
from tfmplayground.interface import NanoTabPFNClassifier
from tfmplayground.model import NanoTabPFNModel
from tfmplayground.priors.tabicl import TabICLPriorDataLoader
from tfmplayground.priors.dataloader import PriorDumpDataLoader
from tfmplayground.priors.dataset import PriorDumpDataset
from tfmplayground.utils import set_randomness_seed
from tfmplayground.training.trainer import BaseTrainer
from tfmplayground.training.util import tqdm_on_main

set_randomness_seed(2402)

class ToyEvaluationLoggerCallback(ConsoleLoggerCallback):
def __init__(self, tasks):
self.tasks = tasks

def on_epoch_end(self, epoch: int, epoch_time: float, loss: float, model, **kwargs):
classifier = NanoTabPFNClassifier(model, "cuda")
predictions = get_openml_predictions(model=classifier, tasks=self.tasks)
scores = []
for dataset_name, (y_true, y_pred, y_proba) in predictions.items():
scores.append(accuracy_score(y_true, y_pred))
avg_score = sum(scores) / len(scores)
tqdm_on_main(f'epoch {epoch:5d} | time {epoch_time:5.2f}s | mean loss {loss:5.2f} | avg accuracy {avg_score:.3f}')

class ProductionEvaluationLoggerCallback(WandbLoggerCallback):
def __init__(self, project: str, name: str = None, config: dict = None, log_dir: str = None):
super().__init__(project, name, config, log_dir)

def on_epoch_end(self, epoch: int, epoch_time: float, loss: float, model, **kwargs):
classifier = NanoTabPFNClassifier(model, "cuda")
predictions = get_openml_predictions(model=classifier, classification=True, tasks=TABARENA_TASKS)
scores = []
for dataset_name, (y_true, y_pred, y_proba) in predictions.items():
scores.append(roc_auc_score(y_true, y_proba, multi_class='ovr'))
avg_score = sum(scores) / len(scores)
self.wandb.log({
'epoch': epoch,
'epoch_time': epoch_time,
'mean_loss': loss,
'tabarena_avg_roc_auc': avg_score
})
print(f'epoch {epoch:5d} | time {epoch_time:5.2f}s | mean loss {loss:5.2f} | avg roc auc {avg_score:.3f}',
flush=True)

@hydra.main(version_base=None, config_path="configs", config_name="train_classification")
def main(cfg: DictConfig):
dataset = PriorDumpDataset(
**cfg.dataset,
num_steps=cfg.training.steps
)
model = NanoTabPFNModel(
**cfg.model,
num_outputs=dataset.max_num_classes,
)
# dataset = TabICLPriorDataLoader(
# **cfg.dataset
# )
callbacks = [ToyEvaluationLoggerCallback(TOY_TASKS_CLASSIFICATION)]
trainer = BaseTrainer(
model=model,
train_dataset=dataset,
criterion=nn.CrossEntropyLoss(),
callbacks=callbacks,
**cfg.training
)
model = trainer.train()

if __name__ == "__main__":
main()
@@ -4,7 +4,7 @@
from pfns.bar_distribution import FullSupportBarDistribution
from sklearn.metrics import r2_score

-from tfmplayground.callbacks import ConsoleLoggerCallback
+from tfmplayground.training.callbacks import ConsoleLoggerCallback
from tfmplayground.evaluation import get_openml_predictions, TOY_TASKS_REGRESSION
from tfmplayground.interface import NanoTabPFNRegressor
from tfmplayground.model import NanoTabPFNModel
2 changes: 2 additions & 0 deletions tfmplayground/evaluation.py
@@ -3,13 +3,15 @@
import numpy as np
import openml
import torch
import logging
from openml.config import set_root_cache_directory
from openml.tasks import TaskType
from sklearn.metrics import balanced_accuracy_score, roc_auc_score, r2_score
from sklearn.preprocessing import LabelEncoder

from tfmplayground.interface import NanoTabPFNRegressor, NanoTabPFNClassifier

openml.config.set_console_log_level(logging.WARNING)
TOY_TASKS_REGRESSION = [
362443, # diabetes
]
3 changes: 3 additions & 0 deletions tfmplayground/priors/__init__.py
@@ -0,0 +1,3 @@



File renamed without changes.
45 changes: 45 additions & 0 deletions tfmplayground/priors/dataset.py
@@ -0,0 +1,45 @@
from torch.utils.data import IterableDataset
import torch
import h5py

class PriorDumpDataset(IterableDataset):
def __init__(self, filename, num_steps, batch_size, device, starting_index=0):
self.filename = filename
self.num_steps = num_steps
self.batch_size = batch_size
with h5py.File(self.filename, "r") as f:
self.num_datapoints_max = f['X'].shape[0]
if "max_num_classes" in f:
self.max_num_classes = f["max_num_classes"][0]
else:
self.max_num_classes = None
self.problem_type = f['problem_type'][()].decode('utf-8')
self.device = device
self.pointer = starting_index

    def __iter__(self):
        with h5py.File(self.filename, "r") as f:
            self.data = f
            for _ in range(self.num_steps):
                end = self.pointer + self.batch_size

                # Trim the feature axis to the widest dataset in this batch.
                num_features = self.data['num_features'][self.pointer:end].max()
                x = torch.from_numpy(self.data['X'][self.pointer:end, :, :num_features])
                y = torch.from_numpy(self.data['y'][self.pointer:end])
                single_eval_pos = self.data['single_eval_pos'][self.pointer:end]

                # Wrap around once the dump is exhausted and reuse the stored data.
                self.pointer += self.batch_size
                if self.pointer >= self.data['X'].shape[0]:
                    print("Finished iteration over all stored datasets! "
                          "Will start reusing the same data with different splits now.")
                    self.pointer = 0

yield dict(
x=x.to(self.device),
y=y.to(self.device),
target_y=y.to(self.device),
single_eval_pos=single_eval_pos[0].item()
)

def __len__(self):
return self.num_steps
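
Each item yielded by PriorDumpDataset is already a full batch dict, so no collate step is needed. A minimal usage sketch (not part of the diff), reusing the filename and sizes from train_classification.yaml but with device="cpu" for illustration:

from tfmplayground.priors.dataset import PriorDumpDataset

dataset = PriorDumpDataset(
    filename="./50x3_3_100k_classification.h5",
    num_steps=25,
    batch_size=50,
    device="cpu",
)
batch = next(iter(dataset))  # dict with x, y, target_y, single_eval_pos
print(batch["x"].shape, batch["single_eval_pos"])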
68 changes: 68 additions & 0 deletions tfmplayground/priors/tabicl.py
@@ -0,0 +1,68 @@
import torch
from tabicl.prior.dataset import PriorDataset as TabICLPriorDataset
from torch.utils.data import DataLoader


class TabICLPriorDataLoader(DataLoader):
"""DataLoader sampling synthetic prior data on-the-fly from TabICL's PriorDataset.

Args:
num_steps (int): Number of batches to generate per epoch.
batch_size (int): Number of functions per batch.
num_datapoints_min (int): Minimum number of datapoints per function.
num_datapoints_max (int): Maximum number of datapoints per function.
min_features (int): Minimum number of features in x.
max_features (int): Maximum number of features in x.
max_num_classes (int): Maximum number of classes (for classification tasks).
device (torch.device): Target device for tensors.
"""

def __init__(
self,
num_steps: int,
batch_size: int,
num_datapoints_min: int,
num_datapoints_max: int,
min_features: int,
max_features: int,
max_num_classes: int,
device: torch.device,
):
self.num_steps = num_steps
self.batch_size = batch_size
self.num_datapoints_min = num_datapoints_min
self.num_datapoints_max = num_datapoints_max
self.min_features = min_features
self.max_features = max_features
self.max_num_classes = max_num_classes
self.device = device

self.pd = TabICLPriorDataset(
batch_size=batch_size,
batch_size_per_gp=batch_size,
min_features=min_features,
max_features=max_features,
max_classes=max_num_classes,
min_seq_len=num_datapoints_min,
max_seq_len=num_datapoints_max,
)

def tabicl_to_ours(self, d):
x, y, active_features, seqlen, train_size = d
        active_features = active_features[0].item()
        # should be all the same since we use batch_size_per_gp=batch_size (not true in practice!)
x = x[:, :, :active_features]
single_eval_pos = train_size[0].item() # should be all the same since we use batch_size_per_gp=batch_size
return dict(
x=x.to(self.device),
y=y.to(self.device),
target_y=y.to(self.device), # target_y is identical to y (for downstream compatibility)
single_eval_pos=single_eval_pos,
)

def __iter__(self):
return iter(self.tabicl_to_ours(next(self.pd)) for _ in range(self.num_steps))

def __len__(self):
return self.num_steps
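
A minimal usage sketch (not part of the diff), using the same argument values as the commented-out block in the classification training script; it assumes the tabicl package is installed:

import torch
from tfmplayground.priors.tabicl import TabICLPriorDataLoader

prior = TabICLPriorDataLoader(
    num_steps=25,
    batch_size=50,
    num_datapoints_min=100,
    num_datapoints_max=1000,
    min_features=3,
    max_features=15,
    max_num_classes=2,
    device=torch.device("cpu"),
)
batch = next(iter(prior))  # dict with x, y, target_y, single_eval_pos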
2 changes: 1 addition & 1 deletion tfmplayground/train.py
@@ -7,7 +7,7 @@
import schedulefree
import os

-from tfmplayground.callbacks import Callback
+from tfmplayground.training.callbacks import Callback
from tfmplayground.model import NanoTabPFNModel
from tfmplayground.utils import get_default_device

Empty file.
48 changes: 48 additions & 0 deletions tfmplayground/training/base.py
@@ -0,0 +1,48 @@
from abc import ABC, abstractmethod

import torch.nn as nn

class Callback(ABC):
""" Abstract base class for callbacks."""

@abstractmethod
def on_epoch_end(self, epoch: int, epoch_time: float, loss: float, model, **kwargs):
"""
Called at the end of each epoch.

Args:
epoch (int): The current epoch number.
epoch_time (float): Time of the epoch in seconds.
loss (float): Mean loss for the epoch.
model: The model being trained.
**kwargs: Additional arguments.
"""
pass

@abstractmethod
def close(self):
"""
Called to release any resources or perform cleanup.
"""
pass


class Trainer(ABC):
"""Trainer class for training models."""

@abstractmethod
def train(self) -> nn.Module:
"""
Trains the given model on the provided dataset.

Args:
model: The model to be trained.
train_dataset: The dataset to train the model on.
callbacks (list[Callback]): List of callback instances to be used during training.
run_dir (str): Directory for saving training outputs.
run_name (str): Name of the training run.

Returns:
The trained model and final loss.
"""
pass
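
For reference, a minimal concrete Callback against this interface (a hypothetical example, not part of the diff); ToyEvaluationLoggerCallback in the training scripts implements the same contract with real evaluation logic:

class PrintLossCallback(Callback):
    def on_epoch_end(self, epoch: int, epoch_time: float, loss: float, model, **kwargs):
        print(f"epoch {epoch:5d} | time {epoch_time:5.2f}s | mean loss {loss:5.2f}")

    def close(self):
        pass  # no resources to release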
@@ -1,29 +1,4 @@
-from abc import ABC, abstractmethod
-
-
-class Callback(ABC):
-    """ Abstract base class for callbacks."""
-
-    @abstractmethod
-    def on_epoch_end(self, epoch: int, epoch_time: float, loss: float, model, **kwargs):
-        """
-        Called at the end of each epoch.
-
-        Args:
-            epoch (int): The current epoch number.
-            epoch_time (float): Time of the epoch in seconds.
-            loss (float): Mean loss for the epoch.
-            model: The model being trained.
-            **kwargs: Additional arguments.
-        """
-        pass
-
-    @abstractmethod
-    def close(self):
-        """
-        Called to release any resources or perform cleanup.
-        """
-        pass
+from .base import Callback


class BaseLoggerCallback(Callback):