Skip to content

Commit 201aafb

Browse files
authored
Fix implementation of fused_map_reduce (#171)
1 parent 4dc6932 commit 201aafb

File tree

3 files changed

+25
-1
lines changed

3 files changed

+25
-1
lines changed

src/interface.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -666,6 +666,15 @@ function promote_operation_fallback(
666666
return promote_operation(*, promote_operation(adjoint, A), B)
667667
end
668668

669+
function promote_operation_fallback(
670+
::typeof(LinearAlgebra.dot),
671+
::Type{<:AbstractArray{A}},
672+
::Type{<:AbstractArray{B}},
673+
) where {A,B}
674+
C = promote_operation(*, A, B)
675+
return promote_operation(+, C, C)
676+
end
677+
669678
function buffer_for(::typeof(add_dot), a::Type, b::Type, c::Type)
670679
return buffer_for(add_mul, a, promote_operation(adjoint, b), c)
671680
end

src/reduce.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,11 @@ function promote_map_reduce(op::Function, args::Vararg{Any,N}) where {N}
3636
)
3737
end
3838

39+
_concrete_eltype(x) = isempty(x) ? eltype(x) : typeof(first(x))
40+
3941
function fused_map_reduce(op::F, args::Vararg{Any,N}) where {F<:Function,N}
4042
_check_same_length(args...)
41-
T = promote_map_reduce(op, eltype.(args)...)
43+
T = promote_map_reduce(op, _concrete_eltype.(args)...)
4244
accumulator = neutral_element(reduce_op(op), T)
4345
buffer = buffer_for(op, T, eltype.(args)...)
4446
for I in zip(eachindex.(args)...)

test/dispatch.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,4 +29,17 @@ end
2929
# On `DummyBigInt` allocates more on previous releases of Julia
3030
# as it's dynamically allocated
3131
dispatch_tests(DummyBigInt)
32+
33+
@testset "dot non-concrete vector" begin
34+
x = [5.0, 6.0]
35+
y = Vector{Union{Float64,String}}(x)
36+
@test MA.operate(LinearAlgebra.dot, x, y) == LinearAlgebra.dot(x, y)
37+
@test MA.operate(*, x', y) == x' * y
38+
end
39+
40+
@testset "dot vector of vectors" begin
41+
x = [5.0, 6.0]
42+
z = [x, x]
43+
@test MA.operate(LinearAlgebra.dot, z, z) == LinearAlgebra.dot(z, z)
44+
end
3245
end

0 commit comments

Comments
 (0)