Skip to content

Commit ab3a653

Browse files
committed
use return_dialect
1 parent e180896 commit ab3a653

File tree

1 file changed

+2
-13
lines changed

1 file changed

+2
-13
lines changed

src/Ops.jl

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -926,33 +926,22 @@ function sort(
926926

927927
#TODO: move to @trace
928928
(a, b) = (ConcreteRNumber(T(0)), ConcreteRNumber(T(0)))
929-
func = Reactant.make_mlir_fn(comparator, (a, b), (), "main"; no_args_in_result=true)[2]
930-
929+
func = Reactant.make_mlir_fn(comparator, (a, b), (), "main"; no_args_in_result=true, return_dialect=:stablehlo)[2]
930+
@assert MLIR.IR.nregions(func) == 1
931931
fn_name = String(
932932
MLIR.IR.attr(func, String(MLIR.API.mlirSymbolTableGetSymbolAttributeName()))
933933
)
934934
#C5:
935935
@assert fn_name == "main" "$comparator: no function generated"
936-
@assert MLIR.IR.nregions(func) == 1
937936
ftype_attr = MLIR.IR.attr(func, "function_type")
938937
ftype = MLIR.IR.Type(ftype_attr)
939938
@assert MLIR.IR.result(ftype) == MLIR.IR.TensorType((), MLIR.IR.Type(Bool)) error(
940939
"$comparator return type is not tensor<i1>"
941940
)
942941

943-
#TODO: move takebody to utils?
944942
comparator = MLIR.IR.Region()
945943
MLIR.API.mlirRegionTakeBody(comparator, MLIR.IR.region(func, 1))
946944
MLIR.IR.rmfromparent!(func)
947-
for block in MLIR.IR.BlockIterator(comparator)
948-
return_op = MLIR.IR.terminator(block)
949-
MLIR.IR.name(return_op) == "func.return" || continue
950-
operands = [MLIR.IR.operand(return_op, i) for i in 1:MLIR.IR.noperands(return_op)]
951-
MLIR.IR.block!(block) do
952-
MLIR.Dialects.stablehlo.return_(operands; location=MLIR.IR.location(return_op))
953-
MLIR.IR.rmfromparent!(return_op)
954-
end
955-
end
956945

957946
dimension = MLIR.IR.Attribute(dimension - 1)
958947
is_stable = MLIR.IR.Attribute(is_stable)

0 commit comments

Comments
 (0)