Skip to content

Commit 70f4a42

Browse files
committed
chore: simplify spike handling in SNNTorch models and reset hidden states in MNIST model
1 parent 70eb398 commit 70f4a42

File tree

2 files changed

+19
-8
lines changed

2 files changed

+19
-8
lines changed

benchmarks/bechmark_sota.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -222,16 +222,16 @@ def forward(self, x):
222222
x_t = x[:, t] if x.shape[1] == self.n_steps else x[t]
223223

224224
cur = self.conv1(x_t)
225-
spk, _ = self.lif1(cur)
225+
spk = self.lif1(cur)
226226
cur = self.pool1(spk)
227227

228228
cur = self.conv2(cur)
229-
spk, _ = self.lif2(cur)
229+
spk = self.lif2(cur)
230230
cur = self.pool2(spk)
231231

232232
cur = self.flatten(cur)
233233
cur = self.fc1(cur)
234-
spk, _ = self.lif3(cur)
234+
spk = self.lif3(cur)
235235
cur = self.fc2(spk)
236236

237237
outputs.append(cur)
@@ -261,6 +261,7 @@ def __init__(self, n_steps=10, beta=0.9):
261261
self.fc2 = nn.Linear(128, 10)
262262

263263
def forward(self, x):
264+
# Reset hidden states
264265
self.lif1.init_leaky()
265266
self.lif2.init_leaky()
266267
self.lif3.init_leaky()
@@ -270,16 +271,16 @@ def forward(self, x):
270271
x_t = x[:, t] if x.shape[1] == self.n_steps else x[t]
271272

272273
cur = self.conv1(x_t)
273-
spk, _ = self.lif1(cur)
274+
spk = self.lif1(cur)
274275
cur = self.pool1(spk)
275276

276277
cur = self.conv2(cur)
277-
spk, _ = self.lif2(cur)
278+
spk = self.lif2(cur)
278279
cur = self.pool2(spk)
279280

280281
cur = self.flatten(cur)
281282
cur = self.fc1(cur)
282-
spk, _ = self.lif3(cur)
283+
spk = self.lif3(cur)
283284
cur = self.fc2(spk)
284285

285286
outputs.append(cur)

pyproject.toml

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,21 @@ requires = ["hatchling"]
1717
build-backend = "hatchling.build"
1818

1919
[project.optional-dependencies]
20-
dev = ["pytest>=9.0.2", "pytest-cov>=7.0.0"]
20+
dev = [
21+
"pytest>=9.0.2",
22+
"pytest-cov>=7.0.0",
23+
"psutil>=7.2.1",
24+
"snntorch>=0.9.4",
25+
"gputil>=1.4.0",
26+
"spikingjelly>=0.0.0.0.14",
27+
]
2128
docs = [
2229
"sphinx>=8.1.3",
2330
"sphinx-autodoc-typehints>=3.0.1",
2431
"sphinx-rtd-theme>=3.0.0",
2532
"myst-parser>=4.0.1",
2633
"furo>=2025.9.25",
27-
"ruff>=0.14.10"
34+
"ruff>=0.14.10",
2835
]
2936

3037
[tool.hatch.build.targets.wheel]
@@ -37,3 +44,6 @@ branch = "main"
3744
upload_to_pypi = false
3845
upload_to_release = true
3946
build_command = "pip install build && python -m build"
47+
48+
[dependency-groups]
49+
dev = []

0 commit comments

Comments
 (0)