Skip to content

Commit fd226c9

Browse files
[mlir][Python] Roll up of python API fixes.
* As discussed, fixes the ordering or (operands, results) -> (results, operands) in various `create` like methods. * Fixes a syntax error in an ODS accessor method. * Removes the linalg example in favor of a test case that exercises the same. * Fixes FuncOp visibility to properly use None instead of the empty string and defaults it to None. * Implements what was documented for requiring that trailing __init__ args `loc` and `ip` are keyword only. * Adds a check to `InsertionPoint.insert` so that if attempting to insert past the terminator, an exception is raised telling you what to do instead. Previously, this would crash downstream (i.e. when trying to print the resultant module). * Renames `_ods_build_default` -> `build_generic` and documents it. * Removes `result` from the list of prohibited words and for single-result ops, defaults to naming the result `result`, thereby matching expectations and what is already implemented on the base class. * This was intended to be a relatively small set of changes to be inlined with the broader support for ODS generating the most specific builder, but it spidered out once actually testing various combinations, so rolling up separately. Differential Revision: https://reviews.llvm.org/D95320
1 parent 78d41a1 commit fd226c9

File tree

12 files changed

+208
-199
lines changed

12 files changed

+208
-199
lines changed

mlir/docs/Bindings/Python.md

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -439,8 +439,9 @@ defaults on `OpView`):
439439
#### Builders
440440

441441
Presently, only a single, default builder is mapped to the `__init__` method.
442-
Generalizing this facility is under active development. It currently accepts
443-
arguments:
442+
The intent is that this `__init__` method represents the *most specific* of
443+
the builders typically generated for C++; however currently it is just the
444+
generic form below.
444445

445446
* One argument for each declared result:
446447
* For single-valued results: Each will accept an `mlir.ir.Type`.
@@ -453,7 +454,11 @@ arguments:
453454
* `loc`: An explicit `mlir.ir.Location` to use. Defaults to the location
454455
bound to the thread (i.e. `with Location.unknown():`) or an error if none
455456
is bound nor specified.
456-
* `context`: An explicit `mlir.ir.Context` to use. Default to the context
457-
bound to the thread (i.e. `with Context():` or implicitly via `Location` or
458-
`InsertionPoint` context managers) or an error if none is bound nor
459-
specified.
457+
* `ip`: An explicit `mlir.ir.InsertionPoint` to use. Default to the insertion
458+
point bound to the thread (i.e. `with InsertionPoint(...):`).
459+
460+
In addition, each `OpView` inherits a `build_generic` method which allows
461+
construction via a (nested in the case of variadic) sequence of `results` and
462+
`operands`. This can be used to get some default construction semantics for
463+
operations that are otherwise unsupported in Python, at the expense of having
464+
a very generic signature.

mlir/examples/python/.style.yapf

Lines changed: 0 additions & 4 deletions
This file was deleted.

mlir/examples/python/linalg_matmul.py

Lines changed: 0 additions & 81 deletions
This file was deleted.

mlir/lib/Bindings/Python/IRModules.cpp

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -891,8 +891,8 @@ PyBlock PyOperation::getBlock() {
891891
}
892892

893893
py::object PyOperation::create(
894-
std::string name, llvm::Optional<std::vector<PyValue *>> operands,
895-
llvm::Optional<std::vector<PyType *>> results,
894+
std::string name, llvm::Optional<std::vector<PyType *>> results,
895+
llvm::Optional<std::vector<PyValue *>> operands,
896896
llvm::Optional<py::dict> attributes,
897897
llvm::Optional<std::vector<PyBlock *>> successors, int regions,
898898
DefaultingPyLocation location, py::object maybeIp) {
@@ -1039,12 +1039,12 @@ py::object PyOperation::createOpView() {
10391039
//------------------------------------------------------------------------------
10401040

10411041
py::object
1042-
PyOpView::odsBuildDefault(py::object cls, py::list operandList,
1043-
py::list resultTypeList,
1044-
llvm::Optional<py::dict> attributes,
1045-
llvm::Optional<std::vector<PyBlock *>> successors,
1046-
llvm::Optional<int> regions,
1047-
DefaultingPyLocation location, py::object maybeIp) {
1042+
PyOpView::buildGeneric(py::object cls, py::list resultTypeList,
1043+
py::list operandList,
1044+
llvm::Optional<py::dict> attributes,
1045+
llvm::Optional<std::vector<PyBlock *>> successors,
1046+
llvm::Optional<int> regions,
1047+
DefaultingPyLocation location, py::object maybeIp) {
10481048
PyMlirContextRef context = location->getContext();
10491049
// Class level operation construction metadata.
10501050
std::string name = py::cast<std::string>(cls.attr("OPERATION_NAME"));
@@ -1288,8 +1288,9 @@ PyOpView::odsBuildDefault(py::object cls, py::list operandList,
12881288
}
12891289

12901290
// Delegate to create.
1291-
return PyOperation::create(std::move(name), /*operands=*/std::move(operands),
1291+
return PyOperation::create(std::move(name),
12921292
/*results=*/std::move(resultTypes),
1293+
/*operands=*/std::move(operands),
12931294
/*attributes=*/std::move(attributes),
12941295
/*successors=*/std::move(successors),
12951296
/*regions=*/*regions, location, maybeIp);
@@ -1357,6 +1358,16 @@ void PyInsertionPoint::insert(PyOperationBase &operationBase) {
13571358
// Insert before operation.
13581359
(*refOperation)->checkValid();
13591360
beforeOp = (*refOperation)->get();
1361+
} else {
1362+
// Insert at end (before null) is only valid if the block does not
1363+
// already end in a known terminator (violating this will cause assertion
1364+
// failures later).
1365+
if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) {
1366+
throw py::index_error("Cannot insert operation at the end of a block "
1367+
"that already has a terminator. Did you mean to "
1368+
"use 'InsertionPoint.at_block_terminator(block)' "
1369+
"versus 'InsertionPoint(block)'?");
1370+
}
13601371
}
13611372
mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation);
13621373
operation.setAttached();
@@ -3646,8 +3657,8 @@ void mlir::python::populateIRSubmodule(py::module &m) {
36463657

36473658
py::class_<PyOperation, PyOperationBase>(m, "Operation")
36483659
.def_static("create", &PyOperation::create, py::arg("name"),
3649-
py::arg("operands") = py::none(),
36503660
py::arg("results") = py::none(),
3661+
py::arg("operands") = py::none(),
36513662
py::arg("attributes") = py::none(),
36523663
py::arg("successors") = py::none(), py::arg("regions") = 0,
36533664
py::arg("loc") = py::none(), py::arg("ip") = py::none(),
@@ -3681,12 +3692,11 @@ void mlir::python::populateIRSubmodule(py::module &m) {
36813692
opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true);
36823693
opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none();
36833694
opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none();
3684-
opViewClass.attr("_ods_build_default") = classmethod(
3685-
&PyOpView::odsBuildDefault, py::arg("cls"),
3686-
py::arg("operands") = py::none(), py::arg("results") = py::none(),
3687-
py::arg("attributes") = py::none(), py::arg("successors") = py::none(),
3688-
py::arg("regions") = py::none(), py::arg("loc") = py::none(),
3689-
py::arg("ip") = py::none(),
3695+
opViewClass.attr("build_generic") = classmethod(
3696+
&PyOpView::buildGeneric, py::arg("cls"), py::arg("results") = py::none(),
3697+
py::arg("operands") = py::none(), py::arg("attributes") = py::none(),
3698+
py::arg("successors") = py::none(), py::arg("regions") = py::none(),
3699+
py::arg("loc") = py::none(), py::arg("ip") = py::none(),
36903700
"Builds a specific, generated OpView based on class level attributes.");
36913701

36923702
//----------------------------------------------------------------------------

mlir/lib/Bindings/Python/IRModules.h

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -455,8 +455,8 @@ class PyOperation : public PyOperationBase, public BaseContextObject {
455455

456456
/// Creates an operation. See corresponding python docstring.
457457
static pybind11::object
458-
create(std::string name, llvm::Optional<std::vector<PyValue *>> operands,
459-
llvm::Optional<std::vector<PyType *>> results,
458+
create(std::string name, llvm::Optional<std::vector<PyType *>> results,
459+
llvm::Optional<std::vector<PyValue *>> operands,
460460
llvm::Optional<pybind11::dict> attributes,
461461
llvm::Optional<std::vector<PyBlock *>> successors, int regions,
462462
DefaultingPyLocation location, pybind11::object ip);
@@ -498,12 +498,12 @@ class PyOpView : public PyOperationBase {
498498
pybind11::object getOperationObject() { return operationObject; }
499499

500500
static pybind11::object
501-
odsBuildDefault(pybind11::object cls, pybind11::list operandList,
502-
pybind11::list resultTypeList,
503-
llvm::Optional<pybind11::dict> attributes,
504-
llvm::Optional<std::vector<PyBlock *>> successors,
505-
llvm::Optional<int> regions, DefaultingPyLocation location,
506-
pybind11::object maybeIp);
501+
buildGeneric(pybind11::object cls, pybind11::list resultTypeList,
502+
pybind11::list operandList,
503+
llvm::Optional<pybind11::dict> attributes,
504+
llvm::Optional<std::vector<PyBlock *>> successors,
505+
llvm::Optional<int> regions, DefaultingPyLocation location,
506+
pybind11::object maybeIp);
507507

508508
private:
509509
PyOperation &operation; // For efficient, cast-free access from C++

mlir/lib/Bindings/Python/mlir/dialects/_builtin.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77
class ModuleOp:
88
"""Specialization for the module op class."""
99

10-
def __init__(self, loc=None, ip=None):
10+
def __init__(self, *, loc=None, ip=None):
1111
super().__init__(
12-
self._ods_build_default(operands=[], results=[], loc=loc, ip=ip))
12+
self.build_generic(results=[], operands=[], loc=loc, ip=ip))
1313
body = self.regions[0].blocks.append()
1414
with InsertionPoint(body):
1515
Operation.create("module_terminator")
@@ -25,7 +25,8 @@ class FuncOp:
2525
def __init__(self,
2626
name,
2727
type,
28-
visibility,
28+
*,
29+
visibility=None,
2930
body_builder=None,
3031
loc=None,
3132
ip=None):
@@ -34,8 +35,8 @@ def __init__(self,
3435
- `name` is a string representing the function name.
3536
- `type` is either a FunctionType or a pair of list describing inputs and
3637
results.
37-
- `visibility` is a string matching `public`, `private`, or `nested`. The
38-
empty string implies a private visibility.
38+
- `visibility` is a string matching `public`, `private`, or `nested`. None
39+
implies private visibility.
3940
- `body_builder` is an optional callback, when provided a new entry block
4041
is created and the callback is invoked with the new op as argument within
4142
an InsertionPoint context already set for the block. The callback is
@@ -50,7 +51,7 @@ def __init__(self,
5051
type = TypeAttr.get(type)
5152
sym_visibility = StringAttr.get(
5253
str(visibility)) if visibility is not None else None
53-
super().__init__(sym_name, type, sym_visibility, loc, ip)
54+
super().__init__(sym_name, type, sym_visibility, loc=loc, ip=ip)
5455
if body_builder:
5556
entry_block = self.add_entry_block()
5657
with InsertionPoint(entry_block):

mlir/lib/Bindings/Python/mlir/dialects/_linalg.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,10 @@ def __init__(self, inputs, outputs=(), results=(), loc=None, ip=None):
1111
raise ValueError(
1212
"Structured ops must have outputs or results, but not both.")
1313
super().__init__(
14-
self._ods_build_default(operands=[list(inputs),
15-
list(outputs)],
16-
results=list(results),
17-
loc=loc,
18-
ip=ip))
14+
self.build_generic(results=list(results),
15+
operands=[list(inputs), list(outputs)],
16+
loc=loc,
17+
ip=ip))
1918

2019

2120
def select_opview_mixin(parent_opview_cls):
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# RUN: %PYTHON %s | FileCheck %s
2+
3+
from mlir.ir import *
4+
from mlir.dialects import builtin
5+
from mlir.dialects import linalg
6+
from mlir.dialects import std
7+
8+
9+
def run(f):
10+
print("\nTEST:", f.__name__)
11+
f()
12+
13+
14+
# CHECK-LABEL: TEST: testStructuredOpOnTensors
15+
def testStructuredOpOnTensors():
16+
with Context() as ctx, Location.unknown():
17+
module = Module.create()
18+
f32 = F32Type.get()
19+
tensor_type = RankedTensorType.get((2, 3, 4), f32)
20+
with InsertionPoint.at_block_terminator(module.body):
21+
func = builtin.FuncOp(name="matmul_test",
22+
type=FunctionType.get(
23+
inputs=[tensor_type, tensor_type],
24+
results=[tensor_type]))
25+
with InsertionPoint(func.add_entry_block()):
26+
lhs, rhs = func.entry_block.arguments
27+
result = linalg.MatmulOp([lhs, rhs], results=[tensor_type]).result
28+
std.ReturnOp([result])
29+
30+
# CHECK: %[[R:.*]] = linalg.matmul ins(%arg0, %arg1 : tensor<2x3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
31+
print(module)
32+
33+
34+
run(testStructuredOpOnTensors)
35+
36+
37+
# CHECK-LABEL: TEST: testStructuredOpOnBuffers
38+
def testStructuredOpOnBuffers():
39+
with Context() as ctx, Location.unknown():
40+
module = Module.create()
41+
f32 = F32Type.get()
42+
memref_type = MemRefType.get((2, 3, 4), f32)
43+
with InsertionPoint.at_block_terminator(module.body):
44+
func = builtin.FuncOp(name="matmul_test",
45+
type=FunctionType.get(
46+
inputs=[memref_type, memref_type, memref_type],
47+
results=[]))
48+
with InsertionPoint(func.add_entry_block()):
49+
lhs, rhs, result = func.entry_block.arguments
50+
linalg.MatmulOp([lhs, rhs], outputs=[result])
51+
std.ReturnOp([])
52+
53+
# CHECK: linalg.matmul ins(%arg0, %arg1 : memref<2x3x4xf32>, memref<2x3x4xf32>) outs(%arg2 : memref<2x3x4xf32>)
54+
print(module)
55+
56+
57+
run(testStructuredOpOnBuffers)

mlir/test/Bindings/Python/insertion_point.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,21 @@ def test_insert_at_block_terminator_missing():
125125
run(test_insert_at_block_terminator_missing)
126126

127127

128+
# CHECK-LABEL: TEST: test_insert_at_end_with_terminator_errors
129+
def test_insert_at_end_with_terminator_errors():
130+
with Context() as ctx, Location.unknown():
131+
ctx.allow_unregistered_dialects = True
132+
m = Module.create() # Module is created with a terminator.
133+
with InsertionPoint(m.body):
134+
try:
135+
Operation.create("custom.op1", results=[], operands=[])
136+
except IndexError as e:
137+
# CHECK: ERROR: Cannot insert operation at the end of a block that already has a terminator.
138+
print(f"ERROR: {e}")
139+
140+
run(test_insert_at_end_with_terminator_errors)
141+
142+
128143
# CHECK-LABEL: TEST: test_insertion_point_context
129144
def test_insertion_point_context():
130145
ctx = Context()

0 commit comments

Comments
 (0)