Summary
NIR's type inference system currently requires neuron parameters (tau, v_threshold, etc.) to be arrays with explicit shapes to determine the number of neurons in a layer. This forces scalar parameters to be broadcast into full arrays, creating unnecessary weight duplication that is inefficient for hardware deployment.
Problem Description
When exporting spiking neural networks to NIR, scalar parameters (which are common and hardware-efficient) must be converted to full arrays to satisfy NIR's type checking system.
Example
Original snnTorch model:

```python
import snntorch as snn

lif = snn.Leaky(beta=0.9)  # Single scalar parameter for all neurons
```

Current NIR export requirement:

```python
# For a 128-neuron layer, scalar beta=0.9 becomes:
nir.LIF(
    tau=np.array([0.001, 0.001, 0.001, ..., 0.001]),  # 128 identical values
    v_threshold=np.array([1.0, 1.0, 1.0, ..., 1.0]),  # 128 identical values
    # ... all parameters duplicated 128 times
)
```

What we want:

```python
nir.LIF(
    tau=np.array(0.001),        # Single scalar value
    v_threshold=np.array(1.0),  # Single scalar value
    neuron_count=128,           # Explicit neuron count
)
```

Root Cause
NIR's type inference relies on parameter array shapes to determine neuron counts:

```python
# From NIR source - infers neuron count from parameter shapes
def infer_neuron_count(self):
    return self.tau.shape[0] if self.tau.ndim > 0 else None
```

When parameters are scalars (0-d arrays), NIR cannot infer the neuron count and sets empty type annotations:

```python
input_type:  {'input': array([], dtype=float64)}   # wrong
output_type: {'output': array([], dtype=float64)}  # wrong
```

This causes type checking to fail with errors like:

```
ValueError: Type inference error: type mismatch: fc1.output: [[128]] -> lif.input: []
```
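The failure can be reproduced with plain NumPy, independent of NIR. The `Node` class below is a minimal stand-in that mirrors the quoted inference logic, not actual NIR code:

```python
import numpy as np

class Node:
    """Stand-in for a NIR neuron node that infers its size from parameter shapes."""
    def __init__(self, tau):
        self.tau = np.asarray(tau)

    def infer_neuron_count(self):
        # Mirrors the quoted NIR logic: only arrays with ndim > 0 carry a count.
        return self.tau.shape[0] if self.tau.ndim > 0 else None

print(Node(np.full(128, 0.001)).infer_neuron_count())  # 128
print(Node(np.array(0.001)).infer_neuron_count())      # None -> empty type annotation
```

A 0-d array has shape `()`, so there is simply no axis to read a neuron count from.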
Impact
1. Memory Inefficiency
- A scalar parameter becomes an array of N identical values
- For large layers (e.g., 1024 neurons), this creates 1024x memory overhead
- Particularly problematic for neuromorphic hardware with limited memory
2. Hardware Deployment Issues
- Neuromorphic chips are optimized for shared scalar parameters
- Broadcasting forces hardware to store/access redundant identical weights
- Reduces effective model capacity on memory-constrained devices
3. Export Complexity
- Forces export tools to choose between efficiency and type safety
- Current workarounds require parameter broadcasting that loses semantic meaning
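The memory cost described above is easy to quantify with NumPy's `nbytes`:

```python
import numpy as np

scalar = np.array(0.001)          # 0-d array: a single float64
broadcast = np.full(1024, 0.001)  # what NIR's type checking currently forces

print(scalar.nbytes)     # 8 bytes
print(broadcast.nbytes)  # 8192 bytes: 1024x overhead for one parameter
```

Each per-neuron parameter (tau, v_threshold, ...) multiplies this cost again.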
Proposed Solutions
Option 1: Explicit Neuron Count Parameter

Add an optional neuron_count parameter to neuron nodes:

```python
nir.LIF(
    tau=np.array(0.001),        # Scalar - applies to all neurons
    v_threshold=np.array(1.0),
    neuron_count=128,           # Explicit count for type inference
    input_type={'input': np.array([128])},    # Auto-generated from neuron_count
    output_type={'output': np.array([128])},  # Auto-generated from neuron_count
)
```

Option 2: Broadcast Semantics

Add support for broadcast-compatible parameters:

```python
nir.LIF(
    tau=np.array(0.001),                    # Scalar broadcasts to all neurons
    v_threshold=np.array([1.0, 1.2, 1.0]),  # Per-neuron values where needed
    broadcast_params=['tau'],               # Explicit broadcast annotation
)
```

Option 3: Shape Inference Context

Allow type inference to use context from connected layers:

```python
# Infer neuron count from the connected Affine layer's output shape
fc1 = nir.Affine(weight=np.random.randn(128, 784))  # 128 outputs
lif = nir.LIF(tau=np.array(0.001))  # Infer 128 neurons from the fc1 connection
```

Current Workaround
Export tools currently broadcast scalars to satisfy type checking:

```python
def to_array(val, n_neurons):
    """Broadcast a scalar to an array of length n_neurons."""
    if np.isscalar(val):
        return np.full(n_neurons, float(val))
    # ... handle existing arrays
```

This works but loses the semantic meaning and efficiency of scalar parameters.
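A fuller version of this workaround also accepts 0-d arrays and validates pre-existing arrays. This is a sketch, not code from any particular export tool:

```python
import numpy as np

def to_array(val, n_neurons):
    """Broadcast a scalar (or 0-d array) to shape (n_neurons,); validate arrays."""
    if np.isscalar(val) or (isinstance(val, np.ndarray) and val.ndim == 0):
        return np.full(n_neurons, float(val))
    arr = np.asarray(val, dtype=float)
    if arr.shape != (n_neurons,):
        raise ValueError(f"expected shape ({n_neurons},), got {arr.shape}")
    return arr

print(to_array(0.9, 4))            # [0.9 0.9 0.9 0.9]
print(to_array(np.array(0.9), 4))  # 0-d arrays broadcast the same way
```

Even with the validation, the exported graph still carries `n_neurons` redundant copies of each scalar, which is exactly what this issue asks NIR to make unnecessary.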
Test Case

```python
import nir
import numpy as np

# This should work without broadcasting
lif_scalar = nir.LIF(
    tau=np.array(0.001),        # Scalar parameter
    v_threshold=np.array(1.0),
    neuron_count=128,           # Proposed: explicit neuron count
)

# Type inference should succeed
graph = nir.NIRGraph(
    nodes={'lif': lif_scalar, ...},
    edges=[...],
)

# Should pass: nir.read(file, type_check=True)
```

References
- snnTorch export issue: Scalar parameters must be broadcast to arrays for NIR compatibility
- Hardware efficiency: Neuromorphic chips (Intel Loihi, SpiNNaker) optimize for shared scalar parameters
- Memory impact: Large models with broadcast parameters can exceed hardware memory limits
Environment
- NIR version: 1.0.7
- nirtorch version: 2.0.5
- Use case: snnTorch → NIR → neuromorphic hardware deployment
Expected Outcome: NIR should support efficient scalar parameters while maintaining type safety, enabling optimal hardware deployment without unnecessary memory overhead.