 
 from __future__ import annotations
 
+import copy
 import logging
-import tempfile
 from functools import partial
-from pathlib import Path
 from typing import Any
 
 import numpy as np
 
 from tabpfn import TabPFNClassifier
 from tabpfn.finetune_utils import clone_model_for_evaluation
-from tabpfn.model_loading import load_fitted_tabpfn_model, save_fitted_tabpfn_model
 from tabpfn.utils import meta_dataset_collator
 
 # Configure logging to show INFO level messages (including validation metrics)
@@ -188,15 +186,15 @@ def fit(self, X: np.ndarray, y: np.ndarray) -> FinetunedTabPFNClassifier:
         X_train, X_val, y_train, y_val = validation_splitter(X, y)
 
         # Calculate the context size used during finetuning
-        context_size = min(self.n_inference_context_samples, len(y_train))
-        print(f"Context size: {context_size}")
+        n_finetuning_fit_predict_context_samples = min(self.n_inference_context_samples, len(y_train))
 
+        # Unpack kwargs to allow any TabPFNClassifier hyperparameter to be specified,
+        # then override with required config values
         classifier_config = {
+            **self.kwargs,
             "ignore_pretraining_limits": True,
             "device": self.device,
-            "n_estimators": self.kwargs.get("n_estimators", 8),
             "random_state": self.random_state,
-            # inference_precision": torch.float32,
         }
 
         # Initialize the base TabPFNClassifier
@@ -207,38 +205,15 @@ def fit(self, X: np.ndarray, y: np.ndarray) -> FinetunedTabPFNClassifier:
         )
         # Required to access model parameters for the optimizer
         self.finetuned_classifier_._initialize_model_variables()
+        self.finetuned_classifier_.softmax_temperature_ = self.finetuned_classifier_.softmax_temperature
 
         eval_config = {
             **classifier_config,
             "inference_config": {
-                "SUBSAMPLE_SAMPLES": context_size,  # Passing this to the dataloader causes an error, so we set eval config separately from the classifier config.
+                "SUBSAMPLE_SAMPLES": n_finetuning_fit_predict_context_samples,  # Passing this to the dataloader causes an error, so we set eval config separately from the classifier config.
             },
         }
 
-        # Prepare data for the fine-tuning loop
-        # This splitter function will be applied to the training data to create
-        # (context, query) pairs for each step of the loop.
-
-        training_splitter = partial(
-            train_test_split,
-            test_size=self.finetune_split_ratio,
-            random_state=self.random_state,
-        )
-
-        training_datasets = self.finetuned_classifier_.get_preprocessed_datasets(
-            X_train,
-            y_train,
-            training_splitter,
-            context_size,
-            equal_split_size=False,
-        )
-
-        finetuning_dataloader = DataLoader(
-            training_datasets,
-            batch_size=self.meta_batch_size,
-            collate_fn=meta_dataset_collator,
-        )
-
         # Setup optimizer and loss function
         optimizer = Adam(
             self.finetuned_classifier_.model_.parameters(),  # type: ignore
@@ -259,9 +234,34 @@ def fit(self, X: np.ndarray, y: np.ndarray) -> FinetunedTabPFNClassifier:
         # Early stopping variables
         best_roc_auc = -np.inf
         patience_counter = 0
-        best_model_path = None
+        best_model = None
 
         for epoch in range(self.epochs):
+            # Regenerate datasets each epoch with a different random_state to ensure
+            # diversity in context/query pairs across epochs. This prevents the model
+            # from seeing the exact same splits in every epoch, which could reduce
+            # training signal diversity.
+            training_splitter = partial(
+                train_test_split,
+                test_size=self.finetune_split_ratio,
+                random_state=self.random_state + epoch,
+            )
+
+            training_datasets = self.finetuned_classifier_.get_preprocessed_datasets(
+                X_train,
+                y_train,
+                training_splitter,
+                n_finetuning_fit_predict_context_samples,
+                equal_split_size=False,
+            )
+
+            finetuning_dataloader = DataLoader(
+                training_datasets,
+                batch_size=self.meta_batch_size,
+                collate_fn=meta_dataset_collator,
+                shuffle=True,
+            )
+
             progress_bar = tqdm(
                 finetuning_dataloader,
                 desc=f"Finetuning Epoch {epoch + 1}/{self.epochs}",
@@ -274,7 +274,7 @@ def fit(self, X: np.ndarray, y: np.ndarray) -> FinetunedTabPFNClassifier:
                 cat_ixs,
                 confs,
             ) in progress_bar:
-
+
                 ctx = set(np.unique(y_context_batch))
                 qry = set(np.unique(y_query_batch))
                 if not qry.issubset(ctx):
@@ -285,11 +285,11 @@ def fit(self, X: np.ndarray, y: np.ndarray) -> FinetunedTabPFNClassifier:
 
                 if (
                     X_context_batch[0].shape[1] + X_query_batch[0].shape[1]
-                    != context_size
+                    != n_finetuning_fit_predict_context_samples
                 ):
                     actual_size = X_context_batch[0].shape[1] + X_query_batch[0].shape[1]
                     logging.warning(
-                        f"Skipping batch: total batch size {actual_size} does not match context size {context_size}"
+                        f"Skipping batch: total batch size {actual_size} does not match n_finetuning_fit_predict_context_samples {n_finetuning_fit_predict_context_samples}"
                     )
                     continue
 
@@ -361,7 +361,7 @@ def fit(self, X: np.ndarray, y: np.ndarray) -> FinetunedTabPFNClassifier:
                 y_train,  # pyright: ignore[reportArgumentType]
                 X_val,  # pyright: ignore[reportArgumentType]
                 y_val,  # pyright: ignore[reportArgumentType]
-            )
+            )
 
             logging.info(
                 f"📊 Epoch {epoch + 1} Evaluation | Val ROC: {roc_auc:.4f}, Val Log Loss: {log_loss_score:.4f}\n",
@@ -375,16 +375,8 @@ def fit(self, X: np.ndarray, y: np.ndarray) -> FinetunedTabPFNClassifier:
             if roc_auc > best_roc_auc + self.min_delta:
                 best_roc_auc = roc_auc
                 patience_counter = 0
-                # Save the best model using TabPFN's official save function
-                with tempfile.NamedTemporaryFile(
-                    suffix=".tabpfn_fit",
-                    delete=False,
-                ) as tmp_file:
-                    best_model_path = Path(tmp_file.name)
-                    save_fitted_tabpfn_model(
-                        self.finetuned_classifier_,
-                        best_model_path,
-                    )
+                # Save the best model
+                best_model = copy.deepcopy(self.finetuned_classifier_)
             else:
                 patience_counter += 1
                 logging.info(
@@ -394,28 +386,19 @@ def fit(self, X: np.ndarray, y: np.ndarray) -> FinetunedTabPFNClassifier:
             if patience_counter >= self.patience:
                 logging.info(
                     f"🛑 Early stopping triggered. Best ROC AUC: {best_roc_auc:.4f}",
-                )
-                # Restore the best model using TabPFN's official load function
-                if best_model_path is not None:
-                    self.finetuned_classifier_ = load_fitted_tabpfn_model(
-                        best_model_path,
-                        device=self.device,
                 )
-                # Clean up the temporary file
-                best_model_path.unlink(missing_ok=True)
+                # Restore the best model
+                if best_model is not None:
+                    self.finetuned_classifier_ = best_model
                 break
 
         logging.info("--- ✅ Fine-tuning Finished ---")
 
-        # Clean up temporary file if early stopping didn't trigger
-        if best_model_path is not None and best_model_path.exists():
-            best_model_path.unlink(missing_ok=True)
-
         finetuned_inference_classifier = clone_model_for_evaluation(
             self.finetuned_classifier_,  # type: ignore
             eval_config,
             TabPFNClassifier,
-        )
+        )
         self.finetuned_inference_classifier_ = finetuned_inference_classifier
         self.finetuned_inference_classifier_.fit_mode = "fit_preprocessors"  # type: ignore
         self.finetuned_inference_classifier_.fit(self.X_, self.y_)  # type: ignore
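
Note: the main behavioral change above is the early-stopping bookkeeping, which swaps the temp-file round trip through save_fitted_tabpfn_model / load_fitted_tabpfn_model for an in-memory snapshot via copy.deepcopy. A minimal, self-contained sketch of that pattern is below; train_one_epoch and evaluate are illustrative placeholders, not TabPFN APIs.

import copy

def finetune_with_early_stopping(model, train_one_epoch, evaluate,
                                 epochs=10, patience=3, min_delta=0.0):
    # train_one_epoch(model) and evaluate(model) -> float are placeholder
    # callables used only for this sketch.
    best_score = float("-inf")
    best_model = None
    patience_counter = 0

    for _ in range(epochs):
        train_one_epoch(model)
        score = evaluate(model)  # e.g. validation ROC AUC

        if score > best_score + min_delta:
            best_score = score
            patience_counter = 0
            # Snapshot the best model in memory instead of writing it to disk.
            best_model = copy.deepcopy(model)
        else:
            patience_counter += 1
            if patience_counter >= patience:
                # Early stop: fall back to the best snapshot if one exists.
                return best_model if best_model is not None else model

    # All epochs completed without exhausting patience; keep the latest
    # weights, mirroring the loop in the diff.
    return model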
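
For context, a hedged usage sketch of the wrapper this file defines follows. The import path, constructor keyword names, and defaults are assumptions inferred from the attributes referenced in the diff (epochs, patience, device, random_state, ...), not a confirmed signature; only fit(X, y) and the finetuned_inference_classifier_ attribute appear in the diff itself.

from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

# Hypothetical import path for the class defined in this file.
from finetuned_tabpfn import FinetunedTabPFNClassifier

X, y = make_classification(n_samples=2000, n_features=20, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Keyword arguments mirror attributes used in the fine-tuning loop;
# their exact names and defaults in __init__ are assumptions.
clf = FinetunedTabPFNClassifier(
    epochs=10,
    patience=3,
    device="cuda",
    random_state=0,
)
clf.fit(X_train, y_train)  # runs the fine-tuning loop shown in the diff

# After fit, the loop above leaves a ready-to-use TabPFNClassifier in
# finetuned_inference_classifier_, fitted on the stored training data.
proba = clf.finetuned_inference_classifier_.predict_proba(X_test)
print(proba.shape)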