Compute training data influence for MLPs and Transformers using K-FAC. Find which training examples most affected a model’s prediction — interpretability and debugging for PyTorch.
Influence functions estimate how much each training example contributed to a model’s prediction. This repo implements efficient influence computation via K-FAC (Kronecker-Factored Approximate Curvature) for:
- MLPs — e.g. MNIST image classification
- Transformers — autoregressive character-level language models
Use it for model interpretability, data debugging, finding mislabeled or influential examples, and understanding prediction behavior.
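Concretely, for a training example z and a query point z_query, the classical influence-function estimate (Koh & Liang, 2017) that this repo approximates is:

```latex
\mathcal{I}(z, z_{\text{query}})
  = -\,\nabla_\theta L(z_{\text{query}}, \theta)^{\top} \, H^{-1} \, \nabla_\theta L(z, \theta)
```

where H is the Hessian of the training loss at the trained parameters; inverting H is the expensive part that K-FAC makes tractable.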
- PyTorch — native `torch` models and training loops
- K-FAC / EKFAC — scalable second-order influence approximation
- MLP + Transformer — MNIST demo and mini decoder-only transformer
- Visualization — top influential training samples (e.g. MNIST images)
- Modular — `InfluenceCalculable` interface to plug in your own layers
- Python 3.8+
- PyTorch 2.x (CPU or CUDA)
- See `requirements.txt` for full dependencies (`torch`, `torchvision`, `einops`, `matplotlib`, `tqdm`, etc.)
```bash
git clone https://github.com/KuchikiRenji/InfluenceFunctions.git
cd InfluenceFunctions
pip install -r requirements.txt
```
- Train (optional; or use a pre-trained checkpoint):

  ```bash
  python mnist_mlp.py
  ```

  In `mnist_mlp.py`, uncomment `train_model()` in `if __name__ == "__main__"` to train and save `model.ckpt`.
- Run influence analysis (find the training examples that most influenced selected test samples):

  ```bash
  python mnist_mlp.py
  ```

  With the default `run_influence("model.ckpt", 1, 300, 1000)`, this loads the model, picks 1 query from the test set, uses 300 samples for gradient fitting, and searches over 1000 training samples. Results are printed and saved as `results_*.png`.
- Train (optional):

  In `mini_transformer.py`, uncomment `train_char_predict()` in `if __name__ == "__main__"` to train and save `small_transformer.pth`.
- Run influence:

  ```bash
  python mini_transformer.py
  ```

  With the default `calc_influence("small_transformer.pth")`, this computes influential training sequences for the chosen query sequences.
| File | Description |
|---|---|
| `influence_functions_mlp.py` | K-FAC influence for MLP blocks (MNIST-style) |
| `influence_functions_transformer.py` | K-FAC influence for transformer MLP blocks (autoregressive loss) |
| `mnist_mlp.py` | MNIST MLP model, training, and influence + visualization |
| `mini_transformer.py` | Small decoder-only transformer and influence on char-level data |
| `requirements.txt` | Python dependencies |
- K-FAC factors — approximate the Hessian with Kronecker factors from activations and gradient covariances over a dataset.
- Inverse-Hessian–vector products — computed efficiently using these factors.
- Influence scores — for a query point, estimate the effect of each training example on the loss (or prediction) via gradients and the approximate inverse Hessian.
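A minimal sketch of these three steps for a single linear layer, assuming damping is added per factor (the repo's exact shapes and damping scheme may differ):

```python
import torch

def kfac_factors(acts, grads):
    """Kronecker factors estimated over a fitting set.
    acts:  (N, d_in)  cached layer inputs
    grads: (N, d_out) loss gradients w.r.t. pre-activations
    """
    A = acts.T @ acts / acts.shape[0]     # activation covariance, (d_in, d_in)
    S = grads.T @ grads / grads.shape[0]  # gradient covariance,   (d_out, d_out)
    return A, S

def ihvp(V, A, S, damping=1e-3):
    """Inverse-Hessian-vector product via the Kronecker identity
    (A ⊗ S)^-1 vec(V) = vec(S^-1 V A^-1), with damping on each factor."""
    A_inv = torch.linalg.inv(A + damping * torch.eye(A.shape[0]))
    S_inv = torch.linalg.inv(S + damping * torch.eye(S.shape[0]))
    return S_inv @ V @ A_inv              # V: (d_out, d_in) weight gradient

def influence(query_grad, train_grad, A, S):
    """Classical influence estimate -(∇L_query)ᵀ H^-1 ∇L_train for one
    training example, per layer; sum over layers for the total score."""
    return -(ihvp(query_grad, A, S) * train_grad).sum()
```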
The code uses an `InfluenceCalculable` interface: each block exposes its input activations, gradients w.r.t. pre-activations, and weight gradients so that K-FAC factors and influence scores can be computed layer-wise.
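Illustratively, a block that wants influence support provides roughly the following (the method names here are hypothetical; see the source files for the actual signatures):

```python
# Hypothetical shape of the InfluenceCalculable contract; the real
# interface in this repo may name these methods differently.
class InfluenceCalculableBlock:
    def get_activations(self):
        """Layer inputs a, cached during the forward pass."""
        ...

    def get_preactivation_grads(self):
        """Gradients dL/d(pre-activations) g, cached during backward."""
        ...

    def get_weight_grad(self):
        """Per-example weight gradient (the outer product g aᵀ)."""
        ...
```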
KuchikiRenji
- GitHub: github.com/KuchikiRenji
- Email: [email protected]
- Discord: `kuchiki_renji`
See the repository for license information.