99from brainpy .dyn .base import DynamicalSystem
1010from brainpy .errors import MathError
1111from brainpy .initialize import XavierNormal , ZeroInit , Initializer , parameter
12- from brainpy .modes import Mode , TrainingMode , training
12+ from brainpy .modes import Mode , TrainingMode , BatchingMode , training , batching
1313from brainpy .tools .checking import check_initializer
1414from brainpy .types import Array
1515
1616__all__ = [
1717 'Dense' ,
18+ 'Flatten'
1819]
1920
2021
@@ -188,3 +189,29 @@ def offline_fit(self,
188189 bias , Wff = bm .split (weights , [1 ])
189190 self .W .value = Wff
190191 self .b .value = bias [0 ]
192+
193+
194+ class Flatten (DynamicalSystem ):
195+ r"""Flattens a contiguous range of dims into 2D or 1D.
196+
197+ Parameters:
198+ ----------
199+ name: str, Optional
200+ The name of the object
201+ mode: Mode
202+ Enable training this node or not. (default True)
203+ """
204+ def __init__ (self ,
205+ name : Optional [str ] = None ,
206+ mode : Optional [Mode ] = batching ,
207+ ):
208+ super ().__init__ (name , mode )
209+
210+ def update (self , shr , x ):
211+ if isinstance (self .mode , BatchingMode ):
212+ return x .reshape ((x .shape [0 ], - 1 ))
213+ else :
214+ return x .flatten ()
215+
216+ def reset_state (self , batch_size = None ):
217+ pass
0 commit comments