99from brainpy .dyn .base import DynamicalSystem
1010from brainpy .errors import MathError
1111from brainpy .initialize import XavierNormal , ZeroInit , Initializer , parameter
12- from brainpy .modes import Mode , TrainingMode , BatchingMode , training , batching
12+ from brainpy .modes import Mode , TrainingMode , BatchingMode , training , batching
1313from brainpy .tools .checking import check_initializer
1414from brainpy .types import Array
1515
@@ -201,17 +201,19 @@ class Flatten(DynamicalSystem):
201201 mode: Mode
202202 Enable training this node or not. (default True)
203203 """
204- def __init__ (self ,
205- name : Optional [str ] = None ,
206- mode : Optional [Mode ] = batching ,
207- ):
204+
205+ def __init__ (
206+ self ,
207+ name : Optional [str ] = None ,
208+ mode : Optional [Mode ] = batching ,
209+ ):
208210 super ().__init__ (name , mode )
209-
211+
210212 def update (self , shr , x ):
211213 if isinstance (self .mode , BatchingMode ):
212214 return x .reshape ((x .shape [0 ], - 1 ))
213215 else :
214216 return x .flatten ()
215-
217+
216218 def reset_state (self , batch_size = None ):
217- pass
219+ pass
0 commit comments