Skip to content

Commit 94f6b31

Browse files
authored
Define interface for broadcast expressions with unspecified style (#19)
1 parent dddda50 commit 94f6b31

File tree

4 files changed

+8
-5
lines changed

4 files changed

+8
-5
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DerivableInterfaces"
22
uuid = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.3.12"
4+
version = "0.3.14"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/abstractarrayinterface.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@ function interface(::Type{<:Broadcast.AbstractArrayStyle})
55
return DefaultArrayInterface()
66
end
77

8+
function interface(::Type{<:Broadcast.Broadcasted{Nothing}})
9+
return DefaultArrayInterface()
10+
end
11+
812
function interface(::Type{<:Broadcast.Broadcasted{<:Style}}) where {Style}
913
return interface(Style)
1014
end
@@ -123,9 +127,7 @@ end
123127
@interface interface::AbstractArrayInterface function Base.copyto!(
124128
a_dest::AbstractArray, bc::Broadcast.Broadcasted{Broadcast.DefaultArrayStyle{0}}
125129
)
126-
m = Mapped(bc)
127-
isempty(m.args) || error("Bad broadcast expression.")
128-
return @interface interface map!(m.f, a_dest, a_dest)
130+
@interface interface fill!(a_dest, bc.f(bc.args...)[])
129131
end
130132

131133
# This is defined in this way so we can rely on the Broadcast logic

test/Project.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
33
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
44
DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f"
55
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
6-
MapBroadcast = "ebd9b9da-f48d-417c-9660-449667d60261"
76
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
87
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
98
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

test/test_defaultarrayinterface.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,4 +33,6 @@ end
3333

3434
@testset "Broadcast.DefaultArrayStyle" begin
3535
@test interface(Broadcast.DefaultArrayStyle) == DefaultArrayInterface()
36+
@test interface(Broadcast.Broadcasted(nothing, +, (randn(2), randn(2)))) ==
37+
DefaultArrayInterface()
3638
end

0 commit comments

Comments
 (0)