
Commit c45cbec

Merge pull request #3 from PriorLabs/callbacks
Add Callbacks
2 parents 101f41e + 857555d commit c45cbec

File tree

6 files changed: +157 -46 lines changed

README.md

Lines changed: 2 additions & 6 deletions
@@ -62,6 +62,7 @@ from nanotabpfn.train import train
 from nanotabpfn.utils import get_default_device
 from nanotabpfn.interface import NanoTabPFNClassifier
 from torch.nn import CrossEntropyLoss
+from nanotabpfn.callbacks import ConsoleLoggerCallback
 ```
 then we instantiate our model and loss criterion:
 ```python
@@ -81,17 +82,12 @@ prior = PriorDumpDataLoader(filename='50x3_3_100k_classification.h5', num_steps=
 ```
 and finally train our model:
 ```python
-def epoch_callback(epoch, epoch_time, mean_loss, model):
-    classifier = NanoTabPFNClassifier(model, device)
-    # you can add your own eval code here that runs after every epoch
-    print(f'epoch {epoch:5d} | time {epoch_time:5.2f}s | mean loss {mean_loss:5.2f}', flush=True)
-
 trained_model, loss = train(
     model=model,
     prior=prior,
     criterion=criterion,
     epochs=80,
     device=device,
-    epoch_callback=epoch_callback
+    callbacks=[ConsoleLoggerCallback()]
 )
 ```
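The same `callbacks` list accepts the other loggers added in this PR. A minimal sketch of combining them (the `log_dir` value is an illustrative assumption, and the TensorBoard logger requires the optional `tensorboard` extra):

```python
# Sketch: console + TensorBoard logging together. Both callback classes come
# from this PR; the log_dir path is an arbitrary example value.
from nanotabpfn.callbacks import ConsoleLoggerCallback, TensorboardLoggerCallback

trained_model, loss = train(
    model=model,
    prior=prior,
    criterion=criterion,
    epochs=80,
    device=device,
    callbacks=[ConsoleLoggerCallback(), TensorboardLoggerCallback(log_dir='runs/nanotabpfn')],
)
```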

nanotabpfn/callbacks.py

Lines changed: 90 additions & 0 deletions
@@ -0,0 +1,90 @@
+from abc import ABC, abstractmethod
+
+
+class Callback(ABC):
+    """ Abstract base class for callbacks. """
+
+    @abstractmethod
+    def on_epoch_end(self, epoch: int, epoch_time: float, loss: float, model, **kwargs):
+        """
+        Called at the end of each epoch.
+
+        Args:
+            epoch (int): The current epoch number.
+            epoch_time (float): Duration of the epoch in seconds.
+            loss (float): Mean loss for the epoch.
+            model: The model being trained.
+            **kwargs: Additional arguments.
+        """
+        pass
+
+    @abstractmethod
+    def close(self):
+        """
+        Called to release any resources or perform cleanup.
+        """
+        pass
+
+
+class BaseLoggerCallback(Callback):
+    """ Abstract base class for logger callbacks. """
+    pass
+
+
+class ConsoleLoggerCallback(BaseLoggerCallback):
+    """ Logger callback that prints epoch information to the console. """
+
+    def on_epoch_end(self, epoch: int, epoch_time: float, loss: float, model, **kwargs):
+        print(f'Epoch {epoch:5d} | Time {epoch_time:5.2f}s | Mean Loss {loss:5.2f}', flush=True)
+
+    def close(self):
+        """ Nothing to clean up for the console logger. """
+        pass
+
+
+class TensorboardLoggerCallback(BaseLoggerCallback):
+    """ Logger callback that logs epoch information to TensorBoard. """
+
+    def __init__(self, log_dir: str):
+        from torch.utils.tensorboard import SummaryWriter
+        self.writer = SummaryWriter(log_dir=log_dir)
+
+    def on_epoch_end(self, epoch: int, epoch_time: float, loss: float, model, **kwargs):
+        self.writer.add_scalar('Loss/train', loss, epoch)
+        self.writer.add_scalar('Time/epoch', epoch_time, epoch)
+
+    def close(self):
+        self.writer.close()
+
+
+class WandbLoggerCallback(BaseLoggerCallback):
+    """ Logger callback that logs epoch information to Weights & Biases. """
+
+    def __init__(self, project: str, name: str = None, config: dict = None, log_dir: str = None):
+        """
+        Initializes a WandbLoggerCallback.
+
+        Args:
+            project (str): The name of the wandb project.
+            name (str, optional): The name of the run. Defaults to None.
+            config (dict, optional): Configuration dictionary for the run. Defaults to None.
+            log_dir (str, optional): Directory to save wandb logs. Defaults to None.
+        """
+        try:
+            import wandb
+            self.wandb = wandb  # store the wandb module so it is only imported when this callback is used
+            wandb.init(
+                project=project,
+                name=name,
+                config=config,
+                dir=log_dir,
+            )
+        except ImportError as e:
+            raise ImportError("wandb is not installed. Install it with: pip install wandb") from e
+
+    def on_epoch_end(self, epoch: int, epoch_time: float, loss: float, model, **kwargs):
+        log_dict = {'epoch': epoch, 'loss': loss, 'epoch_time': epoch_time}
+        self.wandb.log(log_dict)
+
+    def close(self):
+        self.wandb.finish()
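Since `Callback` is an ABC with just two hooks, custom behaviour plugs in by subclassing it. A hypothetical sketch (class name, path, and saving policy are illustrative, not part of this commit):

```python
# Hypothetical example of a user-defined callback built on the new ABC:
# saves the model's weights whenever the mean epoch loss improves.
import torch

from nanotabpfn.callbacks import Callback


class BestModelCheckpointCallback(Callback):
    """ Keeps the best model weights seen so far (illustrative, not in the PR). """

    def __init__(self, path: str = 'best_model.pth'):
        self.path = path
        self.best_loss = float('inf')

    def on_epoch_end(self, epoch: int, epoch_time: float, loss: float, model, **kwargs):
        if loss < self.best_loss:
            self.best_loss = loss
            torch.save(model.state_dict(), self.path)

    def close(self):
        pass  # no resources to release
```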

nanotabpfn/train.py

Lines changed: 20 additions & 12 deletions
@@ -2,16 +2,18 @@
 from torch import nn
 import time
 from torch.utils.data import DataLoader
-from typing import Tuple, Dict, Callable
+from typing import Dict
 from pfns.bar_distribution import FullSupportBarDistribution
 import schedulefree
 
+from nanotabpfn.callbacks import Callback
 from nanotabpfn.model import NanoTabPFNModel
 from nanotabpfn.utils import get_default_device
 
-def train(model: NanoTabPFNModel, prior: DataLoader, criterion: nn.CrossEntropyLoss | FullSupportBarDistribution, epochs: int,
-          accumulate_gradients: int = 1, lr: float = 1e-4, device: torch.device = None,
-          epoch_callback: Callable[[int, float, float, NanoTabPFNModel, FullSupportBarDistribution | None], None] = None, ckpt: Dict[str, torch.Tensor] = None):
+
+def train(model: NanoTabPFNModel, prior: DataLoader, criterion: nn.CrossEntropyLoss | FullSupportBarDistribution,
+          epochs: int, accumulate_gradients: int = 1, lr: float = 1e-4, device: torch.device = None,
+          callbacks: list[Callback] = None, ckpt: Dict[str, torch.Tensor] = None):
     """
     Trains our model on the given prior using the given criterion.
 
@@ -22,14 +24,17 @@ def train(model: NanoTabPFNModel, prior: DataLoader, criterion: nn.CrossEntropyL
         epochs: (int) the number of epochs we train for, the number of steps that constitute an epoch are decided by the prior
         accumulate_gradients: (int) the number of gradients to accumulate before updating the weights
         device: (torch.device) the device we are using
-        epoch_callback: (Callable[[int, float, float, NanoTabPFNModel], None]) optional callback function that will be called
-                        at the end of each epoch with the current epoch, epoch duration, mean loss, and the model,
-                        intended to be used for logging/validation/evaluation
+        callbacks: A list of callback instances to execute at the end of each epoch. These can be used for
+                   logging, validation, or other custom actions.
+        ckpt (Dict[str, torch.Tensor], optional): A checkpoint dictionary containing the model and optimizer states,
+            as well as the last completed epoch. If provided, training resumes from this checkpoint.
 
     Returns:
         (NanoTabPFNModel, float) the trained model and the total loss of the final epoch
     """
     # print(f"Using a Transformer with {sum(p.numel() for p in model.parameters())/1000/1000:.{2}f} M parameters")
+    if callbacks is None:
+        callbacks = []
     if not device:
         device = get_default_device()
     model.to(device)
@@ -41,8 +46,8 @@ def train(model: NanoTabPFNModel, prior: DataLoader, criterion: nn.CrossEntropyL
     assert prior.num_steps % accumulate_gradients == 0, 'num_steps must be divisible by accumulate_gradients'
 
     try:
-        for epoch in range(ckpt['epoch']+1 if ckpt else 1, epochs + 1):
-            start_time = time.time()
+        for epoch in range(ckpt['epoch'] + 1 if ckpt else 1, epochs + 1):
+            epoch_start_time = time.time()
             model.train()  # Turn on the train mode
             optimizer.train()
             total_loss = 0.
@@ -81,12 +86,15 @@ def train(model: NanoTabPFNModel, prior: DataLoader, criterion: nn.CrossEntropyL
             }
             torch.save(training_state, 'latest_checkpoint.pth')
 
-            if epoch_callback:
+            for callback in callbacks:
                 if type(criterion) is FullSupportBarDistribution:
-                    epoch_callback(epoch, end_time - start_time, mean_loss, model, dist=criterion)
+                    callback.on_epoch_end(epoch, end_time - epoch_start_time, mean_loss, model, dist=criterion)
                 else:
-                    epoch_callback(epoch, end_time-start_time, mean_loss, model)
+                    callback.on_epoch_end(epoch, end_time - epoch_start_time, mean_loss, model)
     except KeyboardInterrupt:
         pass
+    finally:
+        for callback in callbacks:
+            callback.close()
 
     return model, total_loss
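Note that `train()` forwards `dist=criterion` to `on_epoch_end` only when the criterion is a `FullSupportBarDistribution`, so callbacks that need it should read it from `**kwargs`. A hedged sketch (the class name is hypothetical):

```python
# Hypothetical callback reading the optional `dist` kwarg that train() passes
# for regression runs (FullSupportBarDistribution criteria).
from nanotabpfn.callbacks import Callback


class DistAwareCallback(Callback):
    def on_epoch_end(self, epoch: int, epoch_time: float, loss: float, model, **kwargs):
        dist = kwargs.get('dist')  # None for classification runs
        if dist is not None:
            # `dist` is the FullSupportBarDistribution criterion; use it e.g. to
            # build a NanoTabPFNRegressor for evaluation, as pretrain_regression.py does
            print(f'epoch {epoch}: bar-distribution criterion available')

    def close(self):
        pass
```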

pretrain_classification.py

Lines changed: 20 additions & 13 deletions
@@ -5,6 +5,7 @@
 from torch import nn
 from functools import partial
 
+from nanotabpfn.callbacks import ConsoleLoggerCallback
 from nanotabpfn.priors import PriorDumpDataLoader
 from nanotabpfn.model import NanoTabPFNModel
 from nanotabpfn.train import train
@@ -29,15 +30,14 @@
 parser.add_argument("-epochs", type=int, default=10000, help="number of epochs to train for")
 parser.add_argument("-loadcheckpoint", type=str, default=None, help="checkpoint from which to continue training")
 
-
 args = parser.parse_args()
 
 set_randomness_seed(2402)
 
 device = get_default_device()
-ckpt=None
+ckpt = None
 if args.loadcheckpoint:
-    ckpt=torch.load(args.loadcheckpoint)
+    ckpt = torch.load(args.loadcheckpoint)
 
 prior = PriorDumpDataLoader(filename=args.priordump, num_steps=args.steps, batch_size=args.batchsize, device=device, starting_index=args.steps*(ckpt['epoch'] if ckpt else 0))
 
@@ -60,17 +60,24 @@
 datasets.append(train_test_split(*load_wine(return_X_y=True), test_size=0.5, random_state=42))
 datasets.append(train_test_split(*load_breast_cancer(return_X_y=True), test_size=0.5, random_state=42))
 
-def epoch_callback(epoch, epoch_time, mean_loss, model):
-    classifier = NanoTabPFNClassifier(model, device)
-    scores = []
-    for X_train, X_test, y_train, y_test in datasets:
-        classifier.fit(X_train, y_train)
-        pred = classifier.predict(X_test)
-        scores.append(accuracy_score(y_test, pred))
-    avg_score = sum(scores)/len(scores)
-    print(f'epoch {epoch:5d} | time {epoch_time:5.2f}s | mean loss {mean_loss:5.2f} | avg accuracy {avg_score:.3f}', flush=True)
+
+class EvaluationLoggerCallback(ConsoleLoggerCallback):
+    def __init__(self, datasets):
+        self.datasets = datasets
+
+    def on_epoch_end(self, epoch: int, epoch_time: float, loss: float, model, **kwargs):
+        classifier = NanoTabPFNClassifier(model, device)
+        scores = []
+        for X_train, X_test, y_train, y_test in self.datasets:
+            classifier.fit(X_train, y_train)
+            pred = classifier.predict(X_test)
+            scores.append(accuracy_score(y_test, pred))
+        avg_score = sum(scores) / len(scores)
+        print(f'epoch {epoch:5d} | time {epoch_time:5.2f}s | mean loss {loss:5.2f} | avg accuracy {avg_score:.3f}',
+              flush=True)
 
 
+callbacks = [EvaluationLoggerCallback(datasets)]
 
 trained_model, loss = train(
     model=model,
@@ -80,7 +87,7 @@ def epoch_callback(epoch, epoch_time, mean_loss, model):
     accumulate_gradients=args.accumulate,
     lr=args.lr,
     device=device,
-    epoch_callback=epoch_callback,
+    callbacks=callbacks,
     ckpt=ckpt
 )
pretrain_regression.py

Lines changed: 21 additions & 14 deletions
@@ -1,6 +1,7 @@
 import argparse
 import torch
 
+from nanotabpfn.callbacks import ConsoleLoggerCallback
 from nanotabpfn.priors import PriorDumpDataLoader
 from nanotabpfn.model import NanoTabPFNModel
 from nanotabpfn.train import train
@@ -30,15 +31,14 @@
 parser.add_argument("-loadcheckpoint", type=str, default=None, help="checkpoint from which to continue training")
 parser.add_argument("-n_buckets", type=int, default=100, help="number of buckets for the data loader")
 
-
 args = parser.parse_args()
 
 set_randomness_seed(2402)
 
 device = get_default_device()
-ckpt=None
+ckpt = None
 if args.loadcheckpoint:
-    ckpt=torch.load(args.loadcheckpoint)
+    ckpt = torch.load(args.loadcheckpoint)
 
 prior = PriorDumpDataLoader(filename=args.priordump, num_steps=args.steps, batch_size=args.batchsize, device=device, starting_index=args.steps*(ckpt['epoch'] if ckpt else 0))
 
@@ -69,17 +69,24 @@
 datasets = []
 datasets.append(train_test_split(*load_diabetes(return_X_y=True), test_size=0.5, random_state=42))
 
-def epoch_callback(epoch, epoch_time, mean_loss, model, dist):
-    regressor = NanoTabPFNRegressor(model, dist, device)
-    scores = []
-    for X_train, X_test, y_train, y_test in datasets:
-        regressor.fit(X_train, y_train)
-        pred = regressor.predict(X_test)
-        scores.append(r2_score(y_test, pred))
-        print(r2_score(y_test, pred))
-    avg_score = sum(scores)/len(scores)
-    print(f'epoch {epoch:5d} | time {epoch_time:5.2f}s | mean loss {mean_loss:5.2f} | avg r2 score {avg_score:.3f}', flush=True)
 
+class EvaluationLoggerCallback(ConsoleLoggerCallback):
+    def __init__(self, datasets):
+        self.datasets = datasets
+
+    def on_epoch_end(self, epoch: int, epoch_time: float, loss: float, model, **kwargs):
+        regressor = NanoTabPFNRegressor(model, kwargs['dist'], device)
+        scores = []
+        for X_train, X_test, y_train, y_test in self.datasets:
+            regressor.fit(X_train, y_train)
+            pred = regressor.predict(X_test)
+            scores.append(r2_score(y_test, pred))
+        avg_score = sum(scores) / len(scores)
+        print(f'epoch {epoch:5d} | time {epoch_time:5.2f}s | mean loss {loss:5.2f} | avg r2 score {avg_score:.3f}',
+              flush=True)
+
+
+callbacks = [EvaluationLoggerCallback(datasets)]
 
 trained_model, loss = train(
     model=model,
@@ -89,7 +96,7 @@
     accumulate_gradients=args.accumulate,
     lr=args.lr,
     device=device,
-    epoch_callback=epoch_callback,
+    callbacks=callbacks,
     ckpt=ckpt
 )
95102

pyproject.toml

Lines changed: 4 additions & 1 deletion
@@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta"
 name = "nanotabpfn"
 version = "0.0.1"
 authors = [
-    { name="Alexander Pfefferle", email="[email protected]" },
+    { name = "Alexander Pfefferle", email = "[email protected]" },
 ]
 description = "A Playground for Tabular Foundation Models"
 readme = "README.md"
@@ -25,6 +25,9 @@ dependencies = [
     "pfns==0.3.0",
 ]
 
+[project.optional-dependencies]
+wandb = ["wandb>=0.20"]
+tensorboard = ["tensorboard>=2.19"]
 [project.urls]
 Homepage = "https://github.com/PriorLabs/nanoTabPFN"
 Issues = "https://github.com/PriorLabs/nanoTabPFN/issues"
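With these optional dependency groups, the extra logger backends install on demand, e.g. `pip install nanotabpfn[wandb]` or `pip install nanotabpfn[tensorboard]`.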
