Skip to content

Commit 0d7ad84

Browse files
committed
fix: correct semantics for Colon mapreduce
1 parent d9cf498 commit 0d7ad84

File tree

3 files changed

+14
-2
lines changed

3 files changed

+14
-2
lines changed

src/Compiler.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -575,7 +575,7 @@ end
575575
function compile_xla(f, args; client=nothing)
576576
# register MLIR dialects
577577
ctx = MLIR.IR.Context()
578-
Base.append!(Reactant.registry[]; context=ctx)
578+
append!(Reactant.registry[]; context=ctx)
579579
@ccall MLIR.API.mlir_c.RegisterDialects(ctx::MLIR.API.MlirContext)::Cvoid
580580

581581
return MLIR.IR.context!(ctx) do

src/TracedRArray.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -431,7 +431,11 @@ function Base.mapreduce(
431431
)
432432
red = TracedRArray{T,length(toonedims)}((), red, (toonedims...,))
433433
else
434-
red = TracedRArray{T,length(outdims)}((), red, (outdims...,))
434+
if length(outdims) == 0
435+
red = TracedRNumber{T}((), red)
436+
else
437+
red = TracedRArray{T,length(outdims)}((), red, (outdims...,))
438+
end
435439
end
436440
return red
437441
end

test/basic.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@ end
6565

6666
sumexp(x) = sum(exp, x)
6767

68+
sum_compare(x) = sum(x) > 0
69+
6870
@testset "Basic mapreduce" begin
6971
x = rand(Float32, 10)
7072
a = Reactant.ConcreteRArray(x)
@@ -74,6 +76,12 @@ sumexp(x) = sum(exp, x)
7476
f_res = f(a)
7577

7678
@test f_res r_res
79+
80+
# Ensure we are tracing as scalars. Else this will fail due to > not being defined on
81+
# arrays
82+
f = @compile sum_compare(a)
83+
# We need to use [] to unwrap the scalar. We will fix this in the future.
84+
@test f(a)[] == sum_compare(x)
7785
end
7886

7987
function mysoftmax!(x)

0 commit comments

Comments
 (0)