Skip to content

Commit 9c91bdd

Browse files
authored
Migrate to SciPy sparse arrays (#3613)
1 parent 932f36f commit 9c91bdd

File tree

7 files changed

+82
-45
lines changed

7 files changed

+82
-45
lines changed

openmc/_sparse_compat.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
"""Compatibility module for scipy.sparse arrays
2+
3+
This module provides a compatibility layer for working with scipy.sparse arrays
4+
across different scipy versions. Sparse arrays were introduced gradually in
5+
scipy, with full support arriving in scipy 1.15. This module provides a unified
6+
API that uses sparse arrays when available and falls back to sparse matrices for
7+
older scipy versions.
8+
9+
For more information on the migration from sparse matrices to sparse arrays,
10+
see: https://docs.scipy.org/doc/scipy/reference/sparse.migration_to_sparray.html
11+
"""
12+
13+
import scipy
14+
from scipy import sparse as sp
15+
16+
# Check scipy version for feature availability
17+
_SCIPY_VERSION = tuple(map(int, scipy.__version__.split('.')[:2]))
18+
19+
if _SCIPY_VERSION >= (1, 15):
20+
# Use sparse arrays
21+
csr_array = sp.csr_array
22+
csc_array = sp.csc_array
23+
dok_array = sp.dok_array
24+
lil_array = sp.lil_array
25+
eye_array = sp.eye_array
26+
block_array = sp.block_array
27+
else:
28+
# Fall back to sparse matrices
29+
csr_array = sp.csr_matrix
30+
csc_array = sp.csc_matrix
31+
dok_array = sp.dok_matrix
32+
lil_array = sp.lil_matrix
33+
eye_array = sp.eye
34+
block_array = sp.bmat
35+
36+
__all__ = [
37+
'csr_array',
38+
'csc_array',
39+
'dok_array',
40+
'lil_array',
41+
'eye_array',
42+
'block_array',
43+
]

openmc/cmfd.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from .checkvalue import (check_type, check_length, check_value,
2626
check_greater_than, check_less_than)
2727
from .exceptions import OpenMCError
28+
from ._sparse_compat import csr_array
2829

2930
# See if mpi4py module can be imported, define have_mpi global variable
3031
try:
@@ -980,8 +981,7 @@ def _initialize_linsolver(self):
980981
loss_row = self._loss_row
981982
loss_col = self._loss_col
982983
temp_data = np.ones(len(loss_row))
983-
temp_loss = sparse.csr_matrix((temp_data, (loss_row, loss_col)),
984-
shape=(n, n))
984+
temp_loss = csr_array((temp_data, (loss_row, loss_col)), shape=(n, n))
985985
temp_loss.sort_indices()
986986

987987
# Pass coremap as 1-d array of 32-bit integers
@@ -1585,7 +1585,7 @@ def _build_loss_matrix(self, adjoint):
15851585
# Create csr matrix
15861586
loss_row = self._loss_row
15871587
loss_col = self._loss_col
1588-
loss = sparse.csr_matrix((data, (loss_row, loss_col)), shape=(n, n))
1588+
loss = csr_array((data, (loss_row, loss_col)), shape=(n, n))
15891589
loss.sort_indices()
15901590
return loss
15911591

@@ -1612,7 +1612,7 @@ def _build_prod_matrix(self, adjoint):
16121612
# Create csr matrix
16131613
prod_row = self._prod_row
16141614
prod_col = self._prod_col
1615-
prod = sparse.csr_matrix((data, (prod_row, prod_col)), shape=(n, n))
1615+
prod = csr_array((data, (prod_row, prod_col)), shape=(n, n))
16161616
prod.sort_indices()
16171617
return prod
16181618

openmc/deplete/abc.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -600,7 +600,7 @@ class Integrator(ABC):
600600
User-supplied functions are expected to have the following signature:
601601
``solver(A, n0, t) -> n1`` where
602602
603-
* ``A`` is a :class:`scipy.sparse.csc_matrix` making up the
603+
* ``A`` is a :class:`scipy.sparse.csc_array` making up the
604604
depletion matrix
605605
* ``n0`` is a 1-D :class:`numpy.ndarray` of initial compositions
606606
for a given material in atoms/cm3
@@ -1134,7 +1134,7 @@ class SIIntegrator(Integrator):
11341134
User-supplied functions are expected to have the following signature:
11351135
``solver(A, n0, t) -> n1`` where
11361136
1137-
* ``A`` is a :class:`scipy.sparse.csc_matrix` making up the
1137+
* ``A`` is a :class:`scipy.sparse.csc_array` making up the
11381138
depletion matrix
11391139
* ``n0`` is a 1-D :class:`numpy.ndarray` of initial compositions
11401140
for a given material in atoms/cm3
@@ -1297,7 +1297,7 @@ def __call__(self, A, n0, dt):
12971297
12981298
Parameters
12991299
----------
1300-
A : scipy.sparse.csc_matrix
1300+
A : scipy.sparse.csc_array
13011301
Sparse transmutation matrix ``A[j, i]`` describing rates at
13021302
which isotope ``i`` transmutes to isotope ``j``
13031303
n0 : numpy.ndarray

openmc/deplete/chain.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,13 @@
1717
from typing import List
1818

1919
import lxml.etree as ET
20-
import scipy.sparse as sp
2120

2221
from openmc.checkvalue import check_type, check_greater_than, PathLike
2322
from openmc.data import gnds_name, zam
2423
from openmc.exceptions import DataError
2524
from .nuclide import FissionYieldDistribution, Nuclide
2625
from .._xml import get_text
26+
from .._sparse_compat import csc_array, dok_array
2727
import openmc.data
2828

2929

@@ -619,7 +619,7 @@ def form_matrix(self, rates, fission_yields=None):
619619
620620
Returns
621621
-------
622-
scipy.sparse.csc_matrix
622+
scipy.sparse.csc_array
623623
Sparse matrix representing depletion.
624624
625625
See Also
@@ -713,7 +713,7 @@ def setval(i, j, val):
713713
reactions.clear()
714714

715715
# Return CSC representation instead of DOK
716-
return sp.csc_matrix((vals, (rows, cols)), shape=(n, n))
716+
return csc_array((vals, (rows, cols)), shape=(n, n))
717717

718718
def add_redox_term(self, matrix, buffer, oxidation_states):
719719
r"""Adds a redox term to the depletion matrix from data contained in
@@ -731,7 +731,7 @@ def add_redox_term(self, matrix, buffer, oxidation_states):
731731
732732
Parameters
733733
----------
734-
matrix : scipy.sparse.csc_matrix
734+
matrix : scipy.sparse.csc_array
735735
Sparse matrix representing depletion
736736
buffer : dict
737737
Dictionary of buffer nuclides used to maintain anoins net balance.
@@ -743,7 +743,7 @@ def add_redox_term(self, matrix, buffer, oxidation_states):
743743
states as integers (e.g., +1, 0).
744744
Returns
745745
-------
746-
matrix : scipy.sparse.csc_matrix
746+
matrix : scipy.sparse.csc_array
747747
Sparse matrix with redox term added
748748
"""
749749
# Elements list with the same size as self.nuclides
@@ -769,7 +769,7 @@ def add_redox_term(self, matrix, buffer, oxidation_states):
769769
for nuc, idx in buffer_idx.items():
770770
array[idx] -= redox_change * buffer[nuc] / os[idx]
771771

772-
return sp.csc_matrix(array)
772+
return csc_array(array)
773773

774774
def form_rr_term(self, tr_rates, current_timestep, mats):
775775
"""Function to form the transfer rate term matrices.
@@ -800,13 +800,13 @@ def form_rr_term(self, tr_rates, current_timestep, mats):
800800
801801
Returns
802802
-------
803-
scipy.sparse.csc_matrix
803+
scipy.sparse.csc_array
804804
Sparse matrix representing transfer term.
805805
806806
"""
807807
# Use DOK as intermediate representation
808808
n = len(self)
809-
matrix = sp.dok_matrix((n, n))
809+
matrix = dok_array((n, n))
810810

811811
for i, nuc in enumerate(self.nuclides):
812812
elm = re.split(r'\d+', nuc.name)[0]
@@ -857,15 +857,15 @@ def form_ext_source_term(self, ext_source_rates, current_timestep, mat):
857857
858858
Returns
859859
-------
860-
scipy.sparse.csc_matrix
860+
scipy.sparse.csc_array
861861
Sparse vector representing external source term.
862862
863863
"""
864864
if not ext_source_rates.get_components(mat, current_timestep):
865865
return
866866
# Use DOK as intermediate representation
867867
n = len(self)
868-
vector = sp.dok_matrix((n, 1))
868+
vector = dok_array((n, 1))
869869

870870
for i, nuc in enumerate(self.nuclides):
871871
# Build source term vector

openmc/deplete/cram.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,11 @@
66
import numbers
77

88
import numpy as np
9-
import scipy.sparse as sp
109
import scipy.sparse.linalg as sla
1110

1211
from openmc.checkvalue import check_type, check_length
1312
from .abc import DepSystemSolver
13+
from .._sparse_compat import csc_array, eye_array
1414

1515
__all__ = ["CRAM16", "CRAM48", "Cram16Solver", "Cram48Solver", "IPFCramSolver"]
1616

@@ -60,7 +60,7 @@ def __call__(self, A, n0, dt):
6060
6161
Parameters
6262
----------
63-
A : scipy.sparse.csr_matrix
63+
A : scipy.sparse.csc_array
6464
Sparse transmutation matrix ``A[j, i]`` desribing rates at
6565
which isotope ``i`` transmutes to isotope ``j``
6666
n0 : numpy.ndarray
@@ -75,9 +75,9 @@ def __call__(self, A, n0, dt):
7575
Final compositions after ``dt``
7676
7777
"""
78-
A = dt * sp.csc_matrix(A, dtype=np.float64)
78+
A = dt * csc_array(A, dtype=np.float64)
7979
y = n0.copy()
80-
ident = sp.eye(A.shape[0], format='csc')
80+
ident = eye_array(A.shape[0], format='csc')
8181
for alpha, theta in zip(self.alpha, self.theta):
8282
y += 2*np.real(alpha*sla.spsolve(A - theta*ident, y))
8383
return y * self.alpha0

openmc/deplete/pool.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,11 @@
55
from itertools import repeat, starmap
66
from multiprocessing import Pool
77

8-
from scipy.sparse import bmat, hstack, vstack, csc_matrix
98
import numpy as np
9+
from scipy.sparse import hstack
1010

1111
from openmc.mpi import comm
12+
from .._sparse_compat import block_array
1213

1314
# Configurable switch that enables / disables the use of
1415
# multiprocessing routines during depletion
@@ -159,7 +160,7 @@ def deplete(func, chain, n, rates, dt, current_timestep=None, matrix_func=None,
159160
cols.append(None)
160161

161162
rows.append(cols)
162-
matrix = bmat(rows)
163+
matrix = block_array(rows)
163164

164165
# Concatenate vectors of nuclides in one
165166
n_multi = np.concatenate(n)
@@ -194,7 +195,7 @@ def deplete(func, chain, n, rates, dt, current_timestep=None, matrix_func=None,
194195
# of the nuclide vectors
195196
for i, matrix in enumerate(matrices):
196197
if not np.equal(*matrix.shape):
197-
matrices[i] = vstack([matrix, csc_matrix([0]*matrix.shape[1])])
198+
matrix.resize(matrix.shape[1], matrix.shape[1])
198199
n[i] = np.append(n[i], 1.0)
199200

200201
inputs = zip(matrices, n, repeat(dt))

openmc/tallies.py

Lines changed: 14 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,11 @@
1212
import h5py
1313
import numpy as np
1414
import pandas as pd
15-
import scipy.sparse as sps
1615
from scipy.stats import chi2, norm
1716

1817
import openmc
1918
import openmc.checkvalue as cv
19+
from ._sparse_compat import lil_array
2020
from ._xml import clean_indentation, get_elem_list, get_text
2121
from .mixin import IDManagerMixin
2222
from .mesh import MeshBase
@@ -435,10 +435,10 @@ def _read_results(self):
435435

436436
# Convert NumPy arrays to SciPy sparse LIL matrices
437437
if self.sparse:
438-
self._sum = sps.lil_matrix(self._sum.flatten(), self._sum.shape)
439-
self._sum_sq = sps.lil_matrix(self._sum_sq.flatten(), self._sum_sq.shape)
440-
self._sum_third = sps.lil_matrix(self._sum_third.flatten(), self._sum_third.shape)
441-
self._sum_fourth = sps.lil_matrix(self.sum_fourth.flatten(), self._sum_fourth.shape)
438+
self._sum = lil_array(self._sum.flatten(), self._sum.shape)
439+
self._sum_sq = lil_array(self._sum_sq.flatten(), self._sum_sq.shape)
440+
self._sum_third = lil_array(self._sum_third.flatten(), self._sum_third.shape)
441+
self._sum_fourth = lil_array(self._sum_fourth.flatten(), self._sum_fourth.shape)
442442

443443
# Read simulation time (needed for figure of merit)
444444
self._simulation_time = f["runtime"]["simulation"][()]
@@ -534,8 +534,7 @@ def mean(self):
534534

535535
# Convert NumPy array to SciPy sparse LIL matrix
536536
if self.sparse:
537-
self._mean = sps.lil_matrix(self._mean.flatten(),
538-
self._mean.shape)
537+
self._mean = lil_array(self._mean.flatten(), self._mean.shape)
539538

540539
if self.sparse:
541540
return np.reshape(self._mean.toarray(), self.shape)
@@ -556,8 +555,7 @@ def std_dev(self):
556555

557556
# Convert NumPy array to SciPy sparse LIL matrix
558557
if self.sparse:
559-
self._std_dev = sps.lil_matrix(self._std_dev.flatten(),
560-
self._std_dev.shape)
558+
self._std_dev = lil_array(self._std_dev.flatten(), self._std_dev.shape)
561559

562560
self.with_batch_statistics = True
563561

@@ -588,7 +586,7 @@ def vov(self):
588586
self._vov[mask] = numerator[mask]/denominator[mask] - 1.0/n
589587

590588
if self.sparse:
591-
self._vov = sps.lil_matrix(self._vov.flatten(), self._vov.shape)
589+
self._vov = lil_array(self._vov.flatten(), self._vov.shape)
592590

593591
if self.sparse:
594592
return np.reshape(self._vov.toarray(), self.shape)
@@ -963,22 +961,17 @@ def sparse(self, sparse):
963961
# Convert NumPy arrays to SciPy sparse LIL matrices
964962
if sparse and not self.sparse:
965963
if self._sum is not None:
966-
self._sum = sps.lil_matrix(self._sum.flatten(), self._sum.shape)
964+
self._sum = lil_array(self._sum.flatten(), self._sum.shape)
967965
if self._sum_sq is not None:
968-
self._sum_sq = sps.lil_matrix(self._sum_sq.flatten(),
969-
self._sum_sq.shape)
966+
self._sum_sq = lil_array(self._sum_sq.flatten(), self._sum_sq.shape)
970967
if self._sum_third is not None:
971-
self._sum_third = sps.lil_matrix(self._sum_third.flatten(),
972-
self._sum_third.shape)
968+
self._sum_third = lil_array(self._sum_third.flatten(), self._sum_third.shape)
973969
if self._sum_fourth is not None:
974-
self._sum_fourth = sps.lil_matrix(self._sum_fourth.flatten(),
975-
self._sum_fourth.shape)
970+
self._sum_fourth = lil_array(self._sum_fourth.flatten(), self._sum_fourth.shape)
976971
if self._mean is not None:
977-
self._mean = sps.lil_matrix(self._mean.flatten(),
978-
self._mean.shape)
972+
self._mean = lil_array(self._mean.flatten(), self._mean.shape)
979973
if self._std_dev is not None:
980-
self._std_dev = sps.lil_matrix(self._std_dev.flatten(),
981-
self._std_dev.shape)
974+
self._std_dev = lil_array(self._std_dev.flatten(), self._std_dev.shape)
982975

983976
self._sparse = True
984977

0 commit comments

Comments
 (0)