Skip to content

Commit 6ce5159

Browse files
authored
[MLIR][Python] Use ir.Value directly instead of _SubClassValueT (llvm#82341)
_SubClassValueT is only useful when it is has >1 usage in a signature. This was not true for the signatures produced by tblgen. For example def call(result, callee, operands_, *, loc=None, ip=None) -> _SubClassValueT: ... here a type checker does not have enough information to infer a type argument for _SubClassValueT, and thus effectively treats it as Any.
1 parent 6d160a4 commit 6ce5159

File tree

6 files changed

+10
-24
lines changed

6 files changed

+10
-24
lines changed

mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,4 @@ class _Globals:
1010
def _check_dialect_module_loaded(self, dialect_namespace: str) -> bool: ...
1111

1212
def register_dialect(dialect_class: type) -> object: ...
13-
def register_operation(dialect_class: type) -> object: ...
13+
def register_operation(dialect_class: type, *, replace: bool = ...) -> object: ...

mlir/python/mlir/dialects/_ods_common.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
Sequence as _Sequence,
99
Tuple as _Tuple,
1010
Type as _Type,
11-
TypeVar as _TypeVar,
1211
Union as _Union,
1312
)
1413

@@ -143,12 +142,6 @@ def get_op_result_or_op_results(
143142
else op
144143
)
145144

146-
147-
# This is the standard way to indicate subclass/inheritance relationship
148-
# see the typing.Type doc string.
149-
_U = _TypeVar("_U", bound=_cext.ir.Value)
150-
SubClassValueT = _Type[_U]
151-
152145
ResultValueTypeTuple = _cext.ir.Operation, _cext.ir.OpView, _cext.ir.Value
153146
ResultValueT = _Union[ResultValueTypeTuple]
154147
VariadicResultValueT = _Union[ResultValueT, _Sequence[ResultValueT]]

mlir/python/mlir/dialects/arith.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
get_default_loc_context as _get_default_loc_context,
1313
_cext as _ods_cext,
1414
get_op_result_or_op_results as _get_op_result_or_op_results,
15-
SubClassValueT as _SubClassValueT,
1615
)
1716

1817
from typing import Any, List, Union
@@ -81,5 +80,5 @@ def literal_value(self) -> Union[int, float]:
8180

8281
def constant(
8382
result: Type, value: Union[int, float, Attribute], *, loc=None, ip=None
84-
) -> _SubClassValueT:
83+
) -> Value:
8584
return _get_op_result_or_op_results(ConstantOp(result, value, loc=loc, ip=ip))

mlir/test/mlir-tblgen/op-python-bindings.td

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ include "mlir/Interfaces/InferTypeOpInterface.td"
77
// CHECK: @_ods_cext.register_dialect
88
// CHECK: class _Dialect(_ods_ir.Dialect):
99
// CHECK: DIALECT_NAMESPACE = "test"
10-
// CHECK: pass
1110
def Test_Dialect : Dialect {
1211
let name = "test";
1312
let cppNamespace = "Test";

mlir/test/python/ir/value.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import gc
44
from mlir.ir import *
55
from mlir.dialects import func
6-
from mlir.dialects._ods_common import SubClassValueT
76

87

98
def run(f):
@@ -270,7 +269,7 @@ def __str__(self):
270269
return super().__str__().replace(Value.__name__, NOPBlockArg.__name__)
271270

272271
@register_value_caster(IntegerType.static_typeid)
273-
def cast_int(v) -> SubClassValueT:
272+
def cast_int(v) -> Value:
274273
print("in caster", v.__class__.__name__)
275274
if isinstance(v, OpResult):
276275
return NOPResult(v)

mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ constexpr const char *fileHeader = R"Py(
3131
3232
from ._ods_common import _cext as _ods_cext
3333
from ._ods_common import (
34-
SubClassValueT as _SubClassValueT,
3534
equally_sized_accessor as _ods_equally_sized_accessor,
3635
get_default_loc_context as _ods_get_default_loc_context,
3736
get_op_result_or_op_results as _get_op_result_or_op_results,
@@ -52,8 +51,6 @@ constexpr const char *dialectClassTemplate = R"Py(
5251
@_ods_cext.register_dialect
5352
class _Dialect(_ods_ir.Dialect):
5453
DIALECT_NAMESPACE = "{0}"
55-
pass
56-
5754
)Py";
5855

5956
constexpr const char *dialectExtensionTemplate = R"Py(
@@ -1007,14 +1004,13 @@ static void emitValueBuilder(const Operator &op,
10071004
});
10081005
std::string nameWithoutDialect =
10091006
op.getOperationName().substr(op.getOperationName().find('.') + 1);
1010-
os << llvm::formatv(valueBuilderTemplate, sanitizeName(nameWithoutDialect),
1011-
op.getCppClassName(),
1012-
llvm::join(valueBuilderParams, ", "),
1013-
llvm::join(opBuilderArgs, ", "),
1014-
(op.getNumResults() > 1
1015-
? "_Sequence[_SubClassValueT]"
1016-
: (op.getNumResults() > 0 ? "_SubClassValueT"
1017-
: "_ods_ir.Operation")));
1007+
os << llvm::formatv(
1008+
valueBuilderTemplate, sanitizeName(nameWithoutDialect),
1009+
op.getCppClassName(), llvm::join(valueBuilderParams, ", "),
1010+
llvm::join(opBuilderArgs, ", "),
1011+
(op.getNumResults() > 1
1012+
? "_Sequence[_ods_ir.Value]"
1013+
: (op.getNumResults() > 0 ? "_ods_ir.Value" : "_ods_ir.Operation")));
10181014
}
10191015

10201016
/// Emits bindings for a specific Op to the given output stream.

0 commit comments

Comments
 (0)