Skip to content

Commit a32ab63

Browse files
Define zero on some MatrixField types
Define zero for some MatrixField structs
1 parent dceaa19 commit a32ab63

File tree

4 files changed

+23
-0
lines changed

4 files changed

+23
-0
lines changed

src/MatrixFields/field_matrix_solver.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,9 @@ end
168168
LazySchurComplement(A₁₁, A₁₂, A₂₁, A₂₂) =
169169
LazySchurComplement(A₁₁, A₁₂, A₂₁, A₂₂, nothing, nothing, nothing, nothing)
170170

171+
Base.zero(lsc::LazySchurComplement) =
172+
LazySchurComplement(map(fn -> zero(getfield(lsc, fn)), fieldnames(lsc))...)
173+
171174
NVTX.@annotate function lazy_mul(A₂₂′::LazySchurComplement, x₂)
172175
(; A₁₁, A₁₂, A₂₁, A₂₂, alg₁, cache₁, A₁₂_x₂, invA₁₁_A₁₂_x₂) = A₂₂′
173176
zero_rows = setdiff(keys(A₁₂_x₂), matrix_row_keys(keys(A₁₂)))
@@ -229,6 +232,8 @@ partial pivoting matrix).
229232
"""
230233
struct BlockDiagonalSolve <: FieldMatrixSolverAlgorithm end
231234

235+
Base.zero(alg::BlockDiagonalSolve) = alg
236+
232237
function field_matrix_solver_cache(::BlockDiagonalSolve, A, b)
233238
caches = map(matrix_row_keys(keys(A))) do name
234239
single_field_solver_cache(A[name, name], b[name])
@@ -315,6 +320,9 @@ BlockLowerTriangularSolve(
315320
alg₂ = BlockDiagonalSolve(),
316321
) = BlockLowerTriangularSolve(names₁, alg₁, alg₂)
317322

323+
Base.zero(alg::BlockLowerTriangularSolve) =
324+
BlockLowerTriangularSolve(alg.names₁, zero(alg.alg₁), zero(alg.alg₂))
325+
318326
function field_matrix_solver_cache(alg::BlockLowerTriangularSolve, A, b)
319327
A₁₁, _, _, A₂₂, b₁, b₂ = partition_blocks(alg.names₁, A, b)
320328
cache₁ = field_matrix_solver_cache(alg.alg₁, A₁₁, b₁)
@@ -448,6 +456,9 @@ end
448456
SchurComplementReductionSolve(names₁...; alg₁ = BlockDiagonalSolve(), alg₂) =
449457
SchurComplementReductionSolve(names₁, alg₁, alg₂)
450458

459+
Base.zero(alg::SchurComplementReductionSolve) =
460+
SchurComplementReductionSolve(alg.names₁, zero(alg.alg₁), zero(alg.alg₂))
461+
451462
function field_matrix_solver_cache(alg::SchurComplementReductionSolve, A, b)
452463
A₁₁, A₁₂, A₂₁, A₂₂, b₁, b₂ = partition_blocks(alg.names₁, A, b)
453464
b₁′ = similar(b₁)

src/MatrixFields/field_matrix_with_solver.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ Base.:(==)(A1::FieldMatrixWithSolver, A2::FieldMatrixWithSolver) =
4141
Base.similar(A::FieldMatrixWithSolver) =
4242
FieldMatrixWithSolver(similar(A.matrix), A.solver)
4343

44+
Base.zero(A::FieldMatrixWithSolver) =
45+
FieldMatrixWithSolver(zero(A.matrix), A.solver)
46+
4447
ldiv!(x::Fields.FieldVector, A::FieldMatrixWithSolver, b::Fields.FieldVector) =
4548
field_matrix_solve!(A.solver, x, A.matrix, b)
4649

src/MatrixFields/field_name_dict.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,13 @@ function Base.similar(dict::FieldNameDict)
182182
return FieldNameDict(keys(dict), entries)
183183
end
184184

185+
function Base.zero(dict::FieldNameDict)
186+
entries = unrolled_map(values(dict)) do entry
187+
entry isa UniformScaling ? entry : zero(entry)
188+
end
189+
return FieldNameDict(keys(dict), entries)
190+
end
191+
185192
# Note: This assumes that the matrix has the same row and column units, since I
186193
# cannot be multiplied by anything other than a scalar.
187194
function Base.one(matrix::FieldMatrix)

test/MatrixFields/field_matrix_solvers.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,12 @@ function test_field_matrix_solver(; test_name, alg, A, b, use_rel_error = false)
1818
@testset "$test_name" begin
1919
x = similar(b)
2020
A′ = FieldMatrixWithSolver(A, b, alg)
21+
@test zero(A′) isa typeof(A′)
2122
solve_time =
2223
@benchmark ClimaComms.@cuda_sync comms_device ldiv!(x, A′, b)
2324

2425
b_test = similar(b)
26+
@test zero(b) isa typeof(b)
2527
mul_time =
2628
@benchmark ClimaComms.@cuda_sync comms_device mul!(b_test, A′, x)
2729

0 commit comments

Comments
 (0)