Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MapBroadcast"
uuid = "ebd9b9da-f48d-417c-9660-449667d60261"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.1.6"
version = "0.1.7"

[deps]
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
Expand Down
18 changes: 11 additions & 7 deletions src/MapBroadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ module MapBroadcast

using Base.Broadcast:
Broadcast, BroadcastStyle, Broadcasted, broadcasted, combine_eltypes, instantiate
using BlockArrays: mortar
using Compat: allequal
using FillArrays: Fill

const WrappedScalarArgs = Union{AbstractArray{<:Any,0},Ref{<:Any}}

Expand Down Expand Up @@ -91,18 +93,20 @@ end

# Promote the shape of the arguments to support broadcasting
# over dimensions by expanding singleton dimensions.
function promote_shape(ax, args::AbstractArray...)
function promote_shape(ax, f, args::AbstractArray...)
if allequal((ax, axes.(args)...))
return args
return f, args
end
return promote_shape_tile(ax, args...)
return f, promote_shape_tile(ax, args...)
end
function promote_shape_tile(common_axes, args::AbstractArray...)
return map(arg -> tile(arg, common_axes), args)
end

using BlockArrays: mortar
using FillArrays: Fill
# Catch the case of zero arguments, like `a .= 2`.
function promote_shape(ax, f)
return identity, (Fill(f(), ax),)
end

# Extend by repeating value up to length.
function extend(t::Tuple, value, length)
Expand Down Expand Up @@ -146,13 +150,13 @@ end
function Mapped(::NotMapExpr, bc::Broadcasted)
f = map_function(bc)
ax = axes(bc)
args = promote_shape(ax, map_args(bc)...)
f, args = promote_shape(ax, f, map_args(bc)...)
return Mapped(bc.style, f, args, ax)
end
function Mapped(::MapExpr, bc::Broadcasted)
f = bc.f
ax = axes(bc)
args = promote_shape(ax, bc.args...)
f, args = promote_shape(ax, f, bc.args...)
return Mapped(bc.style, f, args, ax)
end

Expand Down
3 changes: 2 additions & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
[deps]
MapBroadcast = "ebd9b9da-f48d-417c-9660-449667d60261"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
MapBroadcast = "ebd9b9da-f48d-417c-9660-449667d60261"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Expand Down
16 changes: 15 additions & 1 deletion test/test_basics.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using Base.Broadcast: broadcasted
using Base.Broadcast: Broadcasted, broadcasted
using FillArrays: Fill
using MapBroadcast: Mapped, is_map_expr, mapped
using Test: @inferred, @test, @test_throws, @testset

Expand Down Expand Up @@ -48,3 +49,16 @@ using Test: @inferred, @test, @test_throws, @testset
@test copy(m) == copy(bc)
end
end

@testset "Scalar RHS" begin
# Emulates the `Broadcasted` expression that gets instantiated
# in expresions like `a .= 3` or `a .= 2 .+ 1`.
bc = Broadcasted(+, (2, 1), (Base.OneTo(2), Base.OneTo(2)))
m = @inferred Mapped(bc)
@test axes(m) === (Base.OneTo(2), Base.OneTo(2))
@test m.f === identity
@test only(m.args) === Fill(3, 2, 2)
dest = randn(2, 2)
copyto!(dest, m)
@test dest == Fill(3, 2, 2)
end