@@ -914,28 +914,67 @@ end
914914# return TracedRArray{T,N}((), res, size(x))
915915# end
916916
917- # sorting ops
918- # TODO need to trace over `comparator`
919- # function sort(
920- # x::TracedRArray{T,N};
921- # comparator,
922- # dimension=-1,
923- # is_stable=false,
924- # location=mlir_stacktrace("sort", @__FILE__, @__LINE__),
925- # ) where {T,N}
926- # dimension = MLIR.IR.Attribute(dimension)
927- # is_stable = MLIR.IR.Attribute(is_stable)
928- # res = MLIR.IR.result(
929- # stablehlo.sort(
930- # x.mlir_data;
931- # result=mlir_type(TracedRArray{T,N}, size(x)),
932- # dimension,
933- # is_stable,
934- # location,
935- # ),
936- # )
937- # return TracedRArray{T,N}((), res, size(x))
938- # end
917+ function sort (
918+ x:: TracedRArray{T,N} ;
919+ comparator:: Function ,
920+ dimension= 1 ,
921+ is_stable= false ,
922+ location= mlir_stacktrace (" sort" , @__FILE__ , @__LINE__ ),
923+ ) where {T,N}
924+ # C4:
925+ @assert 0 < dimension <= ndims (x) " $x invalid dimension"
926+
927+ # C5:
928+ method = Base. methods (
929+ comparator, (Reactant. TracedRArray{T,N}, Reactant. TracedRArray{T,N})
930+ )
931+ @assert size (method, 1 ) != 0 error (" $comparator is not a valid comparator" )
932+ @assert size (method, 1 ) == 1 error (" $comparator ambiguous candidates" )
933+ # TODO : move to @trace
934+ (a, b) = (ConcreteRNumber (T (0 )), ConcreteRNumber (T (0 )))
935+ func = Reactant. make_mlir_fn (comparator, (a, b), (), " main" ; no_args_in_result= true )[2 ]
936+
937+ fn_name = String (
938+ MLIR. IR. attr (func, String (MLIR. API. mlirSymbolTableGetSymbolAttributeName ()))
939+ )
940+ @assert fn_name == " main" " $comparator : no function generated"
941+ @assert MLIR. IR. nregions (func) == 1
942+ ftype_attr = MLIR. IR. attr (func, " function_type" )
943+ ftype = MLIR. IR. Type (ftype_attr)
944+ @assert MLIR. IR. result (ftype) == MLIR. IR. TensorType ((), MLIR. IR. Type (Bool)) error (
945+ " $comparator return type is not tensor<i1>"
946+ )
947+
948+ # TODO : move takebody to utils?
949+ comparator = MLIR. IR. Region ()
950+ MLIR. API. mlirRegionTakeBody (comparator, MLIR. IR. region (func, 1 ))
951+ MLIR. IR. rmfromparent! (func)
952+ global leaked = comparator
953+ for block in MLIR. IR. BlockIterator (comparator)
954+ return_op = MLIR. IR. terminator (block)
955+ MLIR. IR. name (return_op) == " func.return" || continue
956+ operands = [MLIR. IR. operand (return_op, i) for i in 1 : MLIR. IR. noperands (return_op)]
957+ MLIR. IR. block! (block) do
958+ MLIR. Dialects. stablehlo. return_ (operands; location= MLIR. IR. location (return_op))
959+ MLIR. IR. rmfromparent! (return_op)
960+ end
961+ end
962+
963+ dimension = MLIR. IR. Attribute (dimension - 1 )
964+ is_stable = MLIR. IR. Attribute (is_stable)
965+
966+ res = MLIR. IR. result (
967+ stablehlo. sort (
968+ [x. mlir_data];
969+ result_0= [mlir_type (TracedRArray{T,N}, size (x))],
970+ dimension,
971+ is_stable,
972+ comparator,
973+ location,
974+ ),
975+ )
976+ return TracedRArray {T,N} ((), res, size (x))
977+ end
939978
940979function top_k (
941980 x:: TracedRArray{T,N} , k; location= mlir_stacktrace (" top_k" , @__FILE__ , @__LINE__ )
0 commit comments