src/tabpfn_extensions/rf_pfn/sklearn_based_decision_tree_tabpfn.py (122 changes: 119 additions & 3 deletions)
@@ -3,6 +3,7 @@

from __future__ import annotations

import logging
import random
import warnings

@@ -35,6 +36,9 @@
)
from tabpfn_extensions.utils import softmax

# Define a module-level logger
logger = logging.getLogger(__name__)

###############################################################################
# BASE DECISION TREE #
###############################################################################
@@ -296,6 +300,8 @@ def _fit(
self : DecisionTreeTabPFNBase
The fitted model.
"""
if self.verbose:
logger.info("Starting DecisionTreeTabPFN fit process...")
# Initialize attributes (per scikit-learn conventions)
self._leaf_nodes = []
self._leaf_train_data = {}
@@ -317,6 +323,8 @@ def _fit(
y,
ensure_all_finite=False, # scikit-learn sets self.n_features_in_ automatically
)
if self.verbose:
logger.info(f"Input data shape: X={X.shape}, y={y.shape}")

if self.task_type == "multiclass":
self.classes_ = unique_labels(y)
@@ -345,20 +353,30 @@

# If adaptive_tree is on, do a train/validation split
if self.adaptive_tree:
if self.verbose:
logger.info(
"Adaptive tree is enabled. Preparing train/validation split."
)
stratify = y_ if (self.task_type == "multiclass") else None

# Basic checks for classification to see if splitting is feasible
if self.task_type == "multiclass":
unique_classes, counts = np.unique(y_, return_counts=True)
# Disable adaptive tree in extreme cases
if counts.min() == 1 or len(unique_classes) < 2:
if self.verbose:
logger.info(
"Disabling adaptive tree: minimum class count is 1 or only one class present."
)
self.adaptive_tree = False
elif len(unique_classes) > int(len(y_) * self.adaptive_tree_test_size):
self.adaptive_tree_test_size = min(
0.5,
len(unique_classes) / len(y_) * 1.5,
)
if len(y_) < 10:
if self.verbose:
logger.info("Disabling adaptive tree: fewer than 10 samples.")
self.adaptive_tree = False

if self.adaptive_tree:
@@ -380,9 +398,18 @@ def _fit(
random_state=self.random_state,
stratify=stratify,
)
if self.verbose:
logger.info(
f"Train/Valid split created: "
f"Train size={len(y_train)}, Valid size={len(y_valid)}"
)

# Safety check - if split is empty, revert
if len(y_train) == 0 or len(y_valid) == 0:
if self.verbose:
logger.info(
"Disabling adaptive tree: train or validation split is empty."
)
self.adaptive_tree = False
X_train, X_preproc_train, y_train, sw_train = (
X,
@@ -398,6 +425,10 @@ def _fit(
and self.adaptive_tree
and (len(np.unique(y_train)) != len(np.unique(y_valid)))
):
if self.verbose:
logger.info(
"Disabling adaptive tree: train and validation sets have different classes."
)
self.adaptive_tree = False
else:
# If we were disabled, keep all data as training
@@ -410,6 +441,8 @@ def _fit(
X_valid = X_preproc_valid = y_valid = sw_valid = None
else:
# Not adaptive, everything is train
if self.verbose:
logger.info("Adaptive tree is disabled. Using all data for training.")
X_train, X_preproc_train, y_train, sw_train = (
X,
X_preprocessed,
@@ -419,9 +452,15 @@ def _fit(
X_valid = X_preproc_valid = y_valid = sw_valid = None

# Build the sklearn decision tree
if self.verbose:
logger.info("Fitting the initial scikit-learn decision tree structure...")
self._decision_tree = self._init_decision_tree()
self._decision_tree.fit(X_preproc_train, y_train, sample_weight=sw_train)
self._tree = self._decision_tree # for sklearn compatibility
if self.verbose:
logger.info(
f"Decision tree fitting complete. Tree has {self._tree.tree_.node_count} nodes."
)

# Keep references for potential post-fitting (leaf-level fitting)
self.X = X
@@ -439,6 +478,8 @@ def _fit(

# We will do a leaf-fitting step on demand (lazy) in predict
self._need_post_fit = True
if self.verbose:
logger.info("Leaf fitting is deferred until the first predict() call.")

# If verbose, optionally do it right away:
if self.verbose:
@@ -461,7 +502,7 @@ def _init_decision_tree(self) -> BaseDecisionTree:
def _post_fit(self) -> None:
"""Hook after the decision tree is fitted. Can be used for final prints/logs."""
if self.verbose:
pass
logger.info("Base tree structure has been fitted.")

def _preprocess_data_for_tree(self, X: np.ndarray) -> np.ndarray:
"""Handle missing data prior to feeding into the decision tree.
@@ -620,15 +661,25 @@ def _predict_internal(
"""
# If we haven't yet done the final leaf fit, do it here
if self._need_post_fit:
if self.verbose:
logger.info("First predict call: executing deferred leaf fitting.")
self._need_post_fit = False
if self.adaptive_tree:
# Fit leaves on train data, check performance on valid data if available
if self.verbose:
logger.info(
"Fitting leaves on training data for adaptive pruning..."
)
self.fit_leaves(self.train_X, self.train_y)
if (
hasattr(self, "valid_X")
and self.valid_X is not None
and self.valid_y is not None
):
if self.verbose:
logger.info(
"Evaluating node performance on validation set for pruning decisions."
)
# Force a pass to evaluate node performance
# so we can prune or decide node updates
self._predict_internal(
@@ -637,6 +688,8 @@
check_input=False,
)
# Now fit leaves again using the entire dataset (train + valid, effectively)
if self.verbose:
logger.info("Fitting leaves on the full dataset.")
self.fit_leaves(self.X, self.y)

# Assign TabPFN's categorical features if needed
Expand All @@ -646,6 +699,10 @@ def _predict_internal(
# Find leaf membership in X
X_leaf_nodes = self._apply_tree(X)
n_samples, n_nodes, n_estims = X_leaf_nodes.shape
if self.verbose:
logger.info(
f"Starting prediction for {n_samples} samples across {n_nodes} nodes."
)

# Track intermediate predictions
y_prob: dict[int, dict[int, np.ndarray]] = {}
@@ -701,6 +758,13 @@
X_leaf_nodes[test_sample_indices, leaf_id + 1 :, est_id].sum()
== 0.0
)
if self.verbose:
logger.info(
f"Processing Node {leaf_id}: "
f"Train Samples={X_train_leaf.shape[0]}, "
f"Test Samples={len(test_sample_indices)}, "
f"Is Final Leaf={is_leaf}"
)

# If it's not a leaf and we are not fitting internal nodes, skip
# (unless leaf_id==0 and we do a top-level check for adaptive_tree)
@@ -709,6 +773,10 @@
and (not self.fit_nodes)
and not (leaf_id == 0 and self.adaptive_tree)
):
if self.verbose:
logger.info(
f" -> Skipping Node {leaf_id}: Not a final leaf and fit_nodes is False."
)
if do_pruning:
self._node_prediction_type[est_id][leaf_id] = "previous"
continue
@@ -725,6 +793,10 @@
should_skip_previously_pruned = True

if should_skip_previously_pruned:
if self.verbose:
logger.info(
f" -> Skipping Node {leaf_id}: Node was previously pruned."
)
continue

# Skip if classification is missing a class
@@ -733,6 +805,10 @@
and len(np.unique(y_train_leaf)) < self.n_classes_
and self.adaptive_tree_skip_class_missing
):
if self.verbose:
logger.info(
f" -> Skipping Node {leaf_id}: Not all classes are present in training data."
)
self._node_prediction_type[est_id][leaf_id] = "previous"
continue

@@ -749,6 +825,10 @@
and not is_leaf
)
):
if self.verbose:
logger.info(
f" -> Skipping Node {leaf_id}: Does not meet sample size requirements for adaptive fitting."
)
if do_pruning:
self._node_prediction_type[est_id][leaf_id] = "previous"
continue
@@ -797,10 +877,18 @@ def _predict_internal(
y,
y_prob[est_id][leaf_id],
)
if self.verbose:
logger.info(
f" -> Pruning Result for Node {leaf_id}: "
f"Type='{self._node_prediction_type[est_id][leaf_id]}', "
f"Score={y_metric[est_id][leaf_id]:.4f}"
)
else:
# If not validating and not adaptive, just use replacement
y_prob[est_id][leaf_id] = y_prob_replacement

if self.verbose:
logger.info("Prediction process finished.")
# Final predictions come from the last estimator's last node
return y_prob[n_estims - 1][n_nodes - 1]

@@ -1151,12 +1239,18 @@ def _predict_leaf(

# If only one class, fill probability 1.0 for that class
if len(classes_in_leaf) == 1:
if self.verbose:
logger.info(
f" -> Node {leaf_id}: Only one class present. Predicting 1.0 for class {classes_in_leaf[0]}."
)
y_eval_prob[indices, classes_in_leaf[0]] = 1.0
return y_eval_prob

# Otherwise, fit TabPFN
leaf_seed = leaf_id + self.tree_seed
try:
if self.verbose:
logger.info(f" -> Node {leaf_id}: Fitting TabPFNClassifier.")
self.tabpfn.random_state = leaf_seed
self.tabpfn.fit(X_train_leaf, y_train_leaf)

@@ -1182,6 +1276,10 @@
"One node has constant features for TabPFN. Using class-ratio fallback.",
stacklevel=2,
)
if self.verbose:
logger.warning(
f" -> Node {leaf_id}: TabPFN failed due to constant features. Using class ratio fallback."
)
_, counts = np.unique(y_train_leaf, return_counts=True)
ratio = counts / counts.sum()
for i, c in enumerate(classes_in_leaf):
@@ -1231,7 +1329,7 @@ def predict_proba(self, X: np.ndarray, check_input: bool = True) -> np.ndarray:
def _post_fit(self) -> None:
"""Optional hook after the decision tree is fitted."""
if self.verbose:
pass
logger.info("Classifier tree structure has been fitted.")


###############################################################################
@@ -1354,23 +1452,37 @@ def _predict_leaf(

# If no training data or just 1 sample, fall back to 0 or single value
if len(X_train_leaf) < 1:
if self.verbose:
logger.info(
f" -> Node {leaf_id}: No training samples. Predicting 0.0."
)
warnings.warn(
f"Leaf {leaf_id} has zero training samples. Returning 0.0 predictions.",
stacklevel=2,
)
return y_eval
elif len(X_train_leaf) == 1:
if self.verbose:
logger.info(
f" -> Node {leaf_id}: Only one training sample. Predicting its value."
)
y_eval[indices] = y_train_leaf[0]
return y_eval

# If all y are identical, return that constant
if np.all(y_train_leaf == y_train_leaf[0]):
if self.verbose:
logger.info(
f" -> Node {leaf_id}: All target values are constant. Predicting {y_train_leaf[0]}."
)
y_eval[indices] = y_train_leaf[0]
return y_eval

# Fit TabPFNRegressor
leaf_seed = leaf_id + self.tree_seed
try:
if self.verbose:
logger.info(f" -> Node {leaf_id}: Fitting TabPFNRegressor.")
self.tabpfn.random_state = leaf_seed
self.tabpfn.fit(X_train_leaf, y_train_leaf)

@@ -1389,6 +1501,10 @@ def _predict_leaf(
f"TabPFN fit/predict failed at leaf {leaf_id}: {e}. Using mean fallback.",
stacklevel=2,
)
if self.verbose:
logger.warning(
f" -> Node {leaf_id}: TabPFN failed ({e}). Using mean fallback."
)
y_eval[indices] = np.mean(y_train_leaf)

return y_eval
@@ -1442,4 +1558,4 @@ def predict_full(self, X: np.ndarray) -> np.ndarray:
def _post_fit(self) -> None:
"""Optional hook after the regressor's tree is fitted."""
if self.verbose:
pass
logger.info("Regressor tree structure has been fitted.")