Skip to content

Commit 6c36d80

Browse files
authored
fix: mapreduce with unitrange dims (#1572)
1 parent 70ec4e3 commit 6c36d80

File tree

3 files changed

+9
-2
lines changed

3 files changed

+9
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Reactant"
22
uuid = "3c362404-f566-11ee-1572-e11a4b42c853"
33
authors = ["William Moses <[email protected]>", "Valentin Churavy <[email protected]>", "Sergio Sánchez Ramírez <[email protected]>", "Paul Berg <[email protected]>", "Avik Pal <[email protected]>", "Mosè Giordano <[email protected]>"]
4-
version = "0.2.154"
4+
version = "0.2.155"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/TracedRArray.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -568,7 +568,7 @@ function overloaded_mapreduce(
568568
original_dims = dims
569569
dims isa Int && (dims = Int64[dims])
570570
dims isa Colon && (dims = collect(Int64, 1:N))
571-
dims isa AbstractVector{<:Integer} || (dims = collect(Int64, dims))
571+
dims isa Vector{Int64} || (dims = collect(Int64, dims))
572572

573573
op_in_T = unwrapped_eltype(Core.Compiler.return_type(f, Tuple{T}))
574574
reduce_init = __default_init(op_in_T, op)

test/basic.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1524,3 +1524,10 @@ end
15241524
)
15251525
end
15261526
end
1527+
1528+
@testset "mapreduce with unitrange dims" begin
1529+
x = reshape(collect(Float32, 1:64), 2, 4, 8)
1530+
x_ra = Reactant.to_rarray(x)
1531+
1532+
@test @jit(sum(x_ra; dims=1:2)) sum(x; dims=1:2)
1533+
end

0 commit comments

Comments
 (0)