diff --git a/README.md b/README.md index 5fefe059..4d01e578 100644 --- a/README.md +++ b/README.md @@ -20,16 +20,21 @@ - **rf_pfn**: Combine TabPFN with decision trees and random forests - **unsupervised**: Data generation and outlier detection - **embedding**: Get TabPFNs internal dense sample embeddings +- **tabpfgen_datasynthesizer**: Synthetic tabular data generation with TabPFGen Detailed documentation for each extension is available in the respective module directories. ## āš™ļø Installation ```bash -# Clone and install the repository +# Clone and install the repository (Python 3.9+ compatible) pip install "tabpfn-extensions[all] @ git+https://github.com/PriorLabs/tabpfn-extensions.git" + +# Add TabPFGen Data Synthesizer (requires Python 3.10+) +pip install "tabpfn-extensions[all, tabpfgen_datasynthesizer] @ git+https://github.com/PriorLabs/tabpfn-extensions.git" ``` + ### šŸ”„ Backend Options TabPFN Extensions works with two TabPFN implementations: diff --git a/examples/tabpfgen_datasynthesizer/basic_classification_example.py b/examples/tabpfgen_datasynthesizer/basic_classification_example.py new file mode 100644 index 00000000..fd2fc3f0 --- /dev/null +++ b/examples/tabpfgen_datasynthesizer/basic_classification_example.py @@ -0,0 +1,86 @@ +"""Basic Classification Example with TabPFGen Data Synthesizer + +This example demonstrates how to use TabPFGen for synthetic data generation +in classification tasks, leveraging the actual TabPFGen package features. +""" + +from sklearn.datasets import load_breast_cancer +from sklearn.model_selection import train_test_split + +# Import TabPFN Extensions +from tabpfn_extensions.tabpfgen_datasynthesizer import TabPFNDataSynthesizer +from tabpfn_extensions.tabpfgen_datasynthesizer.utils import analyze_class_distribution + + +def main(): + """Run basic classification example.""" + print("=== TabPFGen Classification Example ===\n") + + # Load breast cancer dataset + print("Loading breast cancer dataset...") + X, y = load_breast_cancer(return_X_y=True) + feature_names = load_breast_cancer().feature_names + + # Split data + X_train, X_test, y_train, y_test = train_test_split( + X, y, test_size=0.3, random_state=42, stratify=y + ) + + print(f"Training data: {X_train.shape[0]} samples, {X_train.shape[1]} features") + print(f"Test data: {X_test.shape[0]} samples") + + # Analyze original distribution + analyze_class_distribution(y_train, "Original Training Data") + + # Initialize TabPFGen synthesizer + print("\nInitializing TabPFGen synthesizer...") + synthesizer = TabPFNDataSynthesizer( + n_sgld_steps=300, # Reduced for faster demo + device="auto", + ) + + # Generate synthetic data using TabPFGen's built-in methods + print("\nGenerating synthetic classification data...") + n_synthetic = 200 + X_synth, y_synth = synthesizer.generate_classification( + X_train, + y_train, + n_samples=n_synthetic, + balance_classes=True, # This balances only the synthetic samples + visualize=True, # Use TabPFGen's built-in visualization + feature_names=list(feature_names), + ) + + print(f"\nGenerated {len(X_synth)} synthetic samples") + analyze_class_distribution(y_synth, "Synthetic Data") + + # Combine original and synthetic data + from tabpfn_extensions.tabpfgen_datasynthesizer.utils import combine_datasets + + X_augmented, y_augmented = combine_datasets( + X_train, y_train, X_synth, y_synth, strategy="append" + ) + + analyze_class_distribution(y_augmented, "Augmented Training Data") + + print("\nāœ… Synthetic data generation completed successfully!") + + # Calculate quality metrics + from tabpfn_extensions.tabpfgen_datasynthesizer.utils import ( + calculate_synthetic_quality_metrics, + ) + + print("\n" + "=" * 60) + print("SYNTHETIC DATA QUALITY METRICS") + print("=" * 60) + + quality_metrics = calculate_synthetic_quality_metrics( + X_train, X_synth, y_train, y_synth + ) + + for metric, value in quality_metrics.items(): + print(f"{metric}: {value:.4f}") + + +if __name__ == "__main__": + main() diff --git a/examples/tabpfgen_datasynthesizer/basic_regression_example.py b/examples/tabpfgen_datasynthesizer/basic_regression_example.py new file mode 100644 index 00000000..e648a2b5 --- /dev/null +++ b/examples/tabpfgen_datasynthesizer/basic_regression_example.py @@ -0,0 +1,113 @@ +"""Basic Regression Example with TabPFGen Data Synthesizer + +This example demonstrates how to use TabPFGen for synthetic data generation +in regression tasks, using TabPFGen's built-in features. +""" + +import numpy as np +from sklearn.datasets import load_diabetes +from sklearn.model_selection import train_test_split + +# Import TabPFN Extensions +from tabpfn_extensions.tabpfgen_datasynthesizer import TabPFNDataSynthesizer +from tabpfn_extensions.tabpfgen_datasynthesizer.utils import ( + calculate_synthetic_quality_metrics, +) + + +def main(): + """Run basic regression example.""" + print("=== TabPFGen Regression Example ===\n") + + # Load diabetes dataset + print("Loading diabetes dataset...") + X, y = load_diabetes(return_X_y=True) + feature_names = load_diabetes().feature_names + + # Split data + X_train, X_test, y_train, y_test = train_test_split( + X, y, test_size=0.3, random_state=42 + ) + + print(f"Training data: {X_train.shape[0]} samples, {X_train.shape[1]} features") + print(f"Test data: {X_test.shape[0]} samples") + print(f"Target range: [{y_train.min():.1f}, {y_train.max():.1f}]") + + # Initialize TabPFGen synthesizer + print("\nInitializing TabPFGen synthesizer...") + synthesizer = TabPFNDataSynthesizer( + n_sgld_steps=300, # Good balance for regression + device="auto", + ) + + # Generate synthetic regression data + print("\nGenerating synthetic regression data...") + n_synthetic = 150 + X_synth, y_synth = synthesizer.generate_regression( + X_train, + y_train, + n_samples=n_synthetic, + use_quantiles=True, # Important for regression quality + visualize=True, # Use TabPFGen's built-in visualization + feature_names=list(feature_names), + ) + + print(f"\nGenerated {len(X_synth)} synthetic samples") + print(f"Synthetic target range: [{y_synth.min():.1f}, {y_synth.max():.1f}]") + + # Combine original and synthetic data + from tabpfn_extensions.tabpfgen_datasynthesizer.utils import combine_datasets + + X_augmented, y_augmented = combine_datasets( + X_train, y_train, X_synth, y_synth, strategy="append" + ) + + print(f"Combined dataset: {len(X_augmented)} samples") + print(f"Combined target range: [{y_augmented.min():.1f}, {y_augmented.max():.1f}]") + + # Calculate quality metrics + print("\n" + "=" * 60) + print("SYNTHETIC DATA QUALITY METRICS") + print("=" * 60) + + quality_metrics = calculate_synthetic_quality_metrics( + X_train, X_synth, y_train, y_synth + ) + + print("\nFeature quality metrics:") + for metric, value in quality_metrics.items(): + print(f"{metric}: {value:.4f}") + + # Statistical comparison + print("\nStatistical comparison:") + print(f"Original data - Mean: {np.mean(X_train):.3f}, Std: {np.std(X_train):.3f}") + print(f"Synthetic data - Mean: {np.mean(X_synth):.3f}, Std: {np.std(X_synth):.3f}") + print("Target correlation preservation:") + + # Check target correlations + orig_target_corr = [] + synth_target_corr = [] + + for i in range(X_train.shape[1]): + orig_corr = np.corrcoef(X_train[:, i], y_train)[0, 1] + synth_corr = np.corrcoef(X_synth[:, i], y_synth)[0, 1] + orig_target_corr.append(orig_corr) + synth_target_corr.append(synth_corr) + + print( + f"Average target correlation - Original: {np.mean(np.abs(orig_target_corr)):.3f}" + ) + print( + f"Average target correlation - Synthetic: {np.mean(np.abs(synth_target_corr)):.3f}" + ) + + correlation_preservation = 1 - np.mean( + np.abs(np.array(orig_target_corr) - np.array(synth_target_corr)) + ) + print(f"Correlation preservation score: {correlation_preservation:.3f}") + + print("\nāœ… Synthetic regression data generation completed successfully!") + + +if __name__ == "__main__": + main() diff --git a/examples/tabpfgen_datasynthesizer/class_balancing_demo.py b/examples/tabpfgen_datasynthesizer/class_balancing_demo.py new file mode 100644 index 00000000..1eb7a84e --- /dev/null +++ b/examples/tabpfgen_datasynthesizer/class_balancing_demo.py @@ -0,0 +1,147 @@ +"""Dataset Balancing Demo with TabPFGen's balance_dataset Method + +This example demonstrates the new balance_dataset method in TabPFGen v0.1.3+ +for automatically balancing imbalanced classification datasets. +""" + + +from sklearn.datasets import make_classification +from sklearn.model_selection import train_test_split + +# Import TabPFN Extensions +from tabpfn_extensions.tabpfgen_datasynthesizer import TabPFNDataSynthesizer + +# Calculate quality metrics for both approaches +from tabpfn_extensions.tabpfgen_datasynthesizer.utils import ( + analyze_class_distribution, + calculate_synthetic_quality_metrics, +) + + +def create_imbalanced_dataset(): + """Create a highly imbalanced classification dataset.""" + X, y = make_classification( + n_samples=1000, + n_features=20, + n_informative=15, + n_redundant=5, + n_classes=3, + weights=[0.7, 0.2, 0.1], # Highly imbalanced: 70%, 20%, 10% + random_state=42, + ) + return X, y + + +def main(): + """Run dataset balancing demonstration.""" + print("=== TabPFGen Dataset Balancing Demo ===\n") + + # Create imbalanced dataset + print("Creating highly imbalanced dataset...") + X, y = create_imbalanced_dataset() + + # Split data + X_train, X_test, y_train, y_test = train_test_split( + X, y, test_size=0.3, random_state=42, stratify=y + ) + + print(f"Training data: {X_train.shape[0]} samples, {X_train.shape[1]} features") + print(f"Test data: {X_test.shape[0]} samples") + + # Analyze original imbalanced distribution + original_analysis = analyze_class_distribution( + y_train, "Original Imbalanced Training Data" + ) + + # Initialize TabPFGen synthesizer + print("\nInitializing TabPFGen synthesizer...") + synthesizer = TabPFNDataSynthesizer( + n_sgld_steps=400, # Good balance of quality and speed + device="auto", + ) + + print("\n" + "=" * 70) + print("AUTOMATIC BALANCING (to majority class size)") + print("=" * 70) + + # Use TabPFGen's balance_dataset method - automatic balancing + X_synth_auto, y_synth_auto, X_balanced_auto, y_balanced_auto = ( + synthesizer.balance_dataset( + X_train, + y_train, + visualize=True, # Use TabPFGen's built-in visualization + feature_names=[f"feature_{i}" for i in range(X_train.shape[1])], + ) + ) + + balanced_analysis_auto = analyze_class_distribution( + y_balanced_auto, "Auto-Balanced Dataset" + ) + + print("\n" + "=" * 70) + print("CUSTOM TARGET BALANCING (1000 samples per class)") + print("=" * 70) + + # Use TabPFGen's balance_dataset method - custom target + X_synth_custom, y_synth_custom, X_balanced_custom, y_balanced_custom = ( + synthesizer.balance_dataset( + X_train, + y_train, + target_per_class=1000, # Custom target + visualize=True, + feature_names=[f"feature_{i}" for i in range(X_train.shape[1])], + ) + ) + + balanced_analysis_custom = analyze_class_distribution( + y_balanced_custom, "Custom-Balanced Dataset (target=1000)" + ) + + balanced_analysis_custom = analyze_class_distribution( + y_balanced_custom, "Custom-Balanced Dataset (target=1000)" + ) + + # Quality analysis + print("\n" + "=" * 70) + print("BALANCING EFFECTIVENESS SUMMARY") + print("=" * 70) + + print( + f"\nOriginal dataset imbalance ratio: {original_analysis['imbalance_ratio']:.1f}:1" + ) + print( + f"Auto-balanced imbalance ratio: {balanced_analysis_auto['imbalance_ratio']:.1f}:1" + ) + print( + f"Custom-balanced imbalance ratio: {balanced_analysis_custom['imbalance_ratio']:.1f}:1" + ) + + print("\nData size summary:") + print(f"Original training: {len(X_train)} samples") + print( + f"Auto-balanced: {len(X_balanced_auto)} samples (+{len(X_synth_auto)} synthetic)" + ) + print( + f"Custom-balanced: {len(X_balanced_custom)} samples (+{len(X_synth_custom)} synthetic)" + ) + + print("\nSynthetic data quality metrics:") + print("Auto-balanced approach:") + quality_auto = calculate_synthetic_quality_metrics( + X_train, X_synth_auto, y_train, y_synth_auto + ) + for metric, value in quality_auto.items(): + print(f" {metric}: {value:.4f}") + + print("\nCustom-balanced approach:") + quality_custom = calculate_synthetic_quality_metrics( + X_train, X_synth_custom, y_train, y_synth_custom + ) + for metric, value in quality_custom.items(): + print(f" {metric}: {value:.4f}") + + print("\nāœ… Dataset balancing demo completed successfully!") + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index d39a29a2..ab0ea751 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,6 +57,10 @@ interpretability = [ "shapiq>=0.4.0", "seaborn>=0.12.2", ] +tabpfgen_datasynthesizer = [ + "tabpfgen>=0.1.4", +] + post_hoc_ensembles = [ "kditransform>=0.2.0", "llvmlite", diff --git a/src/tabpfn_extensions/tabpfgen_datasynthesizer/README.md b/src/tabpfn_extensions/tabpfgen_datasynthesizer/README.md new file mode 100644 index 00000000..98cb568a --- /dev/null +++ b/src/tabpfn_extensions/tabpfgen_datasynthesizer/README.md @@ -0,0 +1,320 @@ +# TabPFGen Data Synthesizer Extension + +A TabPFN extension for synthetic tabular data generation using [TabPFGen](https://github.com/sebhaan/TabPFGen). + +Author: Sebastian Haan + +## Motivation + +While there are many tools available for generating synthetic images or text, creating realistic tabular data that preserves the statistical properties and relationships of the original dataset has been more challenging. + +Generating synthetic tabular data is particularly useful in scenarios where: + +1. You have limited real data but need more samples for training +2. You can't share real data due to privacy concerns +3. You need to balance an imbalanced dataset +4. You want to test how your models would perform with more data + +What makes TabPFGen interesting is that it's built on the TabPFN transformer architecture and doesn't require additional training. It includes built-in visualization tools to help you verify the quality of the generated data by comparing distributions, feature correlations, and other important metrics between the real and synthetic datasets. + + +## Key Features + +- Energy-based synthetic data generation +- Support for both classification and regression tasks +- Automatic dataset balancing for imbalanced classes +- Class-balanced sampling option +- Comprehensive visualization tools +- Built on TabPFN transformer architecture +- No additional training required + +## Requirements + +- **Python 3.10+** (due to TabPFGen dependency) +- TabPFN Extensions framework + +> **Note**: While tabpfn-extensions supports Python 3.9+, this specific extension requires Python 3.10+ due to its TabPFGen dependency. + +## Installation + +```bash +# Ensure Python 3.10+ +python --version # Should show 3.10 or higher + +# Install TabPFN (choose one) +pip install tabpfn # For local inference +pip install tabpfn-client # For cloud-based inference + +# Python 3.10+ users who want every extension including TabPFGen +pip install "tabpfn-extensions[all,tabpfgen_datasynthesizer]" + +# Or install only the tabpfgen_datasynthesizer extension +pip install "tabpfn-extensions[tabpfgen_datasynthesizer]" +``` + +## šŸš€ Quick Start + +### Basic Synthetic Data Generation + +```python +from tabpfn_extensions.tabpfgen_datasynthesizer import TabPFNDataSynthesizer +from sklearn.datasets import load_breast_cancer + +# Load example data +X, y = load_breast_cancer(return_X_y=True) + +# Initialize synthesizer +synthesizer = TabPFNDataSynthesizer(n_sgld_steps=500) + +# Generate synthetic classification data +X_synth, y_synth = synthesizer.generate_classification( + X, y, + n_samples=100, + balance_classes=True, # Only balances synthetic samples + visualize=True # TabPFGen's built-in visualization +) + +# Generate synthetic regression data +from sklearn.datasets import load_diabetes + +# Load regression example dataset +X, y = load_diabetes(return_X_y=True) + +X_synth, y_synth = synthesizer.generate_regression( + X, y, + n_samples=150, + use_quantiles=True, + visualize=True +) +``` + +### Automatic Dataset Balancing + +Automatically balance imbalanced datasets: + +```python +from tabpfn_extensions.tabpfgen_datasynthesizer import TabPFNDataSynthesizer +from sklearn.datasets import make_classification + +# Create imbalanced dataset +X, y = make_classification(n_samples=1000, n_classes=3, + n_informative=3, n_redundant=1, + weights=[0.7, 0.2, 0.1], random_state=42) + +print("Original class distribution:") +print("Class 0: 700 samples (70.0%)") +print("Class 1: 200 samples (20.0%)") +print("Class 2: 100 samples (10.0%)") + +# Initialize synthesizer +synthesizer = TabPFNDataSynthesizer(n_sgld_steps=500) + +# Balance dataset automatically +X_synth, y_synth, X_balanced, y_balanced = synthesizer.balance_dataset( + X, y, visualize=True +) + +print(f"Original dataset: {len(X)} samples") +print(f"Synthetic samples: {len(X_synth)} samples") +print(f"Balanced dataset: {len(X_balanced)} samples") +# Final distribution approximately balanced! +``` + +## šŸ“Š API Reference + +### TabPFNDataSynthesizer + +Main class for synthetic data generation: + +```python +TabPFNDataSynthesizer( + n_sgld_steps=500, # SGLD iterations + sgld_step_size=0.01, # Step size + sgld_noise_scale=0.01, # Noise scale + device='auto' # 'cpu', 'cuda', or 'auto' +) +``` + +**Key Methods:** + +#### `balance_dataset()` ⭐ NEW +```python +X_synth, y_synth, X_combined, y_combined = synthesizer.balance_dataset( + X, y, + target_per_class=None, # Auto-detect majority class size + visualize=False, + feature_names=None +) +``` + +**Returns:** +- `X_synth, y_synth`: Synthetic data only +- `X_combined, y_combined`: Original + synthetic data + +#### `generate_classification()` +```python +X_synth, y_synth = synthesizer.generate_classification( + X, y, + n_samples, + balance_classes=True, # Balance only synthetic samples + visualize=False, + feature_names=None +) +``` + +#### `generate_regression()` +```python +X_synth, y_synth = synthesizer.generate_regression( + X, y, + n_samples, + use_quantiles=True, + visualize=False, + feature_names=None +) +``` + +### TabPFNDataSynthesizer Parameters + +- `n_sgld_steps` (int, default=500): Number of SGLD iterations for generation +- `sgld_step_size` (float, default=0.01): Step size for SGLD updates +- `sgld_noise_scale` (float, default=0.01): Scale of noise in SGLD +- `device` (str, default='auto'): Computing device ('cpu', 'cuda', or 'auto') + +### balance_dataset() Parameters + +- `target_per_class` (int, optional): Custom target samples per class +- `visualize` (bool, default=False): Enable TabPFGen's built-in visualizations +- `feature_names` (list, optional): Feature names for visualization + +### Generation Parameters + +- `n_samples` (int): Number of synthetic samples to generate +- `balance_classes` (bool, default=True): Balance only synthetic samples +- `use_quantiles` (bool, default=True): Quantile-based sampling for regression +- `visualize` (bool, default=False): Enable visualization plots + + +### Utility Functions + +```python +from tabpfn_extensions.tabpfgen_datasynthesizer.utils import ( + validate_tabpfn_data, # Check TabPFN compatibility + analyze_class_distribution, # Analyze class balance + calculate_synthetic_quality_metrics, # Quality assessment + combine_datasets # Combine original + synthetic +) + +# Validate data +is_valid, message = validate_tabpfn_data(X, y) + +# Analyze distribution +analysis = analyze_class_distribution(y, "Dataset Name") + +# Calculate quality +quality = calculate_synthetic_quality_metrics(X_orig, X_synth, y_orig, y_synth) + +# Combine datasets +X_combined, y_combined = combine_datasets( + X_orig, y_orig, X_synth, y_synth, + strategy='append' # 'append', 'replace', or 'balanced' +) +``` + +## šŸŽÆ Use Cases + +### 1. Imbalanced Dataset Balancing +Perfect for datasets with class imbalance: + +```python +# Detect imbalance +is_valid, message = validate_tabpfn_data(X, y) +if "imbalanced" in message: + # Auto-balance + _, _, X_balanced, y_balanced = synthesizer.balance_dataset(X, y) +``` + +### 2. Data Augmentation +Increase training data size: + +```python + +X_synth, y_synth = synthesizer.generate_classification( + X_train, y_train, n_samples=int(len(X_train) * 0.5) +) +X_augmented, y_augmented = combine_datasets( + X_train, y_train, X_synth, y_synth, strategy='append' +) +``` + +### 3. Quality Assessment +Monitor synthetic data quality: + +```python +quality_metrics = calculate_synthetic_quality_metrics( + X_orig, X_synth, y_orig, y_synth +) + +for metric, value in quality_metrics.items(): + print(f"{metric}: {value:.4f}") +``` + +## šŸ“ˆ Examples + +The `examples/` directory contains comprehensive demonstrations: + +1. **`basic_classification_example.py`** - Standard classification workflow +2. **`basic_regression_example.py`** - Regression data generation +3. **`class_balancing_demo.py`** - Showcase of `balance_dataset()` method + + +```bash +cd examples/tabpfgen_datasynthesizer/ +python basic_classification_example.py +``` + +## ⚔ Troubleshooting + +### Common Issues + +1. **TabPFGen Import Error**: + ```bash + pip install tabpfgen>=0.1.4 + ``` + +2. **Memory Issues**: Reduce `n_samples` or `n_sgld_steps` + +3. **Generation Quality**: Increase `n_sgld_steps` or adjust step size + +4. **Imbalanced Results**: Use `balance_dataset()` instead of `generate_classification()` + +### Performance Optimization + +- **Development**: Use 100-300 SGLD steps for faster iteration +- **Production**: Use 500+ SGLD steps for best quality +- **GPU**: Enable with `device='cuda'` for 5-10x speedup +- **Batch Processing**: Generate larger batches rather than multiple small ones + + +## šŸ” Important Notes + +### balance_classes vs balance_dataset() + +- **`balance_classes=True`**: Only balances the generated synthetic samples +- **`balance_dataset()`**: Balances entire dataset by generating minority class samples + +### Approximate Balancing + +Final class distributions may be **approximately balanced** rather than perfectly balanced due to TabPFN's label refinement process, which prioritizes data quality over exact counts. + +## šŸ“š Citation + +@software{haan2025tabpfgen, + author = {Haan, Sebastian}, + title = {TabPFGen: Synthetic Tabular Data Generation with TabPFN}, + url = {https://github.com/sebhaan/TabPFGen}, + year = {2025} +} + +## šŸ“„ License + +Apache License 2.0 - same as TabPFN Extensions. diff --git a/src/tabpfn_extensions/tabpfgen_datasynthesizer/__init__.py b/src/tabpfn_extensions/tabpfgen_datasynthesizer/__init__.py new file mode 100644 index 00000000..ffc3cdce --- /dev/null +++ b/src/tabpfn_extensions/tabpfgen_datasynthesizer/__init__.py @@ -0,0 +1,34 @@ +"""TabPFGen Data Synthesizer Extension for TabPFN + +This extension integrates TabPFGen for synthetic tabular data generation +with the TabPFN ecosystem, providing seamless workflows for data augmentation, +class balancing, and performance improvement. + +Requirements: + - Python 3.10+ + - TabPFGen >= 0.1.4 + +Note: While tabpfn-extensions supports Python 3.9+, this extension specifically +requires Python 3.10+ due to the underlying TabPFGen package requirements. + +Citation: +- TabPFGen package: https://github.com/sebhaan/TabPFGen +""" + +# Check Python version before any other imports +import sys + +if sys.version_info < (3, 10): + raise ImportError( + "TabPFGen Data Synthesizer requires Python 3.10+ " + f"(current: {sys.version_info.major}.{sys.version_info.minor})." + "Please upgrade Python or use other TabPFN extensions." + ) + +from .tabpfgen_wrapper import TabPFNDataSynthesizer +from .utils import combine_datasets, validate_tabpfn_data + +__version__ = "0.1.0" +__author__ = "Sebastian Haan" + +__all__ = ["TabPFNDataSynthesizer", "validate_tabpfn_data", "combine_datasets"] diff --git a/src/tabpfn_extensions/tabpfgen_datasynthesizer/tabpfgen_wrapper.py b/src/tabpfn_extensions/tabpfgen_datasynthesizer/tabpfgen_wrapper.py new file mode 100644 index 00000000..e5dec3df --- /dev/null +++ b/src/tabpfn_extensions/tabpfgen_datasynthesizer/tabpfgen_wrapper.py @@ -0,0 +1,259 @@ +"""Streamlined wrapper around TabPFGen for integration with TabPFN Extensions ecosystem.""" +from __future__ import annotations + +import warnings +from typing import TYPE_CHECKING + +import numpy as np + +if TYPE_CHECKING: + import pandas as pd + +try: + from tabpfgen import TabPFGen + from tabpfgen.visuals import ( + visualize_classification_results, + visualize_regression_results, + ) + + TABPFGEN_AVAILABLE = True +except ImportError: + TABPFGEN_AVAILABLE = False + TabPFGen = None + + +class TabPFNDataSynthesizer: + """Streamlined wrapper around TabPFGen for synthetic tabular data generation. + + This class provides a clean interface to TabPFGen functionality with + sensible defaults optimized for TabPFN workflows. It relies heavily on + the actual TabPFGen package features and built-in visualizations. + + Parameters + ---------- + n_sgld_steps : int, default=500 + Number of SGLD iterations for generation + sgld_step_size : float, default=0.01 + Step size for SGLD updates + sgld_noise_scale : float, default=0.01 + Scale of noise in SGLD + device : str, default='auto' + Computing device ('cpu', 'cuda', or 'auto') + + Examples: + -------- + >>> from sklearn.datasets import load_breast_cancer + >>> X, y = load_breast_cancer(return_X_y=True) + >>> synthesizer = TabPFNDataSynthesizer(n_sgld_steps=300) + >>> X_synth, y_synth = synthesizer.generate_classification(X, y, n_samples=100) + """ + + def __init__( + self, + n_sgld_steps: int = 500, + sgld_step_size: float = 0.01, + sgld_noise_scale: float = 0.01, + device: str = "auto", + ): + if not TABPFGEN_AVAILABLE: + raise ImportError( + "TabPFGen is required but not installed. " + "Install it with: pip install tabpfgen" + ) + + self.n_sgld_steps = n_sgld_steps + self.sgld_step_size = sgld_step_size + self.sgld_noise_scale = sgld_noise_scale + self.device = device + + # Initialize TabPFGen generator + self.generator = TabPFGen( + n_sgld_steps=n_sgld_steps, + sgld_step_size=sgld_step_size, + sgld_noise_scale=sgld_noise_scale, + device=device, + ) + + def generate_classification( + self, + X: np.ndarray | pd.DataFrame, + y: np.ndarray | pd.Series, + n_samples: int, + balance_classes: bool = True, + visualize: bool = False, + feature_names: list[str] | None = None, + ) -> tuple[np.ndarray, np.ndarray]: + """Generate synthetic classification data using TabPFGen. + + Parameters + ---------- + X : array-like of shape (n_samples, n_features) + Training features + y : array-like of shape (n_samples,) + Training labels + n_samples : int + Number of synthetic samples to generate + balance_classes : bool, default=True + Whether to generate balanced class distributions + visualize : bool, default=False + Whether to create TabPFGen's built-in visualization plots + feature_names : list, optional + Names of features for visualization + + Returns: + ------- + X_synth : ndarray of shape (n_samples, n_features) + Generated synthetic features + y_synth : ndarray of shape (n_samples,) + Generated synthetic labels + """ + # Convert inputs to numpy arrays if needed + X = np.asarray(X) + y = np.asarray(y) + + # Generate synthetic data using TabPFGen + X_synth, y_synth = self.generator.generate_classification( + X, y, n_samples=n_samples, balance_classes=balance_classes + ) + + # Use TabPFGen's built-in visualization if requested + if visualize and TABPFGEN_AVAILABLE: + try: + visualize_classification_results( + X, y, X_synth, y_synth, feature_names=feature_names + ) + except (ImportError, AttributeError, ValueError, TypeError) as e: + warnings.warn(f"TabPFGen visualization failed: {e}", stacklevel=2) + + return X_synth, y_synth + + def generate_regression( + self, + X: np.ndarray | pd.DataFrame, + y: np.ndarray | pd.Series, + n_samples: int, + use_quantiles: bool = True, + visualize: bool = False, + feature_names: list[str] | None = None, + ) -> tuple[np.ndarray, np.ndarray]: + """Generate synthetic regression data using TabPFGen. + + Parameters + ---------- + X : array-like of shape (n_samples, n_features) + Training features + y : array-like of shape (n_samples,) + Training targets + n_samples : int + Number of synthetic samples to generate + use_quantiles : bool, default=True + Whether to use quantile-based sampling + visualize : bool, default=False + Whether to create TabPFGen's built-in visualization plots + feature_names : list, optional + Names of features for visualization + + Returns: + ------- + X_synth : ndarray of shape (n_samples, n_features) + Generated synthetic features + y_synth : ndarray of shape (n_samples,) + Generated synthetic targets + """ + # Convert inputs to numpy arrays if needed + X = np.asarray(X) + y = np.asarray(y) + + # Generate synthetic data using TabPFGen + X_synth, y_synth = self.generator.generate_regression( + X, y, n_samples=n_samples, use_quantiles=use_quantiles + ) + + # Use TabPFGen's built-in visualization if requested + if visualize and TABPFGEN_AVAILABLE: + try: + visualize_regression_results( + X, y, X_synth, y_synth, feature_names=feature_names + ) + except (ImportError, AttributeError, ValueError, TypeError) as e: + warnings.warn(f"TabPFGen visualization failed: {e}", stacklevel=2) + + return X_synth, y_synth + + def balance_dataset( + self, + X: np.ndarray | pd.DataFrame, + y: np.ndarray | pd.Series, + target_per_class: int | None = None, + visualize: bool = False, + feature_names: list[str] | None = None, + ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """Balance classification dataset using TabPFGen's balance_dataset method. + + This method uses TabPFGen's new balance_dataset functionality to automatically + generate synthetic samples for minority classes, bringing them up to the + majority class size or a specified target. + + Parameters + ---------- + X : array-like of shape (n_samples, n_features) + Training features + y : array-like of shape (n_samples,) + Training labels + target_per_class : int, optional + Target number of samples per class. If None, balances to majority class size + visualize : bool, default=False + Whether to create TabPFGen's built-in visualization plots + feature_names : list, optional + Names of features for visualization + + Returns: + ------- + X_synth : ndarray + Generated synthetic features only + y_synth : ndarray + Generated synthetic labels only + X_combined : ndarray + Combined dataset features (original + synthetic) + y_combined : ndarray + Combined dataset labels (original + synthetic) + + Notes: + ----- + The final class distribution may be approximately balanced rather than + perfectly balanced due to TabPFN's label refinement process, which + prioritizes data quality and realism over exact class counts. + """ + X = np.asarray(X) + y = np.asarray(y) + + # Show original class distribution + unique_classes, class_counts = np.unique(y, return_counts=True) + + # Use TabPFGen's balance_dataset method + if target_per_class is None: + # Balance to majority class size automatically + X_synth, y_synth, X_combined, y_combined = self.generator.balance_dataset( + X, y + ) + else: + # Balance to specified target per class + X_synth, y_synth, X_combined, y_combined = self.generator.balance_dataset( + X, y, target_per_class=target_per_class + ) + + # Show results + + # Show final distribution + final_unique, final_counts = np.unique(y_combined, return_counts=True) + + # Use TabPFGen's built-in visualization if requested + if visualize and TABPFGEN_AVAILABLE: + try: + visualize_classification_results( + X, y, X_synth, y_synth, feature_names=feature_names + ) + except (ImportError, AttributeError, ValueError, TypeError) as e: + warnings.warn(f"TabPFGen visualization failed: {e}", stacklevel=2) + + return X_synth, y_synth, X_combined, y_combined diff --git a/src/tabpfn_extensions/tabpfgen_datasynthesizer/utils.py b/src/tabpfn_extensions/tabpfgen_datasynthesizer/utils.py new file mode 100644 index 00000000..e69d3e7b --- /dev/null +++ b/src/tabpfn_extensions/tabpfgen_datasynthesizer/utils.py @@ -0,0 +1,317 @@ +"""Utility functions for TabPFGen data synthesizer extension.""" +from __future__ import annotations + +import warnings +from typing import TYPE_CHECKING, Any + +import numpy as np + +if TYPE_CHECKING: + import pandas as pd + + +def validate_tabpfn_data( + X: np.ndarray | pd.DataFrame, + y: np.ndarray | pd.Series, + max_samples: int = 10000, + max_features: int = 100, +) -> tuple[bool, str]: + """Validate data compatibility with TabPFN requirements. + + Parameters + ---------- + X : array-like + Feature data + y : array-like + Target data + max_samples : int, default=10000 + Maximum number of samples recommended for TabPFN + max_features : int, default=100 + Maximum number of features recommended for TabPFN + + Returns: + ------- + is_valid : bool + Whether data meets TabPFN requirements + message : str + Validation message or warning + """ + X = np.asarray(X) + y = np.asarray(y) + + # Handle edge cases + if X.size == 0 or y.size == 0: + return False, "Dataset is empty" + + if len(X) != len(y): + return False, "X and y have mismatched dimensions" + + n_samples, n_features = X.shape + + warnings_list = [] + + # Check sample count + if n_samples > max_samples: + warnings_list.append( + f"Dataset has {n_samples} samples, but TabPFN works best with <={max_samples} samples" + ) + + # Check feature count + if n_features > max_features: + warnings_list.append( + f"Dataset has {n_features} features, but TabPFN works best with <={max_features} features" + ) + + # Check for missing values + if np.isnan(X).any(): + warnings_list.append("Dataset contains missing values - consider imputation") + + # Check for infinite values + if np.isinf(X).any(): + warnings_list.append("Dataset contains infinite values") + + # For classification, check class distribution + unique_classes = np.unique(y) + if ( + len(unique_classes) <= 20 and len(unique_classes) > 0 + ): # Assume classification if few unique values + unique_classes, counts = np.unique(y, return_counts=True) + + # Only check if we have actual counts + if len(counts) > 0: + min_count = np.min(counts) + if min_count < 2: + warnings_list.append( + "Some classes have very few samples - consider class balancing" + ) + + # Check for imbalance + max_count = np.max(counts) + imbalance_ratio = max_count / min_count if min_count > 0 else float("inf") + if imbalance_ratio > 5: + warnings_list.append( + f"Dataset is imbalanced (ratio: {imbalance_ratio:.1f}:1) - consider using balance_dataset()" + ) + + is_valid = len(warnings_list) == 0 + message = "; ".join(warnings_list) if warnings_list else "Data validation passed" + + return is_valid, message + + +def combine_datasets( + X_original: np.ndarray | pd.DataFrame, + y_original: np.ndarray | pd.Series, + X_synthetic: np.ndarray, + y_synthetic: np.ndarray, + strategy: str = "append", +) -> tuple[np.ndarray, np.ndarray]: + """Combine original and synthetic datasets. + + Parameters + ---------- + X_original, y_original : array-like + Original training data + X_synthetic, y_synthetic : array-like + Synthetic data generated by TabPFGen + strategy : str, default='append' + How to combine data: 'append', 'replace', or 'balanced' + + Returns: + ------- + X_combined : ndarray + Combined feature data + y_combined : ndarray + Combined target data + """ + X_orig = np.asarray(X_original) + y_orig = np.asarray(y_original) + + if strategy == "append": + # Simply append synthetic data to original + X_combined = np.vstack([X_orig, X_synthetic]) + y_combined = np.hstack([y_orig, y_synthetic]) + + elif strategy == "replace": + # Use only synthetic data + X_combined = X_synthetic + y_combined = y_synthetic + + elif strategy == "balanced": + # Balance original and synthetic data equally + n_orig = len(X_orig) + n_synth = len(X_synthetic) + + if n_synth > n_orig: + # Subsample synthetic data + indices = np.random.choice(n_synth, n_orig, replace=False) + X_synthetic = X_synthetic[indices] + y_synthetic = y_synthetic[indices] + elif n_synth < n_orig: + # Subsample original data + indices = np.random.choice(n_orig, n_synth, replace=False) + X_orig = X_orig[indices] + y_orig = y_orig[indices] + + X_combined = np.vstack([X_orig, X_synthetic]) + y_combined = np.hstack([y_orig, y_synthetic]) + + else: + raise ValueError("strategy must be 'append', 'replace', or 'balanced'") + + return X_combined, y_combined + + +def analyze_class_distribution( + y: np.ndarray | pd.Series, title: str = "Class Distribution" +) -> dict[str, Any]: + """Analyze and display class distribution statistics. + + Parameters + ---------- + y : array-like + Target labels + title : str + Title for the analysis + + Returns: + ------- + analysis : dict + Dictionary containing distribution statistics + """ + y = np.asarray(y) + unique_classes, counts = np.unique(y, return_counts=True) + + total_samples = len(y) + percentages = counts / total_samples * 100 + + # Calculate imbalance metrics + max_count = np.max(counts) + min_count = np.min(counts) + imbalance_ratio = max_count / min_count if min_count > 0 else float("inf") + + analysis = { + "title": title, + "classes": unique_classes.tolist(), + "counts": counts.tolist(), + "percentages": percentages.tolist(), + "total_samples": total_samples, + "num_classes": len(unique_classes), + "max_count": max_count, + "min_count": min_count, + "imbalance_ratio": imbalance_ratio, + "is_balanced": imbalance_ratio <= 2.0, # Reasonable threshold + } + + for cls, count, pct in zip(unique_classes, counts, percentages): + pass + + + return analysis + + +def calculate_synthetic_quality_metrics( + X_original: np.ndarray | pd.DataFrame, + X_synthetic: np.ndarray, + y_original: np.ndarray | pd.Series | None = None, + y_synthetic: np.ndarray | None = None, +) -> dict[str, float]: + """Calculate quality metrics comparing original and synthetic data. + + Parameters + ---------- + X_original : array-like + Original feature data + X_synthetic : array-like + Synthetic feature data + y_original, y_synthetic : array-like, optional + Original and synthetic labels + + Returns: + ------- + metrics : dict + Dictionary of quality metrics + """ + X_orig = np.asarray(X_original) + X_synth = np.asarray(X_synthetic) + + metrics = {} + + # Feature distribution comparison + try: + mean_orig = np.mean(X_orig, axis=0) + mean_synth = np.mean(X_synth, axis=0) + metrics["mean_absolute_error"] = np.mean(np.abs(mean_orig - mean_synth)) + + std_orig = np.std(X_orig, axis=0) + std_synth = np.std(X_synth, axis=0) + metrics["std_absolute_error"] = np.mean(np.abs(std_orig - std_synth)) + + except (ValueError, TypeError, np.linalg.LinAlgError) as e: + warnings.warn(f"Could not calculate distribution metrics: {e}", stacklevel=2) + + # Correlation comparison + try: + if X_orig.shape[1] > 1: # Need at least 2 features for correlation + corr_orig = np.corrcoef(X_orig.T) + corr_synth = np.corrcoef(X_synth.T) + + # Compare upper triangle (excluding diagonal) + mask = np.triu(np.ones_like(corr_orig), k=1).astype(bool) + if np.sum(mask) > 0: # Ensure we have correlations to compare + corr_diff = np.abs(corr_orig[mask] - corr_synth[mask]) + metrics["correlation_mae"] = np.mean(corr_diff) + + except (ValueError, TypeError, np.linalg.LinAlgError) as e: + warnings.warn(f"Could not calculate distribution metrics: {e}", stacklevel=2) + + # Class distribution comparison (if labels provided) + if y_original is not None and y_synthetic is not None: + try: + y_orig = np.asarray(y_original) + y_synth = np.asarray(y_synthetic) + + # Get class distributions + unique_orig, counts_orig = np.unique(y_orig, return_counts=True) + unique_synth, counts_synth = np.unique(y_synth, return_counts=True) + + # Normalize to proportions + prop_orig = counts_orig / len(y_orig) + prop_synth = counts_synth / len(y_synth) + + # Calculate KL divergence (simple version) + # Add small epsilon to avoid log(0) + eps = 1e-8 + prop_orig = prop_orig + eps + prop_synth = prop_synth + eps + + # Ensure same classes in both + all_classes = np.unique(np.concatenate([unique_orig, unique_synth])) + + prop_orig_full = np.zeros(len(all_classes)) + eps + prop_synth_full = np.zeros(len(all_classes)) + eps + + for i, cls in enumerate(all_classes): + if cls in unique_orig: + idx = np.where(unique_orig == cls)[0][0] + prop_orig_full[i] = prop_orig[idx] + if cls in unique_synth: + idx = np.where(unique_synth == cls)[0][0] + prop_synth_full[i] = prop_synth[idx] + + # Normalize again + prop_orig_full = prop_orig_full / np.sum(prop_orig_full) + prop_synth_full = prop_synth_full / np.sum(prop_synth_full) + + # Calculate JS divergence (symmetric version of KL) + m = 0.5 * (prop_orig_full + prop_synth_full) + js_div = 0.5 * np.sum( + prop_orig_full * np.log(prop_orig_full / m) + ) + 0.5 * np.sum(prop_synth_full * np.log(prop_synth_full / m)) + + metrics["js_divergence"] = js_div + + except (ValueError, TypeError, np.linalg.LinAlgError) as e: + warnings.warn(f"Could not calculate distribution metrics: {e}", stacklevel=2) + + return metrics diff --git a/tests/test_tabpfgen_datasynthesizer.py b/tests/test_tabpfgen_datasynthesizer.py new file mode 100644 index 00000000..43c18bea --- /dev/null +++ b/tests/test_tabpfgen_datasynthesizer.py @@ -0,0 +1,501 @@ +"""Test suite for TabPFGen Data Synthesizer Extension.""" + +from __future__ import annotations + +import sys + +import numpy as np +import pandas as pd +import pytest +from sklearn.datasets import make_classification, make_regression + +# Python version check +PYTHON_VERSION_OK = sys.version_info >= (3, 10) +SKIP_REASON_PYTHON = "Python >=3.10 required for tabpfgen_datasynthesizer extension" + +# Import modules to test (with proper error handling) +try: + from tabpfn_extensions.tabpfgen_datasynthesizer import TabPFNDataSynthesizer + from tabpfn_extensions.tabpfgen_datasynthesizer.tabpfgen_wrapper import ( + TABPFGEN_AVAILABLE, + ) + from tabpfn_extensions.tabpfgen_datasynthesizer.utils import ( + analyze_class_distribution, + calculate_synthetic_quality_metrics, + combine_datasets, + validate_tabpfn_data, + ) + + EXTENSION_IMPORTABLE = True +except ImportError: + EXTENSION_IMPORTABLE = False + # Create dummy objects to prevent test collection errors + TabPFNDataSynthesizer = None + TABPFGEN_AVAILABLE = False + + +# Combined skip condition - skip if Python version is too old OR if extension can't be imported +SKIP_EXTENSION_TESTS = not PYTHON_VERSION_OK or not EXTENSION_IMPORTABLE +SKIP_TABPFGEN_TESTS = SKIP_EXTENSION_TESTS or not TABPFGEN_AVAILABLE + +# Skip reasons +SKIP_REASON_EXTENSION = ( + SKIP_REASON_PYTHON + if not PYTHON_VERSION_OK + else "tabpfgen_datasynthesizer extension not available" +) +SKIP_REASON_TABPFGEN = ( + SKIP_REASON_EXTENSION if SKIP_EXTENSION_TESTS else "TabPFGen not available" +) + + +# Test data fixtures +@pytest.fixture +def classification_data(): + """Generate small classification dataset for testing.""" + X, y = make_classification( + n_samples=100, n_features=5, n_classes=2, n_informative=3, random_state=42 + ) + return X, y + + +@pytest.fixture +def regression_data(): + """Generate small regression dataset for testing.""" + X, y = make_regression(n_samples=100, n_features=5, noise=0.1, random_state=42) + return X, y + + +@pytest.fixture +def imbalanced_data(): + """Generate imbalanced classification dataset for testing.""" + X, y = make_classification( + n_samples=200, + n_features=6, + n_classes=3, + weights=[0.7, 0.2, 0.1], + n_informative=4, + random_state=42, + ) + return X, y + + +@pytest.mark.skipif(SKIP_EXTENSION_TESTS, reason=SKIP_REASON_EXTENSION) +class TestTabPFNDataSynthesizer: + """Test the TabPFNDataSynthesizer class.""" + + def test_init(self): + """Test synthesizer initialization.""" + synthesizer = TabPFNDataSynthesizer(n_sgld_steps=100) + assert synthesizer.n_sgld_steps == 100 + assert synthesizer.device == "auto" + + @pytest.mark.skipif(SKIP_TABPFGEN_TESTS, reason=SKIP_REASON_TABPFGEN) + def test_generate_classification(self, classification_data): + """Test classification data generation.""" + X, y = classification_data + + synthesizer = TabPFNDataSynthesizer(n_sgld_steps=50) # Small for testing + X_synth, y_synth = synthesizer.generate_classification( + X, y, n_samples=20, balance_classes=True + ) + + assert X_synth.shape == (20, X.shape[1]) + assert y_synth.shape == (20,) + assert set(np.unique(y_synth)).issubset(set(np.unique(y))) + + @pytest.mark.skipif(SKIP_TABPFGEN_TESTS, reason=SKIP_REASON_TABPFGEN) + def test_generate_regression(self, regression_data): + """Test regression data generation.""" + X, y = regression_data + + synthesizer = TabPFNDataSynthesizer(n_sgld_steps=50) # Small for testing + X_synth, y_synth = synthesizer.generate_regression( + X, y, n_samples=20, use_quantiles=True + ) + + assert X_synth.shape == (20, X.shape[1]) + assert y_synth.shape == (20,) + assert np.isfinite(X_synth).all() + assert np.isfinite(y_synth).all() + + @pytest.mark.skipif(SKIP_TABPFGEN_TESTS, reason=SKIP_REASON_TABPFGEN) + def test_balance_dataset(self, imbalanced_data): + """Test the balance_dataset method.""" + X, y = imbalanced_data + + synthesizer = TabPFNDataSynthesizer(n_sgld_steps=30) # Very small for testing + X_synth, y_synth, X_combined, y_combined = synthesizer.balance_dataset( + X, y, visualize=False + ) + + # Check return types and shapes + assert isinstance(X_synth, np.ndarray) + assert isinstance(y_synth, np.ndarray) + assert isinstance(X_combined, np.ndarray) + assert isinstance(y_combined, np.ndarray) + + # Check that combined data includes original data + assert len(X_combined) >= len(X) + assert len(y_combined) >= len(y) + + # Check that synthetic data was generated + assert len(X_synth) > 0 + assert len(y_synth) > 0 + + # Check feature dimensions match + assert X_synth.shape[1] == X.shape[1] + assert X_combined.shape[1] == X.shape[1] + + @pytest.mark.skipif(SKIP_EXTENSION_TESTS, reason=SKIP_REASON_EXTENSION) + def test_pandas_input(self, classification_data): + """Test that pandas DataFrames are handled correctly.""" + X, y = classification_data + X_df = pd.DataFrame(X) + y_series = pd.Series(y) + + synthesizer = TabPFNDataSynthesizer(n_sgld_steps=50) + try: + X_synth, y_synth = synthesizer.generate_classification( + X_df, y_series, n_samples=10 + ) + assert isinstance(X_synth, np.ndarray) + assert isinstance(y_synth, np.ndarray) + except ImportError: + pytest.skip("TabPFGen not available") + + +@pytest.mark.skipif(SKIP_EXTENSION_TESTS, reason=SKIP_REASON_EXTENSION) +class TestUtilityFunctions: + """Test utility functions.""" + + def test_validate_tabpfn_data_valid(self, classification_data): + """Test data validation with valid data.""" + X, y = classification_data + is_valid, message = validate_tabpfn_data(X, y) + assert isinstance(is_valid, bool) + assert isinstance(message, str) + + def test_validate_tabpfn_data_large(self): + """Test data validation with large dataset.""" + X = np.random.randn(15000, 10) # Too large + y = np.random.randint(0, 2, 15000) + + is_valid, message = validate_tabpfn_data(X, y, max_samples=10000) + assert not is_valid + assert "samples" in message.lower() + + def test_validate_tabpfn_data_many_features(self): + """Test data validation with too many features.""" + X = np.random.randn(100, 150) # Too many features + y = np.random.randint(0, 2, 100) + + is_valid, message = validate_tabpfn_data(X, y, max_features=100) + assert not is_valid + assert "features" in message.lower() + + def test_validate_tabpfn_data_imbalanced(self, imbalanced_data): + """Test data validation with imbalanced data.""" + X, y = imbalanced_data + + is_valid, message = validate_tabpfn_data(X, y) + # Should detect imbalance + assert "imbalanced" in message.lower() or "balance_dataset" in message.lower() + + def test_combine_datasets_append(self, classification_data): + """Test combine_datasets with append strategy.""" + X, y = classification_data + X_synth = np.random.randn(20, X.shape[1]) + y_synth = np.random.randint(0, 2, 20) + + X_combined, y_combined = combine_datasets( + X, y, X_synth, y_synth, strategy="append" + ) + + assert X_combined.shape[0] == len(X) + len(X_synth) + assert y_combined.shape[0] == len(y) + len(y_synth) + + def test_combine_datasets_replace(self, classification_data): + """Test combine_datasets with replace strategy.""" + X, y = classification_data + X_synth = np.random.randn(20, X.shape[1]) + y_synth = np.random.randint(0, 2, 20) + + X_combined, y_combined = combine_datasets( + X, y, X_synth, y_synth, strategy="replace" + ) + + assert X_combined.shape[0] == len(X_synth) + assert y_combined.shape[0] == len(y_synth) + np.testing.assert_array_equal(X_combined, X_synth) + np.testing.assert_array_equal(y_combined, y_synth) + + def test_combine_datasets_balanced(self, classification_data): + """Test combine_datasets with balanced strategy.""" + X, y = classification_data + X_synth = np.random.randn(50, X.shape[1]) # Different size + y_synth = np.random.randint(0, 2, 50) + + X_combined, y_combined = combine_datasets( + X, y, X_synth, y_synth, strategy="balanced" + ) + + # Should have equal amounts of original and synthetic data + expected_size = 2 * min(len(X), len(X_synth)) + assert X_combined.shape[0] == expected_size + assert y_combined.shape[0] == expected_size + + def test_combine_datasets_invalid_strategy(self, classification_data): + """Test combine_datasets with invalid strategy.""" + X, y = classification_data + X_synth = np.random.randn(20, X.shape[1]) + y_synth = np.random.randint(0, 2, 20) + + with pytest.raises(ValueError): + combine_datasets(X, y, X_synth, y_synth, strategy="invalid") + + def test_analyze_class_distribution(self, classification_data): + """Test class distribution analysis.""" + X, y = classification_data + + analysis = analyze_class_distribution(y, "Test Dataset") + + assert isinstance(analysis, dict) + assert "classes" in analysis + assert "counts" in analysis + assert "percentages" in analysis + assert "total_samples" in analysis + assert "imbalance_ratio" in analysis + assert "is_balanced" in analysis + + assert analysis["total_samples"] == len(y) + assert len(analysis["classes"]) == len(np.unique(y)) + assert sum(analysis["counts"]) == len(y) + assert abs(sum(analysis["percentages"]) - 100.0) < 1e-10 + + def test_calculate_synthetic_quality_metrics(self, classification_data): + """Test synthetic data quality metrics calculation.""" + X, y = classification_data + + # Create synthetic data (just random for testing) + X_synth = np.random.randn(50, X.shape[1]) + y_synth = np.random.choice(np.unique(y), 50) + + metrics = calculate_synthetic_quality_metrics(X, X_synth, y, y_synth) + + assert isinstance(metrics, dict) + + # Check that we get some expected metrics + expected_metrics = ["mean_absolute_error", "std_absolute_error"] + for metric in expected_metrics: + if metric in metrics: + assert isinstance(metrics[metric], (int, float)) + assert np.isfinite(metrics[metric]) + + def test_calculate_synthetic_quality_metrics_no_labels(self, classification_data): + """Test quality metrics calculation without labels.""" + X, y = classification_data + X_synth = np.random.randn(50, X.shape[1]) + + metrics = calculate_synthetic_quality_metrics(X, X_synth) + + assert isinstance(metrics, dict) + # Should still get feature-based metrics + if "mean_absolute_error" in metrics: + assert isinstance(metrics["mean_absolute_error"], (int, float)) + + +@pytest.mark.skipif(SKIP_TABPFGEN_TESTS, reason=SKIP_REASON_TABPFGEN) +class TestIntegration: + """Integration tests requiring TabPFGen.""" + + def test_end_to_end_classification_with_balancing(self, imbalanced_data): + """Test complete classification workflow with balancing.""" + X, y = imbalanced_data + + # Validate original data + is_valid, message = validate_tabpfn_data(X, y) + print(f"Validation: {message}") + + # Analyze original distribution + original_analysis = analyze_class_distribution(y, "Original") + assert not original_analysis["is_balanced"] # Should be imbalanced + + # Balance dataset + synthesizer = TabPFNDataSynthesizer(n_sgld_steps=20) # Small for testing + X_synth, y_synth, X_balanced, y_balanced = synthesizer.balance_dataset( + X, y, visualize=False + ) + + # Analyze balanced distribution + balanced_analysis = analyze_class_distribution(y_balanced, "Balanced") + + # Should be more balanced than original + assert ( + balanced_analysis["imbalance_ratio"] < original_analysis["imbalance_ratio"] + ) + + # Calculate quality metrics + quality_metrics = calculate_synthetic_quality_metrics(X, X_synth, y, y_synth) + assert isinstance(quality_metrics, dict) + + # Test combination + X_combined_append, y_combined_append = combine_datasets( + X, y, X_synth, y_synth, strategy="append" + ) + + assert len(X_combined_append) == len(X_balanced) + assert len(y_combined_append) == len(y_balanced) + + def test_end_to_end_regression(self, regression_data): + """Test complete regression workflow.""" + X, y = regression_data + + # Generate synthetic regression data + synthesizer = TabPFNDataSynthesizer(n_sgld_steps=20) + X_synth, y_synth = synthesizer.generate_regression( + X, y, n_samples=30, visualize=False + ) + + # Combine data + X_combined, y_combined = combine_datasets( + X, y, X_synth, y_synth, strategy="append" + ) + + # Calculate quality + quality_metrics = calculate_synthetic_quality_metrics(X, X_synth) + + assert X_combined.shape[0] == len(X) + len(X_synth) + assert isinstance(quality_metrics, dict) + + # Basic sanity checks + assert np.isfinite(X_synth).all() + assert np.isfinite(y_synth).all() + assert X_synth.shape[1] == X.shape[1] + + +@pytest.mark.skipif(SKIP_EXTENSION_TESTS, reason=SKIP_REASON_EXTENSION) +class TestErrorHandling: + """Test error handling and edge cases.""" + + def test_tabpfgen_not_available(self, monkeypatch): + """Test behavior when TabPFGen is not available.""" + # Mock TabPFGen as unavailable + import tabpfn_extensions.tabpfgen_datasynthesizer.tabpfgen_wrapper as wrapper + + monkeypatch.setattr(wrapper, "TABPFGEN_AVAILABLE", False) + + with pytest.raises(ImportError, match="TabPFGen is required"): + TabPFNDataSynthesizer() + + def test_empty_dataset(self): + """Test handling of empty datasets.""" + X = np.array([]).reshape(0, 5) + y = np.array([]) + + is_valid, message = validate_tabpfn_data(X, y) + # Should handle gracefully and detect empty dataset + assert not is_valid + assert "empty" in message.lower() + + def test_single_class_dataset(self): + """Test handling of single-class datasets.""" + X = np.random.randn(100, 5) + y = np.zeros(100) # All same class + + is_valid, message = validate_tabpfn_data(X, y) + analysis = analyze_class_distribution(y, "Single Class") + + assert analysis["num_classes"] == 1 + assert analysis["imbalance_ratio"] == 1.0 # Perfectly balanced (trivially) + + def test_mismatched_dimensions(self): + """Test handling of mismatched X and y dimensions.""" + X = np.random.randn(100, 5) + y = np.random.randint(0, 2, 90) # Wrong size + + is_valid, message = validate_tabpfn_data(X, y) + # Should detect dimension mismatch + assert not is_valid + assert "mismatch" in message.lower() + + +# Performance and benchmarking tests +@pytest.mark.skipif(SKIP_TABPFGEN_TESTS, reason=SKIP_REASON_TABPFGEN) +class TestPerformance: + """Test performance characteristics.""" + + def test_generation_performance(self, classification_data): + """Test that generation completes in reasonable time.""" + import time + + X, y = classification_data + synthesizer = TabPFNDataSynthesizer(n_sgld_steps=10) # Very small for speed + + start_time = time.time() + X_synth, y_synth = synthesizer.generate_classification( + X, y, n_samples=20, visualize=False + ) + elapsed_time = time.time() - start_time + + # Should complete in reasonable time (adjust threshold as needed) + assert elapsed_time < 30.0 # 30 seconds max for small dataset + assert len(X_synth) == 20 + + @pytest.mark.skipif(SKIP_EXTENSION_TESTS, reason=SKIP_REASON_EXTENSION) + def test_utility_function_performance(self): + """Test that utility functions perform well on larger datasets.""" + import time + + # Create larger dataset + X = np.random.randn(5000, 20) + y = np.random.randint(0, 5, 5000) + + start_time = time.time() + is_valid, message = validate_tabpfn_data(X, y) + analysis = analyze_class_distribution(y, "Large Dataset") + elapsed_time = time.time() - start_time + + # Should be fast for utility functions + assert elapsed_time < 5.0 # 5 seconds max + assert isinstance(analysis, dict) + + +# Test to verify version checking works correctly +class TestVersionCompatibility: + """Test version compatibility checks.""" + + def test_python_version_detection(self): + """Test that Python version detection works correctly.""" + assert isinstance(PYTHON_VERSION_OK, bool) + assert isinstance(SKIP_EXTENSION_TESTS, bool) + assert isinstance(SKIP_TABPFGEN_TESTS, bool) + + # On Python <3.10, extension tests should be skipped + if sys.version_info < (3, 10): + assert SKIP_EXTENSION_TESTS + assert "Python >=3.10 required" in SKIP_REASON_EXTENSION + + def test_skip_conditions_logic(self): + """Test that skip conditions are logically correct.""" + # SKIP_TABPFGEN_TESTS should always be True if SKIP_EXTENSION_TESTS is True + if SKIP_EXTENSION_TESTS: + assert SKIP_TABPFGEN_TESTS + + # If Python version is OK and extension is importable, + # then SKIP_EXTENSION_TESTS should be False + if PYTHON_VERSION_OK and EXTENSION_IMPORTABLE: + assert not SKIP_EXTENSION_TESTS + + +if __name__ == "__main__": + # Print version information for debugging + print(f"Python version: {sys.version}") + print(f"Python >=3.10: {PYTHON_VERSION_OK}") + print(f"Extension importable: {EXTENSION_IMPORTABLE}") + print(f"TabPFGen available: {TABPFGEN_AVAILABLE}") + print(f"Skip extension tests: {SKIP_EXTENSION_TESTS}") + print(f"Skip TabPFGen tests: {SKIP_TABPFGEN_TESTS}") + + # Run tests with pytest + pytest.main([__file__, "-v"])