diff --git a/Project.toml b/Project.toml index cccf987..348eedb 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "NullBroadcasts" uuid = "0d71be07-595a-4f89-9529-4065a4ab43a6" authors = ["CliMA Contributors "] -version = "0.1.0" +version = "0.1.1" [compat] Aqua = "0.8" diff --git a/src/NullBroadcasts.jl b/src/NullBroadcasts.jl index 7946945..496146c 100644 --- a/src/NullBroadcasts.jl +++ b/src/NullBroadcasts.jl @@ -46,18 +46,21 @@ broadcasted_sum(args) = Base.broadcasted(::NullBroadcasted, ::typeof(+), args...) = broadcasted_sum(filter(arg -> !(arg isa NullBroadcasted), args)) +#! format: off + Base.broadcasted(op::typeof(-), ::NullBroadcasted, arg) = Base.broadcasted(op, arg) -Base.broadcasted(op::typeof(-), arg, ::NullBroadcasted) = - Base.broadcasted(Base.identity, arg) Base.broadcasted(op::typeof(-), a::NullBroadcasted) = NullBroadcasted() -Base.broadcasted(op::typeof(-), a::NullBroadcasted, ::NullBroadcasted) = - Base.broadcasted(op, a) +Base.broadcasted(op::typeof(-), a::NullBroadcasted, ::NullBroadcasted) = Base.broadcasted(op, a) +# Specialize on identity cases: +Base.broadcasted(::typeof(-), a, ::NullBroadcasted) = a + +Base.broadcasted(op::typeof(+), ::NullBroadcasted, a, args...) = Base.broadcasted(op, a, args...) +Base.broadcasted(op::typeof(+), arg, ::NullBroadcasted, a, args...) = Base.broadcasted(op, arg, a, args...) +Base.broadcasted(op::typeof(+), a::NullBroadcasted, ::NullBroadcasted, args...) = Base.broadcasted(op, a, args...) -Base.broadcasted(op::typeof(+), ::NullBroadcasted, args...) = Base.broadcasted(op, args...) -Base.broadcasted(op::typeof(+), arg, ::NullBroadcasted, args...) = - Base.broadcasted(op, arg, args...) -Base.broadcasted(op::typeof(+), a::NullBroadcasted, ::NullBroadcasted, args...) = - Base.broadcasted(op, a, args...) +# Specialize on identity cases: +Base.broadcasted(::typeof(+), ::NullBroadcasted, a) = a +Base.broadcasted(::typeof(+), a, ::NullBroadcasted) = a Base.broadcasted(op::typeof(*), ::NullBroadcasted, args...) = NullBroadcasted() Base.broadcasted(op::typeof(*), arg, ::NullBroadcasted) = NullBroadcasted() @@ -68,6 +71,8 @@ Base.broadcasted(op::typeof(/), ::NullBroadcasted, ::NullBroadcasted) = NullBroa Base.broadcasted(op::typeof(identity), a::NullBroadcasted) = a +#! format: on + function skip_materialize(dest, bc::Base.Broadcast.Broadcasted) if typeof(bc.f) <: typeof(+) || typeof(bc.f) <: typeof(-) if length(bc.args) == 2 && diff --git a/test/runtests.jl b/test/runtests.jl index 391677a..165576c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -12,17 +12,16 @@ import Base.Broadcast: instantiate, materialize, Broadcasted, DefaultArrayStyle @testset "NullBroadcasted" begin x = [1] a = NullBroadcasted() - @test typeof(lazy.(x .+ a)) <: Broadcasted{ + @test typeof(lazy.(x .* 1 .+ a)) <: Broadcasted{ DefaultArrayStyle{1}, Tuple{Base.OneTo{Int64}}, - typeof(+), - Tuple{Vector{Int64}}, + typeof(*), + Tuple{Vector{Int64}, Int64} } - @test typeof(lazy.(a .+ x)) <: Broadcasted{ + @test typeof(lazy.(a .+ x .* 1)) <: Broadcasted{ DefaultArrayStyle{1}, Tuple{Base.OneTo{Int64}}, - typeof(+), - Tuple{Vector{Int64}}, + typeof(*), Tuple{Vector{Int64}, Int64} } @test lazy.(a .* x) isa NullBroadcasted @test lazy.(a ./ x) isa NullBroadcasted