@@ -926,33 +926,22 @@ function sort(
926926
927927 # TODO : move to @trace
928928 (a, b) = (ConcreteRNumber (T (0 )), ConcreteRNumber (T (0 )))
929- func = Reactant. make_mlir_fn (comparator, (a, b), (), " main" ; no_args_in_result= true )[2 ]
930-
929+ func = Reactant. make_mlir_fn (comparator, (a, b), (), " main" ; no_args_in_result= true , return_dialect = :stablehlo )[2 ]
930+ @assert MLIR . IR . nregions (func) == 1
931931 fn_name = String (
932932 MLIR. IR. attr (func, String (MLIR. API. mlirSymbolTableGetSymbolAttributeName ()))
933933 )
934934 # C5:
935935 @assert fn_name == " main" " $comparator : no function generated"
936- @assert MLIR. IR. nregions (func) == 1
937936 ftype_attr = MLIR. IR. attr (func, " function_type" )
938937 ftype = MLIR. IR. Type (ftype_attr)
939938 @assert MLIR. IR. result (ftype) == MLIR. IR. TensorType ((), MLIR. IR. Type (Bool)) error (
940939 " $comparator return type is not tensor<i1>"
941940 )
942941
943- # TODO : move takebody to utils?
944942 comparator = MLIR. IR. Region ()
945943 MLIR. API. mlirRegionTakeBody (comparator, MLIR. IR. region (func, 1 ))
946944 MLIR. IR. rmfromparent! (func)
947- for block in MLIR. IR. BlockIterator (comparator)
948- return_op = MLIR. IR. terminator (block)
949- MLIR. IR. name (return_op) == " func.return" || continue
950- operands = [MLIR. IR. operand (return_op, i) for i in 1 : MLIR. IR. noperands (return_op)]
951- MLIR. IR. block! (block) do
952- MLIR. Dialects. stablehlo. return_ (operands; location= MLIR. IR. location (return_op))
953- MLIR. IR. rmfromparent! (return_op)
954- end
955- end
956945
957946 dimension = MLIR. IR. Attribute (dimension - 1 )
958947 is_stable = MLIR. IR. Attribute (is_stable)
0 commit comments