Description
Summary
When exporting an SNN model from snnTorch to NIR with export_to_nir(), the exported CubaLIF nodes (created from snn.Synaptic layers) are missing their input_type and output_type fields. As a result, type inference fails when the model is loaded with nir.read().
The weights and network structure are exported correctly; only the shape metadata is missing.
Environment
python: 3.11.13
snntorch: 0.9.4
nir: 1.0.7
torch: 2.9.1
Minimal Example
import torch
import torch.nn as nn
import snntorch as snn
from snntorch.export_nir import export_to_nir
import nir

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(2450, 128)
        self.lif1 = snn.Synaptic(alpha=0.5, beta=0.9)
        self.fc2 = nn.Linear(128, 4)
        self.lif2 = snn.Synaptic(alpha=0.9, beta=0.9)

    def forward(self, x):
        cur1 = self.fc1(x)
        spk1, syn1, mem1 = self.lif1(cur1)
        cur2 = self.fc2(spk1)
        spk2, syn2, mem2 = self.lif2(cur2)
        return spk2

net = Net()
sample = torch.randn(2450)
nir_model = export_to_nir(net, sample)
nir.write("model.nir", nir_model)

# This fails with the error below
nir.read("model.nir")

Error:

ValueError: Type inference error: type mismatch: fc1.output: [[128]] -> lif1.input: []
Root Cause
The exported CubaLIF nodes have empty type dictionaries:
# Node IDs may differ depending on how snnTorch assigns names
nir_model = nir.read("model.nir", type_check=False)
print(nir_model.nodes['lif1'].input_type)   # {}
print(nir_model.nodes['lif1'].output_type)  # {}

Meanwhile, the Affine nodes created from nn.Linear are correctly annotated:

print(nir_model.nodes['fc1'].input_type)   # {0: array([2450])}
print(nir_model.nodes['fc1'].output_type)  # {0: array([128])}

In other words, the exporter creates CubaLIF nodes without setting input_type and output_type.
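To find every affected node in a larger exported graph rather than checking IDs one by one, a small helper along these lines could be run on the graph loaded with type_check=False. It is a sketch, not part of snnTorch or NIR: find_untyped_nodes and has_shape are hypothetical names, and the only assumption about the graph is that graph.nodes is a dict of nodes exposing input_type/output_type dicts, as in the printouts above. The stub graph lets it run without nir or snnTorch installed.

```python
from types import SimpleNamespace


def has_shape(types):
    """True if a type dict is non-empty and every shape in it is non-empty."""
    return bool(types) and all(len(shape) > 0 for shape in types.values())


def find_untyped_nodes(graph):
    """Return IDs of nodes whose input_type or output_type carries no shape.

    Assumes only that ``graph.nodes`` is a dict of node objects exposing
    ``input_type`` and ``output_type`` dicts (missing or None counts as empty).
    """
    return [
        node_id
        for node_id, node in graph.nodes.items()
        if not (
            has_shape(getattr(node, "input_type", None) or {})
            and has_shape(getattr(node, "output_type", None) or {})
        )
    ]


# Stub graph mirroring the printed values above (hypothetical, no nir import)
graph = SimpleNamespace(nodes={
    "fc1": SimpleNamespace(input_type={0: [2450]}, output_type={0: [128]}),
    "lif1": SimpleNamespace(input_type={}, output_type={}),
})
print(find_untyped_nodes(graph))  # ['lif1']
```

Both empty dicts and empty shapes are flagged, since the error message reports lif1's input as [] rather than as a missing key.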
Expected Behavior
export_to_nir() should populate input_type and output_type for all neuron nodes so that nir.read() passes type checking without requiring type_check=False.
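Concretely, passing type checking means every edge's source output shape matches its target input shape. The sketch below is a simplified stand-in for that check, not NIR's actual type-inference code: it assumes one input and one output port per node and a graph.edges list of (source, target) IDs, and it flattens each port to a single shape, so its message prints [128] where NIR prints [[128]]. check_edge_types and first_shape are hypothetical names.

```python
from types import SimpleNamespace


def first_shape(types):
    """Shape of the first port in a type dict as a list of ints ([] if untyped)."""
    for shape in (types or {}).values():
        return [int(d) for d in shape]
    return []


def check_edge_types(graph):
    """Return a message for every edge whose shapes disagree (simplified check)."""
    errors = []
    for src, dst in graph.edges:
        out_shape = first_shape(graph.nodes[src].output_type)
        in_shape = first_shape(graph.nodes[dst].input_type)
        if out_shape != in_shape:
            errors.append(f"{src}.output: {out_shape} -> {dst}.input: {in_shape}")
    return errors


# Stub of the exported graph: fc1 is typed, lif1 is not (hypothetical data)
graph = SimpleNamespace(
    nodes={
        "fc1": SimpleNamespace(input_type={0: [2450]}, output_type={0: [128]}),
        "lif1": SimpleNamespace(input_type={}, output_type={}),
    },
    edges=[("fc1", "lif1")],
)
print(check_edge_types(graph))  # ['fc1.output: [128] -> lif1.input: []']
```

Once lif1 carries input_type={0: [128]}, the same check returns an empty list, which is the behavior requested here.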
Suggested Fix
When creating CubaLIF nodes in the exporter, include the type annotations:
# Current
nir.CubaLIF(
    tau_syn=tau_syn,
    tau_mem=tau_mem,
    v_threshold=threshold,
)

# Fixed: input_shape / output_shape are the neuron layer's shape
# (e.g. np.array([128]) for lif1, matching fc1's output)
nir.CubaLIF(
    tau_syn=tau_syn,
    tau_mem=tau_mem,
    v_threshold=threshold,
    input_type={0: input_shape},
    output_type={0: output_shape},
)

This applies to all neuron types: Synaptic, Leaky, LIF, etc.
Workaround
Currently, users must either pass type_check=False or manually construct the NIR graph with the proper annotations. Neither is ideal for production use.
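A third stopgap might be to patch the exported graph in memory before writing it. The sketch below is hypothetical (propagate_neuron_types is not an snnTorch or NIR API) and assumes that neuron nodes are elementwise, so a neuron's shape equals the output shape of the node feeding it, and that node attributes can be reassigned. The default key 0 follows the suggested fix above; the error message's port names (fc1.output, lif1.input) hint that your NIR version may expect "input"/"output" keys instead. Whether patched types survive a nir.write/nir.read round trip is not verified here.

```python
from types import SimpleNamespace


def has_shape(types):
    """True if a type dict is non-empty and every shape in it is non-empty."""
    return bool(types) and all(len(shape) > 0 for shape in types.values())


def propagate_neuron_types(graph, neuron_ids, in_key=0, out_key=0):
    """Fill empty types on neuron nodes from their predecessor's output shape.

    Repeats passes over the edges until nothing changes, so chains of
    neurons (neuron -> neuron) are filled in regardless of edge order.
    """
    changed = True
    while changed:
        changed = False
        for src, dst in graph.edges:
            node = graph.nodes[dst]
            if dst not in neuron_ids or has_shape(node.input_type):
                continue  # not a neuron, or already typed
            src_types = graph.nodes[src].output_type
            if not has_shape(src_types):
                continue  # predecessor untyped so far; retry on the next pass
            # Reuse the predecessor's shape object (an np.ndarray in real NIR graphs)
            shape = next(iter(src_types.values()))
            node.input_type = {in_key: shape}
            node.output_type = {out_key: shape}
            changed = True
    return graph


# Stub of the example network; lif1/lif2 come out of the exporter untyped
graph = SimpleNamespace(
    nodes={
        "fc1": SimpleNamespace(input_type={0: [2450]}, output_type={0: [128]}),
        "lif1": SimpleNamespace(input_type={}, output_type={}),
        "fc2": SimpleNamespace(input_type={0: [128]}, output_type={0: [4]}),
        "lif2": SimpleNamespace(input_type={}, output_type={}),
    },
    edges=[("fc1", "lif1"), ("lif1", "fc2"), ("fc2", "lif2")],
)
propagate_neuron_types(graph, {"lif1", "lif2"})
print(graph.nodes["lif1"].input_type, graph.nodes["lif2"].output_type)
# {0: [128]} {0: [4]}
```

The neuron IDs must be supplied by the caller, since, as noted above, they depend on how snnTorch names the nodes. This only works around the symptom; the fix belongs in the exporter.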