Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 36 additions & 31 deletions mlir/test/mlir-tblgen/op-python-bindings.td
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
Optional<AnyType>:$variadic2);
}

// CHECK: def attr_sized_operands(variadic1, non_variadic, *, variadic2=None, loc=None, ip=None)
// CHECK: def attr_sized_operands(variadic1, non_variadic, *, variadic2=None, loc=None, ip=None) -> AttrSizedOperandsOp:
// CHECK: return AttrSizedOperandsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip)

// CHECK: @_ods_cext.register_operation(_Dialect)
Expand Down Expand Up @@ -108,8 +108,9 @@ def AttrSizedResultsOp : TestOp<"attr_sized_results",
Variadic<AnyType>:$variadic2);
}

// CHECK: def attr_sized_results(variadic1, non_variadic, variadic2, *, loc=None, ip=None)
// CHECK: return _get_op_result_or_op_results(AttrSizedResultsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
// CHECK: def attr_sized_results(variadic1, non_variadic, variadic2, *, loc=None, ip=None) -> _Union[_ods_ir.OpResult, _ods_ir.OpResultList, AttrSizedResultsOp]:
// CHECK: op = AttrSizedResultsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip); results = op.results
// CHECK: return results if len(results) > 1 else (results[0] if len(results) == 1 else op)


// CHECK: @_ods_cext.register_operation(_Dialect)
Expand Down Expand Up @@ -159,7 +160,7 @@ def AttributedOp : TestOp<"attributed_op"> {
UnitAttr:$unitAttr, I32Attr:$in);
}

// CHECK: def attributed_op(i32attr, in_, *, optional_f32_attr=None, unit_attr=None, loc=None, ip=None)
// CHECK: def attributed_op(i32attr, in_, *, optional_f32_attr=None, unit_attr=None, loc=None, ip=None) -> AttributedOp:
// CHECK: return AttributedOp(i32attr=i32attr, in_=in_, optionalF32Attr=optional_f32_attr, unitAttr=unit_attr, loc=loc, ip=ip)

// CHECK: @_ods_cext.register_operation(_Dialect)
Expand Down Expand Up @@ -196,7 +197,7 @@ def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
let arguments = (ins I32, UnitAttr:$in, F32, OptionalAttr<F32Attr>:$is);
}

// CHECK: def attributed_op_with_operands(_gen_arg_0, _gen_arg_2, *, in_=None, is_=None, loc=None, ip=None)
// CHECK: def attributed_op_with_operands(_gen_arg_0, _gen_arg_2, *, in_=None, is_=None, loc=None, ip=None) -> AttributedOpWithOperands
// CHECK: return AttributedOpWithOperands(_gen_arg_0=_gen_arg_0, _gen_arg_2=_gen_arg_2, in_=in_, is_=is_, loc=loc, ip=ip)

// CHECK: @_ods_cext.register_operation(_Dialect)
Expand All @@ -221,7 +222,7 @@ def DefaultValuedAttrsOp : TestOp<"default_valued_attrs"> {
let results = (outs);
}

// CHECK: def default_valued_attrs(*, arr=None, unsupported=None, loc=None, ip=None)
// CHECK: def default_valued_attrs(*, arr=None, unsupported=None, loc=None, ip=None) -> DefaultValuedAttrsOp:
// CHECK: return DefaultValuedAttrsOp(arr=arr, unsupported=unsupported, loc=loc, ip=ip)

// CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_op"
Expand All @@ -239,7 +240,7 @@ def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResu
let results = (outs AnyType:$res, AnyType);
}

// CHECK: def derive_result_types_op(type_, *, results=None, loc=None, ip=None)
// CHECK: def derive_result_types_op(type_, *, results=None, loc=None, ip=None) -> _ods_ir.OpResultList:
// CHECK: return DeriveResultTypesOp(type_=type_, results=results, loc=loc, ip=ip).results

// CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_variadic_op"
Expand All @@ -249,8 +250,9 @@ def DeriveResultTypesVariadicOp : TestOp<"derive_result_types_variadic_op", [Fir
let results = (outs AnyType:$res, Variadic<AnyType>);
}

// CHECK: def derive_result_types_variadic_op(res, _gen_res_1, type_, *, loc=None, ip=None)
// CHECK: return _get_op_result_or_op_results(DeriveResultTypesVariadicOp(res=res, _gen_res_1=_gen_res_1, type_=type_, loc=loc, ip=ip))
// CHECK: def derive_result_types_variadic_op(res, _gen_res_1, type_, *, loc=None, ip=None) -> _Union[_ods_ir.OpResult, _ods_ir.OpResultList, DeriveResultTypesVariadicOp]:
// CHECK: op = DeriveResultTypesVariadicOp(res=res, _gen_res_1=_gen_res_1, type_=type_, loc=loc, ip=ip); results = op.results
// CHECK: return results if len(results) > 1 else (results[0] if len(results) == 1 else op)

// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class EmptyOp(_ods_ir.OpView):
Expand All @@ -267,7 +269,7 @@ def EmptyOp : TestOp<"empty">;
// CHECK: attributes=attributes, results=results, operands=operands,
// CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip)

// CHECK: def empty(*, loc=None, ip=None)
// CHECK: def empty(*, loc=None, ip=None) -> EmptyOp:
// CHECK: return EmptyOp(loc=loc, ip=ip)

// CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_implied_op"
Expand All @@ -281,7 +283,7 @@ def InferResultTypesImpliedOp : TestOp<"infer_result_types_implied_op"> {
let results = (outs I32:$i32, F32:$f32);
}

// CHECK: def infer_result_types_implied_op(*, results=None, loc=None, ip=None)
// CHECK: def infer_result_types_implied_op(*, results=None, loc=None, ip=None) -> _ods_ir.OpResultList:
// CHECK: return InferResultTypesImpliedOp(results=results, loc=loc, ip=ip).results

// CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_op"
Expand All @@ -295,7 +297,7 @@ def InferResultTypesOp : TestOp<"infer_result_types_op", [InferTypeOpInterface]>
let results = (outs AnyType, AnyType, AnyType);
}

// CHECK: def infer_result_types_op(*, results=None, loc=None, ip=None)
// CHECK: def infer_result_types_op(*, results=None, loc=None, ip=None) -> _ods_ir.OpResultList:
// CHECK: return InferResultTypesOp(results=results, loc=loc, ip=ip).results

// CHECK: @_ods_cext.register_operation(_Dialect)
Expand Down Expand Up @@ -334,7 +336,7 @@ def MissingNamesOp : TestOp<"missing_names"> {
let results = (outs I32:$i32, AnyFloat, I64:$i64);
}

// CHECK: def missing_names(i32, _gen_res_1, i64, _gen_arg_0, f32, _gen_arg_2, *, loc=None, ip=None)
// CHECK: def missing_names(i32, _gen_res_1, i64, _gen_arg_0, f32, _gen_arg_2, *, loc=None, ip=None) -> _ods_ir.OpResultList:
// CHECK: return MissingNamesOp(i32=i32, _gen_res_1=_gen_res_1, i64=i64, _gen_arg_0=_gen_arg_0, f32=f32, _gen_arg_2=_gen_arg_2, loc=loc, ip=ip).results

// CHECK: @_ods_cext.register_operation(_Dialect)
Expand Down Expand Up @@ -366,7 +368,7 @@ def OneOptionalOperandOp : TestOp<"one_optional_operand"> {
// CHECK: return None if len(self.operation.operands) < 2 else self.operation.operands[1]
}

// CHECK: def one_optional_operand(non_optional, *, optional=None, loc=None, ip=None)
// CHECK: def one_optional_operand(non_optional, *, optional=None, loc=None, ip=None) -> OneOptionalOperandOp:
// CHECK: return OneOptionalOperandOp(non_optional=non_optional, optional=optional, loc=loc, ip=ip)

// CHECK: @_ods_cext.register_operation(_Dialect)
Expand Down Expand Up @@ -399,7 +401,7 @@ def OneVariadicOperandOp : TestOp<"one_variadic_operand"> {
let arguments = (ins AnyType:$non_variadic, Variadic<AnyType>:$variadic);
}

// CHECK: def one_variadic_operand(non_variadic, variadic, *, loc=None, ip=None)
// CHECK: def one_variadic_operand(non_variadic, variadic, *, loc=None, ip=None) -> OneVariadicOperandOp:
// CHECK: return OneVariadicOperandOp(non_variadic=non_variadic, variadic=variadic, loc=loc, ip=ip)

// CHECK: @_ods_cext.register_operation(_Dialect)
Expand Down Expand Up @@ -433,8 +435,9 @@ def OneVariadicResultOp : TestOp<"one_variadic_result"> {
let results = (outs Variadic<AnyType>:$variadic, AnyType:$non_variadic);
}

// CHECK: def one_variadic_result(variadic, non_variadic, *, loc=None, ip=None)
// CHECK: return _get_op_result_or_op_results(OneVariadicResultOp(variadic=variadic, non_variadic=non_variadic, loc=loc, ip=ip))
// CHECK: def one_variadic_result(variadic, non_variadic, *, loc=None, ip=None) -> _Union[_ods_ir.OpResult, _ods_ir.OpResultList, OneVariadicResultOp]:
// CHECK: op = OneVariadicResultOp(variadic=variadic, non_variadic=non_variadic, loc=loc, ip=ip); results = op.results
// CHECK: return results if len(results) > 1 else (results[0] if len(results) == 1 else op)

// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class PythonKeywordOp(_ods_ir.OpView):
Expand All @@ -458,7 +461,7 @@ def PythonKeywordOp : TestOp<"python_keyword"> {
let arguments = (ins AnyType:$in);
}

// CHECK: def python_keyword(in_, *, loc=None, ip=None)
// CHECK: def python_keyword(in_, *, loc=None, ip=None) -> PythonKeywordOp:
// CHECK: return PythonKeywordOp(in_=in_, loc=loc, ip=ip)

// CHECK-LABEL: OPERATION_NAME = "test.same_results"
Expand All @@ -471,8 +474,8 @@ def SameResultsOp : TestOp<"same_results", [SameOperandsAndResultType]> {
let results = (outs AnyType:$res);
}

// CHECK: def same_results(in1, in2, *, results=None, loc=None, ip=None)
// CHECK: return SameResultsOp(in1=in1, in2=in2, results=results, loc=loc, ip=ip)
// CHECK: def same_results(in1, in2, *, results=None, loc=None, ip=None) -> _ods_ir.OpResult:
// CHECK: return SameResultsOp(in1=in1, in2=in2, results=results, loc=loc, ip=ip).result

// CHECK-LABEL: OPERATION_NAME = "test.same_results_variadic"
def SameResultsVariadicOp : TestOp<"same_results_variadic", [SameOperandsAndResultType]> {
Expand All @@ -481,8 +484,9 @@ def SameResultsVariadicOp : TestOp<"same_results_variadic", [SameOperandsAndResu
let results = (outs Variadic<AnyType>:$res);
}

// CHECK: def same_results_variadic(res, in1, in2, *, loc=None, ip=None)
// CHECK: return _get_op_result_or_op_results(SameResultsVariadicOp(res=res, in1=in1, in2=in2, loc=loc, ip=ip))
// CHECK: def same_results_variadic(res, in1, in2, *, loc=None, ip=None) -> _Union[_ods_ir.OpResult, _ods_ir.OpResultList, SameResultsVariadicOp]:
// CHECK: op = SameResultsVariadicOp(res=res, in1=in1, in2=in2, loc=loc, ip=ip); results = op.results
// CHECK: return results if len(results) > 1 else (results[0] if len(results) == 1 else op)


// CHECK: @_ods_cext.register_operation(_Dialect)
Expand All @@ -508,7 +512,7 @@ def SameVariadicOperandSizeOp : TestOp<"same_variadic_operand",
Variadic<AnyType>:$variadic2);
}

// CHECK: def same_variadic_operand(variadic1, non_variadic, variadic2, *, loc=None, ip=None)
// CHECK: def same_variadic_operand(variadic1, non_variadic, variadic2, *, loc=None, ip=None) -> SameVariadicOperandSizeOp:
// CHECK: return SameVariadicOperandSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip)

// CHECK: @_ods_cext.register_operation(_Dialect)
Expand All @@ -534,8 +538,9 @@ def SameVariadicResultSizeOp : TestOp<"same_variadic_result",
Variadic<AnyType>:$variadic2);
}

// CHECK: def same_variadic_result(variadic1, non_variadic, variadic2, *, loc=None, ip=None)
// CHECK: return _get_op_result_or_op_results(SameVariadicResultSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
// CHECK: def same_variadic_result(variadic1, non_variadic, variadic2, *, loc=None, ip=None) -> _Union[_ods_ir.OpResult, _ods_ir.OpResultList, SameVariadicResultSizeOp]:
// CHECK: op = SameVariadicResultSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip); results = op.results
// CHECK: return results if len(results) > 1 else (results[0] if len(results) == 1 else op)

// CHECK: @_ods_cext.register_operation(_Dialect)
// CHECK: class SimpleOp(_ods_ir.OpView):
Expand Down Expand Up @@ -575,7 +580,7 @@ def SimpleOp : TestOp<"simple"> {
let results = (outs I64:$i64, AnyFloat:$f64);
}

// CHECK: def simple(i64, f64, i32, f32, *, loc=None, ip=None)
// CHECK: def simple(i64, f64, i32, f32, *, loc=None, ip=None) -> _ods_ir.OpResultList:
// CHECK: return SimpleOp(i64=i64, f64=f64, i32=i32, f32=f32, loc=loc, ip=ip).results

// CHECK: class VariadicAndNormalRegionOp(_ods_ir.OpView):
Expand Down Expand Up @@ -603,7 +608,7 @@ def VariadicAndNormalRegionOp : TestOp<"variadic_and_normal_region"> {
// CHECK: return self.regions[2:]
}

// CHECK: def variadic_and_normal_region(num_variadic, *, loc=None, ip=None)
// CHECK: def variadic_and_normal_region(num_variadic, *, loc=None, ip=None) -> VariadicAndNormalRegionOp:
// CHECK: return VariadicAndNormalRegionOp(num_variadic=num_variadic, loc=loc, ip=ip)

// CHECK: class VariadicRegionOp(_ods_ir.OpView):
Expand All @@ -627,7 +632,7 @@ def VariadicRegionOp : TestOp<"variadic_region"> {
// CHECK: return self.regions[0:]
}

// CHECK: def variadic_region(num_variadic, *, loc=None, ip=None)
// CHECK: def variadic_region(num_variadic, *, loc=None, ip=None) -> VariadicRegionOp:
// CHECK: return VariadicRegionOp(num_variadic=num_variadic, loc=loc, ip=ip)

// CHECK: @_ods_cext.register_operation(_Dialect)
Expand All @@ -636,7 +641,7 @@ def VariadicRegionOp : TestOp<"variadic_region"> {
def WithSpecialCharactersOp : TestOp<"123with--special.characters"> {
}

// CHECK: def _123with__special_characters(*, loc=None, ip=None)
// CHECK: def _123with__special_characters(*, loc=None, ip=None) -> WithSpecialCharactersOp:
// CHECK: return WithSpecialCharactersOp(loc=loc, ip=ip)

// CHECK: @_ods_cext.register_operation(_Dialect)
Expand All @@ -651,11 +656,11 @@ def WithSuccessorsOp : TestOp<"with_successors"> {
VariadicSuccessor<AnySuccessor>:$successors);
}

// CHECK: def with_successors(successor, successors, *, loc=None, ip=None)
// CHECK: def with_successors(successor, successors, *, loc=None, ip=None) -> WithSuccessorsOp:
// CHECK: return WithSuccessorsOp(successor=successor, successors=successors, loc=loc, ip=ip)

// CHECK: class snake_case(_ods_ir.OpView):
// CHECK-LABEL: OPERATION_NAME = "test.snake_case"
def already_snake_case : TestOp<"snake_case"> {}
// CHECK: def snake_case_(*, loc=None, ip=None)
// CHECK: def snake_case_(*, loc=None, ip=None) -> snake_case:
// CHECK: return snake_case(loc=loc, ip=ip)
43 changes: 42 additions & 1 deletion mlir/test/python/dialects/python_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# RUN: %PYTHON %s pybind11 | FileCheck %s
# RUN: %PYTHON %s nanobind | FileCheck %s

import inspect
import sys
from typing import Union

from mlir.ir import *
import mlir.dialects.func as func
import mlir.dialects.python_test as test
Expand Down Expand Up @@ -323,6 +325,7 @@ def resultTypesDefinedByTraits():
# CHECK: f32 index
print(no_infer.single.type, no_infer.doubled.type)


# CHECK-LABEL: TEST: testOptionalOperandOp
@run
def testOptionalOperandOp():
Expand Down Expand Up @@ -594,6 +597,17 @@ def testInferTypeOpInterface():
# CHECK: f32
print(two_operands.result.type)

assert (
inspect.signature(
test.infer_results_variadic_inputs_op
).return_annotation
is OpResult
)
assert isinstance(
test.infer_results_variadic_inputs_op(single=zero, doubled=zero),
OpResult,
)


# CHECK-LABEL: TEST: testVariadicOperandAccess
@run
Expand Down Expand Up @@ -621,6 +635,15 @@ def values(lst):
# CHECK: ['Value(%{{.*}} = arith.constant 3 : i32)', 'Value(%{{.*}} = arith.constant 4 : i32)']
print(values(variadic_operands.variadic2))

assert (
inspect.signature(test.same_variadic_operand).return_annotation
is test.SameVariadicOperandSizeOp
)
assert isinstance(
test.same_variadic_operand([zero, one], two, [three, four]),
test.SameVariadicOperandSizeOp,
)


# CHECK-LABEL: TEST: testVariadicResultAccess
@run
Expand All @@ -642,6 +665,15 @@ def types(lst):
# CHECK: [IntegerType(i3), IntegerType(i4)]
print(types(op.variadic2))

assert (
inspect.signature(test.same_variadic_result_vfv).return_annotation
is Union[OpResult, OpResultList, test.SameVariadicResultSizeOpVFV]
)
assert isinstance(
test.same_variadic_result_vfv([i[0], i[1]], i[2], [i[3], i[4]]),
OpResultList,
)

# Test Variadic-Variadic-Variadic
op = test.SameVariadicResultSizeOpVVV(
[i[0], i[1]], [i[2], i[3]], [i[4], i[5]]
Expand Down Expand Up @@ -713,3 +745,12 @@ def types(lst):
print(types(op.variadic2))
# CHECK: i4
print(op.non_variadic3.type)

assert (
inspect.signature(test.results_variadic).return_annotation
is Union[OpResult, OpResultList, test.ResultsVariadicOp]
)
assert isinstance(
test.results_variadic([i[0]]),
OpResult,
)
6 changes: 3 additions & 3 deletions mlir/test/python/ir/auto_location.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def testInferLocations():
_cext.globals.register_traceback_file_inclusion(_arith_ops_gen.__file__)
three = arith.constant(IndexType.get(), 3)
# fmt: off
# CHECK: loc(callsite("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]_arith_ops_gen.py":397:4 to :235) at callsite("testInferLocations"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":52:16 to :50) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":13:4 to :7) at "<module>"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":26:1 to :4)))))
# CHECK: loc(callsite("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]_arith_ops_gen.py":396:4 to :235) at callsite("testInferLocations"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":52:16 to :50) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":13:4 to :7) at "<module>"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":26:1 to :4)))))
# fmt: on
print(three.location)

Expand All @@ -60,14 +60,14 @@ def foo():
print(four.location)

# fmt: off
# CHECK: loc(callsite("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]_arith_ops_gen.py":397:4 to :235) at callsite("testInferLocations.<locals>.foo"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":59:19 to :53) at callsite("testInferLocations"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":65:8 to :13) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":13:4 to :7) at "<module>"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":26:1 to :4))))))
# CHECK: loc(callsite("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]_arith_ops_gen.py":396:4 to :235) at callsite("testInferLocations.<locals>.foo"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":59:19 to :53) at callsite("testInferLocations"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":65:8 to :13) at callsite("run"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":13:4 to :7) at "<module>"("{{.*}}[[SEP]]test[[SEP]]python[[SEP]]ir[[SEP]]auto_location.py":26:1 to :4))))))
# fmt: on
foo()

_cext.globals.register_traceback_file_exclusion(__file__)

# fmt: off
# CHECK: loc("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]_arith_ops_gen.py":397:4 to :235))
# CHECK: loc("ConstantOp.__init__"("{{.*}}[[SEP]]mlir[[SEP]]dialects[[SEP]]_arith_ops_gen.py":396:4 to :235))
# fmt: on
foo()

Expand Down
4 changes: 4 additions & 0 deletions mlir/test/python/python_test_ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -265,4 +265,8 @@ def SameVariadicResultSizeOpFVFVF : TestOp<"same_variadic_result_fvfvf",
AnyType:$non_variadic3);
}

def ResultsVariadicOp : TestOp<"results_variadic"> {
let results = (outs Variadic<AnyType>:$res);
}

#endif // PYTHON_TEST_OPS
Loading
Loading