@@ -188,6 +188,64 @@ def get_function_name(function):
188
188
return function .__name__
189
189
190
190
191
+ def get_default_value_for_repr (value ):
192
+ """Return a substitute for rendering the default value of a funciton arg.
193
+
194
+ Function and object instances are rendered as <Foo object at 0x00000000>
195
+ which can't be parsed by black. We substitute functions with the function
196
+ name and objects with a rendered version of the constructor like
197
+ `Foo(a=2, b="bar")`.
198
+
199
+ Args:
200
+ value: The value to find a better rendering of.
201
+
202
+ Returns:
203
+ Another value or `None` if no substitution is needed.
204
+ """
205
+
206
+ class ReprWrapper :
207
+ def __init__ (self , representation ):
208
+ self .representation = representation
209
+
210
+ def __repr__ (self ):
211
+ return self .representation
212
+
213
+ if value is inspect ._empty :
214
+ return None
215
+
216
+ if inspect .isfunction (value ):
217
+ # Render the function name instead
218
+ return ReprWrapper (value .__name__ )
219
+
220
+ if (
221
+ repr (value ).startswith ("<" ) # <Foo object at 0x00000000>
222
+ and hasattr (value , "__class__" ) # it is an object
223
+ and hasattr (value , "get_config" ) # it is a Keras object
224
+ ):
225
+ config = value .get_config ()
226
+ init_args = [] # The __init__ arguments to render
227
+ for p in inspect .signature (value .__class__ .__init__ ).parameters .values ():
228
+ if p .name == "self" :
229
+ continue
230
+ if p .kind == inspect .Parameter .POSITIONAL_ONLY :
231
+ # Required positional, render without a name
232
+ init_args .append (repr (config [p .name ]))
233
+ elif p .default is inspect ._empty or p .default != config [p .name ]:
234
+ # Keyword arg with non-default value, render
235
+ init_args .append (p .name + "=" + repr (config [p .name ]))
236
+ # else don't render that argument
237
+ return ReprWrapper (
238
+ value .__class__ .__module__
239
+ + "."
240
+ + value .__class__ .__name__
241
+ + "("
242
+ + ", " .join (init_args )
243
+ + ")"
244
+ )
245
+
246
+ return None
247
+
248
+
191
249
def get_signature_start (function ):
192
250
"""For the Dense layer, it should return the string 'keras.layers.Dense'"""
193
251
if ismethod (function ):
@@ -209,9 +267,12 @@ def get_signature_end(function):
209
267
210
268
formatted_params = []
211
269
for x in params :
270
+ default = get_default_value_for_repr (x .default )
271
+ if default :
272
+ x = inspect .Parameter (
273
+ x .name , x .kind , default = default , annotation = x .annotation
274
+ )
212
275
str_x = str (x )
213
- if "<function" in str_x :
214
- str_x = re .sub (r'<function (.*?) at 0x[0-9a-fA-F]+>' , r'\1' , str_x )
215
276
formatted_params .append (str_x )
216
277
signature_end = "(" + ", " .join (formatted_params ) + ")"
217
278
@@ -382,10 +443,8 @@ def get_class_from_method(meth):
382
443
return cls
383
444
meth = meth .__func__ # fallback to __qualname__ parsing
384
445
if inspect .isfunction (meth ):
385
- cls = getattr (
386
- inspect .getmodule (meth ),
387
- meth .__qualname__ .split (".<locals>" , 1 )[0 ].rsplit ("." , 1 )[0 ],
388
- )
446
+ cls_name = meth .__qualname__ .split (".<locals>" , 1 )[0 ].rsplit ("." , 1 )[0 ]
447
+ cls = getattr (inspect .getmodule (meth ), cls_name , None )
389
448
if isinstance (cls , type ):
390
449
return cls
391
450
return getattr (meth , "__objclass__" , None ) # handle special descriptor objects
0 commit comments