Skip to content

Commit ce3d4f3

Browse files
Merge pull request #446 from Snowd1n/arraypart_zero
NamedArrayPartition zeromatrix
2 parents cd9a975 + 3450c85 commit ce3d4f3

File tree

2 files changed

+10
-0
lines changed

2 files changed

+10
-0
lines changed

src/named_array_partition.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,13 @@ end
145145
return dest
146146
end
147147

148+
#Overwrite ArrayInterface zeromatrix to work with NamedArrayPartitions & implicit solvers within OrdinaryDiffEq
149+
function ArrayInterface.zeromatrix(A::NamedArrayPartition)
150+
B = ArrayPartition(A)
151+
x = reduce(vcat,vec.(B.x))
152+
x .* x' .* false
153+
end
154+
148155
# `x = find_NamedArrayPartition(x)` returns the first `NamedArrayPartition` among broadcast arguments.
149156
find_NamedArrayPartition(bc::Base.Broadcast.Broadcasted) = find_NamedArrayPartition(bc.args)
150157
function find_NamedArrayPartition(args::Tuple)

test/named_array_partition_tests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,13 @@ using RecursiveArrayTools, Test
99
@test x.a ones(10)
1010
@test typeof(x .+ x[1:end]) <: Vector # test broadcast precedence
1111
@test all(x .== x[1:end])
12+
@test ArrayInterface.zeromatrix(x) isa Matrix
13+
@test size(ArrayInterface.zeromatrix(x)) == (30,30)
1214
y = copy(x)
1315
@test zero(x, (10, 20)) == zero(x) # test that ignoring dims works
1416
@test typeof(zero(x)) <: NamedArrayPartition
1517
@test (y .*= 2).a[1] 2 # test in-place bcast
18+
1619

1720
@test length(Array(x)) == 30
1821
@test typeof(Array(x)) <: Array

0 commit comments

Comments
 (0)