Skip to content

Commit 772c307

Browse files
authored
Improve type-inference in broadcasting (#351)
* Improve type-inference in broadcasting * Add test using JET * Allow more JET versions * Reshape dest if ndims don't match * Use ndims from type param * Rename variable
1 parent 74e4928 commit 772c307

File tree

3 files changed

+34
-3
lines changed

3 files changed

+34
-3
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ Aqua = "0.8"
1212
ArrayLayouts = "1.0.8"
1313
Documenter = "1"
1414
FillArrays = "1"
15+
JET = "0.4, 0.6, 0.7, 0.8"
1516
LinearAlgebra = "1.6"
1617
OffsetArrays = "1"
1718
Random = "1.6"
@@ -23,11 +24,12 @@ julia = "1.6"
2324
[extras]
2425
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
2526
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
27+
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
2628
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
2729
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2830
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2931
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
3032
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3133

3234
[targets]
33-
test = ["Aqua", "Documenter", "OffsetArrays", "SparseArrays", "StaticArrays", "Test", "Random"]
35+
test = ["Aqua", "Documenter", "JET", "OffsetArrays", "SparseArrays", "StaticArrays", "Test", "Random"]

src/blockbroadcast.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,8 +140,13 @@ end
140140
@inline _bview(arg, ::Vararg) = arg
141141
@inline _bview(A::AbstractArray, I...) = view(A, I...)
142142

143-
@inline function Base.Broadcast.materialize!(dest, bc::Broadcasted{BS}) where {BS<:AbstractBlockStyle}
144-
return copyto!(dest, Base.Broadcast.instantiate(Base.Broadcast.Broadcasted{BS}(bc.f, bc.args, combine_blockaxes.(axes(dest),axes(bc)))))
143+
@inline function Broadcast.materialize!(dest, bc::Broadcasted{BS}) where {NDims, BS<:AbstractBlockStyle{NDims}}
144+
dest_reshaped = ndims(dest) == NDims ? dest : reshape(dest, size(bc))
145+
bc2 = Broadcast.instantiate(
146+
Broadcast.Broadcasted{BS}(bc.f, bc.args,
147+
map(combine_blockaxes, axes(dest_reshaped), axes(bc))))
148+
copyto!(dest_reshaped, bc2)
149+
return dest
145150
end
146151

147152
function _generic_blockbroadcast_copyto!(dest::AbstractArray,

test/test_blockbroadcast.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using BlockArrays, FillArrays, Test
22
import BlockArrays: SubBlockIterator, BlockIndexRange, Diagonal
3+
using JET
34

45
@testset "broadcast" begin
56
@testset "BlockArray" begin
@@ -24,6 +25,24 @@ import BlockArrays: SubBlockIterator, BlockIndexRange, Diagonal
2425

2526
@test axes(A + A) == axes(A .+ A) == axes(A)
2627
@test axes(A .+ 1) == axes(A)
28+
29+
@testset "mismatched ndims" begin
30+
u = BlockArray(randn(5), [2,3])
31+
dest = zeros(size(u)..., 1)
32+
@test (dest .= u) isa typeof(dest)
33+
@static if isdefined(JET, :test_opt)
34+
@test_opt ((dest,u) -> dest .= u)(dest,u)
35+
end
36+
@test reshape(dest, size(u)) == u
37+
38+
u = BlockArray(randn(3,3), [1,2], [1,2])
39+
dest = zeros(length(u))
40+
@test (dest .= u) isa typeof(dest)
41+
@static if isdefined(JET, :test_opt)
42+
@test_opt ((dest,u) -> dest .= u)(dest,u)
43+
end
44+
@test reshape(dest, size(u)) == u
45+
end
2746
end
2847

2948
@testset "PseudoBlockArray" begin
@@ -180,8 +199,13 @@ import BlockArrays: SubBlockIterator, BlockIndexRange, Diagonal
180199

181200
@testset "type inference" begin
182201
u = BlockArray(randn(5), [2,3]);
202+
A = zeros(size(u))
183203
@inferred(copyto!(similar(u), Base.broadcasted(exp, u)))
184204
@test exp.(u) == exp.(Vector(u))
205+
# test_opt isn't available on JET v0.4, which is installed on Julia v1.6
206+
@static if isdefined(JET, :test_opt)
207+
@test_opt ((A,B) -> A .= B)(A,u)
208+
end
185209
end
186210

187211
@testset "adjtrans" begin

0 commit comments

Comments
 (0)