This repository contains the implementation of Spectral-Distilled Neural Forests (SD-NF), a framework for distilling large Gradient Boosted Decision Trees (GBDTs) into compact, efficient, and interpretable ensembles of shallow oblique neural trees.
The core logic is contained in a single, self-contained Python script: sd_nf_colab.py.
This script is designed to be "drop-in" ready for Google Colab or local environments. It handles:
- Dependency Management: Automatically installs required packages (xgboost, openml, pytorch-tabnet, etc.) if missing.
- Data Pipeline: Downloads and preprocesses datasets from OpenML-CC18 (Adult, Bank-Marketing, Electricity, Nomao).
- Model Training:
  - Teacher: Trains a strong XGBoost classifier.
  - Baselines: Trains Logistic Regression, CART (Decision Tree), MLP, and TabNet.
  - Ablation: Trains a "Vanilla" Neural Forest (Sigmoid routing, no diversity regularization).
  - SD-NF: Trains the proposed model using Entmax sparse routing and Spectral Diversity Regularization (SDR).
- Evaluation: Computes AUC, Accuracy, NLL, and theoretical FLOPs.
- Artifact Generation: Produces training curves, eigen-spectrum plots, and extracted boolean rules.
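As a rough illustration of the evaluation step, the three prediction-quality metrics named above (AUC, Accuracy, NLL) can be computed for a binary classifier as in the sketch below. It uses only the standard library, and the function names `accuracy`, `nll`, and `auc` are hypothetical, chosen for this example. The script itself may well rely on library implementations instead, and the FLOPs estimate is not covered here.

```python
import math

def accuracy(y_true, y_prob, threshold=0.5):
    """Fraction of examples whose thresholded probability matches the label."""
    hits = sum((p >= threshold) == bool(y) for y, p in zip(y_true, y_prob))
    return hits / len(y_true)

def nll(y_true, y_prob, eps=1e-12):
    """Mean negative log-likelihood (binary cross-entropy) of the labels."""
    total = 0.0
    for y, p in zip(y_true, y_prob):
        p = min(max(p, eps), 1.0 - eps)  # clip to avoid log(0)
        total -= y * math.log(p) + (1 - y) * math.log(1.0 - p)
    return total / len(y_true)

def auc(y_true, y_score):
    """ROC AUC as the probability that a positive outranks a negative (ties count 0.5)."""
    pos = [s for y, s in zip(y_true, y_score) if y == 1]
    neg = [s for y, s in zip(y_true, y_score) if y == 0]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))

labels = [0, 0, 1, 1]
probs = [0.1, 0.4, 0.35, 0.8]
print(accuracy(labels, probs), nll(labels, probs), auc(labels, probs))
```

The pairwise AUC is quadratic in the number of examples, which is fine for a sanity check but too slow for the full datasets; a rank-based implementation would be the usual choice there.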
- Oblique Decision Trees: Neural trees that split data using learnable linear hyperplanes ($w^Tx + b$) rather than axis-aligned cuts.
- Sparse Routing: Uses entmax1.5 so that routing probabilities can reach exactly 0 or 1, enabling conditional computation and crisp rule extraction.
- Spectral Diversity Regularization (SDR): A determinantal point process (DPP) based loss term that pushes the decision boundaries of different trees toward orthogonality, preventing mode collapse during distillation.
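The three ideas above can be sketched in a few lines of plain Python. This is an illustrative sketch under stated assumptions, not the code in sd_nf_colab.py: `entmax15` follows the standard sort-based 1.5-entmax algorithm, `route_right` is a hypothetical single oblique node that routes on the two logits `[w·x + b, 0]`, and `sdr_penalty` assumes the diversity term is the negative log-determinant of the Gram matrix of unit-normalised hyperplane normals, which is one common DPP-style choice. The actual SDR definition in the script may differ.

```python
import math

def entmax15(logits):
    """1.5-entmax over a list of floats, using the sort-based exact algorithm."""
    top = max(logits)
    x = [(v - top) / 2.0 for v in logits]  # shift for stability, scale by (alpha - 1)
    xs = sorted(x, reverse=True)
    taus, cum, cum_sq = [], 0.0, 0.0
    for k, v in enumerate(xs, start=1):
        cum += v
        cum_sq += v * v
        mean = cum / k
        ss = k * (cum_sq / k - mean * mean)
        delta = max((1.0 - ss) / k, 0.0)
        taus.append(mean - math.sqrt(delta))
    support = sum(1 for t, v in zip(taus, xs) if t <= v)  # size of the non-zero set
    tau_star = taus[support - 1]
    return [max(v - tau_star, 0.0) ** 2 for v in x]

def route_right(x, w, b):
    """Probability of taking the right branch at one oblique node (hypothetical helper)."""
    score = sum(wi * xi for wi, xi in zip(w, x)) + b  # oblique split: w^T x + b
    return entmax15([score, 0.0])[0]

def sdr_penalty(normals, eps=1e-6):
    """Assumed diversity term: -log det(G + eps*I), G = Gram matrix of unit normals."""
    units = []
    for w in normals:  # assumes every normal is non-zero
        norm = math.sqrt(sum(c * c for c in w))
        units.append([c / norm for c in w])
    n = len(units)
    gram = [[sum(a * b for a, b in zip(units[i], units[j])) + (eps if i == j else 0.0)
             for j in range(n)] for i in range(n)]
    # log det via Cholesky; eps keeps the matrix positive definite
    chol = [[0.0] * n for _ in range(n)]
    log_det = 0.0
    for i in range(n):
        for j in range(i + 1):
            s = gram[i][j] - sum(chol[i][k] * chol[j][k] for k in range(j))
            if i == j:
                chol[i][j] = math.sqrt(s)
                log_det += 2.0 * math.log(chol[i][j])
            else:
                chol[i][j] = s / chol[j][j]
    return -log_det

print(route_right([1.0, 2.0], [1.0, 1.0], 0.0))    # score 3.0: fully routed right
print(route_right([0.2, 0.1], [1.0, 1.0], 0.0))    # small score: soft routing
print(sdr_penalty([[1.0, 0.0], [0.0, 1.0]]))        # orthogonal normals: near zero
print(sdr_penalty([[1.0, 0.0], [1.0, 0.01]]))       # near-parallel normals: large
```

With two logits, 1.5-entmax returns exactly 0 and 1 once the score's magnitude reaches 2, and a soft split below that, which is what makes hard rule extraction possible after training. The log-determinant equals the sum of log-eigenvalues of the Gram matrix, so the penalty grows as the normals become collinear and the spectrum collapses.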
To run the experiments locally:

    python sd_nf_colab.py

Hyperparameters are defined in the TRAINING_CONFIG dictionary near the top of the file:

    TRAINING_CONFIG = {
        "nf_depth": 5,          # Depth of student trees
        "nf_trees": 6,          # Number of trees in student forest
        "lambda_sdr": 0.01,     # Strength of spectral diversity regularization
        "lambda_sparse": 1e-3,  # Strength of entropy sparsity penalty
        # ...
    }

Outputs (metrics, plots, and rules) will be saved to /content/sd_nf_outputs (or relative to execution path).
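To make the role of the two regularization weights concrete, here is a minimal sketch of how `lambda_sdr` and `lambda_sparse` might enter the student's training objective. The additive form, the helper names `binary_entropy` and `total_loss`, and the use of mean binary entropy over routing probabilities as the sparsity penalty are all assumptions for illustration; the script's actual loss may be assembled differently.

```python
import math

# Only the two weights discussed here; values copied from TRAINING_CONFIG above.
TRAINING_CONFIG = {"lambda_sdr": 0.01, "lambda_sparse": 1e-3}

def binary_entropy(p):
    """Entropy of a Bernoulli(p) routing decision; zero when routing is hard."""
    if p <= 0.0 or p >= 1.0:
        return 0.0
    return -(p * math.log(p) + (1.0 - p) * math.log(1.0 - p))

def total_loss(distill_loss, sdr_term, routing_probs, config=TRAINING_CONFIG):
    """Assumed objective: distillation loss plus weighted diversity and sparsity terms."""
    sparsity = sum(binary_entropy(p) for p in routing_probs) / len(routing_probs)
    return (distill_loss
            + config["lambda_sdr"] * sdr_term
            + config["lambda_sparse"] * sparsity)

# Hard routing contributes no sparsity penalty; soft routing adds a small one.
print(total_loss(1.0, 0.0, [0.0, 1.0]))
print(total_loss(1.0, 2.0, [0.5]))
```

Under this reading, raising `lambda_sparse` pushes routing probabilities toward 0 or 1, while raising `lambda_sdr` trades distillation fidelity for more diverse trees.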