@@ -44,6 +44,11 @@ def _type_to_str_(tp):
44
44
return framework_pb2 .AttrType .Name (tp )
45
45
46
46
47
+ _two_dollar_pattern_ = re .compile (r"\$\$([^\$]+)\$\$" )
48
+ _single_dollar_pattern_ = re .compile (r"\$([^\$]+)\$" )
49
+ _two_bang_pattern_ = re .compile (r"!!([^!]+)!!" )
50
+
51
+
47
52
def _generate_doc_string_ (op_proto ):
48
53
"""
49
54
Generate docstring by OpProto
@@ -55,22 +60,26 @@ def _generate_doc_string_(op_proto):
55
60
str: the document string
56
61
"""
57
62
63
+ def escape_math (text ):
64
+ return _two_bang_pattern_ .sub (
65
+ r'$$\1$$' ,
66
+ _single_dollar_pattern_ .sub (
67
+ r':math:`\1`' , _two_dollar_pattern_ .sub (r"!!\1!!" , text )))
68
+
58
69
if not isinstance (op_proto , framework_pb2 .OpProto ):
59
70
raise TypeError ("OpProto should be `framework_pb2.OpProto`" )
60
71
61
72
buf = cStringIO .StringIO ()
62
- buf .write (op_proto .comment )
73
+ buf .write (escape_math ( op_proto .comment ) )
63
74
buf .write ('\n Args:\n ' )
64
75
for each_input in op_proto .inputs :
65
76
line_begin = ' {0}: ' .format (_convert_ (each_input .name ))
66
77
buf .write (line_begin )
67
- buf .write (each_input .comment )
68
- buf .write ('\n ' )
69
- buf .write (' ' * len (line_begin ))
70
- buf .write ('Duplicable: ' )
71
- buf .write (str (each_input .duplicable ))
72
- buf .write (' Optional: ' )
73
- buf .write (str (each_input .dispensable ))
78
+ buf .write (escape_math (each_input .comment ))
79
+ if each_input .duplicable :
80
+ buf .write (" Duplicatable." )
81
+ if each_input .dispensable :
82
+ buf .write (" Optional." )
74
83
buf .write ('\n ' )
75
84
76
85
skip_attrs = OpProtoHolder .generated_op_attr_names ()
@@ -83,7 +92,7 @@ def _generate_doc_string_(op_proto):
83
92
buf .write (' (' )
84
93
buf .write (_type_to_str_ (each_attr .type ))
85
94
buf .write ('): ' )
86
- buf .write (each_attr .comment )
95
+ buf .write (escape_math ( each_attr .comment ) )
87
96
buf .write ('\n ' )
88
97
89
98
if len (op_proto .outputs ) != 0 :
@@ -92,7 +101,7 @@ def _generate_doc_string_(op_proto):
92
101
for each_opt in op_proto .outputs :
93
102
if not each_opt .intermediate :
94
103
break
95
- buf .write (each_opt .comment )
104
+ buf .write (escape_math ( each_opt .comment ) )
96
105
97
106
return buf .getvalue ()
98
107
0 commit comments