Skip to content

Commit 916e863

Browse files
authored
Merge pull request #11504 from reyoung/feature/polish_generate_fn
Polish inline math and duplicable/optional in auto generated doc
2 parents 0329ee7 + 3571df8 commit 916e863

File tree

1 file changed

+19
-10
lines changed

1 file changed

+19
-10
lines changed

python/paddle/fluid/layers/layer_function_generator.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,11 @@ def _type_to_str_(tp):
4444
return framework_pb2.AttrType.Name(tp)
4545

4646

47+
_two_dollar_pattern_ = re.compile(r"\$\$([^\$]+)\$\$")
48+
_single_dollar_pattern_ = re.compile(r"\$([^\$]+)\$")
49+
_two_bang_pattern_ = re.compile(r"!!([^!]+)!!")
50+
51+
4752
def _generate_doc_string_(op_proto):
4853
"""
4954
Generate docstring by OpProto
@@ -55,22 +60,26 @@ def _generate_doc_string_(op_proto):
5560
str: the document string
5661
"""
5762

63+
def escape_math(text):
64+
return _two_bang_pattern_.sub(
65+
r'$$\1$$',
66+
_single_dollar_pattern_.sub(
67+
r':math:`\1`', _two_dollar_pattern_.sub(r"!!\1!!", text)))
68+
5869
if not isinstance(op_proto, framework_pb2.OpProto):
5970
raise TypeError("OpProto should be `framework_pb2.OpProto`")
6071

6172
buf = cStringIO.StringIO()
62-
buf.write(op_proto.comment)
73+
buf.write(escape_math(op_proto.comment))
6374
buf.write('\nArgs:\n')
6475
for each_input in op_proto.inputs:
6576
line_begin = ' {0}: '.format(_convert_(each_input.name))
6677
buf.write(line_begin)
67-
buf.write(each_input.comment)
68-
buf.write('\n')
69-
buf.write(' ' * len(line_begin))
70-
buf.write('Duplicable: ')
71-
buf.write(str(each_input.duplicable))
72-
buf.write(' Optional: ')
73-
buf.write(str(each_input.dispensable))
78+
buf.write(escape_math(each_input.comment))
79+
if each_input.duplicable:
80+
buf.write(" Duplicatable.")
81+
if each_input.dispensable:
82+
buf.write(" Optional.")
7483
buf.write('\n')
7584

7685
skip_attrs = OpProtoHolder.generated_op_attr_names()
@@ -83,7 +92,7 @@ def _generate_doc_string_(op_proto):
8392
buf.write(' (')
8493
buf.write(_type_to_str_(each_attr.type))
8594
buf.write('): ')
86-
buf.write(each_attr.comment)
95+
buf.write(escape_math(each_attr.comment))
8796
buf.write('\n')
8897

8998
if len(op_proto.outputs) != 0:
@@ -92,7 +101,7 @@ def _generate_doc_string_(op_proto):
92101
for each_opt in op_proto.outputs:
93102
if not each_opt.intermediate:
94103
break
95-
buf.write(each_opt.comment)
104+
buf.write(escape_math(each_opt.comment))
96105

97106
return buf.getvalue()
98107

0 commit comments

Comments
 (0)