Skip to content

Commit db06b8f

Browse files
authored
Fix corner case of broadcast expressions with scalar RHS (#3)
1 parent 6429367 commit db06b8f

File tree

4 files changed

+29
-10
lines changed

4 files changed

+29
-10
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MapBroadcast"
22
uuid = "ebd9b9da-f48d-417c-9660-449667d60261"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.1.6"
4+
version = "0.1.7"
55

66
[deps]
77
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"

src/MapBroadcast.jl

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@ module MapBroadcast
55

66
using Base.Broadcast:
77
Broadcast, BroadcastStyle, Broadcasted, broadcasted, combine_eltypes, instantiate
8+
using BlockArrays: mortar
89
using Compat: allequal
10+
using FillArrays: Fill
911

1012
const WrappedScalarArgs = Union{AbstractArray{<:Any,0},Ref{<:Any}}
1113

@@ -91,18 +93,20 @@ end
9193

9294
# Promote the shape of the arguments to support broadcasting
9395
# over dimensions by expanding singleton dimensions.
94-
function promote_shape(ax, args::AbstractArray...)
96+
function promote_shape(ax, f, args::AbstractArray...)
9597
if allequal((ax, axes.(args)...))
96-
return args
98+
return f, args
9799
end
98-
return promote_shape_tile(ax, args...)
100+
return f, promote_shape_tile(ax, args...)
99101
end
100102
function promote_shape_tile(common_axes, args::AbstractArray...)
101103
return map(arg -> tile(arg, common_axes), args)
102104
end
103105

104-
using BlockArrays: mortar
105-
using FillArrays: Fill
106+
# Catch the case of zero arguments, like `a .= 2`.
107+
function promote_shape(ax, f)
108+
return identity, (Fill(f(), ax),)
109+
end
106110

107111
# Extend by repeating value up to length.
108112
function extend(t::Tuple, value, length)
@@ -146,13 +150,13 @@ end
146150
function Mapped(::NotMapExpr, bc::Broadcasted)
147151
f = map_function(bc)
148152
ax = axes(bc)
149-
args = promote_shape(ax, map_args(bc)...)
153+
f, args = promote_shape(ax, f, map_args(bc)...)
150154
return Mapped(bc.style, f, args, ax)
151155
end
152156
function Mapped(::MapExpr, bc::Broadcasted)
153157
f = bc.f
154158
ax = axes(bc)
155-
args = promote_shape(ax, bc.args...)
159+
f, args = promote_shape(ax, f, bc.args...)
156160
return Mapped(bc.style, f, args, ax)
157161
end
158162

test/Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
[deps]
2-
MapBroadcast = "ebd9b9da-f48d-417c-9660-449667d60261"
32
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
3+
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
4+
MapBroadcast = "ebd9b9da-f48d-417c-9660-449667d60261"
45
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
56
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
67
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

test/test_basics.jl

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
using Base.Broadcast: broadcasted
1+
using Base.Broadcast: Broadcasted, broadcasted
2+
using FillArrays: Fill
23
using MapBroadcast: Mapped, is_map_expr, mapped
34
using Test: @inferred, @test, @test_throws, @testset
45

@@ -48,3 +49,16 @@ using Test: @inferred, @test, @test_throws, @testset
4849
@test copy(m) == copy(bc)
4950
end
5051
end
52+
53+
@testset "Scalar RHS" begin
54+
# Emulates the `Broadcasted` expression that gets instantiated
55+
# in expresions like `a .= 3` or `a .= 2 .+ 1`.
56+
bc = Broadcasted(+, (2, 1), (Base.OneTo(2), Base.OneTo(2)))
57+
m = @inferred Mapped(bc)
58+
@test axes(m) === (Base.OneTo(2), Base.OneTo(2))
59+
@test m.f === identity
60+
@test only(m.args) === Fill(3, 2, 2)
61+
dest = randn(2, 2)
62+
copyto!(dest, m)
63+
@test dest == Fill(3, 2, 2)
64+
end

0 commit comments

Comments
 (0)