From daa4cb27582080b085315ca2f07d2d7f4020f365 Mon Sep 17 00:00:00 2001 From: HenrySnowden Date: Mon, 12 May 2025 14:42:31 +0100 Subject: [PATCH 1/2] Added the zeromatrix function to NamedArrayPartition A function similar to what is implemented commit 2094a78 but for NamedArrayPartitions rather than ArrayPartitions. Tested privately to work with Implicit solvers in OrdinaryDiffEq.jl --- src/named_array_partition.jl | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/named_array_partition.jl b/src/named_array_partition.jl index de8fa91a..99c83512 100644 --- a/src/named_array_partition.jl +++ b/src/named_array_partition.jl @@ -145,6 +145,13 @@ end return dest end +#Overwrite ArrayInterface zeromatrix to work with NamedArrayPartitions & implicit solvers within OrdinaryDiffEq +function ArrayInterface.zeromatrix(A::NamedArrayPartition) + B = ArrayPartition(A) + x = reduce(vcat,vec.(B.x)) + x .* x' .* false +end + # `x = find_NamedArrayPartition(x)` returns the first `NamedArrayPartition` among broadcast arguments. find_NamedArrayPartition(bc::Base.Broadcast.Broadcasted) = find_NamedArrayPartition(bc.args) function find_NamedArrayPartition(args::Tuple) From 3450c85ab59c9eb51115bd641fb229c42b21e6e3 Mon Sep 17 00:00:00 2001 From: HenrySnowden Date: Mon, 12 May 2025 14:50:09 +0100 Subject: [PATCH 2/2] Added tests for new zero matrix function --- test/named_array_partition_tests.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/named_array_partition_tests.jl b/test/named_array_partition_tests.jl index d5647bad..e6747969 100644 --- a/test/named_array_partition_tests.jl +++ b/test/named_array_partition_tests.jl @@ -9,10 +9,13 @@ using RecursiveArrayTools, Test @test x.a ≈ ones(10) @test typeof(x .+ x[1:end]) <: Vector # test broadcast precedence @test all(x .== x[1:end]) + @test ArrayInterface.zeromatrix(x) isa Matrix + @test size(ArrayInterface.zeromatrix(x)) == (30,30) y = copy(x) @test zero(x, (10, 20)) == zero(x) # test that ignoring dims works @test typeof(zero(x)) <: NamedArrayPartition @test (y .*= 2).a[1] ≈ 2 # test in-place bcast + @test length(Array(x)) == 30 @test typeof(Array(x)) <: Array