from torch import nn
import time
from torch.utils.data import DataLoader
- from typing import Tuple, Dict, Callable
+ from typing import Dict
from pfns.bar_distribution import FullSupportBarDistribution
import schedulefree

+ from nanotabpfn.callbacks import Callback
from nanotabpfn.model import NanoTabPFNModel
from nanotabpfn.utils import get_default_device

- def train(model: NanoTabPFNModel, prior: DataLoader, criterion: nn.CrossEntropyLoss | FullSupportBarDistribution, epochs: int,
-           accumulate_gradients: int = 1, lr: float = 1e-4, device: torch.device = None,
-           epoch_callback: Callable[[int, float, float, NanoTabPFNModel, FullSupportBarDistribution | None], None] = None, ckpt: Dict[str, torch.Tensor] = None):
+
+ def train(model: NanoTabPFNModel, prior: DataLoader, criterion: nn.CrossEntropyLoss | FullSupportBarDistribution,
+           epochs: int, accumulate_gradients: int = 1, lr: float = 1e-4, device: torch.device = None,
+           callbacks: list[Callback] = None, ckpt: Dict[str, torch.Tensor] = None):
1517 """
1618 Trains our model on the given prior using the given criterion.
1719
@@ -22,14 +24,17 @@ def train(model: NanoTabPFNModel, prior: DataLoader, criterion: nn.CrossEntropyL
        epochs: (int) the number of epochs we train for; the number of steps that constitute an epoch is decided by the prior
        accumulate_gradients: (int) the number of gradients to accumulate before updating the weights
        device: (torch.device) the device we are using
-         epoch_callback: (Callable[[int, float, float, NanoTabPFNModel], None]) optional callback function that will be called
-             at the end of each epoch with the current epoch, epoch duration, mean loss, and the model,
-             intended to be used for logging/validation/evaluation
+         callbacks: (list[Callback], optional) a list of callback instances to execute at the end of each epoch;
+             these can be used for logging, validation, or other custom actions
+         ckpt: (Dict[str, torch.Tensor], optional) a checkpoint dictionary containing the model and optimizer states,
+             as well as the last completed epoch; if provided, training resumes from this checkpoint

    Returns:
        the trained model and the total loss accumulated over the final epoch
    """
    # print(f"Using a Transformer with {sum(p.numel() for p in model.parameters())/1000/1000:.{2}f} M parameters")
+     if callbacks is None:
+         callbacks = []
    if not device:
        device = get_default_device()
    model.to(device)
@@ -41,8 +46,8 @@ def train(model: NanoTabPFNModel, prior: DataLoader, criterion: nn.CrossEntropyL
    assert prior.num_steps % accumulate_gradients == 0, 'num_steps must be divisible by accumulate_gradients'

    try:
-         for epoch in range(ckpt['epoch']+1 if ckpt else 1, epochs+1):
-             start_time = time.time()
+         for epoch in range(ckpt['epoch'] + 1 if ckpt else 1, epochs + 1):
+             epoch_start_time = time.time()
            model.train()  # Turn on the train mode
            optimizer.train()
            total_loss = 0.
@@ -81,12 +86,15 @@ def train(model: NanoTabPFNModel, prior: DataLoader, criterion: nn.CrossEntropyL
                }
                torch.save(training_state, 'latest_checkpoint.pth')
-             if epoch_callback:
+             for callback in callbacks:
                if type(criterion) is FullSupportBarDistribution:
-                     epoch_callback(epoch, end_time - start_time, mean_loss, model, dist=criterion)
+                     callback.on_epoch_end(epoch, end_time - epoch_start_time, mean_loss, model, dist=criterion)
                else:
-                     epoch_callback(epoch, end_time - start_time, mean_loss, model)
+                     callback.on_epoch_end(epoch, end_time - epoch_start_time, mean_loss, model)
    except KeyboardInterrupt:
        pass
+     finally:
+         for callback in callbacks:
+             callback.close()

    return model, total_loss
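
A minimal usage sketch of the new callbacks and ckpt parameters. It assumes only what the diff above shows: callbacks derived from the Callback base class in nanotabpfn.callbacks are invoked via on_epoch_end(epoch, epoch_time, mean_loss, model, dist=...) and close(). PrintLossCallback, the model/prior/criterion objects, and the epoch counts are illustrative placeholders, not part of this change.

import torch
from nanotabpfn.callbacks import Callback

class PrintLossCallback(Callback):
    """Hypothetical callback that logs the mean loss after every epoch."""
    def on_epoch_end(self, epoch, epoch_time, mean_loss, model, dist=None):
        print(f"epoch {epoch}: mean loss {mean_loss:.4f} ({epoch_time:.1f}s)")

    def close(self):
        pass  # nothing to release for plain stdout logging

# Fresh training run with a single logging callback; model, prior, and criterion
# are assumed to be constructed elsewhere.
trained_model, last_epoch_loss = train(
    model=model,
    prior=prior,
    criterion=criterion,
    epochs=10,
    callbacks=[PrintLossCallback()],
)

# Resuming: pass the checkpoint written by the loop above; train() reads
# ckpt['epoch'] to decide which epoch to start from and is expected to restore
# the saved model/optimizer states from the same dictionary.
ckpt = torch.load('latest_checkpoint.pth')
trained_model, last_epoch_loss = train(
    model=model,
    prior=prior,
    criterion=criterion,
    epochs=20,
    callbacks=[PrintLossCallback()],
    ckpt=ckpt,
)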