Skip to content

Commit 16ec8b2

Browse files
authored
Refactor DEExporter/DEModel/csc_matrix (#2311)
* Refactor DEExporter/DEModel/csc_matrix Reduce unnecessary coupling: * `csc_matrix` as free function - removes the need for the codeprinter in DEModel * Move the codeprinter to `DEExporter` where it's actually needed * ..
1 parent b355ab7 commit 16ec8b2

File tree

4 files changed

+119
-121
lines changed

4 files changed

+119
-121
lines changed

python/sdist/amici/cxxcodeprinter.py

Lines changed: 81 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -207,90 +207,6 @@ def format_line(symbol: sp.Symbol):
207207
if math not in [0, 0.0]
208208
]
209209

210-
def csc_matrix(
211-
self,
212-
matrix: sp.Matrix,
213-
rownames: list[sp.Symbol],
214-
colnames: list[sp.Symbol],
215-
identifier: Optional[int] = 0,
216-
pattern_only: Optional[bool] = False,
217-
) -> tuple[list[int], list[int], sp.Matrix, list[str], sp.Matrix]:
218-
"""
219-
Generates the sparse symbolic identifiers, symbolic identifiers,
220-
sparse matrix, column pointers and row values for a symbolic
221-
variable
222-
223-
:param matrix:
224-
dense matrix to be sparsified
225-
226-
:param rownames:
227-
ids of the variable of which the derivative is computed (assuming
228-
matrix is the jacobian)
229-
230-
:param colnames:
231-
ids of the variable with respect to which the derivative is computed
232-
(assuming matrix is the jacobian)
233-
234-
:param identifier:
235-
additional identifier that gets appended to symbol names to
236-
ensure their uniqueness in outer loops
237-
238-
:param pattern_only:
239-
flag for computing sparsity pattern without whole matrix
240-
241-
:return:
242-
symbol_col_ptrs, symbol_row_vals, sparse_list, symbol_list,
243-
sparse_matrix
244-
"""
245-
idx = 0
246-
247-
nrows, ncols = matrix.shape
248-
249-
if not pattern_only:
250-
sparse_matrix = sp.zeros(nrows, ncols)
251-
symbol_list = []
252-
sparse_list = []
253-
symbol_col_ptrs = []
254-
symbol_row_vals = []
255-
256-
for col in range(ncols):
257-
symbol_col_ptrs.append(idx)
258-
for row in range(nrows):
259-
if matrix[row, col] == 0:
260-
continue
261-
262-
symbol_row_vals.append(row)
263-
idx += 1
264-
symbol_name = (
265-
f"d{rownames[row].name}" f"_d{colnames[col].name}"
266-
)
267-
if identifier:
268-
symbol_name += f"_{identifier}"
269-
symbol_list.append(symbol_name)
270-
if pattern_only:
271-
continue
272-
273-
sparse_matrix[row, col] = sp.Symbol(symbol_name, real=True)
274-
sparse_list.append(matrix[row, col])
275-
276-
if idx == 0:
277-
symbol_col_ptrs = [] # avoid bad memory access for empty matrices
278-
else:
279-
symbol_col_ptrs.append(idx)
280-
281-
if pattern_only:
282-
sparse_matrix = None
283-
else:
284-
sparse_list = sp.Matrix(sparse_list)
285-
286-
return (
287-
symbol_col_ptrs,
288-
symbol_row_vals,
289-
sparse_list,
290-
symbol_list,
291-
sparse_matrix,
292-
)
293-
294210
@staticmethod
295211
def print_bool(expr) -> str:
296212
"""Print the boolean value of the given expression"""
@@ -360,3 +276,84 @@ def get_switch_statement(
360276
),
361277
indent0 + "}",
362278
]
279+
280+
281+
def csc_matrix(
282+
matrix: sp.Matrix,
283+
rownames: list[sp.Symbol],
284+
colnames: list[sp.Symbol],
285+
identifier: Optional[int] = 0,
286+
pattern_only: Optional[bool] = False,
287+
) -> tuple[list[int], list[int], sp.Matrix, list[str], sp.Matrix]:
288+
"""
289+
Generates the sparse symbolic identifiers, symbolic identifiers,
290+
sparse matrix, column pointers and row values for a symbolic
291+
variable
292+
293+
:param matrix:
294+
dense matrix to be sparsified
295+
296+
:param rownames:
297+
ids of the variable of which the derivative is computed (assuming
298+
matrix is the jacobian)
299+
300+
:param colnames:
301+
ids of the variable with respect to which the derivative is computed
302+
(assuming matrix is the jacobian)
303+
304+
:param identifier:
305+
additional identifier that gets appended to symbol names to
306+
ensure their uniqueness in outer loops
307+
308+
:param pattern_only:
309+
flag for computing sparsity pattern without whole matrix
310+
311+
:return:
312+
symbol_col_ptrs, symbol_row_vals, sparse_list, symbol_list,
313+
sparse_matrix
314+
"""
315+
idx = 0
316+
nrows, ncols = matrix.shape
317+
318+
if not pattern_only:
319+
sparse_matrix = sp.zeros(nrows, ncols)
320+
symbol_list = []
321+
sparse_list = []
322+
symbol_col_ptrs = []
323+
symbol_row_vals = []
324+
325+
for col in range(ncols):
326+
symbol_col_ptrs.append(idx)
327+
for row in range(nrows):
328+
if matrix[row, col] == 0:
329+
continue
330+
331+
symbol_row_vals.append(row)
332+
idx += 1
333+
symbol_name = f"d{rownames[row].name}" f"_d{colnames[col].name}"
334+
if identifier:
335+
symbol_name += f"_{identifier}"
336+
symbol_list.append(symbol_name)
337+
if pattern_only:
338+
continue
339+
340+
sparse_matrix[row, col] = sp.Symbol(symbol_name, real=True)
341+
sparse_list.append(matrix[row, col])
342+
343+
if idx == 0:
344+
symbol_col_ptrs = [] # avoid bad memory access for empty matrices
345+
else:
346+
symbol_col_ptrs.append(idx)
347+
348+
if pattern_only:
349+
sparse_matrix = None
350+
else:
351+
sparse_list = sp.Matrix(sparse_list)
352+
353+
return (
354+
symbol_col_ptrs,
355+
symbol_row_vals,
356+
sparse_list,
357+
symbol_list,
358+
sparse_matrix,
359+
)

python/sdist/amici/de_export.py

Lines changed: 29 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,11 @@
4444
splines,
4545
)
4646
from .constants import SymbolId
47-
from .cxxcodeprinter import AmiciCxxCodePrinter, get_switch_statement
47+
from .cxxcodeprinter import (
48+
AmiciCxxCodePrinter,
49+
get_switch_statement,
50+
csc_matrix,
51+
)
4852
from .de_model import *
4953
from .import_utils import (
5054
ObservableTransformation,
@@ -725,9 +729,6 @@ class DEModel:
725729
whether all observables have a gaussian noise model, i.e. whether
726730
res and FIM make sense.
727731
728-
:ivar _code_printer:
729-
Code printer to generate C++ code
730-
731732
:ivar _z2event:
732733
list of event indices for each event observable
733734
"""
@@ -869,10 +870,6 @@ def cached_simplify(
869870
self._has_quadratic_nllh: bool = True
870871
set_log_level(logger, verbose)
871872

872-
self._code_printer = AmiciCxxCodePrinter()
873-
for fun in CUSTOM_FUNCTIONS:
874-
self._code_printer.known_functions[fun["sympy"]] = fun["c++"]
875-
876873
def differential_states(self) -> list[DifferentialState]:
877874
"""Get all differential states."""
878875
return self._differential_states
@@ -1882,7 +1879,7 @@ def _generate_sparse_symbol(self, name: str) -> None:
18821879
sparse_list,
18831880
symbol_list,
18841881
sparse_matrix,
1885-
) = self._code_printer.csc_matrix(
1882+
) = csc_matrix(
18861883
matrix[iy, :],
18871884
rownames=rownames,
18881885
colnames=colnames,
@@ -1900,7 +1897,7 @@ def _generate_sparse_symbol(self, name: str) -> None:
19001897
sparse_list,
19011898
symbol_list,
19021899
sparse_matrix,
1903-
) = self._code_printer.csc_matrix(
1900+
) = csc_matrix(
19041901
matrix,
19051902
rownames=rownames,
19061903
colnames=colnames,
@@ -2884,6 +2881,9 @@ class DEExporter:
28842881
If the given model uses special functions, this set contains hints for
28852882
model building.
28862883
2884+
:ivar _code_printer:
2885+
Code printer to generate C++ code
2886+
28872887
:ivar generate_sensitivity_code:
28882888
Specifies whether code for sensitivity computation is to be generated
28892889
@@ -2950,10 +2950,14 @@ def __init__(
29502950
self.set_name(model_name)
29512951
self.set_paths(outdir)
29522952

2953+
self._code_printer = AmiciCxxCodePrinter()
2954+
for fun in CUSTOM_FUNCTIONS:
2955+
self._code_printer.known_functions[fun["sympy"]] = fun["c++"]
2956+
29532957
# Signatures and properties of generated model functions (see
29542958
# include/amici/model.h for details)
29552959
self.model: DEModel = de_model
2956-
self.model._code_printer.known_functions.update(
2960+
self._code_printer.known_functions.update(
29572961
splines.spline_user_functions(
29582962
self.model._splines, self._get_index("p")
29592963
)
@@ -3519,7 +3523,7 @@ def _get_function_body(
35193523
f"reinitialization_state_idxs.cend(), {index}) != "
35203524
"reinitialization_state_idxs.cend())",
35213525
f" {function}[{index}] = "
3522-
f"{self.model._code_printer.doprint(formula)};",
3526+
f"{self._code_printer.doprint(formula)};",
35233527
]
35243528
)
35253529
cases[ipar] = expressions
@@ -3534,12 +3538,12 @@ def _get_function_body(
35343538
f"reinitialization_state_idxs.cend(), {index}) != "
35353539
"reinitialization_state_idxs.cend())\n "
35363540
f"{function}[{index}] = "
3537-
f"{self.model._code_printer.doprint(formula)};"
3541+
f"{self._code_printer.doprint(formula)};"
35383542
)
35393543

35403544
elif function in event_functions:
35413545
cases = {
3542-
ie: self.model._code_printer._get_sym_lines_array(
3546+
ie: self._code_printer._get_sym_lines_array(
35433547
equations[ie], function, 0
35443548
)
35453549
for ie in range(self.model.num_events())
@@ -3552,7 +3556,7 @@ def _get_function_body(
35523556
for ie, inner_equations in enumerate(equations):
35533557
inner_lines = []
35543558
inner_cases = {
3555-
ipar: self.model._code_printer._get_sym_lines_array(
3559+
ipar: self._code_printer._get_sym_lines_array(
35563560
inner_equations[:, ipar], function, 0
35573561
)
35583562
for ipar in range(self.model.num_par())
@@ -3567,7 +3571,7 @@ def _get_function_body(
35673571
and equations.shape[1] == self.model.num_par()
35683572
):
35693573
cases = {
3570-
ipar: self.model._code_printer._get_sym_lines_array(
3574+
ipar: self._code_printer._get_sym_lines_array(
35713575
equations[:, ipar], function, 0
35723576
)
35733577
for ipar in range(self.model.num_par())
@@ -3577,15 +3581,15 @@ def _get_function_body(
35773581
elif function in multiobs_functions:
35783582
if function == "dJydy":
35793583
cases = {
3580-
iobs: self.model._code_printer._get_sym_lines_array(
3584+
iobs: self._code_printer._get_sym_lines_array(
35813585
equations[iobs], function, 0
35823586
)
35833587
for iobs in range(self.model.num_obs())
35843588
if not smart_is_zero_matrix(equations[iobs])
35853589
}
35863590
else:
35873591
cases = {
3588-
iobs: self.model._code_printer._get_sym_lines_array(
3592+
iobs: self._code_printer._get_sym_lines_array(
35893593
equations[:, iobs], function, 0
35903594
)
35913595
for iobs in range(equations.shape[1])
@@ -3605,12 +3609,12 @@ def _get_function_body(
36053609
symbols = list(map(sp.Symbol, self.model.sparsesym(function)))
36063610
else:
36073611
symbols = self.model.sym(function)
3608-
lines += self.model._code_printer._get_sym_lines_symbols(
3612+
lines += self._code_printer._get_sym_lines_symbols(
36093613
symbols, equations, function, 4
36103614
)
36113615

36123616
else:
3613-
lines += self.model._code_printer._get_sym_lines_array(
3617+
lines += self._code_printer._get_sym_lines_array(
36143618
equations, function, 4
36153619
)
36163620

@@ -3766,10 +3770,10 @@ def _write_model_header_cpp(self) -> None:
37663770
"NK": self.model.num_const(),
37673771
"O2MODE": "amici::SecondOrderMode::none",
37683772
# using code printer ensures proper handling of nan/inf
3769-
"PARAMETERS": self.model._code_printer.doprint(
3770-
self.model.val("p")
3771-
)[1:-1],
3772-
"FIXED_PARAMETERS": self.model._code_printer.doprint(
3773+
"PARAMETERS": self._code_printer.doprint(self.model.val("p"))[
3774+
1:-1
3775+
],
3776+
"FIXED_PARAMETERS": self._code_printer.doprint(
37733777
self.model.val("k")
37743778
)[1:-1],
37753779
"PARAMETER_NAMES_INITIALIZER_LIST": self._get_symbol_name_initializer_list(
@@ -3961,7 +3965,7 @@ def _get_symbol_id_initializer_list(self, name: str) -> str:
39613965
Template initializer list of ids
39623966
"""
39633967
return "\n".join(
3964-
f'"{self.model._code_printer.doprint(symbol)}", // {name}[{idx}]'
3968+
f'"{self._code_printer.doprint(symbol)}", // {name}[{idx}]'
39653969
for idx, symbol in enumerate(self.model.sym(name))
39663970
)
39673971

python/sdist/amici/pysb_import.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,10 @@ def pysb2amici(
178178
compiler=compiler,
179179
generate_sensitivity_code=generate_sensitivity_code,
180180
)
181+
# Sympy code optimizations are incompatible with PySB objects, as
182+
# `pysb.Observable` comes with its own `.match` which overrides
183+
# `sympy.Basic.match()`, breaking `sympy.codegen.rewriting.optimize`.
184+
exporter._code_printer._fpoptimizer = None
181185
exporter.generate_model_code()
182186

183187
if compile:
@@ -241,10 +245,6 @@ def ode_model_from_pysb_importer(
241245
simplify=simplify,
242246
cache_simplify=cache_simplify,
243247
)
244-
# Sympy code optimizations are incompatible with PySB objects, as
245-
# `pysb.Observable` comes with its own `.match` which overrides
246-
# `sympy.Basic.match()`, breaking `sympy.codegen.rewriting.optimize`.
247-
ode._code_printer._fpoptimizer = None
248248

249249
if constant_parameters is None:
250250
constant_parameters = []

0 commit comments

Comments
 (0)