@@ -117,78 +117,77 @@ def get_stats(self) -> Tuple[float, float]:
117117# MODEL DEFINITIONS
118118# ============================================================================
119119
120+ class TimeDistributed (nn .Module ):
121+ """Processes the entire sequence through a standard 2D layer by folding Time into Batch."""
122+ def __init__ (self , module ):
123+ super ().__init__ ()
124+ self .module = module
125+
126+ def forward (self , x ):
127+ # x shape: (Time, Batch, C, H, W) or (Time, Batch, Features)
128+ T , B = x .shape [:2 ]
129+ # Reshape to (T*B, ...) for standard PyTorch layers
130+ x_reshaped = x .reshape (T * B , * x .shape [2 :])
131+ y = self .module (x_reshaped )
132+ # Reshape back to (Time, Batch, ...)
133+ return y .reshape (T , B , * y .shape [1 :])
134+
120135class TetherCIFAR10Model (nn .Module ):
121- """Tether model for CIFAR-10."""
136+ """Refactored Tether model for CIFAR-10 using Fused Temporal Processing ."""
122137 def __init__ (self , n_steps = 10 ):
123138 super ().__init__ ()
124139 self .n_steps = n_steps
125140
126- self .features = nn .Sequential (
127- nn .Conv2d (3 , 64 , 3 , padding = 1 ),
141+ self .model = nn .Sequential (
142+ TimeDistributed ( nn .Conv2d (3 , 64 , 3 , padding = 1 ) ),
128143 TetherLIF (64 * 32 * 32 ),
129- nn .AvgPool2d (2 ),
130- nn .Conv2d (64 , 128 , 3 , padding = 1 ),
144+ TimeDistributed ( nn .AvgPool2d (2 ) ),
145+ TimeDistributed ( nn .Conv2d (64 , 128 , 3 , padding = 1 ) ),
131146 TetherLIF (128 * 16 * 16 ),
132- nn .AvgPool2d (2 ),
133- )
134-
135- self .classifier = nn .Sequential (
136- nn .Flatten (),
137- nn .Linear (128 * 8 * 8 , 256 ),
147+ TimeDistributed (nn .AvgPool2d (2 )),
148+ TimeDistributed (nn .Flatten ()),
149+ TimeDistributed (nn .Linear (128 * 8 * 8 , 256 )),
138150 TetherLIF (256 ),
139- nn .Linear (256 , 10 )
151+ TimeDistributed ( nn .Linear (256 , 10 ) )
140152 )
141153
142154 def forward (self , x ):
143- # x: (Batch, Time, C, H, W) or (Time, Batch, C, H, W)
144- if len (x .shape ) == 5 and x .shape [0 ] != self .n_steps :
145- x = x .transpose (0 , 1 ) # (Time, Batch, C, H, W)
146-
147- outputs = []
148- for t in range (self .n_steps ):
149- x_t = x [t ]
150- feat = self .features (x_t )
151- out = self .classifier (feat )
152- outputs .append (out )
155+ if len (x .shape ) == 5 and x .shape [1 ] == self .n_steps :
156+ x = x .transpose (0 , 1 )
153157
154- return torch . stack ( outputs ). mean ( 0 )
155-
158+ x = self . model ( x )
159+ return x . mean ( 0 )
156160
157161class TetherMNISTModel (nn .Module ):
158- """Tether model for MNIST."""
162+ """Refactored Tether model for MNIST using Fused Temporal Processing ."""
159163 def __init__ (self , n_steps = 10 ):
160164 super ().__init__ ()
161165 self .n_steps = n_steps
162166
163- self .features = nn .Sequential (
164- nn .Conv2d (1 , 32 , 3 , padding = 1 ),
165- TetherLIF (32 * 28 * 28 ),
166- nn .AvgPool2d (2 ),
167- nn .Conv2d (32 , 64 , 3 , padding = 1 ),
167+ self .model = nn .Sequential (
168+ TimeDistributed ( nn .Conv2d (1 , 32 , 3 , padding = 1 ) ),
169+ TetherLIF (32 * 28 * 28 ), # Fused LIF: Processes entire T sequence at once
170+ TimeDistributed ( nn .AvgPool2d (2 ) ),
171+ TimeDistributed ( nn .Conv2d (32 , 64 , 3 , padding = 1 ) ),
168172 TetherLIF (64 * 14 * 14 ),
169- nn .AvgPool2d (2 ),
170- )
171-
172- self .classifier = nn .Sequential (
173- nn .Flatten (),
174- nn .Linear (64 * 7 * 7 , 128 ),
173+ TimeDistributed (nn .AvgPool2d (2 )),
174+ TimeDistributed (nn .Flatten ()),
175+ TimeDistributed (nn .Linear (64 * 7 * 7 , 128 )),
175176 TetherLIF (128 ),
176- nn .Linear (128 , 10 )
177+ TimeDistributed ( nn .Linear (128 , 10 ) )
177178 )
178179
179180 def forward (self , x ):
180- if len (x .shape ) == 5 and x .shape [0 ] != self .n_steps :
181+ # Ensure Time is the first dimension: (Time, Batch, C, H, W)
182+ if len (x .shape ) == 5 and x .shape [1 ] == self .n_steps :
181183 x = x .transpose (0 , 1 )
182184
183- outputs = []
184- for t in range (self .n_steps ):
185- x_t = x [t ]
186- feat = self .features (x_t )
187- out = self .classifier (feat )
188- outputs .append (out )
185+ # Process the entire sequence through the model in one go
186+ # No more manual 'for t in range(n_steps)' loop!
187+ x = self .model (x )
189188
190- return torch . stack ( outputs ). mean ( 0 )
191-
189+ # Mean over time for the final classification output
190+ return x . mean ( 0 )
192191
193192class SNNTorchCIFAR10Model (nn .Module ):
194193 """snnTorch model for CIFAR-10."""
0 commit comments