From 3ef5a3fd8f60755d835696dfda86c23c3c617f14 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Tue, 4 Feb 2025 12:32:12 -0500 Subject: [PATCH] Fix corner case of broadcast expressions with scalar RHS --- Project.toml | 2 +- src/MapBroadcast.jl | 18 +++++++++++------- test/Project.toml | 3 ++- test/test_basics.jl | 16 +++++++++++++++- 4 files changed, 29 insertions(+), 10 deletions(-) diff --git a/Project.toml b/Project.toml index babb24c..ae37222 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MapBroadcast" uuid = "ebd9b9da-f48d-417c-9660-449667d60261" authors = ["ITensor developers and contributors"] -version = "0.1.6" +version = "0.1.7" [deps] BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" diff --git a/src/MapBroadcast.jl b/src/MapBroadcast.jl index 207c36b..602968d 100644 --- a/src/MapBroadcast.jl +++ b/src/MapBroadcast.jl @@ -5,7 +5,9 @@ module MapBroadcast using Base.Broadcast: Broadcast, BroadcastStyle, Broadcasted, broadcasted, combine_eltypes, instantiate +using BlockArrays: mortar using Compat: allequal +using FillArrays: Fill const WrappedScalarArgs = Union{AbstractArray{<:Any,0},Ref{<:Any}} @@ -91,18 +93,20 @@ end # Promote the shape of the arguments to support broadcasting # over dimensions by expanding singleton dimensions. -function promote_shape(ax, args::AbstractArray...) +function promote_shape(ax, f, args::AbstractArray...) if allequal((ax, axes.(args)...)) - return args + return f, args end - return promote_shape_tile(ax, args...) + return f, promote_shape_tile(ax, args...) end function promote_shape_tile(common_axes, args::AbstractArray...) return map(arg -> tile(arg, common_axes), args) end -using BlockArrays: mortar -using FillArrays: Fill +# Catch the case of zero arguments, like `a .= 2`. +function promote_shape(ax, f) + return identity, (Fill(f(), ax),) +end # Extend by repeating value up to length. function extend(t::Tuple, value, length) @@ -146,13 +150,13 @@ end function Mapped(::NotMapExpr, bc::Broadcasted) f = map_function(bc) ax = axes(bc) - args = promote_shape(ax, map_args(bc)...) + f, args = promote_shape(ax, f, map_args(bc)...) return Mapped(bc.style, f, args, ax) end function Mapped(::MapExpr, bc::Broadcasted) f = bc.f ax = axes(bc) - args = promote_shape(ax, bc.args...) + f, args = promote_shape(ax, f, bc.args...) return Mapped(bc.style, f, args, ax) end diff --git a/test/Project.toml b/test/Project.toml index 865cb37..969a747 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,6 +1,7 @@ [deps] -MapBroadcast = "ebd9b9da-f48d-417c-9660-449667d60261" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" +MapBroadcast = "ebd9b9da-f48d-417c-9660-449667d60261" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/test_basics.jl b/test/test_basics.jl index 1c2b04d..bbd87ef 100644 --- a/test/test_basics.jl +++ b/test/test_basics.jl @@ -1,4 +1,5 @@ -using Base.Broadcast: broadcasted +using Base.Broadcast: Broadcasted, broadcasted +using FillArrays: Fill using MapBroadcast: Mapped, is_map_expr, mapped using Test: @inferred, @test, @test_throws, @testset @@ -48,3 +49,16 @@ using Test: @inferred, @test, @test_throws, @testset @test copy(m) == copy(bc) end end + +@testset "Scalar RHS" begin + # Emulates the `Broadcasted` expression that gets instantiated + # in expresions like `a .= 3` or `a .= 2 .+ 1`. + bc = Broadcasted(+, (2, 1), (Base.OneTo(2), Base.OneTo(2))) + m = @inferred Mapped(bc) + @test axes(m) === (Base.OneTo(2), Base.OneTo(2)) + @test m.f === identity + @test only(m.args) === Fill(3, 2, 2) + dest = randn(2, 2) + copyto!(dest, m) + @test dest == Fill(3, 2, 2) +end