Skip to content

Commit 44ce2a2

Browse files
committed
chore: refactor Tether models for CIFAR-10 and MNIST using TimeDistributed for fused temporal processing
1 parent a342bed commit 44ce2a2

File tree

1 file changed

+45
-46
lines changed

1 file changed

+45
-46
lines changed

benchmarks/benchmark_sota.py

Lines changed: 45 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -117,78 +117,77 @@ def get_stats(self) -> Tuple[float, float]:
117117
# MODEL DEFINITIONS
118118
# ============================================================================
119119

120+
class TimeDistributed(nn.Module):
    """Apply a standard 2D layer to every timestep by folding Time into Batch.

    Wraps any module that expects (N, ...) input so it can consume a
    (Time, Batch, ...) sequence in a single call.
    """

    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, x):
        # x: (Time, Batch, C, H, W) or (Time, Batch, Features).
        # Merge the leading two axes, run the wrapped layer once over the
        # flattened sequence, then split Time and Batch back apart.
        t, b = x.shape[0], x.shape[1]
        flat = x.reshape(t * b, *x.shape[2:])
        out = self.module(flat)
        return out.reshape(t, b, *out.shape[1:])
134+
120135
class TetherCIFAR10Model(nn.Module):
    """Refactored Tether model for CIFAR-10 using fused temporal processing.

    Spatial layers (conv / pool / flatten / linear) are wrapped in
    ``TimeDistributed`` so the whole (Time, Batch, ...) sequence flows through
    the network in one pass, while the ``TetherLIF`` layers consume the full
    temporal sequence directly (no per-timestep Python loop).
    """

    def __init__(self, n_steps=10):
        """Build the network.

        Args:
            n_steps: Number of simulation timesteps expected in the input.
        """
        super().__init__()
        self.n_steps = n_steps

        self.model = nn.Sequential(
            TimeDistributed(nn.Conv2d(3, 64, 3, padding=1)),
            TetherLIF(64 * 32 * 32),
            TimeDistributed(nn.AvgPool2d(2)),
            TimeDistributed(nn.Conv2d(64, 128, 3, padding=1)),
            TetherLIF(128 * 16 * 16),
            TimeDistributed(nn.AvgPool2d(2)),
            TimeDistributed(nn.Flatten()),
            TimeDistributed(nn.Linear(128 * 8 * 8, 256)),
            TetherLIF(256),
            TimeDistributed(nn.Linear(256, 10)),
        )

    def forward(self, x):
        """Run the fused temporal forward pass.

        Args:
            x: Tensor of shape (Time, Batch, C, H, W) or (Batch, Time, C, H, W).

        Returns:
            Class logits of shape (Batch, 10), averaged over the time axis.
        """
        # Ensure Time is the leading dimension. Only transpose when dim 0 is
        # NOT already the time length but dim 1 is: checking
        # `x.shape[1] == self.n_steps` alone would wrongly transpose
        # time-first input whose batch size happens to equal n_steps.
        if x.dim() == 5 and x.shape[0] != self.n_steps and x.shape[1] == self.n_steps:
            x = x.transpose(0, 1)

        # Single pass over the entire (Time, Batch, ...) sequence.
        x = self.model(x)
        # Mean over time yields the final classification output.
        return x.mean(0)
156160

157161
class TetherMNISTModel(nn.Module):
    """Refactored Tether model for MNIST using fused temporal processing.

    Spatial layers (conv / pool / flatten / linear) are wrapped in
    ``TimeDistributed`` so the whole (Time, Batch, ...) sequence flows through
    the network in one pass, while the ``TetherLIF`` layers consume the full
    temporal sequence directly (no per-timestep Python loop).
    """

    def __init__(self, n_steps=10):
        """Build the network.

        Args:
            n_steps: Number of simulation timesteps expected in the input.
        """
        super().__init__()
        self.n_steps = n_steps

        self.model = nn.Sequential(
            TimeDistributed(nn.Conv2d(1, 32, 3, padding=1)),
            TetherLIF(32 * 28 * 28),  # Fused LIF: processes entire T sequence at once
            TimeDistributed(nn.AvgPool2d(2)),
            TimeDistributed(nn.Conv2d(32, 64, 3, padding=1)),
            TetherLIF(64 * 14 * 14),
            TimeDistributed(nn.AvgPool2d(2)),
            TimeDistributed(nn.Flatten()),
            TimeDistributed(nn.Linear(64 * 7 * 7, 128)),
            TetherLIF(128),
            TimeDistributed(nn.Linear(128, 10)),
        )

    def forward(self, x):
        """Run the fused temporal forward pass.

        Args:
            x: Tensor of shape (Time, Batch, C, H, W) or (Batch, Time, C, H, W).

        Returns:
            Class logits of shape (Batch, 10), averaged over the time axis.
        """
        # Ensure Time is the leading dimension. Only transpose when dim 0 is
        # NOT already the time length but dim 1 is: checking
        # `x.shape[1] == self.n_steps` alone would wrongly transpose
        # time-first input whose batch size happens to equal n_steps.
        if x.dim() == 5 and x.shape[0] != self.n_steps and x.shape[1] == self.n_steps:
            x = x.transpose(0, 1)

        # Single pass over the entire (Time, Batch, ...) sequence.
        x = self.model(x)
        # Mean over time yields the final classification output.
        return x.mean(0)
192191

193192
class SNNTorchCIFAR10Model(nn.Module):
194193
"""snnTorch model for CIFAR-10."""

0 commit comments

Comments
 (0)