 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import warnings
 import numpy as np
+import numbers
 
 import paddle
 import paddle.nn as nn
@@ -86,8 +88,10 @@ def forward(self, inputs):
     elif isinstance(input_size, list):
         _input_size = []
         for item in input_size:
+            if isinstance(item, int):
+                item = (item, )
             assert isinstance(item,
-                              (list, InputSpec)), 'When input_size is list, \
+                              (tuple, InputSpec)), 'When input_size is list, \
             expect item in input_size is a tuple or InputSpec, but got {}'.format(
                 type(item))
 
@@ -97,12 +101,19 @@ def forward(self, inputs):
                     batch_size = item.shape[0]
             else:
                 _input_size.append(item)
+    elif isinstance(input_size, int):
+        _input_size = (input_size, )
     else:
         _input_size = input_size
 
     if batch_size is None:
         batch_size = -1
 
+    if not paddle.in_dynamic_mode():
+        warnings.warn(
+            "Your model was created in static mode, this may not get correct summary information!"
+        )
+
     result, params_info = summary_string(net, _input_size, batch_size, dtypes)
     print(result)
 
@@ -117,16 +128,16 @@ def summary_string(model, input_size, batch_size=-1, dtypes=None):
 
     depth = len(list(model.sublayers()))
 
-    def register_hook(module):
-        def hook(module, input, output):
-            class_name = str(module.__class__).split(".")[-1].split("'")[0]
+    def register_hook(layer):
+        def hook(layer, input, output):
+            class_name = str(layer.__class__).split(".")[-1].split("'")[0]
 
             try:
-                module_idx = int(module._full_name.split('_')[-1])
+                layer_idx = int(layer._full_name.split('_')[-1])
             except:
-                module_idx = len(summary)
+                layer_idx = len(summary)
 
-            m_key = "%s-%i" % (class_name, module_idx + 1)
+            m_key = "%s-%i" % (class_name, layer_idx + 1)
             summary[m_key] = OrderedDict()
             summary[m_key]["input_shape"] = list(input[0].shape)
             summary[m_key]["input_shape"][0] = batch_size
@@ -138,23 +149,50 @@ def hook(module, input, output):
             summary[m_key]["output_shape"][0] = batch_size
 
             params = 0
-            if hasattr(module, "weight"):
-                params += np.prod(module.weight.shape)
-                summary[m_key]["trainable"] = module.weight.trainable or (
-                    not module.weight.stop_gradient)
-            if hasattr(module, "bias"):
-                params += np.prod(module.bias.shape)
+
+            if paddle.in_dynamic_mode():
+                layer_state_dict = layer._parameters
+            else:
+                layer_state_dict = layer.state_dict()
+
+            for k, v in layer_state_dict.items():
+                params += np.prod(v.shape)
+
+                try:
+                    if (getattr(getattr(layer, k), 'trainable')) and (
+                            not getattr(getattr(layer, k), 'stop_gradient')):
+                        summary[m_key]["trainable"] = True
+                    else:
+                        summary[m_key]["trainable"] = False
+                except:
+                    summary[m_key]["trainable"] = True
+
             summary[m_key]["nb_params"] = params
 
-        if (not isinstance(module, nn.Sequential) and
-                not isinstance(module, nn.LayerList) and
-                (not (module == model) or depth < 1)):
+        if (not isinstance(layer, nn.Sequential) and
+                not isinstance(layer, nn.LayerList) and
+                (not (layer == model) or depth < 1)):
+
+            hooks.append(layer.register_forward_post_hook(hook))
+
+    def _check_input_size(input_sizes):
+        for input_size in input_sizes:
+            for item in input_size:
+                if not isinstance(item, numbers.Number):
+                    raise TypeError(
+                        "Expected item in input size be a number, but got {}".
+                        format(type(item)))
 
-            hooks.append(module.register_forward_post_hook(hook))
+                if item <= 0:
+                    raise ValueError(
+                        "Expected item in input size greater than zero, but got {}".
+                        format(item))
 
     if isinstance(input_size, tuple):
         input_size = [input_size]
 
+    _check_input_size(input_size)
+
     x = [
         paddle.rand(
             [2] + list(in_size), dtype=dtype)
@@ -193,7 +231,12 @@ def hook(module, input, output):
193231 "{0:,}" .format (summary [layer ]["nb_params" ]), )
194232 total_params += summary [layer ]["nb_params" ]
195233
196- total_output += np .prod (summary [layer ]["output_shape" ])
234+ try :
235+ total_output += np .prod (summary [layer ]["output_shape" ])
236+ except :
237+ for output_shape in summary [layer ]["output_shape" ]:
238+ total_output += np .prod (output_shape )
239+
197240 if "trainable" in summary [layer ]:
198241 if summary [layer ]["trainable" ] == True :
199242 trainable_params += summary [layer ]["nb_params" ]
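
Taken together, the change lets summary accept a bare int, or ints inside a list, for input_size, warns when the model is not running in dynamic mode, counts parameters from each sublayer's parameter dict rather than only weight/bias, and validates that every size entry is a positive number. Below is a minimal usage sketch; it assumes the function is exposed as paddle.summary and uses a throwaway single-Linear model that is not part of this commit.

import paddle
import paddle.nn as nn

# Hypothetical toy model, only here to illustrate the summary call.
class SimpleNet(nn.Layer):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc = nn.Linear(100, 10)

    def forward(self, x):
        return self.fc(x)

net = SimpleNet()

# An int input_size is now normalized to a one-element tuple, so these
# calls should all describe the same (batch, 100) input:
paddle.summary(net, input_size=(100, ))
paddle.summary(net, input_size=100)    # accepted after this change
paddle.summary(net, input_size=[100])  # list items may be ints as well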