
Commit c8c0141

Fix non-diversity of training samples during finetuning
1 parent 877e65b commit c8c0141
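
The core of the change: the (context, query) split and the finetuning DataLoader used to be built once before the epoch loop, so every epoch trained on identical splits; they are now rebuilt at the start of each epoch with an epoch-dependent random_state, and the DataLoader is shuffled. Below is a minimal sketch of that pattern, not code from the commit: make_epoch_dataloader and its parameters are illustrative names, while get_preprocessed_datasets and meta_dataset_collator are the existing TabPFN APIs that appear in the diff.

from functools import partial

from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader

from tabpfn.utils import meta_dataset_collator


def make_epoch_dataloader(clf, X_train, y_train, epoch, *, base_seed,
                          split_ratio, context_size, batch_size):
    # Illustrative helper (not part of the library): rebuild the per-epoch
    # data pipeline. Varying the seed by epoch yields different
    # (context, query) pairs in every epoch.
    splitter = partial(
        train_test_split,
        test_size=split_ratio,
        random_state=base_seed + epoch,
    )
    datasets = clf.get_preprocessed_datasets(
        X_train, y_train, splitter, context_size, equal_split_size=False,
    )
    # shuffle=True also draws the meta-batches in a different order each epoch.
    return DataLoader(
        datasets, batch_size=batch_size,
        collate_fn=meta_dataset_collator, shuffle=True,
    )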

2 files changed: 65 additions, 76 deletions

examples/finetune/finetune_example.py

Lines changed: 22 additions & 16 deletions
@@ -1,18 +1,23 @@
-from sklearn.datasets import fetch_covtype
-from sklearn.model_selection import train_test_split
-from sklearn.metrics import log_loss, roc_auc_score
 import numpy as np
 import torch
-from tabpfn import TabPFNClassifier
+from sklearn.datasets import fetch_covtype
+from sklearn.metrics import log_loss, roc_auc_score
+from sklearn.model_selection import train_test_split
 
+from tabpfn import TabPFNClassifier
 from tabpfn_extensions.finetune.finetune_classifier import FinetunedTabPFNClassifier
 
 # 1. Load and prepare the data
 # We use a small subset for a quick demonstration.
 print("--- 1. Loading Data ---")
 X_all, y_all = fetch_covtype(return_X_y=True, shuffle=True)
-X, y = X_all[:10000], y_all[:10000]
+X, y = X_all[:11000], y_all[:11000]
+
+# df = pd.read_csv("/home/anurag_priorlabs_ai/tabpfn-extensions/PrudentialLifeInsuranceAssessment.csv")
 
+# print(df.columns)
+# X = df.drop(columns=["Id", "Response"])
+# y = df["Response"]
 # Create a final hold-out test set. This is NOT used during fine-tuning.
 X_train, X_test, y_train, y_test = train_test_split(
     X, y, test_size=0.2, random_state=42, stratify=y
@@ -22,17 +27,16 @@
 # Calculate ROC AUC
 def calculate_roc_auc(y_true: np.ndarray, y_pred_proba: np.ndarray) -> float:
     if len(np.unique(y_true)) == 2:
-        return roc_auc_score(y_true, y_pred_proba[:, 1])
-    else:
-        return roc_auc_score(y_true, y_pred_proba, multi_class="ovr", average="weighted")
+        return roc_auc_score(y_true, y_pred_proba[:, 1])  # pyright: ignore[reportReturnType]
+    return roc_auc_score(y_true, y_pred_proba, multi_class="ovr", average="weighted")  # pyright: ignore[reportReturnType]
 
 # 2. Initial model evaluation on test set
 
-base_clf = TabPFNClassifier(device='cuda' if torch.cuda.is_available() else 'cpu', n_estimators=2)
+base_clf = TabPFNClassifier(device="cuda" if torch.cuda.is_available() else "cpu", n_estimators=2)
 base_clf.fit(X_train, y_train)
 
 base_pred_proba = base_clf.predict_proba(X_test)
-roc_auc = calculate_roc_auc(y_test, base_pred_proba)
+roc_auc = calculate_roc_auc(y_test, base_pred_proba)  # pyright: ignore[reportReturnType, reportArgumentType]
 log_loss_score = log_loss(y_test, base_pred_proba)
 
 print(f"📊 Initial Test ROC: {roc_auc:.4f}")
@@ -43,25 +47,27 @@ def calculate_roc_auc(y_true: np.ndarray, y_pred_proba: np.ndarray) -> float:
 
 # Instantiate the wrapper with your desired hyperparameters
 finetuned_clf = FinetunedTabPFNClassifier(
-    device='cuda' if torch.cuda.is_available() else 'cpu',
+    device="cuda" if torch.cuda.is_available() else "cpu",
     epochs=10,
-    learning_rate=1e-5,
+    learning_rate=1e-6,
     n_inference_context_samples=10_000,
     finetune_split_ratio=0.3,
     random_state=42,
     n_estimators=2,
-    patience=3
+    patience=3,
+    ignore_pretraining_limits=True,
+    grad_clip_value=1.0,
 )
 
 # 4. Call .fit() to start the fine-tuning process on the training data
-finetuned_clf.fit(X_train, y_train)
+finetuned_clf.fit(X_train, y_train)  # pyright: ignore[reportArgumentType]
 print("\n")
 
 # 5. Evaluate the fine-tuned model
 print("--- 3. Evaluating Model on Held-out Test Set ---\n")
-y_pred_proba = finetuned_clf.predict_proba(X_test)
+y_pred_proba = finetuned_clf.predict_proba(X_test)  # pyright: ignore[reportArgumentType]
 
-roc_auc = calculate_roc_auc(y_test, y_pred_proba)
+roc_auc = calculate_roc_auc(y_test, y_pred_proba)  # pyright: ignore[reportArgumentType]
 loss = log_loss(y_test, y_pred_proba)
 
 print(f"📊 Final Test ROC: {roc_auc:.4f}")

src/tabpfn_extensions/finetune/finetune_classifier.py

Lines changed: 43 additions & 60 deletions
@@ -3,10 +3,9 @@
 
 from __future__ import annotations
 
+import copy
 import logging
-import tempfile
 from functools import partial
-from pathlib import Path
 from typing import Any
 
 import numpy as np
@@ -24,7 +23,6 @@
 
 from tabpfn import TabPFNClassifier
 from tabpfn.finetune_utils import clone_model_for_evaluation
-from tabpfn.model_loading import load_fitted_tabpfn_model, save_fitted_tabpfn_model
 from tabpfn.utils import meta_dataset_collator
 
 # Configure logging to show INFO level messages (including validation metrics)
@@ -188,15 +186,15 @@ def fit(self, X: np.ndarray, y: np.ndarray) -> FinetunedTabPFNClassifier:
         X_train, X_val, y_train, y_val = validation_splitter(X, y)
 
         # Calculate the context size used during finetuning
-        context_size = min(self.n_inference_context_samples, len(y_train))
-        print(f"Context size: {context_size}")
+        n_finetuning_fit_predict_context_samples = min(self.n_inference_context_samples, len(y_train))
 
+        # Unpack kwargs to allow any TabPFNClassifier hyperparameter to be specified,
+        # then override with required config values
         classifier_config = {
+            **self.kwargs,
             "ignore_pretraining_limits": True,
             "device": self.device,
-            "n_estimators": self.kwargs.get("n_estimators", 8),
             "random_state": self.random_state,
-            # inference_precision": torch.float32,
         }
 
         # Initialize the base TabPFNClassifier
@@ -207,38 +205,15 @@ def fit(self, X: np.ndarray, y: np.ndarray) -> FinetunedTabPFNClassifier:
         )
         # Required to access model parameters for the optimizer
         self.finetuned_classifier_._initialize_model_variables()
+        self.finetuned_classifier_.softmax_temperature_ = self.finetuned_classifier_.softmax_temperature
 
         eval_config = {
             **classifier_config,
             "inference_config": {
-                "SUBSAMPLE_SAMPLES": context_size,  # Passing this to the dataloader causes an error, so we set eval config separately from the classifier config.
+                "SUBSAMPLE_SAMPLES": n_finetuning_fit_predict_context_samples,  # Passing this to the dataloader causes an error, so we set eval config separately from the classifier config.
             },
         }
 
-        # Prepare data for the fine-tuning loop
-        # This splitter function will be applied to the training data to create
-        # (context, query) pairs for each step of the loop.
-
-        training_splitter = partial(
-            train_test_split,
-            test_size=self.finetune_split_ratio,
-            random_state=self.random_state,
-        )
-
-        training_datasets = self.finetuned_classifier_.get_preprocessed_datasets(
-            X_train,
-            y_train,
-            training_splitter,
-            context_size,
-            equal_split_size=False,
-        )
-
-        finetuning_dataloader = DataLoader(
-            training_datasets,
-            batch_size=self.meta_batch_size,
-            collate_fn=meta_dataset_collator,
-        )
-
         # Setup optimizer and loss function
         optimizer = Adam(
             self.finetuned_classifier_.model_.parameters(),  # type: ignore
@@ -259,9 +234,34 @@ def fit(self, X: np.ndarray, y: np.ndarray) -> FinetunedTabPFNClassifier:
         # Early stopping variables
        best_roc_auc = -np.inf
         patience_counter = 0
-        best_model_path = None
+        best_model = None
 
         for epoch in range(self.epochs):
+            # Regenerate datasets each epoch with a different random_state to ensure
+            # diversity in context/query pairs across epochs. This prevents the model
+            # from seeing the exact same splits in every epoch, which could reduce
+            # training signal diversity.
+            training_splitter = partial(
+                train_test_split,
+                test_size=self.finetune_split_ratio,
+                random_state=self.random_state + epoch,
+            )
+
+            training_datasets = self.finetuned_classifier_.get_preprocessed_datasets(
+                X_train,
+                y_train,
+                training_splitter,
+                n_finetuning_fit_predict_context_samples,
+                equal_split_size=False,
+            )
+
+            finetuning_dataloader = DataLoader(
+                training_datasets,
+                batch_size=self.meta_batch_size,
+                collate_fn=meta_dataset_collator,
+                shuffle=True,
+            )
+
             progress_bar = tqdm(
                 finetuning_dataloader,
                 desc=f"Finetuning Epoch {epoch + 1}/{self.epochs}",
@@ -274,7 +274,7 @@ def fit(self, X: np.ndarray, y: np.ndarray) -> FinetunedTabPFNClassifier:
                 cat_ixs,
                 confs,
             ) in progress_bar:
-
+
                 ctx = set(np.unique(y_context_batch))
                 qry = set(np.unique(y_query_batch))
                 if not qry.issubset(ctx):
@@ -285,11 +285,11 @@ def fit(self, X: np.ndarray, y: np.ndarray) -> FinetunedTabPFNClassifier:
 
                 if (
                     X_context_batch[0].shape[1] + X_query_batch[0].shape[1]
-                    != context_size
+                    != n_finetuning_fit_predict_context_samples
                 ):
                     actual_size = X_context_batch[0].shape[1] + X_query_batch[0].shape[1]
                     logging.warning(
-                        f"Skipping batch: total batch size {actual_size} does not match context size {context_size}"
+                        f"Skipping batch: total batch size {actual_size} does not match n_finetuning_fit_predict_context_samples {n_finetuning_fit_predict_context_samples}"
                     )
                     continue
 
@@ -361,7 +361,7 @@ def fit(self, X: np.ndarray, y: np.ndarray) -> FinetunedTabPFNClassifier:
                 y_train,  # pyright: ignore[reportArgumentType]
                 X_val,  # pyright: ignore[reportArgumentType]
                 y_val,  # pyright: ignore[reportArgumentType]
-            )
+            )
 
             logging.info(
                 f"📊 Epoch {epoch + 1} Evaluation | Val ROC: {roc_auc:.4f}, Val Log Loss: {log_loss_score:.4f}\n",
@@ -375,16 +375,8 @@ def fit(self, X: np.ndarray, y: np.ndarray) -> FinetunedTabPFNClassifier:
             if roc_auc > best_roc_auc + self.min_delta:
                 best_roc_auc = roc_auc
                 patience_counter = 0
-                # Save the best model using TabPFN's official save function
-                with tempfile.NamedTemporaryFile(
-                    suffix=".tabpfn_fit",
-                    delete=False,
-                ) as tmp_file:
-                    best_model_path = Path(tmp_file.name)
-                save_fitted_tabpfn_model(
-                    self.finetuned_classifier_,
-                    best_model_path,
-                )
+                # Save the best model
+                best_model = copy.deepcopy(self.finetuned_classifier_)
             else:
                 patience_counter += 1
                 logging.info(
@@ -394,28 +386,19 @@ def fit(self, X: np.ndarray, y: np.ndarray) -> FinetunedTabPFNClassifier:
             if patience_counter >= self.patience:
                 logging.info(
                     f"🛑 Early stopping triggered. Best ROC AUC: {best_roc_auc:.4f}",
-                )
-                # Restore the best model using TabPFN's official load function
-                if best_model_path is not None:
-                    self.finetuned_classifier_ = load_fitted_tabpfn_model(
-                        best_model_path,
-                        device=self.device,
                 )
-                # Clean up the temporary file
-                best_model_path.unlink(missing_ok=True)
+                # Restore the best model
+                if best_model is not None:
+                    self.finetuned_classifier_ = best_model
                 break
 
         logging.info("--- ✅ Fine-tuning Finished ---")
 
-        # Clean up temporary file if early stopping didn't trigger
-        if best_model_path is not None and best_model_path.exists():
-            best_model_path.unlink(missing_ok=True)
-
         finetuned_inference_classifier = clone_model_for_evaluation(
             self.finetuned_classifier_,  # type: ignore
             eval_config,
             TabPFNClassifier,
-        )
+        )
         self.finetuned_inference_classifier_ = finetuned_inference_classifier
         self.finetuned_inference_classifier_.fit_mode = "fit_preprocessors"  # type: ignore
         self.finetuned_inference_classifier_.fit(self.X_, self.y_)  # type: ignore