
Conversation

@anuragg1209

Hi @noahho & @klemens-floege,

Below is the code for the fine-tuning wrapper. Could you please review it when you get a chance? I’d appreciate any suggestions for improvements or changes to make it better. Thanks a lot!

Copilot AI review requested due to automatic review settings August 29, 2025 10:11
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Summary of Changes

Hello @anuragg1209, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a significant enhancement by providing a dedicated wrapper for fine-tuning the TabPFN model. This allows users to adapt the powerful TabPFN model to their specific datasets more easily, integrating seamlessly with the familiar scikit-learn interface. The change aims to improve the model's applicability and performance on custom data by enabling targeted optimization.

Highlights

  • New Fine-tuning Wrapper: Introduced a new FinetunedTabPFNClassifier class, providing a scikit-learn compatible API for fine-tuning the TabPFN model on specific datasets (a usage sketch follows this list).
  • Fine-tuning Logic: The FinetunedTabPFNClassifier encapsulates the fine-tuning loop, utilizing Adam optimizer and CrossEntropyLoss, and supports early stopping based on validation ROC AUC.
  • Example Usage: Added finetune_example.py to demonstrate how to load data, perform initial model evaluation, and then fine-tune and evaluate the FinetunedTabPFNClassifier.
  • Model Evaluation Utility: A helper function evaluate_model was added to facilitate performance assessment during the fine-tuning process.
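
For orientation, here is a hedged usage sketch assembled from names mentioned in this thread (the import path from the file listing; epochs, learning_rate, finetune_split_ratio, device, and random_state from the diff and review comments). The exact constructor signature and the predict_proba call are assumptions, not the merged code:

from sklearn.datasets import fetch_covtype
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split

from tabpfn_extensions.finetune.finetune_classifier import FinetunedTabPFNClassifier

# Small covtype subsample, mirroring the example script's setup.
X, y = fetch_covtype(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X[:10_000], y[:10_000], test_size=0.2, random_state=42
)

clf = FinetunedTabPFNClassifier(
    epochs=10,                 # number of fine-tuning epochs
    learning_rate=1e-5,        # Adam learning rate
    finetune_split_ratio=0.2,  # fraction held out for validation / early stopping
    device="cuda",             # or "cpu"
    random_state=42,
)
clf.fit(X_train, y_train)          # runs the fine-tuning loop with early stopping
proba = clf.predict_proba(X_test)  # familiar scikit-learn prediction API
print("Test ROC AUC:", roc_auc_score(y_test, proba, multi_class="ovr"))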

Contributor

Copilot AI left a comment


Pull Request Overview

This PR introduces a fine-tuning wrapper for TabPFN that provides a scikit-learn compatible interface for fine-tuning the TabPFNClassifier on specific datasets. The wrapper encapsulates the fine-tuning loop with features like early stopping, validation tracking, and familiar .fit() and .predict() APIs.

Key changes:

  • Implements FinetunedTabPFNClassifier with comprehensive fine-tuning capabilities including early stopping and validation monitoring
  • Adds evaluation utilities for model performance assessment during training
  • Provides a complete example demonstrating the fine-tuning workflow on the covtype dataset

Reviewed Changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 5 comments.

  • src/tabpfn_extensions/finetune/finetune_classifier.py: Core implementation of the fine-tuning wrapper with early stopping, evaluation, and scikit-learn compatibility
  • examples/finetune/finetune_example.py: Example script demonstrating usage of the fine-tuning wrapper with before/after performance comparison


Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a FinetunedTabPFNClassifier wrapper and an example script. The implementation is a solid foundation, but there are several opportunities for improvement regarding performance, robustness, and maintainability. My review focuses on addressing inefficiencies in the prediction methods, ensuring robust data splitting by using stratification, and highlighting dependencies on private library components. The example script can also be made more concise and efficient. The suggested changes will help make the new wrapper more performant and reliable.

@anuragg1209
Author

Example output for covertype dataset:

--- 1. Loading Data ---
Data split: 8000 training samples, 2000 test samples.

📊 Initial Test ROC: 0.9613
📊 Initial Test Log Loss: 0.3597

--- 2. Initializing and Fitting Model ---

--- 🚀 Starting Fine-tuning ---
Finetuning Epoch 1/10: 100%|███████████████████████████████████████████████████████████████████████████████████| 6/6 [00:02<00:00,  2.52it/s, loss=0.6835]
📊 Epoch 1 Evaluation | Val ROC: 0.9395, Val Log Loss: 0.4755

Finetuning Epoch 2/10: 100%|███████████████████████████████████████████████████████████████████████████████████| 6/6 [00:02<00:00,  2.73it/s, loss=0.6981]
📊 Epoch 2 Evaluation | Val ROC: 0.9397, Val Log Loss: 0.4790

Finetuning Epoch 3/10: 100%|███████████████████████████████████████████████████████████████████████████████████| 6/6 [00:02<00:00,  2.77it/s, loss=0.5625]
📊 Epoch 3 Evaluation | Val ROC: 0.9398, Val Log Loss: 0.4771

Finetuning Epoch 4/10: 100%|███████████████████████████████████████████████████████████████████████████████████| 6/6 [00:02<00:00,  2.78it/s, loss=0.6195]
📊 Epoch 4 Evaluation | Val ROC: 0.9400, Val Log Loss: 0.4759

Finetuning Epoch 5/10: 100%|███████████████████████████████████████████████████████████████████████████████████| 6/6 [00:02<00:00,  2.74it/s, loss=0.5666]
📊 Epoch 5 Evaluation | Val ROC: 0.9402, Val Log Loss: 0.4718

Finetuning Epoch 6/10: 100%|███████████████████████████████████████████████████████████████████████████████████| 6/6 [00:02<00:00,  2.75it/s, loss=0.6196]
📊 Epoch 6 Evaluation | Val ROC: 0.9400, Val Log Loss: 0.4681

⚠️  No improvement for 1 epochs. Best ROC AUC: 0.9402
Finetuning Epoch 7/10: 100%|███████████████████████████████████████████████████████████████████████████████████| 6/6 [00:02<00:00,  2.47it/s, loss=0.5549]
📊 Epoch 7 Evaluation | Val ROC: 0.9403, Val Log Loss: 0.4672

⚠️  No improvement for 2 epochs. Best ROC AUC: 0.9402
Finetuning Epoch 8/10: 100%|███████████████████████████████████████████████████████████████████████████████████| 6/6 [00:02<00:00,  2.72it/s, loss=0.5938]
📊 Epoch 8 Evaluation | Val ROC: 0.9406, Val Log Loss: 0.4679

Finetuning Epoch 9/10: 100%|███████████████████████████████████████████████████████████████████████████████████| 6/6 [00:02<00:00,  2.75it/s, loss=0.5921]
📊 Epoch 9 Evaluation | Val ROC: 0.9407, Val Log Loss: 0.4681

⚠️  No improvement for 1 epochs. Best ROC AUC: 0.9406
Finetuning Epoch 10/10: 100%|██████████████████████████████████████████████████████████████████████████████████| 6/6 [00:02<00:00,  2.70it/s, loss=0.5994]
📊 Epoch 10 Evaluation | Val ROC: 0.9405, Val Log Loss: 0.4663

⚠️  No improvement for 2 epochs. Best ROC AUC: 0.9406
--- ✅ Fine-tuning Finished ---


--- 3. Evaluating Model on Held-out Test Set ---

📊 Final Test ROC: 0.9571
📊 Final Test Log Loss: 0.3706

@klemens-floege
Contributor

Thanks for the PR @anuragg1209. Looking at the logs you shared, it seems like the fine-tuning is not really helping. Shall we still merge it and do parameter tweaking later, or should we figure this out first?

@anuragg1209
Author

anuragg1209 commented Aug 29, 2025

Thanks for the PR @anuragg1209. Looking at the logs you shared, it seems like the fine-tuning is not really helping. Shall we still merge it and do parameter tweaking later, or should we figure this out first?

Hi Klemens, it would be helpful to first review the code to check whether there's a bug in the wrapper that's preventing the fine-tuning from being effective, or whether it's just the hyperparameter settings that need tuning. A thorough review before merging would be beneficial. Thanks! :)

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
noahho previously requested changes Aug 29, 2025
Contributor

@klemens-floege klemens-floege left a comment


Nice initial draft. One general comment: could we create a BaseClass with the main training functionality and have the classifier inherit from it? I think this would make implementing the regressor much easier; a sketch of the idea follows.
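
A minimal sketch of the suggested structure; the class and method names below are illustrative assumptions, not part of this PR:

class FinetunedTabPFNBase:
    """Shared fine-tuning machinery: optimizer setup, epoch loop, early stopping."""

    def _loss(self, predictions, targets):
        raise NotImplementedError  # task-specific loss supplied by subclasses

    def _finetune(self, X, y):
        # The common epoch loop would live here and call self._loss(...) per batch.
        ...


class FinetunedTabPFNClassifier(FinetunedTabPFNBase):
    def _loss(self, predictions, targets):
        ...  # e.g. cross-entropy

    def fit(self, X, y):
        self._finetune(X, y)
        return self


class FinetunedTabPFNRegressor(FinetunedTabPFNBase):
    def _loss(self, predictions, targets):
        ...  # e.g. a regression loss

    def fit(self, X, y):
        self._finetune(X, y)
        return self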

anuragg1209 and others added 7 commits August 29, 2025 19:13
… update parameter names. Changed `meta_dataset_size` to `n_inference_context_samples`. Enhanced error handling in model evaluation and ensured context size consistency during fine-tuning and inference.
@noahho
Contributor

noahho commented Oct 1, 2025

/gemini review

epochs: int = 5,
learning_rate: float = 1e-5,
n_inference_context_samples: int = 10_000,
finetune_split_ratio: float = 0.2,
Contributor


I'd rename this to validation_ratio. In my understanding, this part of the data is used for validation while fine-tuning, and the fine-tuning part of the name is obvious given the context.

Suggested change
- finetune_split_ratio: float = 0.2,
+ validation_ratio: float = 0.2,

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a fine-tuning wrapper for TabPFN, which is a great addition. The implementation is well-structured and includes a helpful example. My review focuses on improving robustness, fixing a critical data leakage issue, and enhancing maintainability. Key suggestions include addressing the hardcoded 'cuda' device, simplifying the early stopping logic, and avoiding the use of internal TabPFN APIs. I've also pointed out some minor inconsistencies in the example script to improve clarity.

Comment on lines 315 to 324
# Save the best model using TabPFN's official save function
with tempfile.NamedTemporaryFile(
    suffix=".tabpfn_fit",
    delete=False,
) as tmp_file:
    best_model_path = Path(tmp_file.name)
    save_fitted_tabpfn_model(
        self.finetuned_classifier_,
        best_model_path,
    )
Contributor


medium

The early stopping mechanism saves the best model to a temporary file on disk, which adds complexity with file I/O and manual cleanup. A simpler and more efficient approach would be to store a deep copy of the best model object in memory.

This can be implemented as follows:

  1. Add import copy at the top of the file.
  2. In fit, initialize best_model = None instead of best_model_path = None.
  3. When an improvement is found, save a copy of the model:
    if roc_auc > best_roc_auc + self.min_delta:
        best_roc_auc = roc_auc
        patience_counter = 0
        best_model = copy.deepcopy(self.finetuned_classifier_)
  4. When restoring, assign the saved model, and remove the file cleanup logic:
    if patience_counter >= self.patience:
        # ... logging ...
        if best_model is not None:
            self.finetuned_classifier_ = best_model
        break

This avoids disk I/O and the need to manage temporary files.

training_splitter = partial(
    train_test_split,
    test_size=self.finetune_split_ratio,
    random_state=self.random_state,
Contributor


Could this be a problem and lead to significantly degraded performance?

Static Meta-Dataset Splits: The training_splitter uses a fixed random_state.

The Problem: This means that in every epoch, the model sees the exact same context/query pairs drawn from the training data. This lack of stochasticity can reduce the diversity of the training signal.

Suggestion: For potentially more robust training, the splits could be regenerated with a different seed each epoch. This could be achieved by creating a new partial function for the splitter inside the epoch loop, using a different random_state each time (e.g., self.random_state + epoch).
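
A hedged sketch of this idea, with base_seed, finetune_split_ratio, and n_epochs standing in for the wrapper's attributes (a simplification, not the actual fit loop):

from functools import partial

import numpy as np
from sklearn.model_selection import train_test_split

rng = np.random.default_rng(0)
X, y = rng.normal(size=(200, 5)), rng.integers(0, 2, size=200)
base_seed, finetune_split_ratio, n_epochs = 42, 0.2, 3

for epoch in range(n_epochs):
    # A fresh random_state per epoch draws different context/query pairs each time.
    training_splitter = partial(
        train_test_split,
        test_size=finetune_split_ratio,
        random_state=base_seed + epoch,
    )
    X_context, X_query, y_context, y_query = training_splitter(X, y)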

Contributor


Not sure if this is actually randomized already though!

X_train, X_val, y_train, y_val = validation_splitter(X, y)

# Calculate the context size used during finetuning
context_size = min(self.n_inference_context_samples, len(y_train))
Contributor

@noahho noahho Oct 1, 2025


I think we'd actually want to make this self.n_inference_context_samples / (1 - self.finetune_split_ratio). This way we fine-tune with a train set of self.n_inference_context_samples samples, right? Or is that already the case? My worry is that we are splitting the validation set out of this again, making the context smaller during fine-tuning than at final inference.
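
For concreteness, a worked example of the proposed adjustment; the variable names stand in for the wrapper's attributes and the numbers are illustrative:

n_inference_context_samples = 10_000
finetune_split_ratio = 0.2  # fraction split off for validation before fine-tuning

# Splitting 12,500 samples at a 0.2 ratio leaves 10,000 for the fine-tuning context,
# matching the context size used at final inference.
pre_split_size = int(n_inference_context_samples / (1 - finetune_split_ratio))  # 12500
assert round(pre_split_size * (1 - finetune_split_ratio)) == n_inference_context_samples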

X_train, X_val, y_train, y_val = validation_splitter(X, y)

# Calculate the context size used during finetuning
context_size = min(self.n_inference_context_samples, len(y_train))
Contributor

@noahho noahho Oct 1, 2025


Suggested change
- context_size = min(self.n_inference_context_samples, len(y_train))
+ n_finetuning_fit_predict_context_samples = min(self.n_inference_context_samples, len(y_train))

Contributor


Very long name, but it would clarify that this covers both fit and predict, and it would be consistent with the other context-samples parameter.

Contributor

@noahho noahho left a comment


I made some minor comments and would also add these small points:

  • Mixed precision and gradient clipping: on CUDA, wrap the forward/loss/backward with torch.cuda.amp.autocast() and GradScaler; add gradient clipping (e.g., torch.nn.utils.clip_grad_norm_(..., 1.0)).

  • Learning rate schedulers: fine-tuning often benefits greatly from a learning rate that changes over time (e.g., cosine annealing + warmup?). You could add a scheduler parameter to __init__ that accepts a PyTorch scheduler instance, and then call scheduler.step() after each epoch. A sketch of both points follows.
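
A minimal, self-contained sketch of both points, using a placeholder model and random data rather than the wrapper's actual training loop (torch.cuda.amp is used as named above; newer PyTorch versions expose the same API under torch.amp):

import torch
from torch.utils.data import DataLoader, TensorDataset

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(10, 2).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)
scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))
criterion = torch.nn.CrossEntropyLoss()
loader = DataLoader(
    TensorDataset(torch.randn(64, 10), torch.randint(0, 2, (64,))), batch_size=16
)

for epoch in range(10):
    for X_batch, y_batch in loader:
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)
        optimizer.zero_grad()
        with torch.cuda.amp.autocast(enabled=(device == "cuda")):
            loss = criterion(model(X_batch), y_batch)
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)  # unscale so clipping sees true gradient norms
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
    scheduler.step()  # advance the cosine schedule once per epoch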

X_context_batch[0].shape[1] + X_query_batch[0].shape[1]
!= context_size
):
continue
Contributor


Can we please log if this happens?

ctx = set(np.unique(y_context_batch))
qry = set(np.unique(y_query_batch))
if not qry.issubset(ctx):
    continue
Contributor


Can we log if this happens?
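
A hedged sketch of such a warning; the logger name, message wording, and label sets are illustrative:

import logging

logger = logging.getLogger(__name__)

ctx = {0, 1, 2}  # context labels in this batch (illustrative)
qry = {0, 1, 3}  # query labels in this batch (illustrative)
if not qry.issubset(ctx):
    logger.warning(
        "Skipping fine-tuning batch: query classes %s missing from context classes %s",
        sorted(qry - ctx),
        sorted(ctx),
    )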

classifier_config = {
    "ignore_pretraining_limits": True,
    "device": self.device,
    "n_estimators": self.kwargs.get("n_estimators", 8),
Contributor


Can we just unpack kwargs here directly? That would allow us to specify any hyperparameter.
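
A hedged sketch of that suggestion; apart from the keys shown in the diff above, the entries are illustrative:

user_kwargs = {"n_estimators": 8, "random_state": 42}  # whatever the caller passed via **kwargs
classifier_config = {
    "ignore_pretraining_limits": True,
    "device": "cpu",
    **user_kwargs,  # forward any TabPFNClassifier hyperparameter directly
}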

…NClassifier

- Added a cosine annealing learning rate scheduler.
- Updated logging to provide warnings for skipped batches during fine-tuning.
@anuragg1209 anuragg1209 requested a review from a team as a code owner November 11, 2025 15:15
@anuragg1209 anuragg1209 requested review from priorphil and removed request for a team November 11, 2025 15:15
@anuragg1209 anuragg1209 requested review from bejaeger and removed request for priorphil November 13, 2025 09:17
Contributor

@bejaeger bejaeger left a comment


Very nice, thanks @anuragg1209!
I approved so that you can merge later, but it would be great if you could implement the (mostly stylistic) things I commented on.
@klemens-floege @noahho, we decided to get this in sooner rather than later so we can iterate on it more easily; hope that's fine.

Contributor

@klemens-floege klemens-floege left a comment


LGTM; I am removing my requested changes in order not to block.

@bejaeger bejaeger requested review from noahho and removed request for noahho November 13, 2025 17:13
anuragg1209 and others added 11 commits November 14, 2025 00:05
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
… update parameter names. Changed `meta_dataset_size` to `n_inference_context_samples`. Enhanced error handling in model evaluation and ensured context size consistency during fine-tuning and inference.
…NClassifier

- Added a cosine annealing learning rate scheduler.
- Updated logging to provide warnings for skipped batches during fine-tuning.
@anuragg1209
Author

PR Paused — Development Relocated

Hi, further development of this feature is continuing in a separate private repository. This PR will remain open in its current state for now and will be updated later.

@bejaeger bejaeger marked this pull request as draft November 14, 2025 08:02