
Influence Functions for Neural Networks — PyTorch

Compute training data influence for MLPs and Transformers using K-FAC. Find which training examples most affected a model’s prediction — interpretability and debugging for PyTorch.


What Are Influence Functions?

Influence functions estimate how much each training example contributed to a model’s prediction. This repo implements efficient influence computation via K-FAC (Kronecker-Factored Approximate Curvature) for:

  • MLPs — e.g. MNIST image classification
  • Transformers — autoregressive character-level language models

Use it for model interpretability, data debugging, finding mislabeled or influential examples, and understanding prediction behavior.


Features

  • PyTorch — native torch models and training loops
  • K-FAC / EKFAC — scalable second-order influence approximation
  • MLP + Transformer — MNIST demo and mini decoder-only transformer
  • Visualization — top influential training samples (e.g. MNIST images)
  • Modular — InfluenceCalculable interface to plug in your own layers

Requirements

  • Python 3.8+
  • PyTorch 2.x (CPU or CUDA)
  • See requirements.txt for full dependencies (torch, torchvision, einops, matplotlib, tqdm, etc.)

Installation

git clone https://github.com/KuchikiRenji/InfluenceFunctions.git
cd InfluenceFunctions
pip install -r requirements.txt

Quick Start

MNIST MLP — influence on test predictions

  1. Train (optional; or use a pre-trained checkpoint):

    In mnist_mlp.py, uncomment train_model() in if __name__ == "__main__", then run:

    python mnist_mlp.py

    This trains the model and saves model.ckpt.

  2. Run influence analysis (find training examples that most influenced selected test samples):

    python mnist_mlp.py

    With the default run_influence("model.ckpt", 1, 300, 1000), this loads the model, picks 1 query from the test set, uses 300 samples for gradient fitting, and searches over 1000 training samples. Results are printed and saved as results_*.png.
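
For reference, a minimal sketch of that entry point with the default parameters, based on the description above (the exact signature lives in mnist_mlp.py; treat this as illustrative, not authoritative):

    if __name__ == "__main__":
        # train_model()               # uncomment to (re)train and save model.ckpt
        run_influence(
            "model.ckpt",   # trained checkpoint to analyze
            1,              # number of test-set queries
            300,            # samples used for gradient fitting
            1000,           # training samples searched for influence
        )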

Transformer — character-level influence

  1. Train (optional):

    In mini_transformer.py, uncomment train_char_predict() in if __name__ == "__main__" to train and save small_transformer.pth.

  2. Run influence:

    python mini_transformer.py

    With the default calc_influence("small_transformer.pth"), this computes influential training sequences for chosen query sequences.
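
Likewise, a short sketch of the mini_transformer.py entry point, using the function names referenced above (check the file for the exact signatures):

    if __name__ == "__main__":
        # train_char_predict()        # uncomment to train and save small_transformer.pth
        calc_influence("small_transformer.pth")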


Project Structure

  • influence_functions_mlp.py — K-FAC influence for MLP blocks (MNIST-style)
  • influence_functions_transformer.py — K-FAC influence for transformer MLP blocks (autoregressive loss)
  • mnist_mlp.py — MNIST MLP model, training, and influence + visualization
  • mini_transformer.py — small decoder-only transformer and influence on char-level data
  • requirements.txt — Python dependencies

How It Works (High Level)

  1. K-FAC factors — approximate the Hessian with Kronecker factors from activations and gradient covariances over a dataset.
  2. Inverse-Hessian–vector products — computed efficiently using these factors.
  3. Influence scores — for a query point, estimate the effect of each training example on the loss (or prediction) via gradients and the approximate inverse Hessian.
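
Concretely, the influence of a training example z on a query z_test is approximated as -∇L(z_test)ᵀ H⁻¹ ∇L(z), with H replaced by its K-FAC approximation. A hedged sketch of the per-layer computation (function names and the damping term are illustrative, not the repository's exact code, and sign conventions vary between implementations):

    import torch

    def kfac_ihvp(grad_w, A, S, damping=1e-3):
        # For one linear layer, K-FAC approximates the Hessian block as A ⊗ S,
        # where A is the input-activation covariance and S the covariance of
        # gradients w.r.t. pre-activations; then (A ⊗ S)^{-1} vec(G) = vec(S^{-1} G A^{-1}).
        A_inv = torch.linalg.inv(A + damping * torch.eye(A.shape[0], dtype=A.dtype, device=A.device))
        S_inv = torch.linalg.inv(S + damping * torch.eye(S.shape[0], dtype=S.dtype, device=S.device))
        return S_inv @ grad_w @ A_inv

    def influence_score(grad_query, grad_train, A, S):
        # ≈ -∇L(query)ᵀ H⁻¹ ∇L(train); a strongly influential training example
        # gets a large-magnitude score (sign conventions differ).
        return -(grad_query * kfac_ihvp(grad_train, A, S)).sum()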

The code uses an InfluenceCalculable interface: each block provides activations, gradient w.r.t. pre-activations, and weight gradients so K-FAC and influence can be computed layer-wise.
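
For intuition, a block that satisfies that contract might look roughly like this (class and method names here are hypothetical; the real interface is defined in the influence_functions_*.py files):

    import torch
    import torch.nn as nn

    class InfluenceLinear(nn.Module):
        """Linear layer that caches what layer-wise K-FAC needs."""
        def __init__(self, d_in, d_out):
            super().__init__()
            self.linear = nn.Linear(d_in, d_out)
            self.a = None    # input activations            -> A factor
            self.ds = None   # grad w.r.t. pre-activations  -> S factor

        def forward(self, x):
            self.a = x.detach()
            s = self.linear(x)
            if s.requires_grad:
                s.register_hook(lambda g: setattr(self, "ds", g.detach()))
            return s

        def weight_grad(self):
            # per-example weight gradient: outer product of the cached quantities
            return torch.einsum("bo,bi->boi", self.ds, self.a)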


Author & Contact

KuchikiRenji


License

See the repository for license information. All other project content remains as in the original repository.
