Skip to content

Commit 0fac4f4

Browse files
tkfdlfivefifty
authored andcommitted
Fix copyto! for multi-argument ApplyArrayBroadcastStyle (#56)
* Fix copyto! for multi-argument ApplyArrayBroadcastStyle * Bring back some examples that works in test/macrotests.jl
1 parent 4775a68 commit 0fac4f4

File tree

2 files changed

+32
-8
lines changed

2 files changed

+32
-8
lines changed

src/lazyapplying.jl

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,10 +84,30 @@ similar(bc::Broadcasted{ApplyArrayBroadcastStyle{N}}, ::Type{ElType}) where {N,E
8484
copyto!(dest, first(bc.args))
8585
end
8686
@inline function copyto!(dest::AbstractArray, bc::Broadcasted{ApplyArrayBroadcastStyle{N}}) where N
87-
@assert length(bc.args) == 1
88-
copyto!(dest, first(bc.args))
87+
if length(bc.args) == 1
88+
copyto!(dest, first(bc.args))
89+
if bc.f !== identity
90+
dest .= bc.f.(dest)
91+
end
92+
else
93+
bc′ = mapbc(bc) do x
94+
if x isa Applied
95+
materialize(x)
96+
else
97+
x
98+
end
99+
end
100+
materialize!(dest, bc′)
101+
end
102+
return dest
89103
end
90104

105+
# Map over all nested Broadcasted and their arguments. Using `broadcasted`
106+
# instead of `Broadcasted` to re-process arguments via `broadcastable`.
107+
@inline mapbc(f, bc::Broadcasted) =
108+
f(broadcasted(bc.f, map(a -> mapbc(f, a), bc.args)...))
109+
@inline mapbc(f, x) = f(x)
110+
91111
struct MatrixFunctionStyle{F} <: AbstractArrayApplyStyle end
92112

93113
for f in (:exp, :sin, :cos, :sqrt)

test/macrotests.jl

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,18 @@ C = randn(6, 6)
1010
expressions_block = quote
1111
exp.(A)
1212
@. exp(A)
13-
# exp(A)
13+
exp(A)
1414
A .+ 2
1515
@. A + 2
1616
A + B
1717
@. A + B
1818
A * B + C
19-
# A * B .+ C
19+
A * B .+ C
2020
A * (B + C)
2121
# A * (B .+ C)
22-
# 2 .* (A * B) .+ 3 .* C
22+
2 .* (A * B) .+ 3 .* C
23+
exp.(A * C) # https://github.com/JuliaArrays/LazyArrays.jl/issues/54
24+
(A * A) .+ (A * C)
2325
end
2426
testparams = [
2527
("$(rmlines(ex))", ex) for ex in expressions_block.args if ex isa Expr
@@ -43,17 +45,19 @@ testparams = [
4345

4446
@testset "LazyArray(@~ $label)" begin
4547
actual = LazyArray(lazy) :: LazyArray
46-
@test actual == desired
48+
@test actual desired
4749
end
4850

4951
@testset "materialize(LazyArray(@~ $label))" begin
50-
@test materialize(LazyArray(lazy)) == desired
52+
@test_skip materialize(LazyArray(lazy)) == desired # should work
53+
@test materialize(LazyArray(lazy)) desired
5154
end
5255

5356
@testset ".= LazyArray(@~ $label)" begin
5457
actual = zero(desired)
5558
actual .= LazyArray(lazy)
56-
@test actual == desired
59+
@test_skip actual == desired # should work
60+
@test actual desired
5761
end
5862
end
5963
end

0 commit comments

Comments
 (0)