Commit 221357c

Merge pull request #266 from NeuroBench/feat/macs_neuron
Feat: Neuron Operations
2 parents 4fa2cfb + 81f58d5 commit 221357c
File tree

7 files changed

+2350 -1441 lines changed
docs/metrics/workload_metrics/index.rst

Lines changed: 1 addition & 0 deletions

@@ -13,4 +13,5 @@ Workload Metrics
    smape
    r2
    mse
+   neuron_operations
Lines changed: 85 additions & 0 deletions

Neuron Operations Metric
========================

The `NeuronOperations` metric is designed to measure the computational workload associated with neuron activity in spiking neural networks (SNNs). It tracks the number of operations required to update the membrane potential of neurons during the forward pass of a model. These operations include the reset mechanisms defined in the `neuron_ops_reset_operations` dictionary, such as "subtract" and "zero".

Purpose
-------

The metric provides insight into the computational cost of neuron updates, which is critical for analyzing and optimizing the efficiency of spiking neural networks. By understanding the workload associated with different neuron types and reset mechanisms, researchers can identify bottlenecks and improve model performance.

Reset Mechanisms
----------------

Each neuron type has specific reset mechanisms that determine how the membrane potential is updated after a spike. The two main reset mechanisms are:

1. **Subtract**: The membrane potential is reduced by a fixed value (the threshold) after a spike.
2. **Zero**: The membrane potential is reset to zero after a spike.

The computational cost of these mechanisms is defined in the `neuron_ops_reset_operations` dictionary. For example, the "Leaky" neuron type has the following costs:

- Subtract mechanism: 4 operations
- Zero mechanism: 4 operations

Why 4 Operations for "Leaky" Neurons?
-------------------------------------

The computational cost of 4 operations for both the "subtract" and "zero" mechanisms in "Leaky" neurons is an abstraction representing the number of basic mathematical operations required to perform the reset: updating the membrane potential, checking conditions, and writing the updated state back to memory.

**Subtract Mechanism**

If the `reset_mechanism` is set to "subtract", the membrane potential :math:`U[t+1]` has the `threshold` subtracted from it whenever the neuron emits a spike. The update equation is:

.. math::

   U[t+1] = \beta U[t] + I_{\rm in}[t+1] - R U_{\rm thr}

Here is the breakdown of the 4 operations:

1. **Decay Term**: Multiply the previous membrane potential :math:`U[t]` by the decay factor :math:`\beta` (1 operation).
2. **Input Current**: Add the input current :math:`I_{\rm in}[t+1]` to the decayed potential (1 operation).
3. **Reset Multiplication**: Multiply the reset factor :math:`R` by the threshold :math:`U_{\rm thr}` (1 operation).
4. **Threshold Subtraction**: Subtract the result of the reset multiplication from the decayed potential and input current (1 operation).

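The four counted operations for the subtract reset can be sketched in plain Python. This is an illustrative toy, not snnTorch's implementation; the constants and function name are invented for the example:

```python
BETA = 0.9    # decay factor (invented for illustration)
U_THR = 1.0   # firing threshold (invented for illustration)

def leaky_subtract_update(u, i_in, spk):
    """One "subtract"-reset membrane update; each line is one of the 4 counted ops."""
    decayed = BETA * u            # op 1: decay term
    integrated = decayed + i_in   # op 2: add input current
    reset = spk * U_THR           # op 3: reset multiplication
    return integrated - reset     # op 4: threshold subtraction

# A neuron that just spiked (spk = 1) has the threshold subtracted:
# leaky_subtract_update(1.2, 0.1, 1.0) ≈ 0.9 * 1.2 + 0.1 - 1.0 = 0.18
```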
**Zero Mechanism**

If the `reset_mechanism` is set to "zero", the membrane potential :math:`U[t+1]` is reset to zero whenever the neuron emits a spike. The update equation is:

.. math::

   U[t+1] = \beta U[t] + I_{\rm in}[t+1] - R(\beta U[t] + I_{\rm in}[t+1])

Here is the breakdown of the 4 operations:

1. **Decay Term**: Multiply the previous membrane potential :math:`U[t]` by the decay factor :math:`\beta` (1 operation).
2. **Input Current**: Add the input current :math:`I_{\rm in}[t+1]` to the decayed potential (1 operation).
3. **Reset Multiplication**: Multiply the reset factor :math:`R` by the sum of the decayed potential and input current, :math:`(\beta U[t] + I_{\rm in}[t+1])` (1 operation).
4. **Reset Subtraction**: Subtract the result of the reset multiplication from the decayed potential and input current (1 operation).

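The zero-reset steps can be sketched the same way (again an illustrative toy, not snnTorch's implementation; `BETA` and the function name are invented):

```python
BETA = 0.9  # decay factor (invented for illustration)

def leaky_zero_update(u, i_in, spk):
    """One "zero"-reset membrane update; each line is one of the 4 counted ops."""
    decayed = BETA * u            # op 1: decay term
    integrated = decayed + i_in   # op 2: add input current
    reset = spk * integrated      # op 3: reset multiplication
    return integrated - reset     # op 4: reset subtraction

# A spiking neuron (spk = 1) subtracts its entire potential, landing at zero:
# leaky_zero_update(1.2, 0.1, 1.0) == 0.0
```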
**Why Abstract the Cost to 4 Operations?**

The value of 4 operations is an abstraction that simplifies the computational workload into a consistent metric. While the actual number of operations may vary slightly depending on the implementation, this abstraction provides a way to compare the computational cost of different neuron types and reset mechanisms. It is particularly useful for analyzing and optimizing spiking neural networks across various implementations.

Example: Leaky Neuron with snnTorch
-----------------------------------

Consider an example using the "Leaky" neuron type with `snntorch`. Assume we have a layer of "Leaky" neurons, and we want to compute the workload for the "subtract" and "zero" reset mechanisms.

1. **Subtract Mechanism**:

   - After a spike, the membrane potential is reduced by a fixed value.
   - If there are 100 neurons in the layer and each neuron spikes once, the total computational cost is:

     .. math::

        \text{Total Cost} = \text{Number of Neurons} \times \text{Cost per Subtract} = 100 \times 4 = 400 \text{ operations.}

2. **Zero Mechanism**:

   - After a spike, the membrane potential is reset to zero.
   - If the same 100 neurons spike once, the total computational cost is:

     .. math::

        \text{Total Cost} = \text{Number of Neurons} \times \text{Cost per Zero} = 100 \times 4 = 400 \text{ operations.}

Outputs
-------

The `NeuronOperations` metric provides two key outputs:

1. **Effective Neuron Ops**: The total number of operations actually performed by neurons, normalized by the number of samples. This value accounts for the actual activity of the neurons during the forward pass, counting only the updates that occur when neurons spike.

2. **Neuron Dense Ops**: The total number of operations that would be performed if every neuron were updated at every time step, regardless of whether it spiked. This represents the theoretical maximum workload for the network under full activity.

These outputs quantify the computational workload of the network, both in terms of actual activity and theoretical maximum activity, and can be used to optimize the model's efficiency.
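The two outputs can be reproduced for a single layer with a small helper. This is a hypothetical sketch mirroring the accounting described above, not part of the NeuroBench API; all names are invented:

```python
# Per-update reset cost, as in the docs above (subset of the full table).
NEURON_OPS = {"Leaky": {"subtract": 4, "zero": 4}}

def layer_neuron_ops(neuron_type, reset_mechanism, n_neurons, n_steps,
                     n_effective_updates, n_samples):
    """Return (effective_ops, dense_ops), each normalized per sample."""
    cost = NEURON_OPS[neuron_type][reset_mechanism]
    effective = n_effective_updates * cost / n_samples  # spike-driven updates only
    dense = n_neurons * n_steps * cost / n_samples      # every neuron, every step
    return effective, dense

# 100 Leaky neurons over 10 time steps, 100 spike-driven updates, 1 sample:
eff, dense = layer_neuron_ops("Leaky", "subtract", 100, 10, 100, 1)
# eff = 400.0, dense = 4000.0
```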

examples/nehar/benchmark.py

Lines changed: 10 additions & 4 deletions

@@ -1,4 +1,11 @@
 import os
+import sys
+import os
+
+SCRIPT_DIR = os.path.dirname(os.path.abspath("./"))
+sys.path.append(os.path.dirname(SCRIPT_DIR))
+print("Adding path to sys.path:", os.path.dirname(SCRIPT_DIR))
+
 from neurobench.datasets import WISDM
 from training import SpikingNetwork
 from neurobench.processors.postprocessors import ChooseMaxCount
@@ -11,6 +18,7 @@
     SynapticOperations,
     ClassificationAccuracy,
     ActivationSparsityByLayer,
+    NeuronOperations,
 )
 from neurobench.metrics.static import (
     ParameterCount,
@@ -34,9 +42,7 @@
 num_outputs = data_module.num_outputs
 num_steps = data_module.num_steps

-spiking_network = SpikingNetwork.load_from_checkpoint(
-    model_path, map_location="cpu"
-)
+spiking_network = SpikingNetwork(lr=1)

 model = SNNTorchModel(spiking_network.model, custom_forward=True)
 test_set_loader = data_module.test_dataloader()
@@ -48,7 +54,7 @@

 # #
 static_metrics = [ParameterCount, Footprint, ConnectionSparsity]
-workload_metrics = [ActivationSparsity, ActivationSparsityByLayer, MembraneUpdates, SynapticOperations, ClassificationAccuracy]
+workload_metrics = [ActivationSparsity, ActivationSparsityByLayer, MembraneUpdates, SynapticOperations, ClassificationAccuracy, NeuronOperations]
 # #
 benchmark = Benchmark(
     model, test_set_loader, [], postprocessors, [static_metrics, workload_metrics]
neurobench/metrics/workload/__init__.py

Lines changed: 8 additions & 1 deletion

@@ -7,6 +7,7 @@
 from .smape import SMAPE
 from .r2 import R2
 from .coco_map import CocoMap
+from .neuron_operations import NeuronOperations

 __stateless__ = [
     "ClassificationAccuracy",
@@ -16,6 +17,12 @@
     "ActivationSparsityByLayer",
 ]

-__stateful__ = ["MembraneUpdates", "SynapticOperations", "R2", "CocoMap"]
+__stateful__ = [
+    "MembraneUpdates",
+    "SynapticOperations",
+    "R2",
+    "CocoMap",
+    "NeuronOperations",
+]

 __all__ = __stateful__ + __stateless__
Lines changed: 141 additions & 0 deletions

import torch
from neurobench.metrics.abstract.workload_metric import AccumulatedMetric
from collections import defaultdict


neuron_ops_reset_operations = {
    "Leaky": {
        "subtract": 4,
        "zero": 4,
    },
    "Synaptic": {
        "subtract": 6,
        "zero": 8,
    },
    "Lapicque": {
        "subtract": 11,
        "zero": 11,
    },
    "Alpha": {
        "subtract": 18,
        "zero": 24,
    },
}
"""
The `neuron_ops_reset_operations` dictionary defines the computational cost associated
with resetting the membrane potential of neurons for different neuron types. The reset
mechanisms are categorized into two types:

1. **Subtract**: Represents a reset mechanism where the membrane potential is reduced
   by a certain value. The associated value indicates the computational cost (in basic
   operations) required to perform this type of reset.

2. **Zero**: Represents a reset mechanism where the membrane potential is reset to
   zero. The associated value indicates the computational cost (in basic operations)
   required to perform this type of reset.

### Neuron Types and Their Computational Costs:
- **Leaky**:
  - Subtract mechanism: 4 operations
  - Zero mechanism: 4 operations
- **Synaptic**:
  - Subtract mechanism: 6 operations
  - Zero mechanism: 8 operations
- **Lapicque**:
  - Subtract mechanism: 11 operations
  - Zero mechanism: 11 operations
- **Alpha**:
  - Subtract mechanism: 18 operations
  - Zero mechanism: 24 operations

### Purpose:
The values in this dictionary represent the computational cost (measured in basic
operations like addition and subtraction) required for each neuron type to reset its
membrane potential using a specific reset mechanism.
"""

class NeuronOperations(AccumulatedMetric):
    """
    Neuron operations metric.

    This metric computes the number of operations performed by neurons during the
    forward pass of the model. The operations are tracked per neuron, per layer.

    The `NeuronOperations` metric is designed to measure the computational workload
    associated with neuron activity in spiking neural networks. Specifically, it
    tracks the number of operations required to update the membrane potential of
    neurons during the forward pass. These operations include the reset mechanisms
    defined in the `neuron_ops_reset_operations` dictionary, such as "subtract"
    and "zero".

    """

    def __init__(self):
        """Initialize the NeuronOperations metric."""
        super().__init__(requires_hooks=True)
        self.total_samples = 0
        self.dense = defaultdict(int)
        self.macs = defaultdict(int)

    def reset(self):
        """Reset the metric state for a new evaluation."""
        self.total_samples = 0
        self.dense = defaultdict(int)
        self.macs = defaultdict(int)

    def __call__(self, model, preds, data):
        """
        Accumulate the neuron operations.

        Args:
            model: A NeuroBenchModel.
            preds: A tensor of model predictions.
            data: A tuple of data and labels.
        Returns:
            dict: Effective and dense neuron operation counts, normalized by the
            number of samples processed so far.

        """
        for hook in model.activation_hooks:
            layer_type = hook.layer.__class__.__name__
            reset_mechanism = hook.layer._reset_mechanism
            updates = 0

            if len(hook.pre_fire_mem_potential) > 1:
                pre_fire_mem = torch.stack(hook.pre_fire_mem_potential[1:])
                post_fire_mem = torch.stack(hook.post_fire_mem_potential[1:])
                updates += torch.count_nonzero(pre_fire_mem - post_fire_mem).item()
            if hook.post_fire_mem_potential:
                updates += hook.post_fire_mem_potential[0].numel()

            self.macs[layer_type] += (
                updates * neuron_ops_reset_operations[layer_type][reset_mechanism]
            )
            self.dense[layer_type] += (
                hook.post_fire_mem_potential[0].numel()
                * len(hook.post_fire_mem_potential)
                * neuron_ops_reset_operations[layer_type][reset_mechanism]
            )

        self.total_samples += data[0].size(0)

        return self.compute()

    def compute(self):
        """
        Compute the accumulated neuron operations normalized by sample count.

        Returns:
            dict: "Effective Neuron Ops" and "Neuron Dense Ops", aggregated across
            all layers and normalized by the number of samples processed.

        """
        if self.total_samples == 0:
            return {"Effective Neuron Ops": 0, "Neuron Dense Ops": 0}

        macs = sum(self.macs.values())
        dense = sum(self.dense.values())

        return {
            "Effective Neuron Ops": macs / self.total_samples,
            "Neuron Dense Ops": dense / self.total_samples,
        }
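The spike-driven update counting in `__call__` can be mirrored torch-free on hypothetical membrane traces. This is a sketch of the counting rule only, with made-up data, not the hook machinery:

```python
# Hypothetical membrane traces for a 2-neuron layer over 3 time steps.
# A neuron counts as updated at step t > 0 when its pre-fire and post-fire
# potentials differ (i.e. a reset happened); every neuron counts at step 0.
pre_fire  = [[0.4, 1.2], [0.9, 0.3], [1.1, 1.3]]
post_fire = [[0.4, 1.2], [0.9, 0.3], [0.1, 0.3]]

updates = sum(
    1
    for pre_step, post_step in zip(pre_fire[1:], post_fire[1:])
    for p, q in zip(pre_step, post_step)
    if p != q
)
updates += len(post_fire[0])  # all neurons are counted at the first step

cost = 4  # neuron_ops_reset_operations["Leaky"]["subtract"]
effective_ops = updates * cost                         # 4 updates * 4 ops = 16
dense_ops = len(post_fire[0]) * len(post_fire) * cost  # 2 neurons * 3 steps * 4 = 24
```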
