Skip to content

Commit 10d0d95

Browse files
[MLIR][Python] Add docstring for generated python op classes (llvm#158198)
This PR adds support in mlir-tblgen for generating docstrings for each Python class corresponding to an MLIR op. The docstrings are currently derived from the op’s description in ODS, with indentation adjusted to display nicely in Python. This makes it easier for Python users to see the op descriptions directly in their IDE or LSP while coding. In the future, we can extend the docstrings to include explanations for each method, attribute, and so on. This idea was previously discussed in the `#mlir-python` channel on Discord with @makslevental and @superbobry. --------- Co-authored-by: Maksim Levental <[email protected]>
1 parent c46cf1e commit 10d0d95

File tree

3 files changed

+59
-6
lines changed

3 files changed

+59
-6
lines changed

mlir/test/mlir-tblgen/op-python-bindings.td

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,34 @@ def DeriveResultTypesVariadicOp : TestOp<"derive_result_types_variadic_op", [Fir
254254
// CHECK: op = DeriveResultTypesVariadicOp(res=res, _gen_res_1=_gen_res_1, type_=type_, loc=loc, ip=ip); results = op.results
255255
// CHECK: return results if len(results) > 1 else (results[0] if len(results) == 1 else op)
256256

257+
258+
// CHECK: class DescriptionOp(_ods_ir.OpView):
259+
// CHECK: r"""
260+
// CHECK: This is a long description.
261+
// CHECK: It has multiple lines.
262+
// CHECK: A code block (to test the indent).
263+
// CHECK: ```mlir
264+
// CHECK: test.loop {
265+
// CHECK: test.yield
266+
// CHECK: }
267+
// CHECK: ```
268+
// CHECK: Add \"\"\" will not terminate the description.
269+
// CHECK: """
270+
def DescriptionOp : TestOp<"description"> {
271+
let description = [{
272+
This is a long description.
273+
It has multiple lines.
274+
275+
A code block (to test the indent).
276+
```mlir
277+
test.loop {
278+
test.yield
279+
}
280+
```
281+
Add """ will not terminate the description.
282+
}];
283+
}
284+
257285
// CHECK: @_ods_cext.register_operation(_Dialect)
258286
// CHECK: class EmptyOp(_ods_ir.OpView):
259287
// CHECK-LABEL: OPERATION_NAME = "test.empty"

mlir/test/python/ir/auto_location.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def testInferLocations():
5151
_cext.globals.register_traceback_file_inclusion(_arith_ops_gen.__file__)
5252
three = arith.constant(IndexType.get(), 3)
5353
# fmt: off
54-
# CHECK: loc(callsite("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]_arith_ops_gen.py":396:4 to :235) at callsite("testInferLocations"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":52:16 to :50) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":13:4 to :7) at "<module>"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":26:1 to :4)))))
54+
# CHECK: loc(callsite("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]_arith_ops_gen.py":{{[0-9]+}}:4 to :235) at callsite("testInferLocations"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":52:16 to :50) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":13:4 to :7) at "<module>"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":26:1 to :4)))))
5555
# fmt: on
5656
print(three.location)
5757

@@ -60,14 +60,14 @@ def foo():
6060
print(four.location)
6161

6262
# fmt: off
63-
# CHECK: loc(callsite("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]_arith_ops_gen.py":396:4 to :235) at callsite("testInferLocations.<locals>.foo"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":59:19 to :53) at callsite("testInferLocations"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":65:8 to :13) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":13:4 to :7) at "<module>"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":26:1 to :4))))))
63+
# CHECK: loc(callsite("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]_arith_ops_gen.py":{{[0-9]+}}:4 to :235) at callsite("testInferLocations.<locals>.foo"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":59:19 to :53) at callsite("testInferLocations"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":65:8 to :13) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":13:4 to :7) at "<module>"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":26:1 to :4))))))
6464
# fmt: on
6565
foo()
6666

6767
_cext.globals.register_traceback_file_exclusion(__file__)
6868

6969
# fmt: off
70-
# CHECK: loc("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]_arith_ops_gen.py":396:4 to :235))
70+
# CHECK: loc("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]_arith_ops_gen.py":{{[0-9]+}}:4 to :235))
7171
# fmt: on
7272
foo()
7373

mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,15 @@
1313

1414
#include "OpGenHelpers.h"
1515

16+
#include "mlir/Support/IndentedOstream.h"
1617
#include "mlir/TableGen/GenInfo.h"
1718
#include "mlir/TableGen/Operator.h"
1819
#include "llvm/ADT/StringSet.h"
1920
#include "llvm/Support/CommandLine.h"
2021
#include "llvm/Support/FormatVariadic.h"
2122
#include "llvm/TableGen/Error.h"
2223
#include "llvm/TableGen/Record.h"
24+
#include <regex>
2325

2426
using namespace mlir;
2527
using namespace mlir::tblgen;
@@ -61,10 +63,11 @@ from ._{0}_ops_gen import _Dialect
6163

6264
/// Template for operation class:
6365
/// {0} is the Python class name;
64-
/// {1} is the operation name.
66+
/// {1} is the operation name;
67+
/// {2} is the docstring for this operation.
6568
constexpr const char *opClassTemplate = R"Py(
6669
@_ods_cext.register_operation(_Dialect)
67-
class {0}(_ods_ir.OpView):
70+
class {0}(_ods_ir.OpView):{2}
6871
OPERATION_NAME = "{1}"
6972
)Py";
7073

@@ -1031,9 +1034,31 @@ static void emitValueBuilder(const Operator &op,
10311034
}
10321035
}
10331036

1037+
/// Retrieve the description of the given op and generate a docstring for it.
1038+
static std::string makeDocStringForOp(const Operator &op) {
1039+
if (!op.hasDescription())
1040+
return "";
1041+
1042+
auto desc = op.getDescription().rtrim(" \t").str();
1043+
// Replace all """ with \"\"\" to avoid early termination of the literal.
1044+
desc = std::regex_replace(desc, std::regex(R"(""")"), R"(\"\"\")");
1045+
1046+
std::string docString = "\n";
1047+
llvm::raw_string_ostream os(docString);
1048+
raw_indented_ostream identedOs(os);
1049+
os << R"( r""")" << "\n";
1050+
identedOs.printReindented(desc, " ");
1051+
if (!StringRef(desc).ends_with("\n"))
1052+
os << "\n";
1053+
os << R"( """)" << "\n";
1054+
1055+
return docString;
1056+
}
1057+
10341058
/// Emits bindings for a specific Op to the given output stream.
10351059
static void emitOpBindings(const Operator &op, raw_ostream &os) {
1036-
os << formatv(opClassTemplate, op.getCppClassName(), op.getOperationName());
1060+
os << formatv(opClassTemplate, op.getCppClassName(), op.getOperationName(),
1061+
makeDocStringForOp(op));
10371062

10381063
// Sized segments.
10391064
if (op.getTrait(attrSizedTraitForKind("operand")) != nullptr) {

0 commit comments

Comments
 (0)