Skip to content

Commit d39d88e

Browse files
authored
reuse_chol -> reuse_fact and avoid try catch in direct solver (#171)
* reuse_chol -> reuse_fact and avoid try catch * bump version * fix typo * formatting * remove ldiv! * lower bound NonconvexPercival * split the tests some more
1 parent d60ba07 commit d39d88e

File tree

8 files changed

+30
-43
lines changed

8 files changed

+30
-43
lines changed

.github/workflows/CI.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ jobs:
2525
- Extended_Tests
2626
- Examples_1
2727
- Examples_2
28+
- Examples_3
2829
- WCSMO14_1
2930
- WCSMO14_2
3031
steps:

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "TopOpt"
22
uuid = "53a1e1a5-51bb-58a9-8a02-02056cc81109"
33
authors = ["mohamed82008 <[email protected]>", "yijiangh <[email protected]>"]
4-
version = "0.8.3"
4+
version = "0.9.0"
55

66
[deps]
77
AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
@@ -65,7 +65,7 @@ MappedArrays = "0.4"
6565
NearestNeighbors = "0.4"
6666
Nonconvex = "2"
6767
NonconvexMMA = "1"
68-
NonconvexPercival = "0.1"
68+
NonconvexPercival = "0.1.4"
6969
NonconvexSemidefinite = "0.1.7"
7070
Parameters = "0.12"
7171
Preconditioners = "0.3, 0.4, 0.5, 0.6"

src/FEA/direct_displacement_solver.jl

Lines changed: 15 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -49,13 +49,12 @@ function (s::DirectDisplacementSolver{T})(
4949
::Type{Val{safe}}=Val{false},
5050
::Type{newT}=T;
5151
assemble_f=true,
52-
reuse_chol=false,
52+
reuse_fact=false,
5353
rhs=assemble_f ? s.globalinfo.f : s.rhs,
5454
lhs=assemble_f ? s.u : s.lhs,
5555
kwargs...,
5656
) where {T,safe,newT}
5757
globalinfo = s.globalinfo
58-
N = size(globalinfo.K, 1)
5958
assemble!(
6059
globalinfo,
6160
s.problem,
@@ -75,40 +74,24 @@ function (s::DirectDisplacementSolver{T})(
7574
end
7675
end
7776
nans = false
78-
if !reuse_chol
79-
try
80-
if T === newT
81-
if s.qr
82-
globalinfo.qrK = qr(K.data)
83-
else
84-
globalinfo.cholK = cholesky(Symmetric(K))
85-
end
86-
else
87-
if s.qr
88-
globalinfo.qrK = qr((newT.(K)).data)
89-
else
90-
globalinfo.cholK = cholesky(Symmetric(newT.(K)))
91-
end
92-
end
93-
catch err
94-
lhs .= T(NaN)
95-
nans = true
96-
end
97-
end
98-
if !nans
99-
if T === newT
100-
if s.qr
101-
lhs .= globalinfo.qrK \ rhs
102-
else
103-
lhs .= globalinfo.cholK \ rhs
104-
end
77+
if !reuse_fact
78+
newK = T === newT ? K : newT.(K)
79+
if s.qr
80+
globalinfo.qrK = qr(newK.data)
10581
else
106-
if s.qr
107-
lhs .= globalinfo.qrK \ newT.(rhs)
82+
cholK = cholesky(Symmetric(K); check=false)
83+
if issuccess(cholK)
84+
globalinfo.cholK = cholK
10885
else
109-
lhs .= globalinfo.cholK \ newT.(rhs)
86+
@warn "The global stiffness matrix is not positive definite. Please check your boundary conditions."
87+
lhs .= T(NaN)
88+
nans = true
11089
end
11190
end
11291
end
92+
nans && return nothing
93+
new_rhs = T === newT ? rhs : newT.(rhs)
94+
fact = s.qr ? globalinfo.qrK : globalinfo.cholK
95+
lhs .= fact \ new_rhs
11396
return nothing
11497
end

src/Functions/block_compliance.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ function compute_approx_bc(bc, F, V, Y)
143143
bc.method.sample_once || bc.method.sample_method(V)
144144
for i in 1:nv
145145
@views mul!(solver.rhs, F, V[:, i])
146-
solver(; assemble_f=false, reuse_chol=(i > 1))
146+
solver(; assemble_f=false, reuse_fact=(i > 1))
147147
invKFv = solver.lhs
148148
Y[:, i] .= invKFv
149149
temp = F' * invKFv
@@ -161,7 +161,7 @@ function compute_jtvp!_bc(out, bc, method::DiagonalEstimation, w)
161161
temp .= 0
162162
#q_i = K^-1 F (w .* v_i)
163163
@views mul!(solver.rhs, F, w .* V[:, i])
164-
solver(; assemble_f=false, reuse_chol=(i > 1))
164+
solver(; assemble_f=false, reuse_fact=(i > 1))
165165
Q[:, i] = solver.lhs
166166
#<q_i, dK/dx_e, y_i>
167167
@views compute_inner(temp, Q[:, i], Y[:, i], solver)

src/Functions/displacement.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ function ChainRulesCore.rrule(dp::Displacement, x::PseudoDensities)
7979
else
8080
solver.rhs .= Δ
8181
end
82-
solver(; reuse_chol=true, assemble_f=false)
82+
solver(; reuse_fact=true, assemble_f=false)
8383
dudx_tmp .= 0
8484
for e in 1:length(x.x)
8585
_, dρe = get_ρ_dρ(x.x[e], penalty, xmin)

src/Functions/mean_compliance.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ function compute_exact_ec(ec, x, grad, F, n)
7575
grad .= 0
7676
for i in 1:size(F, 2)
7777
@views solver.rhs .= F[:, i]
78-
solver(; assemble_f=false, reuse_chol=(i > 1))
78+
solver(; assemble_f=false, reuse_fact=(i > 1))
7979
u = solver.lhs
8080
obj += compute_compliance(
8181
cell_comp, grad_temp, cell_dofs, Kes, u, black, white, varind, x, penalty, xmin
@@ -104,7 +104,7 @@ function compute_approx_ec(ec, x, grad, F, V, n)
104104
ec.method.sample_once || ec.method.sample_method(V)
105105
for i in 1:nv
106106
@views mul!(solver.rhs, F, V[:, i])
107-
solver(; assemble_f=false, reuse_chol=(i > 1))
107+
solver(; assemble_f=false, reuse_fact=(i > 1))
108108
invKFv = solver.lhs
109109
obj += compute_compliance(
110110
cell_comp,

src/Functions/truss_stress.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ end
8181
# u = dp(x)
8282
# return u, Δ -> begin # v
8383
# solver.rhs .= Δ
84-
# solver(reuse_chol = true, assemble_f = false)
84+
# solver(reuse_fact = true, assemble_f = false)
8585
# dudx_tmp .= 0
8686
# for e in 1:length(x)
8787
# _, dρe = get_ρ_dρ(x[e], penalty, xmin)

test/runtests.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,18 +33,21 @@ if GROUP == "All" || GROUP == "Examples_1"
3333
@safetestset "CSIMP example" begin
3434
include("examples/csimp.jl")
3535
end
36+
end
37+
38+
if GROUP == "All" || GROUP == "Examples_2"
3639
@safetestset "Global stress example" begin
3740
include("examples/global_stress.jl")
3841
end
3942
@safetestset "Local stress example" begin
4043
include("examples/local_stress.jl")
4144
end
45+
end
46+
47+
if GROUP == "All" || GROUP == "Examples_3"
4248
@safetestset "More examples" begin
4349
include("examples/test_examples.jl")
4450
end
45-
end
46-
47-
if GROUP == "All" || GROUP == "Examples_2"
4851
@safetestset "Neural network example" begin
4952
include("examples/neural.jl")
5053
end

0 commit comments

Comments
 (0)