Skip to content

Commit 91fdd05

Browse files
committed
stablehlo.sort Ops
1 parent 311498b commit 91fdd05

File tree

2 files changed

+73
-22
lines changed

2 files changed

+73
-22
lines changed

src/Ops.jl

Lines changed: 61 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -914,28 +914,67 @@ end
914914
# return TracedRArray{T,N}((), res, size(x))
915915
# end
916916

917-
# sorting ops
918-
# TODO need to trace over `comparator`
919-
# function sort(
920-
# x::TracedRArray{T,N};
921-
# comparator,
922-
# dimension=-1,
923-
# is_stable=false,
924-
# location=mlir_stacktrace("sort", @__FILE__, @__LINE__),
925-
# ) where {T,N}
926-
# dimension = MLIR.IR.Attribute(dimension)
927-
# is_stable = MLIR.IR.Attribute(is_stable)
928-
# res = MLIR.IR.result(
929-
# stablehlo.sort(
930-
# x.mlir_data;
931-
# result=mlir_type(TracedRArray{T,N}, size(x)),
932-
# dimension,
933-
# is_stable,
934-
# location,
935-
# ),
936-
# )
937-
# return TracedRArray{T,N}((), res, size(x))
938-
# end
917+
function sort(
918+
x::TracedRArray{T,N};
919+
comparator::Function,
920+
dimension=1,
921+
is_stable=false,
922+
location=mlir_stacktrace("sort", @__FILE__, @__LINE__),
923+
) where {T,N}
924+
#C4:
925+
@assert 0 < dimension <= ndims(x) "$x invalid dimension"
926+
927+
#C5:
928+
method = Base.methods(
929+
comparator, (Reactant.TracedRArray{T,N}, Reactant.TracedRArray{T,N})
930+
)
931+
@assert size(method, 1) != 0 error("$comparator is not a valid comparator")
932+
@assert size(method, 1) == 1 error("$comparator ambiguous candidates")
933+
#TODO: move to @trace
934+
(a, b) = (ConcreteRNumber(T(0)), ConcreteRNumber(T(0)))
935+
func = Reactant.make_mlir_fn(comparator, (a, b), (), "main"; no_args_in_result=true)[2]
936+
937+
fn_name = String(
938+
MLIR.IR.attr(func, String(MLIR.API.mlirSymbolTableGetSymbolAttributeName()))
939+
)
940+
@assert fn_name == "main" "$comparator: no function generated"
941+
@assert MLIR.IR.nregions(func) == 1
942+
ftype_attr = MLIR.IR.attr(func, "function_type")
943+
ftype = MLIR.IR.Type(ftype_attr)
944+
@assert MLIR.IR.result(ftype) == MLIR.IR.TensorType((), MLIR.IR.Type(Bool)) error(
945+
"$comparator return type is not tensor<i1>"
946+
)
947+
948+
#TODO: move takebody to utils?
949+
comparator = MLIR.IR.Region()
950+
MLIR.API.mlirRegionTakeBody(comparator, MLIR.IR.region(func, 1))
951+
MLIR.IR.rmfromparent!(func)
952+
global leaked = comparator
953+
for block in MLIR.IR.BlockIterator(comparator)
954+
return_op = MLIR.IR.terminator(block)
955+
MLIR.IR.name(return_op) == "func.return" || continue
956+
operands = [MLIR.IR.operand(return_op, i) for i in 1:MLIR.IR.noperands(return_op)]
957+
MLIR.IR.block!(block) do
958+
MLIR.Dialects.stablehlo.return_(operands; location=MLIR.IR.location(return_op))
959+
MLIR.IR.rmfromparent!(return_op)
960+
end
961+
end
962+
963+
dimension = MLIR.IR.Attribute(dimension - 1)
964+
is_stable = MLIR.IR.Attribute(is_stable)
965+
966+
res = MLIR.IR.result(
967+
stablehlo.sort(
968+
[x.mlir_data];
969+
result_0=[mlir_type(TracedRArray{T,N}, size(x))],
970+
dimension,
971+
is_stable,
972+
comparator,
973+
location,
974+
),
975+
)
976+
return TracedRArray{T,N}((), res, size(x))
977+
end
939978

940979
function top_k(
941980
x::TracedRArray{T,N}, k; location=mlir_stacktrace("top_k", @__FILE__, @__LINE__)

test/ops.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -646,6 +646,18 @@ end
646646
end
647647
end
648648

649+
@testset "sort" begin
650+
basic_sort(x, dimension) = Reactant.Ops.sort(x; comparator=(a, b) -> a < b, dimension)
651+
for i in 1:3
652+
t_size = tuple(fill(10, (i,))...)
653+
x = Reactant.to_rarray(randn(t_size))
654+
655+
for j in 1:i
656+
@test (i == 1 ? sort(x) : sort(x; dims=j)) == @jit basic_sort(x, j)
657+
end
658+
end
659+
end
660+
649661
@testset "slice" begin
650662
x = ConcreteRArray([1, 2, 3, 4])
651663
@test [2, 3] == @jit Ops.slice(x, [2], [3])

0 commit comments

Comments
 (0)