Skip to content

Commit 31536e6

Browse files
[MLIR] [Python] ir.Value is now generic in the type of the value it holds (llvm#166148)
This makes it similar to `mlir::TypedValue` in the MLIR C++ API and allows users to be more specific about the values they produce or accept. Co-authored-by: Maksim Levental <[email protected]>
1 parent f73bcdb commit 31536e6

File tree

4 files changed

+57
-12
lines changed

4 files changed

+57
-12
lines changed

mlir/lib/Bindings/Python/IRCore.cpp

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "mlir/Bindings/Python/Nanobind.h"
1919
#include "mlir/Bindings/Python/NanobindAdaptors.h"
2020
#include "nanobind/nanobind.h"
21+
#include "nanobind/typing.h"
2122
#include "llvm/ADT/ArrayRef.h"
2223
#include "llvm/ADT/SmallVector.h"
2324

@@ -1482,7 +1483,11 @@ class PyConcreteValue : public PyValue {
14821483

14831484
/// Binds the Python module objects to functions of this class.
14841485
static void bind(nb::module_ &m) {
1485-
auto cls = ClassTy(m, DerivedTy::pyClassName);
1486+
auto cls = ClassTy(
1487+
m, DerivedTy::pyClassName, nb::is_generic(),
1488+
nb::sig((Twine("class ") + DerivedTy::pyClassName + "(Value[_T])")
1489+
.str()
1490+
.c_str()));
14861491
cls.def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), nb::arg("value"));
14871492
cls.def_static(
14881493
"isinstance",
@@ -4605,7 +4610,10 @@ void mlir::python::populateIRCore(nb::module_ &m) {
46054610
//----------------------------------------------------------------------------
46064611
// Mapping of Value.
46074612
//----------------------------------------------------------------------------
4608-
nb::class_<PyValue>(m, "Value")
4613+
m.attr("_T") = nb::type_var("_T", nb::arg("bound") = m.attr("Type"));
4614+
4615+
nb::class_<PyValue>(m, "Value", nb::is_generic(),
4616+
nb::sig("class Value(Generic[_T])"))
46094617
.def(nb::init<PyValue &>(), nb::keep_alive<0, 1>(), nb::arg("value"),
46104618
"Creates a Value reference from another `Value`.")
46114619
.def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule,
@@ -4737,7 +4745,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
47374745
[](PyValue &self, const PyType &type) {
47384746
mlirValueSetType(self.get(), type);
47394747
},
4740-
nb::arg("type"), "Sets the type of the value.")
4748+
nb::arg("type"), "Sets the type of the value.",
4749+
nb::sig("def set_type(self, type: _T)"))
47414750
.def(
47424751
"replace_all_uses_with",
47434752
[](PyValue &self, PyValue &with) {

mlir/test/mlir-tblgen/op-python-bindings.td

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -350,16 +350,16 @@ def MissingNamesOp : TestOp<"missing_names"> {
350350
// CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip)
351351

352352
// CHECK: @builtins.property
353-
// CHECK: def f32(self) -> _ods_ir.Value:
353+
// CHECK: def f32(self) -> _ods_ir.Value[_ods_ir.FloatType]:
354354
// CHECK: return self.operation.operands[1]
355355
let arguments = (ins I32, F32:$f32, I64);
356356

357357
// CHECK: @builtins.property
358-
// CHECK: def i32(self) -> _ods_ir.OpResult:
358+
// CHECK: def i32(self) -> _ods_ir.OpResult[_ods_ir.IntegerType]:
359359
// CHECK: return self.operation.results[0]
360360
//
361361
// CHECK: @builtins.property
362-
// CHECK: def i64(self) -> _ods_ir.OpResult:
362+
// CHECK: def i64(self) -> _ods_ir.OpResult[_ods_ir.IntegerType]:
363363
// CHECK: return self.operation.results[2]
364364
let results = (outs I32:$i32, AnyFloat, I64:$i64);
365365
}
@@ -590,20 +590,20 @@ def SimpleOp : TestOp<"simple"> {
590590
// CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip)
591591

592592
// CHECK: @builtins.property
593-
// CHECK: def i32(self) -> _ods_ir.Value:
593+
// CHECK: def i32(self) -> _ods_ir.Value[_ods_ir.IntegerType]:
594594
// CHECK: return self.operation.operands[0]
595595
//
596596
// CHECK: @builtins.property
597-
// CHECK: def f32(self) -> _ods_ir.Value:
597+
// CHECK: def f32(self) -> _ods_ir.Value[_ods_ir.FloatType]:
598598
// CHECK: return self.operation.operands[1]
599599
let arguments = (ins I32:$i32, F32:$f32);
600600

601601
// CHECK: @builtins.property
602-
// CHECK: def i64(self) -> _ods_ir.OpResult:
602+
// CHECK: def i64(self) -> _ods_ir.OpResult[_ods_ir.IntegerType]:
603603
// CHECK: return self.operation.results[0]
604604
//
605605
// CHECK: @builtins.property
606-
// CHECK: def f64(self) -> _ods_ir.OpResult:
606+
// CHECK: def f64(self) -> _ods_ir.OpResult[_ods_ir.FloatType]:
607607
// CHECK: return self.operation.results[1]
608608
let results = (outs I64:$i64, AnyFloat:$f64);
609609
}

mlir/test/python/dialects/python_test.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -554,7 +554,7 @@ def testOptionalOperandOp():
554554
)
555555
assert (
556556
typing.get_type_hints(test.OptionalOperandOp.result.fget)["return"]
557-
is OpResult
557+
== OpResult[IntegerType]
558558
)
559559
assert type(op1.result) is OpResult
560560

@@ -662,6 +662,13 @@ def testCustomType():
662662
raise
663663

664664

665+
@run
666+
# CHECK-LABEL: TEST: testValue
667+
def testValue():
668+
# Check that Value is a generic class at runtime.
669+
assert hasattr(Value, "__class_getitem__")
670+
671+
665672
@run
666673
# CHECK-LABEL: TEST: testTensorValue
667674
def testTensorValue():

mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,22 @@ static std::string attrSizedTraitForKind(const char *kind) {
341341
StringRef(kind).drop_front());
342342
}
343343

344+
static StringRef getPythonType(StringRef cppType) {
345+
return llvm::StringSwitch<StringRef>(cppType)
346+
.Case("::mlir::MemRefType", "_ods_ir.MemRefType")
347+
.Case("::mlir::UnrankedMemRefType", "_ods_ir.UnrankedMemRefType")
348+
.Case("::mlir::RankedTensorType", "_ods_ir.RankedTensorType")
349+
.Case("::mlir::UnrankedTensorType", "_ods_ir.UnrankedTensorType")
350+
.Case("::mlir::VectorType", "_ods_ir.VectorType")
351+
.Case("::mlir::IntegerType", "_ods_ir.IntegerType")
352+
.Case("::mlir::FloatType", "_ods_ir.FloatType")
353+
.Case("::mlir::IndexType", "_ods_ir.IndexType")
354+
.Case("::mlir::ComplexType", "_ods_ir.ComplexType")
355+
.Case("::mlir::TupleType", "_ods_ir.TupleType")
356+
.Case("::mlir::NoneType", "_ods_ir.NoneType")
357+
.Default(StringRef());
358+
}
359+
344360
/// Emits accessors to "elements" of an Op definition. Currently, the supported
345361
/// elements are operands and results, indicated by `kind`, which must be either
346362
/// `operand` or `result` and is used verbatim in the emitted code.
@@ -370,8 +386,11 @@ static void emitElementAccessors(
370386
seenVariableLength = true;
371387
if (element.name.empty())
372388
continue;
373-
const char *type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.Value"
389+
std::string type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.Value"
374390
: "_ods_ir.OpResult";
391+
if (StringRef pythonType = getPythonType(element.constraint.getCppType());
392+
!pythonType.empty())
393+
type = llvm::formatv("{0}[{1}]", type, pythonType);
375394
if (element.isVariableLength()) {
376395
if (element.isOptional()) {
377396
os << formatv(opOneOptionalTemplate, sanitizeName(element.name), kind,
@@ -418,6 +437,11 @@ static void emitElementAccessors(
418437
type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.Value"
419438
: "_ods_ir.OpResult";
420439
}
440+
if (std::strcmp(kind, "operand") == 0) {
441+
StringRef pythonType = getPythonType(element.constraint.getCppType());
442+
if (!pythonType.empty())
443+
type += "[" + pythonType.str() + "]";
444+
}
421445
os << formatv(opVariadicEqualPrefixTemplate, sanitizeName(element.name),
422446
kind, numSimpleLength, numVariadicGroups,
423447
numPrecedingSimple, numPrecedingVariadic, type);
@@ -449,6 +473,11 @@ static void emitElementAccessors(
449473
if (!element.isVariableLength() || element.isOptional()) {
450474
type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.Value"
451475
: "_ods_ir.OpResult";
476+
if (std::strcmp(kind, "operand") == 0) {
477+
StringRef pythonType = getPythonType(element.constraint.getCppType());
478+
if (!pythonType.empty())
479+
type += "[" + pythonType.str() + "]";
480+
}
452481
if (!element.isVariableLength()) {
453482
trailing = "[0]";
454483
} else if (element.isOptional()) {

0 commit comments

Comments
 (0)