Skip to content

Commit eda7ad0

Browse files
authored
Refactor: smoother conversion from SUNMatrixWrapper to SUNMatrix (#2317)
Adds an implicit conversion function to SUNMatrixWrapper make things more readable.
1 parent c9b08ac commit eda7ad0

File tree

8 files changed

+51
-46
lines changed

8 files changed

+51
-46
lines changed

include/amici/rdata.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -576,7 +576,7 @@ class ReturnData : public ModelDimensions {
576576

577577
if (!this->J.empty()) {
578578
SUNMatrixWrapper J(nx_solver, nx_solver);
579-
model.fJ(t_, 0.0, x_solver_, dx_solver_, xdot, J.get());
579+
model.fJ(t_, 0.0, x_solver_, dx_solver_, xdot, J);
580580
// CVODES uses colmajor, so we need to transform to rowmajor
581581
for (int ix = 0; ix < model.nx_solver; ix++)
582582
for (int jx = 0; jx < model.nx_solver; jx++)

include/amici/sundials_matrix_wrapper.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,11 @@ class SUNMatrixWrapper {
7272

7373
~SUNMatrixWrapper();
7474

75+
/**
76+
* @brief Conversion function.
77+
*/
78+
operator SUNMatrix() { return get(); };
79+
7580
/**
7681
* @brief Copy constructor
7782
* @param other

src/model.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2249,7 +2249,7 @@ void Model::fdJydy(int const it, AmiVector const& x, ExpData const& edata) {
22492249

22502250
auto tmp_sparse = SUNMatrixWrapper(tmp_dense, 0.0, CSC_MAT);
22512251
auto ret = SUNMatScaleAdd(
2252-
1.0, derived_state_.dJydy_.at(iyt).get(), tmp_sparse.get()
2252+
1.0, derived_state_.dJydy_.at(iyt), tmp_sparse
22532253
);
22542254
if (ret != SUNMAT_SUCCESS) {
22552255
throw AmiException(
@@ -2897,7 +2897,7 @@ void Model::fdwdp(realtype const t, realtype const* x) {
28972897
}
28982898

28992899
if (always_check_finite_) {
2900-
checkFinite(derived_state_.dwdp_.get(), ModelQuantity::dwdp, t);
2900+
checkFinite(derived_state_.dwdp_, ModelQuantity::dwdp, t);
29012901
}
29022902
}
29032903

@@ -2943,7 +2943,7 @@ void Model::fdwdx(realtype const t, realtype const* x) {
29432943
}
29442944

29452945
if (always_check_finite_) {
2946-
checkFinite(derived_state_.dwdx_.get(), ModelQuantity::dwdx, t);
2946+
checkFinite(derived_state_.dwdx_, ModelQuantity::dwdx, t);
29472947
}
29482948
}
29492949

@@ -2960,7 +2960,7 @@ void Model::fdwdw(realtype const t, realtype const* x) {
29602960
);
29612961

29622962
if (always_check_finite_) {
2963-
checkFinite(dwdw_.get(), ModelQuantity::dwdw, t);
2963+
checkFinite(dwdw_, ModelQuantity::dwdw, t);
29642964
}
29652965
}
29662966

src/model_dae.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ void Model_DAE::fJ(
1414
realtype t, realtype cj, const_N_Vector x, const_N_Vector dx,
1515
const_N_Vector /*xdot*/, SUNMatrix J
1616
) {
17-
fJSparse(t, cj, x, dx, derived_state_.J_.get());
17+
fJSparse(t, cj, x, dx, derived_state_.J_);
1818
derived_state_.J_.refresh();
1919
auto JDense = SUNMatrixWrapper(J);
2020
derived_state_.J_.to_dense(JDense);
@@ -88,7 +88,7 @@ void Model_DAE::fJv(
8888
N_Vector Jv, realtype cj
8989
) {
9090
N_VConst(0.0, Jv);
91-
fJSparse(t, cj, x, dx, derived_state_.J_.get());
91+
fJSparse(t, cj, x, dx, derived_state_.J_);
9292
derived_state_.J_.refresh();
9393
derived_state_.J_.multiply(Jv, v);
9494
}
@@ -135,7 +135,7 @@ void Model_DAE::fJDiag(
135135
realtype const t, AmiVector& JDiag, realtype const /*cj*/,
136136
AmiVector const& x, AmiVector const& dx
137137
) {
138-
fJSparse(t, 0.0, x.getNVector(), dx.getNVector(), derived_state_.J_.get());
138+
fJSparse(t, 0.0, x.getNVector(), dx.getNVector(), derived_state_.J_);
139139
derived_state_.J_.refresh();
140140
derived_state_.J_.to_diag(JDiag.getNVector());
141141
if (checkFinite(JDiag.getVector(), ModelQuantity::JDiag) != AMICI_SUCCESS)
@@ -355,7 +355,7 @@ void Model_DAE::fJB(
355355
realtype t, realtype cj, const_N_Vector x, const_N_Vector dx,
356356
const_N_Vector /*xB*/, const_N_Vector /*dxB*/, SUNMatrix JB
357357
) {
358-
fJSparse(t, cj, x, dx, derived_state_.J_.get());
358+
fJSparse(t, cj, x, dx, derived_state_.J_);
359359
derived_state_.J_.refresh();
360360
auto JBDense = SUNMatrixWrapper(JB);
361361
derived_state_.J_.transpose(JBDense, -1.0, nxtrue_solver);
@@ -376,7 +376,7 @@ void Model_DAE::fJSparseB(
376376
realtype t, realtype cj, const_N_Vector x, const_N_Vector dx,
377377
const_N_Vector /*xB*/, const_N_Vector /*dxB*/, SUNMatrix JB
378378
) {
379-
fJSparse(t, cj, x, dx, derived_state_.J_.get());
379+
fJSparse(t, cj, x, dx, derived_state_.J_);
380380
derived_state_.J_.refresh();
381381
auto JSparseB = SUNMatrixWrapper(JB);
382382
derived_state_.J_.transpose(JSparseB, -1.0, nxtrue_solver);
@@ -387,7 +387,7 @@ void Model_DAE::fJvB(
387387
const_N_Vector dxB, const_N_Vector vB, N_Vector JvB, realtype cj
388388
) {
389389
N_VConst(0.0, JvB);
390-
fJSparseB(t, cj, x, dx, xB, dxB, derived_state_.JB_.get());
390+
fJSparseB(t, cj, x, dx, xB, dxB, derived_state_.JB_);
391391
derived_state_.JB_.refresh();
392392
derived_state_.JB_.multiply(JvB, vB);
393393
}
@@ -397,7 +397,7 @@ void Model_DAE::fxBdot(
397397
const_N_Vector dxB, N_Vector xBdot
398398
) {
399399
N_VConst(0.0, xBdot);
400-
fJSparseB(t, 1.0, x, dx, xB, dxB, derived_state_.JB_.get());
400+
fJSparseB(t, 1.0, x, dx, xB, dxB, derived_state_.JB_);
401401
derived_state_.JB_.refresh();
402402
fM(t, x);
403403
derived_state_.JB_.multiply(xBdot, xB);
@@ -454,7 +454,7 @@ void Model_DAE::fqBdot_ss(
454454

455455
void Model_DAE::fJSparseB_ss(SUNMatrix JB) {
456456
/* Just pass the model Jacobian on to JB */
457-
SUNMatCopy(derived_state_.JB_.get(), JB);
457+
SUNMatCopy(derived_state_.JB_, JB);
458458
derived_state_.JB_.refresh();
459459
}
460460

@@ -465,7 +465,7 @@ void Model_DAE::writeSteadystateJB(
465465
/* Get backward Jacobian */
466466
fJSparseB(
467467
t, cj, x.getNVector(), dx.getNVector(), xB.getNVector(),
468-
dxB.getNVector(), derived_state_.JB_.get()
468+
dxB.getNVector(), derived_state_.JB_
469469
);
470470
derived_state_.JB_.refresh();
471471
/* Switch sign, as we integrate forward in time, not backward */
@@ -491,7 +491,7 @@ void Model_DAE::fsxdot(
491491
// the same for all remaining
492492
fM(t, x);
493493
fdxdotdp(t, x, dx);
494-
fJSparse(t, 0.0, x, dx, derived_state_.J_.get());
494+
fJSparse(t, 0.0, x, dx, derived_state_.J_);
495495
derived_state_.J_.refresh();
496496
}
497497

src/model_ode.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ void Model_ODE::fJ(
1414
void Model_ODE::fJ(
1515
realtype t, const_N_Vector x, const_N_Vector /*xdot*/, SUNMatrix J
1616
) {
17-
fJSparse(t, x, derived_state_.J_.get());
17+
fJSparse(t, x, derived_state_.J_);
1818
derived_state_.J_.refresh();
1919
auto JDense = SUNMatrixWrapper(J);
2020
derived_state_.J_.to_dense(JDense);
@@ -77,7 +77,7 @@ void Model_ODE::fJv(
7777
const_N_Vector v, N_Vector Jv, realtype t, const_N_Vector x
7878
) {
7979
N_VConst(0.0, Jv);
80-
fJSparse(t, x, derived_state_.J_.get());
80+
fJSparse(t, x, derived_state_.J_);
8181
derived_state_.J_.refresh();
8282
derived_state_.J_.multiply(Jv, v);
8383
}
@@ -355,7 +355,7 @@ void Model_ODE::fJB(
355355
realtype t, const_N_Vector x, const_N_Vector /*xB*/,
356356
const_N_Vector /*xBdot*/, SUNMatrix JB
357357
) {
358-
fJSparse(t, x, derived_state_.J_.get());
358+
fJSparse(t, x, derived_state_.J_);
359359
derived_state_.J_.refresh();
360360
auto JDenseB = SUNMatrixWrapper(JB);
361361
derived_state_.J_.transpose(JDenseB, -1.0, nxtrue_solver);
@@ -373,14 +373,14 @@ void Model_ODE::fJSparseB(
373373
realtype t, const_N_Vector x, const_N_Vector /*xB*/,
374374
const_N_Vector /*xBdot*/, SUNMatrix JB
375375
) {
376-
fJSparse(t, x, derived_state_.J_.get());
376+
fJSparse(t, x, derived_state_.J_);
377377
derived_state_.J_.refresh();
378378
auto JSparseB = SUNMatrixWrapper(JB);
379379
derived_state_.J_.transpose(JSparseB, -1.0, nxtrue_solver);
380380
}
381381

382382
void Model_ODE::fJDiag(realtype t, N_Vector JDiag, const_N_Vector x) {
383-
fJSparse(t, x, derived_state_.J_.get());
383+
fJSparse(t, x, derived_state_.J_);
384384
derived_state_.J_.refresh();
385385
derived_state_.J_.to_diag(JDiag);
386386
}
@@ -390,14 +390,14 @@ void Model_ODE::fJvB(
390390
const_N_Vector xB
391391
) {
392392
N_VConst(0.0, JvB);
393-
fJSparseB(t, x, xB, nullptr, derived_state_.JB_.get());
393+
fJSparseB(t, x, xB, nullptr, derived_state_.JB_);
394394
derived_state_.JB_.refresh();
395395
derived_state_.JB_.multiply(JvB, vB);
396396
}
397397

398398
void Model_ODE::fxBdot(realtype t, N_Vector x, N_Vector xB, N_Vector xBdot) {
399399
N_VConst(0.0, xBdot);
400-
fJSparseB(t, x, xB, nullptr, derived_state_.JB_.get());
400+
fJSparseB(t, x, xB, nullptr, derived_state_.JB_);
401401
derived_state_.JB_.refresh();
402402
derived_state_.JB_.multiply(xBdot, xB);
403403
}
@@ -456,7 +456,7 @@ void Model_ODE::fqBdot_ss(realtype /*t*/, N_Vector xB, N_Vector qBdot) const {
456456

457457
void Model_ODE::fJSparseB_ss(SUNMatrix JB) {
458458
/* Just copy the model Jacobian */
459-
SUNMatCopy(derived_state_.JB_.get(), JB);
459+
SUNMatCopy(derived_state_.JB_, JB);
460460
derived_state_.JB_.refresh();
461461
}
462462

@@ -468,7 +468,7 @@ void Model_ODE::writeSteadystateJB(
468468
/* Get backward Jacobian */
469469
fJSparseB(
470470
t, x.getNVector(), xB.getNVector(), xBdot.getNVector(),
471-
derived_state_.JB_.get()
471+
derived_state_.JB_
472472
);
473473
derived_state_.JB_.refresh();
474474
/* Switch sign, as we integrate forward in time, not backward */
@@ -492,7 +492,7 @@ void Model_ODE::fsxdot(
492492
// we only need to call this for the first parameter index will be
493493
// the same for all remaining
494494
fdxdotdp(t, x);
495-
fJSparse(t, x, derived_state_.J_.get());
495+
fJSparse(t, x, derived_state_.J_);
496496
derived_state_.J_.refresh();
497497
}
498498
if (pythonGenerated) {

src/newton_solver.cpp

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ void NewtonSolver::computeNewtonSensis(
108108
NewtonSolverDense::NewtonSolverDense(Model const& model)
109109
: NewtonSolver(model)
110110
, Jtmp_(model.nx_solver, model.nx_solver)
111-
, linsol_(SUNLinSol_Dense(x_.getNVector(), Jtmp_.get())) {
111+
, linsol_(SUNLinSol_Dense(x_.getNVector(), Jtmp_)) {
112112
auto status = SUNLinSolInitialize_Dense(linsol_);
113113
if (status != SUNLS_SUCCESS)
114114
throw NewtonFailure(status, "SUNLinSolInitialize_Dense");
@@ -117,26 +117,26 @@ NewtonSolverDense::NewtonSolverDense(Model const& model)
117117
void NewtonSolverDense::prepareLinearSystem(
118118
Model& model, SimulationState const& state
119119
) {
120-
model.fJ(state.t, 0.0, state.x, state.dx, xdot_, Jtmp_.get());
120+
model.fJ(state.t, 0.0, state.x, state.dx, xdot_, Jtmp_);
121121
Jtmp_.refresh();
122-
auto status = SUNLinSolSetup_Dense(linsol_, Jtmp_.get());
122+
auto status = SUNLinSolSetup_Dense(linsol_, Jtmp_);
123123
if (status != SUNLS_SUCCESS)
124124
throw NewtonFailure(status, "SUNLinSolSetup_Dense");
125125
}
126126

127127
void NewtonSolverDense::prepareLinearSystemB(
128128
Model& model, SimulationState const& state
129129
) {
130-
model.fJB(state.t, 0.0, state.x, state.dx, xB_, dxB_, xdot_, Jtmp_.get());
130+
model.fJB(state.t, 0.0, state.x, state.dx, xB_, dxB_, xdot_, Jtmp_);
131131
Jtmp_.refresh();
132-
auto status = SUNLinSolSetup_Dense(linsol_, Jtmp_.get());
132+
auto status = SUNLinSolSetup_Dense(linsol_, Jtmp_);
133133
if (status != SUNLS_SUCCESS)
134134
throw NewtonFailure(status, "SUNLinSolSetup_Dense");
135135
}
136136

137137
void NewtonSolverDense::solveLinearSystem(AmiVector& rhs) {
138138
auto status = SUNLinSolSolve_Dense(
139-
linsol_, Jtmp_.get(), rhs.getNVector(), rhs.getNVector(), 0.0
139+
linsol_, Jtmp_, rhs.getNVector(), rhs.getNVector(), 0.0
140140
);
141141
Jtmp_.refresh();
142142
// last argument is tolerance and does not have any influence on result
@@ -167,7 +167,7 @@ NewtonSolverDense::~NewtonSolverDense() {
167167
NewtonSolverSparse::NewtonSolverSparse(Model const& model)
168168
: NewtonSolver(model)
169169
, Jtmp_(model.nx_solver, model.nx_solver, model.nnz, CSC_MAT)
170-
, linsol_(SUNKLU(x_.getNVector(), Jtmp_.get())) {
170+
, linsol_(SUNKLU(x_.getNVector(), Jtmp_)) {
171171
auto status = SUNLinSolInitialize_KLU(linsol_);
172172
if (status != SUNLS_SUCCESS)
173173
throw NewtonFailure(status, "SUNLinSolInitialize_KLU");
@@ -177,9 +177,9 @@ void NewtonSolverSparse::prepareLinearSystem(
177177
Model& model, SimulationState const& state
178178
) {
179179
/* Get sparse Jacobian */
180-
model.fJSparse(state.t, 0.0, state.x, state.dx, xdot_, Jtmp_.get());
180+
model.fJSparse(state.t, 0.0, state.x, state.dx, xdot_, Jtmp_);
181181
Jtmp_.refresh();
182-
auto status = SUNLinSolSetup_KLU(linsol_, Jtmp_.get());
182+
auto status = SUNLinSolSetup_KLU(linsol_, Jtmp_);
183183
if (status != SUNLS_SUCCESS)
184184
throw NewtonFailure(status, "SUNLinSolSetup_KLU");
185185
}
@@ -189,18 +189,18 @@ void NewtonSolverSparse::prepareLinearSystemB(
189189
) {
190190
/* Get sparse Jacobian */
191191
model.fJSparseB(
192-
state.t, 0.0, state.x, state.dx, xB_, dxB_, xdot_, Jtmp_.get()
192+
state.t, 0.0, state.x, state.dx, xB_, dxB_, xdot_, Jtmp_
193193
);
194194
Jtmp_.refresh();
195-
auto status = SUNLinSolSetup_KLU(linsol_, Jtmp_.get());
195+
auto status = SUNLinSolSetup_KLU(linsol_, Jtmp_);
196196
if (status != SUNLS_SUCCESS)
197197
throw NewtonFailure(status, "SUNLinSolSetup_KLU");
198198
}
199199

200200
void NewtonSolverSparse::solveLinearSystem(AmiVector& rhs) {
201201
/* Pass pointer to the linear solver */
202202
auto status = SUNLinSolSolve_KLU(
203-
linsol_, Jtmp_.get(), rhs.getNVector(), rhs.getNVector(), 0.0
203+
linsol_, Jtmp_, rhs.getNVector(), rhs.getNVector(), 0.0
204204
);
205205
// last argument is tolerance and does not have any influence on result
206206

@@ -211,7 +211,7 @@ void NewtonSolverSparse::solveLinearSystem(AmiVector& rhs) {
211211
void NewtonSolverSparse::reinitialize() {
212212
/* partial reinitialization, don't need to reallocate Jtmp_ */
213213
auto status = SUNLinSol_KLUReInit(
214-
linsol_, Jtmp_.get(), Jtmp_.capacity(), SUNKLU_REINIT_PARTIAL
214+
linsol_, Jtmp_, Jtmp_.capacity(), SUNKLU_REINIT_PARTIAL
215215
);
216216
if (status != SUNLS_SUCCESS)
217217
throw NewtonFailure(status, "SUNLinSol_KLUReInit");

src/sundials_linsol_wrapper.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ SUNLinSolBand::SUNLinSolBand(N_Vector x, SUNMatrix A)
161161

162162
SUNLinSolBand::SUNLinSolBand(AmiVector const& x, int ubw, int lbw)
163163
: A_(SUNMatrixWrapper(x.getLength(), ubw, lbw)) {
164-
solver_ = SUNLinSol_Band(const_cast<N_Vector>(x.getNVector()), A_.get());
164+
solver_ = SUNLinSol_Band(const_cast<N_Vector>(x.getNVector()), A_);
165165
if (!solver_)
166166
throw AmiException("Failed to create solver.");
167167
}
@@ -170,7 +170,7 @@ SUNMatrix SUNLinSolBand::getMatrix() const { return A_.get(); }
170170

171171
SUNLinSolDense::SUNLinSolDense(AmiVector const& x)
172172
: A_(SUNMatrixWrapper(x.getLength(), x.getLength())) {
173-
solver_ = SUNLinSol_Dense(const_cast<N_Vector>(x.getNVector()), A_.get());
173+
solver_ = SUNLinSol_Dense(const_cast<N_Vector>(x.getNVector()), A_);
174174
if (!solver_)
175175
throw AmiException("Failed to create solver.");
176176
}
@@ -187,7 +187,7 @@ SUNLinSolKLU::SUNLinSolKLU(
187187
AmiVector const& x, int nnz, int sparsetype, StateOrdering ordering
188188
)
189189
: A_(SUNMatrixWrapper(x.getLength(), x.getLength(), nnz, sparsetype)) {
190-
solver_ = SUNLinSol_KLU(const_cast<N_Vector>(x.getNVector()), A_.get());
190+
solver_ = SUNLinSol_KLU(const_cast<N_Vector>(x.getNVector()), A_);
191191
if (!solver_)
192192
throw AmiException("Failed to create solver.");
193193

@@ -197,7 +197,7 @@ SUNLinSolKLU::SUNLinSolKLU(
197197
SUNMatrix SUNLinSolKLU::getMatrix() const { return A_.get(); }
198198

199199
void SUNLinSolKLU::reInit(int nnz, int reinit_type) {
200-
int status = SUNLinSol_KLUReInit(solver_, A_.get(), nnz, reinit_type);
200+
int status = SUNLinSol_KLUReInit(solver_, A_, nnz, reinit_type);
201201
if (status != SUNLS_SUCCESS)
202202
throw AmiException("SUNLinSol_KLUReInit failed with %d", status);
203203
}

tests/cpp/unittests/testMisc.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -681,7 +681,7 @@ TEST(UnravelIndex, UnravelIndexSunMatDense)
681681
A.set_data(2, 1, 5);
682682

683683
for(int i = 0; i < 6; ++i) {
684-
auto idx = unravel_index(i, A.get());
684+
auto idx = unravel_index(i, A);
685685
EXPECT_EQ(A.get_data(idx.first, idx.second), i);
686686
}
687687
}
@@ -706,7 +706,7 @@ TEST(UnravelIndex, UnravelIndexSunMatSparse)
706706
D.set_data(2, 1, 0);
707707
D.set_data(3, 1, 0);
708708

709-
auto S = SUNSparseFromDenseMatrix(D.get(), 1e-15, CSC_MAT);
709+
auto S = SUNSparseFromDenseMatrix(D, 1e-15, CSC_MAT);
710710

711711
EXPECT_EQ(unravel_index(0, S), std::make_pair((sunindextype) 2, (sunindextype) 0));
712712
EXPECT_EQ(unravel_index(1, S), std::make_pair((sunindextype) 3, (sunindextype) 0));
@@ -720,8 +720,8 @@ TEST(UnravelIndex, UnravelIndexSunMatSparseMissingIndices)
720720
{
721721
// Sparse matrix without any indices set
722722
SUNMatrixWrapper mat = SUNMatrixWrapper(2, 3, 2, CSC_MAT);
723-
EXPECT_EQ(unravel_index(0, mat.get()), std::make_pair((sunindextype) -1, (sunindextype) -1));
724-
EXPECT_EQ(unravel_index(1, mat.get()), std::make_pair((sunindextype) -1, (sunindextype) -1));
723+
EXPECT_EQ(unravel_index(0, mat), std::make_pair((sunindextype) -1, (sunindextype) -1));
724+
EXPECT_EQ(unravel_index(1, mat), std::make_pair((sunindextype) -1, (sunindextype) -1));
725725
}
726726

727727

0 commit comments

Comments
 (0)