diff --git a/src/Ops.jl b/src/Ops.jl index 013e0dbc8e..9a6c50106d 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -914,28 +914,50 @@ end # return TracedRArray{T,N}((), res, size(x)) # end -# sorting ops -# TODO need to trace over `comparator` -# function sort( -# x::TracedRArray{T,N}; -# comparator, -# dimension=-1, -# is_stable=false, -# location=mlir_stacktrace("sort", @__FILE__, @__LINE__), -# ) where {T,N} -# dimension = MLIR.IR.Attribute(dimension) -# is_stable = MLIR.IR.Attribute(is_stable) -# res = MLIR.IR.result( -# stablehlo.sort( -# x.mlir_data; -# result=mlir_type(TracedRArray{T,N}, size(x)), -# dimension, -# is_stable, -# location, -# ), -# ) -# return TracedRArray{T,N}((), res, size(x)) -# end +function sort( + x::TracedRArray{T,N}; + comparator, + dimension=1, + is_stable=false, + location=mlir_stacktrace("sort", @__FILE__, @__LINE__), +) where {T,N} + #C4: + @assert 0 < dimension <= ndims(x) "$x invalid dimension" + + #TODO: move to @trace + (a, b) = (ConcreteRNumber(T(0)), ConcreteRNumber(T(0))) + func = Reactant.make_mlir_fn(comparator, (a, b), (), "main"; no_args_in_result=true, return_dialect=:stablehlo)[2] + @assert MLIR.IR.nregions(func) == 1 + fn_name = String( + MLIR.IR.attr(func, String(MLIR.API.mlirSymbolTableGetSymbolAttributeName())) + ) + #C5: + @assert fn_name == "main" "$comparator: no function generated" + ftype_attr = MLIR.IR.attr(func, "function_type") + ftype = MLIR.IR.Type(ftype_attr) + @assert MLIR.IR.result(ftype) == MLIR.IR.TensorType((), MLIR.IR.Type(Bool)) error( + "$comparator return type is not tensor" + ) + + comparator = MLIR.IR.Region() + MLIR.API.mlirRegionTakeBody(comparator, MLIR.IR.region(func, 1)) + MLIR.IR.rmfromparent!(func) + + dimension = MLIR.IR.Attribute(dimension - 1) + is_stable = MLIR.IR.Attribute(is_stable) + + res = MLIR.IR.result( + stablehlo.sort( + [x.mlir_data]; + result_0=[mlir_type(TracedRArray{T,N}, size(x))], + dimension, + is_stable, + comparator, + location, + ), + ) + return TracedRArray{T,N}((), res, size(x)) +end function top_k( x::TracedRArray{T,N}, k; location=mlir_stacktrace("top_k", @__FILE__, @__LINE__) diff --git a/test/Project.toml b/test/Project.toml index 4b50a487fc..b8c18a9167 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -25,7 +25,7 @@ ArrayInterface = "7.10" BenchmarkTools = "1.5" Enzyme = "0.13.21" FFTW = "1.8" -Flux = "0.15" +Flux = "0.15, 0.16" Functors = "0.5" InteractiveUtils = "1.10" LinearAlgebra = "1.10" diff --git a/test/ops.jl b/test/ops.jl index 07f911e88b..81e6d8feb8 100644 --- a/test/ops.jl +++ b/test/ops.jl @@ -646,6 +646,18 @@ end end end +@testset "sort" begin + basic_sort(x, dimension) = Reactant.Ops.sort(x; comparator=(a, b) -> a < b, dimension) + for i in 1:3 + t_size = tuple(fill(10, (i,))...) + x = Reactant.to_rarray(randn(t_size)) + + for j in 1:i + @test (i == 1 ? sort(x) : sort(x; dims=j)) == @jit basic_sort(x, j) + end + end +end + @testset "slice" begin x = ConcreteRArray([1, 2, 3, 4]) @test [2, 3] == @jit Ops.slice(x, [2], [3])