Skip to content

Commit ca42957

Browse files
rwcarlsenGuySten
andauthored
depletion: fix performance of chain matrix construction (#3567)
Co-authored-by: GuySten <[email protected]>
1 parent 3665090 commit ca42957

File tree

1 file changed

+17
-11
lines changed

1 file changed

+17
-11
lines changed

openmc/deplete/chain.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -627,9 +627,15 @@ def form_matrix(self, rates, fission_yields=None):
627627
"""
628628
reactions = set()
629629

630-
# Use DOK matrix as intermediate representation for matrix
631630
n = len(self)
632-
matrix = sp.dok_matrix((n, n))
631+
632+
# we accumulate indices and value entries for everything and create the matrix
633+
# in one step at the end to avoid expensive index checks scipy otherwise does.
634+
rows, cols, vals = [], [], []
635+
def setval(i, j, val):
636+
rows.append(i)
637+
cols.append(j)
638+
vals.append(val)
633639

634640
if fission_yields is None:
635641
fission_yields = self.get_default_fission_yields()
@@ -639,7 +645,7 @@ def form_matrix(self, rates, fission_yields=None):
639645
if nuc.half_life is not None:
640646
decay_constant = math.log(2) / nuc.half_life
641647
if decay_constant != 0.0:
642-
matrix[i, i] -= decay_constant
648+
setval(i, i, -decay_constant)
643649

644650
# Gain from radioactive decay
645651
if nuc.n_decay_modes != 0:
@@ -650,19 +656,19 @@ def form_matrix(self, rates, fission_yields=None):
650656
if branch_val != 0.0:
651657
if target is not None:
652658
k = self.nuclide_dict[target]
653-
matrix[k, i] += branch_val
659+
setval(k, i, branch_val)
654660

655661
# Produce alphas and protons from decay
656662
if 'alpha' in decay_type:
657663
k = self.nuclide_dict.get('He4')
658664
if k is not None:
659665
count = decay_type.count('alpha')
660-
matrix[k, i] += count * branch_val
666+
setval(k, i, count * branch_val)
661667
elif 'p' in decay_type:
662668
k = self.nuclide_dict.get('H1')
663669
if k is not None:
664670
count = decay_type.count('p')
665-
matrix[k, i] += count * branch_val
671+
setval(k, i, count * branch_val)
666672

667673
if nuc.name in rates.index_nuc:
668674
# Extract all reactions for this nuclide in this cell
@@ -679,34 +685,34 @@ def form_matrix(self, rates, fission_yields=None):
679685
if r_type not in reactions:
680686
reactions.add(r_type)
681687
if path_rate != 0.0:
682-
matrix[i, i] -= path_rate
688+
setval(i, i, -path_rate)
683689

684690
# Gain term; allow for total annihilation for debug purposes
685691
if r_type != 'fission':
686692
if target is not None and path_rate != 0.0:
687693
k = self.nuclide_dict[target]
688-
matrix[k, i] += path_rate * br
694+
setval(k, i, path_rate * br)
689695

690696
# Determine light nuclide production, e.g., (n,d) should
691697
# produce H2
692698
light_nucs = REACTIONS[r_type].secondaries
693699
for light_nuc in light_nucs:
694700
k = self.nuclide_dict.get(light_nuc)
695701
if k is not None:
696-
matrix[k, i] += path_rate * br
702+
setval(k, i, path_rate * br)
697703

698704
else:
699705
for product, y in fission_yields[nuc.name].items():
700706
yield_val = y * path_rate
701707
if yield_val != 0.0:
702708
k = self.nuclide_dict[product]
703-
matrix[k, i] += yield_val
709+
setval(k, i, yield_val)
704710

705711
# Clear set of reactions
706712
reactions.clear()
707713

708714
# Return CSC representation instead of DOK
709-
return matrix.tocsc()
715+
return sp.csc_matrix((vals, (rows, cols)), shape=(n, n))
710716

711717
def form_rr_term(self, tr_rates, current_timestep, mats):
712718
"""Function to form the transfer rate term matrices.

0 commit comments

Comments
 (0)