Skip to content

Commit 8f9368a

Browse files
authored
feat: sorted search overloads (#1325)
* feat: sorted search overloads * feat: add reverse overloads * fix: searchsorted last * test: searchsorted routines
1 parent e66fffa commit 8f9368a

File tree

2 files changed

+108
-18
lines changed

2 files changed

+108
-18
lines changed

src/TracedRArray.jl

Lines changed: 76 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,11 @@ using ..TracedUtils: TracedUtils, get_mlir_data, set_mlir_data!, materialize_tra
2121
using ReactantCore: ReactantCore
2222
using GPUArraysCore: GPUArraysCore, @allowscalar
2323

24+
__lt(::Base.Order.ForwardOrdering, a, b) = isless.(a, b)
25+
__lt(o::Base.Order.ReverseOrdering, a, b) = __lt(o.fwd, b, a)
26+
__lt(o::Base.Order.By, a, b) = __lt(o.order, o.by.(a), o.by.(b))
27+
__lt(o::Base.Order.Lt, a, b) = o.lt.(a, b)
28+
2429
ReactantCore.is_traced(::TracedRArray, seen) = true
2530
ReactantCore.is_traced(::TracedRArray) = true
2631

@@ -943,20 +948,26 @@ function overloaded_stack(dims::Union{Integer,Colon}, xs)
943948
end
944949

945950
# sort
946-
function Base.sort(x::AnyTracedRArray; alg=missing, order=missing, kwargs...)
947-
return sort!(copy(x); alg, order, kwargs...)
951+
function Base.sort(x::AnyTracedRArray; alg=missing, kwargs...)
952+
return sort!(copy(x); alg, kwargs...)
948953
end
949-
function Base.sort(x::AnyTracedRVector; alg=missing, order=missing, kwargs...)
950-
return sort!(copy(x); alg, order, kwargs...)
954+
function Base.sort(x::AnyTracedRVector; alg=missing, kwargs...)
955+
return sort!(copy(x); alg, kwargs...)
951956
end
952957

953958
function Base.sort!(
954-
x::AnyTracedRVector; lt=isless, by=identity, rev::Bool=false, alg=missing, order=missing
959+
x::AnyTracedRVector;
960+
lt=isless,
961+
by=identity,
962+
rev::Bool=false,
963+
alg=missing,
964+
order=Base.Order.Forward,
955965
)
956966
@assert alg === missing "Reactant doesn't support `alg` kwarg for `sort!`"
957-
@assert order === missing "Reactant doesn't support `order` kwarg for `sort!`"
958967

959-
comparator = rev ? (a, b) -> !lt(by(a), by(b)) : (a, b) -> lt(by(a), by(b))
968+
ordering = Base.ord(lt, by, rev, order)
969+
comparator = (a, b) -> __lt(ordering, a, b)
970+
960971
res = only(Ops.sort(materialize_traced_array(x); comparator, dimension=1))
961972
set_mlir_data!(x, get_mlir_data(res))
962973
return x
@@ -969,22 +980,23 @@ function Base.sort!(
969980
by=identity,
970981
rev::Bool=false,
971982
alg=missing,
972-
order=missing,
983+
order=Base.Order.Forward,
973984
)
974985
@assert alg === missing "Reactant doesn't support `alg` kwarg for `sort!`"
975-
@assert order === missing "Reactant doesn't support `order` kwarg for `sort!`"
976986

977-
comparator = rev ? (a, b) -> !lt(by(a), by(b)) : (a, b) -> lt(by(a), by(b))
987+
ordering = Base.ord(lt, by, rev, order)
988+
comparator = (a, b) -> __lt(ordering, a, b)
989+
978990
res = only(Ops.sort(materialize_traced_array(x); dimension=dims, comparator))
979991
set_mlir_data!(x, get_mlir_data(res))
980992
return x
981993
end
982994

983-
function Base.sortperm(x::AnyTracedRArray; alg=missing, order=missing, kwargs...)
984-
return sortperm!(similar(x, Int), x; alg, order, kwargs...)
995+
function Base.sortperm(x::AnyTracedRArray; alg=missing, kwargs...)
996+
return sortperm!(similar(x, Int), x; alg, kwargs...)
985997
end
986-
function Base.sortperm(x::AnyTracedRVector; alg=missing, order=missing, kwargs...)
987-
return sortperm!(similar(x, Int), x; alg, order, dims=1, kwargs...)
998+
function Base.sortperm(x::AnyTracedRVector; alg=missing, kwargs...)
999+
return sortperm!(similar(x, Int), x; alg, dims=1, kwargs...)
9881000
end
9891001

9901002
function Base.sortperm!(
@@ -995,18 +1007,18 @@ function Base.sortperm!(
9951007
by=identity,
9961008
rev::Bool=false,
9971009
alg=missing,
998-
order=missing,
1010+
order=Base.Order.Forward,
9991011
) where {N}
10001012
if dims === nothing
10011013
@assert ndims(x) == 1
10021014
dims = 1
10031015
end
10041016

10051017
@assert alg === missing "Reactant doesn't support `alg` kwarg for `sortperm!`"
1006-
@assert order === missing "Reactant doesn't support `order` kwarg for `sortperm!`"
10071018

1008-
comparator =
1009-
rev ? (a, b, i1, i2) -> !lt(by(a), by(b)) : (a, b, i1, i2) -> lt(by(a), by(b))
1019+
ordering = Base.ord(lt, by, rev, order)
1020+
comparator = (a, b, i1, i2) -> __lt(ordering, a, b)
1021+
10101022
idxs = Ops.constant(collect(LinearIndices(x)))
10111023
_, res = Ops.sort(materialize_traced_array(x), idxs; dimension=dims, comparator)
10121024
set_mlir_data!(ix, get_mlir_data(res))
@@ -1346,4 +1358,50 @@ function scan_impl!(
13461358
return output
13471359
end
13481360

1361+
function Base.searchsortedfirst(
1362+
v::AnyTracedRVector, x, lo::T, hi::T, o::Base.Ordering
1363+
) where {T<:Integer}
1364+
return sum(T.(__lt(o, v[lo:hi], x)); init=lo)
1365+
end
1366+
1367+
function Base.searchsortedlast(
1368+
v::AnyTracedRVector, x, lo::T, hi::T, o::Base.Ordering
1369+
) where {T<:Integer}
1370+
return sum(T.(.!(__lt(o, x, v[lo:hi]))); init=lo - 1)
1371+
end
1372+
1373+
function Base.searchsorted(
1374+
v::AnyTracedRVector, x, lo::T, hi::T, o::Base.Ordering
1375+
) where {T<:Integer}
1376+
firstidx = searchsortedfirst(v, x, lo, hi, o)
1377+
lastidx = searchsortedlast(v, x, lo, hi, o)
1378+
return Reactant.TracedRNumberOverrides.TracedUnitRange(firstidx, lastidx)
1379+
end
1380+
1381+
function Base.reverse(
1382+
v::AnyTracedRVector{T}, start::Integer, stop::Integer=lastindex(v)
1383+
) where {T}
1384+
v[start:stop] = reverse!(v[start:stop])
1385+
return v
1386+
end
1387+
1388+
function Base.reverse!(
1389+
v::AnyTracedRVector{T}, start::Integer, stop::Integer=lastindex(v)
1390+
) where {T}
1391+
reverse!(view(v, start:stop))
1392+
return v
1393+
end
1394+
1395+
function Base.reverse!(v::AnyTracedRVector{T}) where {T}
1396+
v_mat = materialize_traced_array(v)
1397+
copyto!(v, Ops.reverse(v_mat; dimensions=1))
1398+
return v
1399+
end
1400+
1401+
function Base._reverse!(a::AnyTracedRArray{T,N}, dims::NTuple{M,Int}) where {T,N,M}
1402+
a_mat = materialize_traced_array(a)
1403+
copyto!(a, Ops.reverse(a_mat; dimensions=dims))
1404+
return a
1405+
end
1406+
13491407
end

test/basic.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1295,3 +1295,35 @@ accum_fn(x, y) = abs2(x) + abs2(y)
12951295
end accumulate(accum_fn, b; dims=3, init=0.0f0)
12961296
end
12971297
end
1298+
1299+
sameunitrange(x, y) = first(x) == first(y) && last(x) == last(y)
1300+
1301+
@testset "searchsorted" begin
1302+
x = [1, 2, 4, 5, 5, 7]
1303+
x_ra = Reactant.to_rarray(x)
1304+
1305+
@testset "searchsortedfirst" begin
1306+
@testset for val in (4, 5, 3, 9, 0)
1307+
@test @jit(searchsortedfirst(x_ra, val)) == searchsortedfirst(x, val)
1308+
@test @jit(searchsortedfirst(x_ra, ConcreteRNumber(val))) ==
1309+
searchsortedfirst(x, val)
1310+
end
1311+
end
1312+
1313+
@testset "searchsortedlast" begin
1314+
@testset for val in (4, 5, 3, 9, 0)
1315+
@test @jit(searchsortedlast(x_ra, val)) == searchsortedlast(x, val)
1316+
@test @jit(searchsortedlast(x_ra, ConcreteRNumber(val))) ==
1317+
searchsortedlast(x, val)
1318+
end
1319+
end
1320+
1321+
@testset "searchsorted" begin
1322+
@testset for val in (4, 5, 3, 9, 0)
1323+
@test sameunitrange(@jit(searchsorted(x_ra, val)), searchsorted(x, val))
1324+
@test sameunitrange(
1325+
@jit(searchsorted(x_ra, ConcreteRNumber(val))), searchsorted(x, val)
1326+
)
1327+
end
1328+
end
1329+
end

0 commit comments

Comments
 (0)