This example demonstrates using DeepSHAP for neural network explanations with ONNX models.
Prerequisites:
- ONNX Runtime installed:
    # macOS
    brew install onnxruntime
    # Linux (Ubuntu/Debian)
    # Download from https://github.com/microsoft/onnxruntime/releases
- Python dependencies for model generation:
    pip install torch scikit-learn onnx
Running the example:
- Generate the ONNX neural network model:
    python generate_model.py
- Run the example:
    go run main.go
The example covers:
- Parsing ONNX graph structure with onnx.ParseGraph
- Creating ActivationSession for intermediate layer capture
- Using DeepSHAP for neural network explanations
- Analyzing layer activations
- Interpreting DeepLIFT-based SHAP values
The example uses a simple MLP (Multi-Layer Perceptron) trained on the Iris dataset:
Input (4 features)
↓
Dense (8 neurons) + ReLU
↓
Dense (8 neurons) + ReLU
↓
Dense (1 neuron) + Sigmoid
↓
Output (probability)
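To make the layer shapes concrete, here is a minimal plain-Go sketch of a forward pass through a 4-8-8-1 network that keeps the activations of every layer. It does not use the ONNX model or this library's API: the `layer` type, `newLayer`, `dense`, and `forward` are hypothetical helpers, and the constant weights are placeholders rather than the trained values.

```go
package main

import (
	"fmt"
	"math"
)

// layer holds the weights (one row per output neuron) and biases of a dense layer.
type layer struct {
	w [][]float64
	b []float64
}

// newLayer builds a placeholder layer with every weight set to w and every bias to b.
// A real model would load trained weights instead.
func newLayer(out, in int, w, b float64) layer {
	l := layer{w: make([][]float64, out), b: make([]float64, out)}
	for i := range l.w {
		l.w[i] = make([]float64, in)
		for j := range l.w[i] {
			l.w[i][j] = w
		}
		l.b[i] = b
	}
	return l
}

// dense computes act(W·x + b) for every output neuron.
func dense(l layer, x []float64, act func(float64) float64) []float64 {
	y := make([]float64, len(l.w))
	for i, row := range l.w {
		z := l.b[i]
		for j, wij := range row {
			z += wij * x[j]
		}
		y[i] = act(z)
	}
	return y
}

func relu(z float64) float64    { return math.Max(0, z) }
func sigmoid(z float64) float64 { return 1 / (1 + math.Exp(-z)) }

// forward runs Dense+ReLU, Dense+ReLU, Dense+Sigmoid and returns
// the activations after each of the three layers.
func forward(fc1, fc2, fc3 layer, x []float64) [][]float64 {
	h1 := dense(fc1, x, relu)
	h2 := dense(fc2, h1, relu)
	out := dense(fc3, h2, sigmoid)
	return [][]float64{h1, h2, out}
}

func main() {
	fc1, fc2, fc3 := newLayer(8, 4, 0.01, 0), newLayer(8, 8, 0.01, 0), newLayer(1, 8, 0.1, 0)
	acts := forward(fc1, fc2, fc3, []float64{6.0, 2.7, 4.5, 1.5})
	for i, a := range acts {
		fmt.Printf("layer %d: %d activations\n", i+1, len(a))
	}
	fmt.Printf("output probability: %.4f\n", acts[2][0])
}
```

Keeping the per-layer activations, rather than only the final output, is what makes layer-wise attribution possible later.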
DeepSHAP combines DeepLIFT with Shapley values:
- Forward pass: Capture activations at each layer
- Reference activations: Compute baseline activations from background data
- Backward propagation: Apply DeepLIFT rescale rule to propagate attributions
- Average: Average attributions over multiple background samples
The rescale rule for each neuron:
multiplier_in = multiplier_out × (activation - reference) / (output - output_ref)
Running the example prints output similar to the following:
ONNX Graph Structure:
Inputs: [input]
Outputs: [output]
Nodes: 7
/fc1/Gemm (Gemm -> dense)
/relu1/Relu (Relu -> relu)
/fc2/Gemm (Gemm -> dense)
/relu2/Relu (Relu -> relu)
/fc3/Gemm (Gemm -> dense)
/sigmoid/Sigmoid (Sigmoid -> sigmoid)
DeepSHAP Neural Network Explanations
=====================================
Instance: Versicolor sample
Features: [6 2.7 4.5 1.5]
Prediction: 0.7234
Base Value: 0.3333
SHAP Values (DeepLIFT-based):
sepal_length : value= 6.00, SHAP=+0.0856
sepal_width : value= 2.70, SHAP=-0.0234
petal_length : value= 4.50, SHAP=+0.2145
petal_width : value= 1.50, SHAP=+0.1134
Top Contributing Features:
1. petal_length: +0.2145 ↑
2. petal_width: +0.1134 ↑
3. sepal_length: +0.0856 ↑
4. sepal_width: -0.0234 ↓
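SHAP values are additive, so the base value plus the four contributions should reproduce the prediction: 0.3333 + 0.0856 - 0.0234 + 0.2145 + 0.1134 = 0.7234. The short Go sketch below checks this for the sample output and rebuilds the ranking by absolute contribution. The `contribution` type and the helper names are hypothetical and not part of the library.

```go
package main

import (
	"fmt"
	"math"
	"sort"
)

// contribution pairs a feature name with its SHAP value.
type contribution struct {
	name string
	shap float64
}

// sampleContributions returns the SHAP values from the sample output above.
func sampleContributions() []contribution {
	return []contribution{
		{"sepal_length", 0.0856},
		{"sepal_width", -0.0234},
		{"petal_length", 0.2145},
		{"petal_width", 0.1134},
	}
}

// additivityGap returns |base + sum(SHAP) - prediction|.
func additivityGap(base, prediction float64, cs []contribution) float64 {
	sum := base
	for _, c := range cs {
		sum += c.shap
	}
	return math.Abs(sum - prediction)
}

// rankByMagnitude returns a copy of cs sorted by |SHAP|, largest first.
func rankByMagnitude(cs []contribution) []contribution {
	ranked := append([]contribution(nil), cs...)
	sort.SliceStable(ranked, func(i, j int) bool {
		return math.Abs(ranked[i].shap) > math.Abs(ranked[j].shap)
	})
	return ranked
}

func main() {
	cs := sampleContributions()
	fmt.Printf("additivity gap: %.4f\n", additivityGap(0.3333, 0.7234, cs))
	for i, c := range rankByMagnitude(cs) {
		arrow := "↑"
		if c.shap < 0 {
			arrow = "↓"
		}
		fmt.Printf("%d. %s: %+.4f %s\n", i+1, c.name, c.shap, arrow)
	}
}
```

A gap noticeably larger than rounding error would suggest that an unsupported layer was skipped or that the background data differs from the data used to compute the base value.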
The current DeepSHAP implementation supports:
- Dense (Gemm, MatMul) layers
- ReLU, Sigmoid, Tanh activations
- Softmax output layers
- Sequential architectures
Not yet supported:
- Convolutional layers
- Recurrent layers
- Residual connections
- Attention mechanisms
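Before explaining a model of your own, it can help to check its operator types against the lists above. The sketch below is one hypothetical way to do that in plain Go. The `supportedOps` set is derived only from the "supports" list using standard ONNX operator names and is an assumption about what the library accepts; the check also looks only at operator types, so it cannot detect non-sequential topologies such as residual connections.

```go
package main

import "fmt"

// supportedOps mirrors the "supports" list above using ONNX operator names.
// It is an assumption for illustration, not read from the library.
var supportedOps = map[string]bool{
	"Gemm": true, "MatMul": true, // dense layers
	"Relu": true, "Sigmoid": true, "Tanh": true, // activations
	"Softmax": true, // output layer
}

// unsupported returns the operator types that are not in supportedOps,
// each listed once, in order of first appearance.
func unsupported(opTypes []string) []string {
	seen := map[string]bool{}
	var out []string
	for _, op := range opTypes {
		if !supportedOps[op] && !seen[op] {
			seen[op] = true
			out = append(out, op)
		}
	}
	return out
}

func main() {
	mlp := []string{"Gemm", "Relu", "Gemm", "Relu", "Gemm", "Sigmoid"}
	cnn := []string{"Conv", "Relu", "Conv", "Relu", "Gemm"}
	fmt.Println("MLP unsupported ops:", unsupported(mlp))
	fmt.Println("CNN unsupported ops:", unsupported(cnn))
}
```

The `mlp` slice matches the operator types listed in the graph structure output, so it passes; the `cnn` slice is a made-up example that trips on `Conv`.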