finetuning wrapper #152
base: main
Conversation
Summary of Changes
Hello @anuragg1209, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request introduces a significant enhancement by providing a dedicated wrapper for fine-tuning the TabPFN model. This allows users to adapt the powerful TabPFN model to their specific datasets more easily, integrating seamlessly with the familiar scikit-learn interface. The change aims to improve the model's applicability and performance on custom data by enabling targeted optimization.
Highlights
- New Fine-tuning Wrapper: Introduced a new `FinetunedTabPFNClassifier` class, providing a scikit-learn compatible API for fine-tuning the TabPFN model on specific datasets (see the usage sketch after this list).
- Fine-tuning Logic: The `FinetunedTabPFNClassifier` encapsulates the fine-tuning loop, using the Adam optimizer and CrossEntropyLoss, and supports early stopping based on validation ROC AUC.
- Example Usage: Added `finetune_example.py` to demonstrate how to load data, perform an initial model evaluation, and then fine-tune and evaluate the `FinetunedTabPFNClassifier`.
- Model Evaluation Utility: A helper function `evaluate_model` was added to facilitate performance assessment during the fine-tuning process.
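A hypothetical usage sketch of the wrapper described above: the import path follows the file added in this PR, but the exact constructor signature is an assumption based on the parameters discussed later in this review.

```python
from sklearn.datasets import fetch_covtype
from sklearn.model_selection import train_test_split

# Module path taken from this PR's file listing; constructor arguments are assumed.
from tabpfn_extensions.finetune.finetune_classifier import FinetunedTabPFNClassifier

X, y = fetch_covtype(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, train_size=10_000, random_state=42
)

clf = FinetunedTabPFNClassifier(epochs=5, learning_rate=1e-5)
clf.fit(X_train, y_train)          # runs the fine-tuning loop with early stopping
proba = clf.predict_proba(X_test)  # familiar scikit-learn prediction API
```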
Pull Request Overview
This PR introduces a fine-tuning wrapper for TabPFN that provides a scikit-learn compatible interface for fine-tuning the TabPFNClassifier on specific datasets. The wrapper encapsulates the fine-tuning loop with features like early stopping, validation tracking, and familiar .fit() and .predict() APIs.
Key changes:
- Implements `FinetunedTabPFNClassifier` with comprehensive fine-tuning capabilities, including early stopping and validation monitoring
- Adds evaluation utilities for model performance assessment during training
- Provides a complete example demonstrating the fine-tuning workflow on the covtype dataset
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| src/tabpfn_extensions/finetune/finetune_classifier.py | Core implementation of the fine-tuning wrapper with early stopping, evaluation, and scikit-learn compatibility |
| examples/finetune/finetune_example.py | Example script demonstrating usage of the fine-tuning wrapper with before/after performance comparison |
Code Review
This pull request introduces a FinetunedTabPFNClassifier wrapper and an example script. The implementation is a solid foundation, but there are several opportunities for improvement regarding performance, robustness, and maintainability. My review focuses on addressing inefficiencies in the prediction methods, ensuring robust data splitting by using stratification, and highlighting dependencies on private library components. The example script can also be made more concise and efficient. The suggested changes will help make the new wrapper more performant and reliable.
Example output for covertype dataset: (log output omitted)
Thanks for the PR @anuragg1209. Looking at the logs you shared, it seems like the fine-tuning is not really helping. Shall we still merge it and do parameter tweaking later, or should we figure this out first?
Hi Klemens, it would be helpful to first review the code and check whether a bug in the wrapper is preventing the fine-tuning from being effective, or whether it's just the hyperparameter settings that need optimization. A thorough review might be beneficial before merging. Thanks! :)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
klemens-floege
left a comment
Nice initial draft. One general comment: could we create a base class with the main training functionality and have the classifier inherit from it? I think this would make implementing the regressor much easier.
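A minimal sketch of the suggested refactor; `FinetunedTabPFNBase` and its method names are hypothetical, not part of this PR:

```python
from abc import ABC, abstractmethod


class FinetunedTabPFNBase(ABC):
    """Hypothetical base class holding the shared fine-tuning loop."""

    def fit(self, X, y):
        model = self._build_model()
        # ... shared logic: validation split, epoch loop, early stopping,
        # operating on `model` regardless of task type ...
        return self

    @abstractmethod
    def _build_model(self):
        """Return the underlying TabPFN estimator to fine-tune."""


class FinetunedTabPFNClassifier(FinetunedTabPFNBase):
    def _build_model(self):
        from tabpfn import TabPFNClassifier
        return TabPFNClassifier()


class FinetunedTabPFNRegressor(FinetunedTabPFNBase):
    def _build_model(self):
        from tabpfn import TabPFNRegressor
        return TabPFNRegressor()
```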
…ing finetuning and inference
… update parameter names. Changed `meta_dataset_size` to `n_inference_context_samples`. Enhanced error handling in model evaluation and ensured context size consistency during fine-tuning and inference.
/gemini review
```python
epochs: int = 5,
learning_rate: float = 1e-5,
n_inference_context_samples: int = 10_000,
finetune_split_ratio: float = 0.2,
```
I'd rename this to validation ratio. In my understanding, this part of the data is used for validation while fine-tuning, and the fine-tuning part is obvious given the context.
```diff
- finetune_split_ratio: float = 0.2,
+ validation_ratio: float = 0.2,
```
Code Review
This pull request introduces a fine-tuning wrapper for TabPFN, which is a great addition. The implementation is well-structured and includes a helpful example. My review focuses on improving robustness, fixing a critical data leakage issue, and enhancing maintainability. Key suggestions include addressing the hardcoded 'cuda' device, simplifying the early stopping logic, and avoiding the use of internal TabPFN APIs. I've also pointed out some minor inconsistencies in the example script to improve clarity.
```python
# Save the best model using TabPFN's official save function
with tempfile.NamedTemporaryFile(
    suffix=".tabpfn_fit",
    delete=False,
) as tmp_file:
    best_model_path = Path(tmp_file.name)
    save_fitted_tabpfn_model(
        self.finetuned_classifier_,
        best_model_path,
    )
```
The early stopping mechanism saves the best model to a temporary file on disk, which adds complexity with file I/O and manual cleanup. A simpler and more efficient approach would be to store a deep copy of the best model object in memory.
This can be implemented as follows:
- Add `import copy` at the top of the file.
- In `fit`, initialize `best_model = None` instead of `best_model_path = None`.
- When an improvement is found, save a copy of the model:

```python
if roc_auc > best_roc_auc + self.min_delta:
    best_roc_auc = roc_auc
    patience_counter = 0
    best_model = copy.deepcopy(self.finetuned_classifier_)
```

- When restoring, assign the saved model and remove the file cleanup logic:

```python
if patience_counter >= self.patience:
    # ... logging ...
    if best_model is not None:
        self.finetuned_classifier_ = best_model
    break
```
This avoids disk I/O and the need to manage temporary files.
```python
training_splitter = partial(
    train_test_split,
    test_size=self.finetune_split_ratio,
    random_state=self.random_state,
```
This could be a problem and lead to significantly degraded performance?
Static Meta-Dataset Splits: The `training_splitter` uses a fixed `random_state`.
The Problem: This means that in every epoch, the model sees the exact same context/query pairs drawn from the training data. This lack of stochasticity can reduce the diversity of the training signal.
Suggestion: For potentially more robust training, the splits could be regenerated with a different seed each epoch. This could be achieved by creating a new partial function for the splitter inside the epoch loop, using a different `random_state` each time (e.g., `self.random_state + epoch`).
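A minimal sketch of the per-epoch reseeding idea; the helper name is hypothetical, and the splitter arguments mirror the snippet above:

```python
from functools import partial

from sklearn.model_selection import train_test_split


def make_epoch_splitter(base_seed: int, epoch: int, split_ratio: float):
    """Return a splitter whose seed varies per epoch but stays reproducible."""
    return partial(
        train_test_split,
        test_size=split_ratio,
        random_state=base_seed + epoch,
    )

# Inside the epoch loop (sketch):
# training_splitter = make_epoch_splitter(self.random_state, epoch, self.finetune_split_ratio)
# X_ctx, X_qry, y_ctx, y_qry = training_splitter(X_train, y_train)
```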
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure if this is actually randomized already though!
```python
X_train, X_val, y_train, y_val = validation_splitter(X, y)

# Calculate the context size used during finetuning
context_size = min(self.n_inference_context_samples, len(y_train))
```
I think we'd actually want to make this `self.n_inference_context_samples / (1 - self.finetune_split_ratio)`. This way we fine-tune with a train set of `self.n_inference_context_samples`, right? Or is that already the case? My worry is that we are splitting out the validation set from this again, making the context smaller during fine-tuning than at final inference.
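A small worked sketch of the proposed adjustment, using the default values from the signature above (the `adjusted_context` name is illustrative):

```python
n_inference_context_samples = 10_000
finetune_split_ratio = 0.2

# Inflate the fine-tuning context so that after splitting off the validation
# fraction, it matches the inference-time context (10_000 / 0.8 = 12_500 here).
adjusted_context = int(n_inference_context_samples / (1 - finetune_split_ratio))
# context_size = min(adjusted_context, len(y_train))
```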
```python
X_train, X_val, y_train, y_val = validation_splitter(X, y)

# Calculate the context size used during finetuning
context_size = min(self.n_inference_context_samples, len(y_train))
```
```diff
- context_size = min(self.n_inference_context_samples, len(y_train))
+ n_finetuning_fit_predict_context_samples = min(self.n_inference_context_samples, len(y_train))
```
Very long name, but it would clarify that this is for fit and predict, and it would make it consistent with the other context samples.
I made some minor comments and would also add these small points:
On CUDA, wrap the forward/loss/backward passes with `torch.cuda.amp.autocast()` and `GradScaler`; add gradient clipping (e.g., `torch.nn.utils.clip_grad_norm_(..., 1.0)`).
Learning Rate Schedulers: Fine-tuning often benefits greatly from a learning rate that changes over time (e.g., cosine annealing + warmup?). You could add a `scheduler` parameter to `__init__` that accepts a PyTorch scheduler instance, and then call `scheduler.step()` after each epoch.
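A self-contained sketch combining both suggestions, with a placeholder linear model and random data standing in for TabPFN's internals:

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Placeholder model, data, and hyperparameters; illustrative only.
model = torch.nn.Linear(10, 2).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=5)
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())
criterion = torch.nn.CrossEntropyLoss()

X = torch.randn(64, 10, device=device)
y = torch.randint(0, 2, (64,), device=device)

for epoch in range(5):
    optimizer.zero_grad()
    # Mixed-precision forward/loss; a no-op on CPU since autocast is disabled.
    with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
        loss = criterion(model(X), y)
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)  # unscale so clipping sees true gradient norms
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    scaler.step(optimizer)
    scaler.update()
    scheduler.step()  # advance the cosine schedule once per epoch
```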
```python
if (
    X_context_batch[0].shape[1] + X_query_batch[0].shape[1]
    != context_size
):
    continue
```
Can we please log if this happens?
```python
ctx = set(np.unique(y_context_batch))
qry = set(np.unique(y_query_batch))
if not qry.issubset(ctx):
    continue
```
Can we log if this happens as well?
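A possible shape for that logging, sketched as a standalone helper (the function name is illustrative; the subset check mirrors the diff above):

```python
import logging

import numpy as np

logger = logging.getLogger(__name__)


def query_fits_context(y_context_batch, y_query_batch) -> bool:
    """Return False (and warn) if query classes are missing from the context."""
    ctx = set(np.unique(y_context_batch))
    qry = set(np.unique(y_query_batch))
    if not qry.issubset(ctx):
        logger.warning(
            "Skipping batch: query classes %s missing from context classes %s",
            qry - ctx,
            ctx,
        )
        return False
    return True
```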
```python
classifier_config = {
    "ignore_pretraining_limits": True,
    "device": self.device,
    "n_estimators": self.kwargs.get("n_estimators", 8),
```
Can we just unpack `kwargs` here directly? That would allow us to specify any hyperparameter.
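A sketch of the suggested unpacking as a standalone helper (the function name is illustrative; fixed keys are taken from the diff above):

```python
def build_classifier_config(device: str, **kwargs) -> dict:
    """Hypothetical helper: merge fixed settings with arbitrary user kwargs."""
    return {
        "ignore_pretraining_limits": True,
        "device": device,
        "n_estimators": 8,  # default; a caller-supplied value below wins
        **kwargs,           # forward any hyperparameter unchanged
    }

# Usage: any keyword argument now reaches the classifier config.
config = build_classifier_config("cpu", n_estimators=4)
```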
…NClassifier - Added a cosine annealing learning rate scheduler. - Updated logging to provide warnings for skipped batches during fine-tuning.
Very nice, thanks @anuragg1209 !
I approved so that you can merge later, but would be great if you could implement the (mostly stylistic) things I commented on.
@klemens-floege @noahho , we decided to get this in sooner rather than later so we can iterate on it more easily; hope that's fine.
klemens-floege
left a comment
LGTM; I am removing my requested changes in order not to block.
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
…ing finetuning and inference
… update parameter names. Changed `meta_dataset_size` to `n_inference_context_samples`. Enhanced error handling in model evaluation and ensured context size consistency during fine-tuning and inference.
…NClassifier - Added a cosine annealing learning rate scheduler. - Updated logging to provide warnings for skipped batches during fine-tuning.
Force-pushed c8c0141 to 871ce0e (compare)
PR Paused — Development Relocated
Hi, further development of this feature is continuing in a separate private repository. This PR will remain open in its current state for now and will be updated later.
Hi @noahho & @klemens-floege,
Below is the code for the fine-tuning wrapper. Could you please review it when you get a chance? I'd appreciate any suggestions for improvements. Thanks a lot!