Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 79 additions & 0 deletions examples/unsupervised/generate_data_following_dag.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Copyright (c) Prior Labs GmbH 2025.
# Licensed under the Apache License, Version 2.0

import torch
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split

from tabpfn_extensions import TabPFNClassifier, unsupervised

# Load the breast cancer dataset
df = load_wine(return_X_y=False)
X, y = df["data"], df["target"]
attribute_names = df["feature_names"]

wine_dag = {
'alcohol': [],
'malic_acid': [],
'ash': ['magnesium'],
'alcalinity_of_ash': ['ash', 'magnesium'],
'magnesium': [],
'total_phenols': ['flavanoids', 'nonflavanoid_phenols', 'proanthocyanins'],
'flavanoids': [],
'nonflavanoid_phenols': [],
'proanthocyanins': [],
'color_intensity': ['flavanoids', 'proanthocyanins', 'total_phenols'],
'hue': ['color_intensity'],
'od280/od315_of_diluted_wines': ['flavanoids', 'total_phenols'],
'proline': ['alcohol', 'total_phenols']
}

# convert feature names to indices in keys and values
dag = {i: [list(wine_dag.keys()).index(dep) for dep in deps] for i, deps in enumerate(wine_dag.values())}
print(dag)

# Split the data
X_train, X_test, y_train, y_test = train_test_split(
X,
y,
test_size=0.5,
random_state=42,
)

# Initialize TabPFN models
# Use parameters that work with both TabPFN and TabPFN-client
clf = TabPFNClassifier(n_estimators=3)

# Import TabPFNRegressor for numerical features
from tabpfn_extensions import TabPFNRegressor

reg = TabPFNRegressor(n_estimators=3)

# Initialize unsupervised model
model_unsupervised = unsupervised.TabPFNUnsupervisedModel(
tabpfn_clf=clf,
tabpfn_reg=reg,
)

# Create and run synthetic experiment
exp_synthetic = unsupervised.experiments.GenerateSyntheticDataExperiment(
task_type="unsupervised",
)

# Convert data to torch tensors
X_tensor = torch.tensor(X_train, dtype=torch.float32)
y_tensor = torch.tensor(y_train, dtype=torch.float32)

# Run the experiment
results = exp_synthetic.run(
tabpfn=model_unsupervised,
X=X_tensor,
y=y_tensor,
attribute_names=attribute_names,
temp=1.0,
n_samples=X_train.shape[0] * 3, # Generate 3x original samples
indices=list(range(X_train.shape[1])), # Use all features
n_permutations=3,
dag=dag,
)

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "tabpfn-extensions"
version = "0.1.0"
version = "0.1.1"
dependencies = [
"torch>=2.1,<3",
"pandas>=1.4.0,<3",
Expand Down
5 changes: 5 additions & 0 deletions src/tabpfn_extensions/unsupervised/experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,9 @@ def run(self, tabpfn, **kwargs):

temp = kwargs.get("temp", 1.0)
n_samples = kwargs.get("n_samples", X.shape[0])

n_permutations = kwargs.get("n_permutations", 1)
dag = kwargs.get("dag", None)

self.X, self.y = X, y
self.X = self.X[:, indices]
Expand All @@ -133,6 +136,8 @@ def run(self, tabpfn, **kwargs):
self.synthetic_X = tabpfn.generate_synthetic_data(
n_samples=n_samples,
t=temp,
n_permutations=n_permutations,
dag=dag,
)

data_real = pd.DataFrame(
Expand Down
43 changes: 38 additions & 5 deletions src/tabpfn_extensions/unsupervised/unsupervised.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
- Compatibility with both TabPFN and TabPFN-client backends
- Support for mixed data types (categorical and numerical features)
- Flexible permutation-based approach for feature dependencies
- Support for Directed Acyclic Graphs (DAGs) to define feature relationships

Example usage:
```python
Expand Down Expand Up @@ -45,6 +46,7 @@
import random
from typing import Any

from graphlib import TopologicalSorter
import numpy as np
import pandas as pd
import torch
Expand Down Expand Up @@ -266,6 +268,7 @@ def impute_(
t: float = 0.000000001,
n_permutations: int = 10,
condition_on_all_features: bool = True,
dag: dict[int, list[int]] | None = None,
fast_mode: bool = False,
) -> torch.Tensor:
"""Impute missing values (np.nan) in X by sampling all cells independently from the trained models.
Expand All @@ -279,6 +282,8 @@ def impute_(
Number of permutations to use for imputation
condition_on_all_features: bool, default=True
Whether to condition on all other features (True) or only previous features (False)
dag: dict[int, list[int]] | None, default=None
Dictionary representing a Directed Acyclic Graph (DAG) defining feature dependencies.
fast_mode: bool, default=False
Whether to use faster settings for testing

Expand All @@ -290,11 +295,28 @@ def impute_(

X_fit = self.X_
impute_X = copy.deepcopy(X)


# check if dag is provided
if dag is not None:
if condition_on_all_features:
raise ValueError(
"DAG cannot be used with condition_on_all_features=True."
)
# fill up the DAG with empty lists for features not in the DAG
for i in all_features:
if i not in dag:
dag[i] = []
ts = TopologicalSorter(dag)
# re-order all_features based on the DAG (throws error in case of cycles)
all_features = list(ts.static_order())

for i in tqdm(range(len(all_features))):
column_idx = all_features[i]

if not condition_on_all_features:
if dag is not None:
# If a DAG is provided, use the dependencies from the DAG
conditional_idx = dag.get(column_idx, [])
elif not condition_on_all_features:
conditional_idx = all_features[:i] if i > 0 else []
else:
conditional_idx = list(set(range(X.shape[1])) - {column_idx})
Expand All @@ -310,7 +332,7 @@ def impute_(
densities: list[Any] = []
# Use fewer permutations in fast mode
actual_n_permutations = 1 if fast_mode else n_permutations

for perm in efficient_random_permutation(
conditional_idx,
actual_n_permutations,
Expand Down Expand Up @@ -502,8 +524,8 @@ def density_(
else:
# If the first feature, use a zero feature as input
# Because of preprocessing, we can't use a zero feature, so we use a random feature
X_fit, y_fit = torch.randn_like(X_fit[:, 0:1]), X_fit[:, 0]
X_predict, y_predict = torch.randn_like(X_predict[:, 0:1]), X_predict[:, 0]
X_fit, y_fit = torch.randn_like(X_fit[:, 0:1]), X_fit[:, column_idx]
X_predict, y_predict = torch.randn_like(X_predict[:, 0:1]), X_predict[:, column_idx]

model = (
self.tabpfn_clf
Expand All @@ -527,6 +549,7 @@ def impute(
X: torch.Tensor | np.ndarray | pd.DataFrame,
t: float = 0.000000001,
n_permutations: int = 10,
dag: dict[int, list[int]] | None = None,
) -> torch.Tensor:
"""Impute missing values in the input data using the fitted TabPFN models.

Expand All @@ -547,6 +570,9 @@ def impute(
n_permutations: int, default=10
Number of random feature permutations to use for imputation.
Higher values may improve robustness but increase computation time.

dag: dict[int, list[int]] | None, default=None
Dictionary representing a Directed Acyclic Graph (DAG) defining feature dependencies.

Returns:
torch.Tensor
Expand All @@ -569,6 +595,7 @@ def impute(
t,
condition_on_all_features=True,
n_permutations=n_permutations,
dag=dag,
fast_mode=fast_mode,
)

Expand Down Expand Up @@ -770,6 +797,7 @@ def generate_synthetic_data(
n_samples: int = 100,
t: float = 1.0,
n_permutations: int = 3,
dag: dict[int, list[int]] | None = None,
) -> torch.Tensor:
"""Generate synthetic tabular data samples using the fitted TabPFN models.

Expand All @@ -789,6 +817,10 @@ def generate_synthetic_data(
n_permutations: int, default=3
Number of feature permutations to use for generation
More permutations may provide more robust results but increase computation time

dag: dict[int, list[int]] | None, default=None
Dictionary representing a Directed Acyclic Graph (DAG) defining feature dependencies.
If provided, the generation will respect the dependencies defined in the DAG.

Returns:
torch.Tensor:
Expand Down Expand Up @@ -822,6 +854,7 @@ def generate_synthetic_data(
condition_on_all_features=False,
n_permutations=actual_n_permutations,
fast_mode=fast_mode,
dag=dag,
)

def get_embeddings(self, X: torch.tensor, per_column: bool = False) -> torch.tensor:
Expand Down