Skip to content

Commit 9c3e172

Browse files
authored
fix: iteration ordering for findfirst (#1537)
* fix: iteration ordering for findfirst * Update src/TracedRArray.jl * fix: return -1 * test: make it deterministic * chore: remove unused deps * Update src/Ops.jl
1 parent d0b6471 commit 9c3e172

File tree

3 files changed

+40
-7
lines changed

3 files changed

+40
-7
lines changed

src/Ops.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1275,6 +1275,30 @@ end
12751275
return (; values, indices)
12761276
end
12771277

1278+
# Taken from https://github.com/JuliaGPU/GPUArrays.jl/blob/49a339c63a50f1a00ac84844675bcb3a11070cb0/src/host/indexing.jl#L193
1279+
@noinline function findfirst(
1280+
x::TracedRArray{Bool,N};
1281+
dimension::Integer=N,
1282+
location=mlir_stacktrace("findfirst", @__FILE__, @__LINE__),
1283+
) where {N}
1284+
return reduce(
1285+
TracedRArray[
1286+
x, iota(Int64, collect(Int64, size(x)); iota_dimension=dimension, location)
1287+
],
1288+
TracedRNumber[
1289+
Reactant.TracedUtils.promote_to(TracedRNumber{Bool}, false),
1290+
Reactant.TracedUtils.promote_to(TracedRNumber{Int64}, typemax(Int64)),
1291+
],
1292+
[dimension],
1293+
function (x, i, y, j)
1294+
cond_val = x | y
1295+
idx = ifelse(x, ifelse(i < j, i, j), ifelse(y, j, typemax(Int64)))
1296+
return cond_val, idx
1297+
end;
1298+
location,
1299+
)[2] .+ 1
1300+
end
1301+
12781302
@noinline function argmax(
12791303
x::TracedRArray{T,N};
12801304
dimension::Integer=N,

src/TracedRArray.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1114,16 +1114,17 @@ Base.argmax(x::AnyTracedRArray; kwargs...) = findmax(identity, x; kwargs...)[2]
11141114
Base.findfirst(x::AnyTracedRArray) = findfirst(identity, x)
11151115
Base.findlast(x::AnyTracedRArray) = findlast(identity, x)
11161116

1117+
# FIXME: we need to conditionally return `nothing` here if idx < 0
11171118
function Base.findfirst(f::Function, x::AnyTracedRArray)
1118-
fA = materialize_traced_array(vec(f.(x)))
1119-
(; indices) = Ops.top_k(fA, 1)
1120-
return @allowscalar indices[1]
1119+
idx = Ops.findfirst(materialize_traced_array(vec(f.(x))))
1120+
return TracedRNumber{Int}((), idx.mlir_data)
11211121
end
11221122

1123+
# FIXME: we need to conditionally return `nothing` here if idx < 0
11231124
function Base.findlast(f::Function, x::AnyTracedRArray)
11241125
fA = Ops.reverse(materialize_traced_array(vec(f.(x))); dimensions=[1])
1125-
(; indices) = Ops.top_k(fA, 1)
1126-
return length(x) - @allowscalar(indices[1]) + 1
1126+
idx = Ops.findfirst(fA)
1127+
return length(x) + 1 - TracedRNumber{Int}((), idx.mlir_data)
11271128
end
11281129

11291130
Base.findmin(x::AnyTracedRVector) = findmin(identity, x; dims=1)

test/sorting.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,11 @@ end
171171
end
172172

173173
@testset "findfirst / findlast" begin
174-
x = rand(Bool, 3, 4)
174+
x = Bool[
175+
0 0 0 0
176+
1 0 1 0
177+
0 1 0 1
178+
]
175179
x_ra = Reactant.to_rarray(x)
176180

177181
ffirstlinindices(x) = LinearIndices(x)[findfirst(x)]
@@ -182,7 +186,11 @@ end
182186
@test ffirstlinindices(x) == @jit(findfirst(x_ra))
183187
@test flastlinindices(x) == @jit(findlast(x_ra))
184188

185-
x = rand(1:256, 3, 4)
189+
x = Int64[
190+
3 5 7 9
191+
4 6 7 8
192+
5 7 8 9
193+
]
186194
x_ra = Reactant.to_rarray(x)
187195

188196
@test ffirstlinindices(iseven, x) == @jit(findfirst(iseven, x_ra))

0 commit comments

Comments
 (0)