@@ -21,6 +21,11 @@ using ..TracedUtils: TracedUtils, get_mlir_data, set_mlir_data!, materialize_tra
21
21
using ReactantCore: ReactantCore
22
22
using GPUArraysCore: GPUArraysCore, @allowscalar
23
23
24
+ __lt (:: Base.Order.ForwardOrdering , a, b) = isless .(a, b)
25
+ __lt (o:: Base.Order.ReverseOrdering , a, b) = __lt (o. fwd, b, a)
26
+ __lt (o:: Base.Order.By , a, b) = __lt (o. order, o. by .(a), o. by .(b))
27
+ __lt (o:: Base.Order.Lt , a, b) = o. lt .(a, b)
28
+
24
29
ReactantCore. is_traced (:: TracedRArray , seen) = true
25
30
ReactantCore. is_traced (:: TracedRArray ) = true
26
31
@@ -943,20 +948,26 @@ function overloaded_stack(dims::Union{Integer,Colon}, xs)
943
948
end
944
949
945
950
# sort
946
- function Base. sort (x:: AnyTracedRArray ; alg= missing , order = missing , kwargs... )
947
- return sort! (copy (x); alg, order, kwargs... )
951
+ function Base. sort (x:: AnyTracedRArray ; alg= missing , kwargs... )
952
+ return sort! (copy (x); alg, kwargs... )
948
953
end
949
- function Base. sort (x:: AnyTracedRVector ; alg= missing , order = missing , kwargs... )
950
- return sort! (copy (x); alg, order, kwargs... )
954
+ function Base. sort (x:: AnyTracedRVector ; alg= missing , kwargs... )
955
+ return sort! (copy (x); alg, kwargs... )
951
956
end
952
957
953
958
function Base. sort! (
954
- x:: AnyTracedRVector ; lt= isless, by= identity, rev:: Bool = false , alg= missing , order= missing
959
+ x:: AnyTracedRVector ;
960
+ lt= isless,
961
+ by= identity,
962
+ rev:: Bool = false ,
963
+ alg= missing ,
964
+ order= Base. Order. Forward,
955
965
)
956
966
@assert alg === missing " Reactant doesn't support `alg` kwarg for `sort!`"
957
- @assert order === missing " Reactant doesn't support `order` kwarg for `sort!`"
958
967
959
- comparator = rev ? (a, b) -> ! lt (by (a), by (b)) : (a, b) -> lt (by (a), by (b))
968
+ ordering = Base. ord (lt, by, rev, order)
969
+ comparator = (a, b) -> __lt (ordering, a, b)
970
+
960
971
res = only (Ops. sort (materialize_traced_array (x); comparator, dimension= 1 ))
961
972
set_mlir_data! (x, get_mlir_data (res))
962
973
return x
@@ -969,22 +980,23 @@ function Base.sort!(
969
980
by= identity,
970
981
rev:: Bool = false ,
971
982
alg= missing ,
972
- order= missing ,
983
+ order= Base . Order . Forward ,
973
984
)
974
985
@assert alg === missing " Reactant doesn't support `alg` kwarg for `sort!`"
975
- @assert order === missing " Reactant doesn't support `order` kwarg for `sort!`"
976
986
977
- comparator = rev ? (a, b) -> ! lt (by (a), by (b)) : (a, b) -> lt (by (a), by (b))
987
+ ordering = Base. ord (lt, by, rev, order)
988
+ comparator = (a, b) -> __lt (ordering, a, b)
989
+
978
990
res = only (Ops. sort (materialize_traced_array (x); dimension= dims, comparator))
979
991
set_mlir_data! (x, get_mlir_data (res))
980
992
return x
981
993
end
982
994
983
- function Base. sortperm (x:: AnyTracedRArray ; alg= missing , order = missing , kwargs... )
984
- return sortperm! (similar (x, Int), x; alg, order, kwargs... )
995
+ function Base. sortperm (x:: AnyTracedRArray ; alg= missing , kwargs... )
996
+ return sortperm! (similar (x, Int), x; alg, kwargs... )
985
997
end
986
- function Base. sortperm (x:: AnyTracedRVector ; alg= missing , order = missing , kwargs... )
987
- return sortperm! (similar (x, Int), x; alg, order, dims= 1 , kwargs... )
998
+ function Base. sortperm (x:: AnyTracedRVector ; alg= missing , kwargs... )
999
+ return sortperm! (similar (x, Int), x; alg, dims= 1 , kwargs... )
988
1000
end
989
1001
990
1002
function Base. sortperm! (
@@ -995,18 +1007,18 @@ function Base.sortperm!(
995
1007
by= identity,
996
1008
rev:: Bool = false ,
997
1009
alg= missing ,
998
- order= missing ,
1010
+ order= Base . Order . Forward ,
999
1011
) where {N}
1000
1012
if dims === nothing
1001
1013
@assert ndims (x) == 1
1002
1014
dims = 1
1003
1015
end
1004
1016
1005
1017
@assert alg === missing " Reactant doesn't support `alg` kwarg for `sortperm!`"
1006
- @assert order === missing " Reactant doesn't support `order` kwarg for `sortperm!`"
1007
1018
1008
- comparator =
1009
- rev ? (a, b, i1, i2) -> ! lt (by (a), by (b)) : (a, b, i1, i2) -> lt (by (a), by (b))
1019
+ ordering = Base. ord (lt, by, rev, order)
1020
+ comparator = (a, b, i1, i2) -> __lt (ordering, a, b)
1021
+
1010
1022
idxs = Ops. constant (collect (LinearIndices (x)))
1011
1023
_, res = Ops. sort (materialize_traced_array (x), idxs; dimension= dims, comparator)
1012
1024
set_mlir_data! (ix, get_mlir_data (res))
@@ -1346,4 +1358,50 @@ function scan_impl!(
1346
1358
return output
1347
1359
end
1348
1360
1361
+ function Base. searchsortedfirst (
1362
+ v:: AnyTracedRVector , x, lo:: T , hi:: T , o:: Base.Ordering
1363
+ ) where {T<: Integer }
1364
+ return sum (T .(__lt (o, v[lo: hi], x)); init= lo)
1365
+ end
1366
+
1367
+ function Base. searchsortedlast (
1368
+ v:: AnyTracedRVector , x, lo:: T , hi:: T , o:: Base.Ordering
1369
+ ) where {T<: Integer }
1370
+ return sum (T .(.! (__lt (o, x, v[lo: hi]))); init= lo - 1 )
1371
+ end
1372
+
1373
+ function Base. searchsorted (
1374
+ v:: AnyTracedRVector , x, lo:: T , hi:: T , o:: Base.Ordering
1375
+ ) where {T<: Integer }
1376
+ firstidx = searchsortedfirst (v, x, lo, hi, o)
1377
+ lastidx = searchsortedlast (v, x, lo, hi, o)
1378
+ return Reactant. TracedRNumberOverrides. TracedUnitRange (firstidx, lastidx)
1379
+ end
1380
+
1381
+ function Base. reverse (
1382
+ v:: AnyTracedRVector{T} , start:: Integer , stop:: Integer = lastindex (v)
1383
+ ) where {T}
1384
+ v[start: stop] = reverse! (v[start: stop])
1385
+ return v
1386
+ end
1387
+
1388
+ function Base. reverse! (
1389
+ v:: AnyTracedRVector{T} , start:: Integer , stop:: Integer = lastindex (v)
1390
+ ) where {T}
1391
+ reverse! (view (v, start: stop))
1392
+ return v
1393
+ end
1394
+
1395
+ function Base. reverse! (v:: AnyTracedRVector{T} ) where {T}
1396
+ v_mat = materialize_traced_array (v)
1397
+ copyto! (v, Ops. reverse (v_mat; dimensions= 1 ))
1398
+ return v
1399
+ end
1400
+
1401
+ function Base. _reverse! (a:: AnyTracedRArray{T,N} , dims:: NTuple{M,Int} ) where {T,N,M}
1402
+ a_mat = materialize_traced_array (a)
1403
+ copyto! (a, Ops. reverse (a_mat; dimensions= dims))
1404
+ return a
1405
+ end
1406
+
1349
1407
end
0 commit comments