adithyap/spectral-distilled-neural-forests
Spectral-Distilled Neural Forests (SD-NF)

This repository contains the implementation of Spectral-Distilled Neural Forests (SD-NF), a framework for distilling large Gradient Boosted Decision Trees (GBDTs) into compact, efficient, and interpretable ensembles of shallow oblique neural trees.

The Code

The core logic is contained in a single, self-contained Python script: sd_nf_colab.py.

This script is designed to be "drop-in" ready for Google Colab or local environments. It handles:

  1. Dependency Management: Automatically installs required packages (xgboost, openml, pytorch-tabnet, etc.) if missing.
  2. Data Pipeline: Downloads and preprocesses datasets from OpenML-CC18 (Adult, Bank-Marketing, Electricity, Nomao).
  3. Model Training:
    • Teacher: Trains a strong XGBoost classifier.
    • Baselines: Trains Logistic Regression, CART (Decision Tree), MLP, and TabNet.
    • Ablation: Trains a "Vanilla" Neural Forest (Sigmoid routing, no diversity regularization).
    • SD-NF: Trains the proposed model using Entmax sparse routing and Spectral Diversity Regularization (SDR).
  4. Evaluation: Computes AUC, Accuracy, NLL, and theoretical FLOPs.
  5. Artifact Generation: Produces training curves, eigen-spectrum plots, and extracted boolean rules.
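The dependency-management step follows a common Colab pattern: try the import, and pip-install on failure. A minimal sketch of that pattern (the `ensure` helper name is illustrative, not necessarily what `sd_nf_colab.py` uses):

```python
import importlib
import subprocess
import sys

def ensure(pkg, import_name=None):
    """Install `pkg` via pip only if it cannot already be imported.

    `import_name` covers packages whose pip name differs from their
    module name (e.g. pip "pytorch-tabnet" -> module "pytorch_tabnet").
    """
    name = import_name or pkg
    try:
        importlib.import_module(name)
    except ImportError:
        subprocess.check_call([sys.executable, "-m", "pip", "install", pkg])

# Example: ensure("pytorch-tabnet", import_name="pytorch_tabnet")
```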

Key Features Implemented

  • Oblique Decision Trees: Neural trees that split data using learnable linear hyperplanes ($w^Tx + b$) rather than axis-aligned cuts.
  • Sparse Routing: Uses entmax-1.5 in place of softmax, so routing probabilities can be driven to exactly 0 or 1, enabling conditional computation and crisp rule extraction.
  • Spectral Diversity Regularization (SDR): A determinantal point process (DPP) based loss term that forces the decision boundaries of different trees to be orthogonal, preventing mode collapse during distillation.
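The SDR idea can be sketched in a few lines: build a similarity kernel over the trees' hyperplane normals and penalize a small determinant, which is largest when the normals are mutually orthogonal. This is a minimal NumPy illustration of that DPP-style log-determinant term, not the PyTorch implementation in `sd_nf_colab.py`:

```python
import numpy as np

def spectral_diversity_loss(W, eps=1e-6):
    """DPP-style diversity penalty on tree hyperplane weights.

    W: (n_trees, n_features) array, one hyperplane normal per tree.
    Returns -log det of the cosine-similarity kernel; the penalty is
    smallest when the hyperplanes are mutually orthogonal (kernel ~ I)
    and large when trees collapse onto the same direction.
    """
    Wn = W / (np.linalg.norm(W, axis=1, keepdims=True) + eps)
    K = Wn @ Wn.T
    _, logdet = np.linalg.slogdet(K + eps * np.eye(len(W)))
    return -logdet

# Orthogonal normals incur almost no penalty; near-identical ones do.
ortho = np.eye(3)
collapsed = np.array([[1.0, 1.0, 1.0],
                      [1.0, 1.0, 1.01],
                      [1.0, 1.01, 1.0]])
assert spectral_diversity_loss(ortho) < spectral_diversity_loss(collapsed)
```

In training, this term (weighted by `lambda_sdr` from the config below) is added to the distillation loss so the student trees are pushed toward complementary decision boundaries.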

Usage

To run the experiments locally:

python sd_nf_colab.py

Configuration

Hyperparameters are defined in the TRAINING_CONFIG dictionary near the top of the file:

TRAINING_CONFIG = {
    "nf_depth": 5,          # Depth of student trees
    "nf_trees": 6,          # Number of trees in student forest
    "lambda_sdr": 0.01,     # Strength of spectral diversity regularization
    "lambda_sparse": 1e-3,  # Strength of entropy sparsity penalty
    # ...
}

Outputs (metrics, plots, and rules) are saved to /content/sd_nf_outputs on Colab, or to a directory relative to the working directory when run locally.
