Skip to content

Commit 028f440

Browse files
authored
Support MPI parallelism in R2SManager (#3632)
1 parent 5c63e0d commit 028f440

File tree

3 files changed

+82
-42
lines changed

3 files changed

+82
-42
lines changed

openmc/deplete/microxs.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -170,13 +170,14 @@ def get_microxs_and_flux(
170170
# Reinitialize with tallies
171171
openmc.lib.init(intracomm=comm)
172172

173-
# create temporary run
174173
with TemporaryDirectory() as temp_dir:
175-
if run_kwargs is None:
176-
run_kwargs = {}
177-
else:
178-
run_kwargs = dict(run_kwargs)
179-
run_kwargs.setdefault('cwd', temp_dir)
174+
# Run in a temporary directory unless being executed through
175+
# openmc.lib, in which case we don't need to specify the cwd
176+
run_kwargs = dict(run_kwargs) if run_kwargs else {}
177+
if not openmc.lib.is_initialized:
178+
run_kwargs.setdefault('cwd', temp_dir)
179+
180+
# Run transport simulation
180181
statepoint_path = model.run(**run_kwargs)
181182

182183
if comm.rank == 0:
@@ -189,15 +190,18 @@ def get_microxs_and_flux(
189190
if path_input is not None:
190191
model.export_to_model_xml(path_input)
191192

192-
with StatePoint(statepoint_path) as sp:
193-
if reaction_rate_mode == 'direct':
194-
rr_tally = sp.tallies[rr_tally.id]
195-
rr_tally._read_results()
196-
flux_tally = sp.tallies[flux_tally.id]
197-
flux_tally._read_results()
193+
# Broadcast updated statepoint path to all ranks
194+
statepoint_path = comm.bcast(statepoint_path)
195+
196+
# Read in tally results (on all ranks)
197+
with StatePoint(statepoint_path) as sp:
198+
if reaction_rate_mode == 'direct':
199+
rr_tally = sp.tallies[rr_tally.id]
200+
rr_tally._read_results()
201+
flux_tally = sp.tallies[flux_tally.id]
202+
flux_tally._read_results()
198203

199204
# Get flux values and make energy groups last dimension
200-
flux_tally = comm.bcast(flux_tally)
201205
flux = flux_tally.get_reshaped_data() # (domains, groups, 1, 1)
202206
flux = np.moveaxis(flux, 1, -1) # (domains, 1, 1, groups)
203207

@@ -206,7 +210,6 @@ def get_microxs_and_flux(
206210

207211
if reaction_rate_mode == 'direct':
208212
# Get reaction rates
209-
rr_tally = comm.bcast(rr_tally)
210213
reaction_rates = rr_tally.get_reshaped_data() # (domains, groups, nuclides, reactions)
211214

212215
# Make energy groups last dimension

openmc/deplete/r2s.py

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@
1111
from .microxs import get_microxs_and_flux, write_microxs_hdf5, read_microxs_hdf5
1212
from .results import Results
1313
from ..checkvalue import PathLike
14+
from ..mpi import comm
15+
from openmc.lib import TemporarySession
16+
from openmc.utility_funcs import change_directory
1417

1518

1619
def get_activation_materials(
@@ -199,8 +202,10 @@ def run(
199202
"""
200203

201204
if output_dir is None:
205+
# Create timestamped output directory and broadcast to all ranks for
206+
# consistency (different ranks may have slightly different times)
202207
stamp = datetime.now().strftime('%Y-%m-%dT%H-%M-%S')
203-
output_dir = Path(f'r2s_{stamp}')
208+
output_dir = Path(comm.bcast(f'r2s_{stamp}'))
204209

205210
# Set run_kwargs for the neutron transport step
206211
if micro_kwargs is None:
@@ -257,18 +262,19 @@ def step1_neutron_transport(
257262
258263
"""
259264

260-
output_dir = Path(output_dir)
265+
output_dir = Path(output_dir).resolve()
261266
output_dir.mkdir(parents=True, exist_ok=True)
262267

263268
if self.method == 'mesh-based':
264269
# Compute material volume fractions on the mesh
265270
if mat_vol_kwargs is None:
266271
mat_vol_kwargs = {}
267-
self.results['mesh_material_volumes'] = mmv = \
268-
self.domains.material_volumes(self.neutron_model, **mat_vol_kwargs)
272+
self.results['mesh_material_volumes'] = mmv = comm.bcast(
273+
self.domains.material_volumes(self.neutron_model, **mat_vol_kwargs))
269274

270275
# Save results to file
271-
mmv.save(output_dir / 'mesh_material_volumes.npz')
276+
if comm.rank == 0:
277+
mmv.save(output_dir / 'mesh_material_volumes.npz')
272278

273279
# Create mesh-material filter based on what combos were found
274280
domains = openmc.MeshMaterialFilter.from_volumes(self.domains, mmv)
@@ -299,13 +305,16 @@ def step1_neutron_transport(
299305
micro_kwargs.setdefault('path_statepoint', output_dir / 'statepoint.h5')
300306
micro_kwargs.setdefault('path_input', output_dir / 'model.xml')
301307

302-
# Run neutron transport and get fluxes and micros
303-
self.results['fluxes'], self.results['micros'] = get_microxs_and_flux(
304-
self.neutron_model, domains, **micro_kwargs)
308+
# Run neutron transport and get fluxes and micros. Run via openmc.lib to
309+
# maintain a consistent parallelism strategy with the activation step.
310+
with TemporarySession():
311+
self.results['fluxes'], self.results['micros'] = get_microxs_and_flux(
312+
self.neutron_model, domains, **micro_kwargs)
305313

306314
# Save flux and micros to file
307-
np.save(output_dir / 'fluxes.npy', self.results['fluxes'])
308-
write_microxs_hdf5(self.results['micros'], output_dir / 'micros.h5')
315+
if comm.rank == 0:
316+
np.save(output_dir / 'fluxes.npy', self.results['fluxes'])
317+
write_microxs_hdf5(self.results['micros'], output_dir / 'micros.h5')
309318

310319
def step2_activation(
311320
self,
@@ -457,15 +466,17 @@ def step3_photon_transport(
457466
# photon model if it is different from the neutron model to account for
458467
# potential material changes
459468
if self.method == 'mesh-based' and different_photon_model:
460-
self.results['mesh_material_volumes_photon'] = photon_mmv = \
461-
self.domains.material_volumes(self.photon_model, **mat_vol_kwargs)
469+
self.results['mesh_material_volumes_photon'] = photon_mmv = comm.bcast(
470+
self.domains.material_volumes(self.photon_model, **mat_vol_kwargs))
462471

463472
# Save photon MMV results to file
464-
photon_mmv.save(output_dir / 'mesh_material_volumes.npz')
473+
if comm.rank == 0:
474+
photon_mmv.save(output_dir / 'mesh_material_volumes.npz')
465475

466-
tally_ids = [tally.id for tally in self.photon_model.tallies]
467-
with open(output_dir / 'tally_ids.json', 'w') as f:
468-
json.dump(tally_ids, f)
476+
if comm.rank == 0:
477+
tally_ids = [tally.id for tally in self.photon_model.tallies]
478+
with open(output_dir / 'tally_ids.json', 'w') as f:
479+
json.dump(tally_ids, f)
469480

470481
self.results['photon_tallies'] = {}
471482

@@ -514,8 +525,9 @@ def step3_photon_transport(
514525
time_index = len(self.results['depletion_results']) + time_index
515526

516527
# Run photon transport calculation
517-
run_kwargs['cwd'] = Path(output_dir) / f'time_{time_index}'
518-
statepoint_path = self.photon_model.run(**run_kwargs)
528+
photon_dir = Path(output_dir) / f'time_{time_index}'
529+
with TemporarySession(self.photon_model, cwd=photon_dir):
530+
statepoint_path = self.photon_model.run(**run_kwargs)
519531

520532
# Store tally results
521533
with openmc.StatePoint(statepoint_path) as sp:

openmc/lib/core.py

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
from . import _dll
1515
from .error import _error_handler
16+
from ..mpi import comm
1617
from openmc.checkvalue import PathLike
1718
import openmc.lib
1819
import openmc
@@ -632,17 +633,23 @@ class TemporarySession:
632633
model : openmc.Model, optional
633634
OpenMC model to use for the session. If None, a minimal working model is
634635
created.
636+
cwd : PathLike, optional
637+
Working directory in which to run OpenMC. If None, a temporary directory
638+
is created and deleted automatically.
635639
**init_kwargs
636640
Keyword arguments to pass to :func:`openmc.lib.init`.
637641
638642
Attributes
639643
----------
640644
model : openmc.Model
641645
The OpenMC model used for the session.
646+
comm : mpi4py.MPI.Intracomm
647+
The MPI intracommunicator used for the session.
642648
643649
"""
644-
def __init__(self, model=None, **init_kwargs):
645-
self.init_kwargs = init_kwargs
650+
def __init__(self, model=None, cwd=None, **init_kwargs):
651+
self.init_kwargs = dict(init_kwargs)
652+
self.cwd = cwd
646653
if model is None:
647654
surf = openmc.Sphere(boundary_type="vacuum")
648655
cell = openmc.Cell(region=-surf)
@@ -652,6 +659,10 @@ def __init__(self, model=None, **init_kwargs):
652659
particles=1, batches=1, output={'summary': False})
653660
self.model = model
654661

662+
# Determine MPI intracommunicator
663+
self.init_kwargs.setdefault('intracomm', comm)
664+
self.comm = self.init_kwargs['intracomm']
665+
655666
def __enter__(self):
656667
"""Initialize the OpenMC library in a temporary directory."""
657668
# If already initialized, the context manager is a no-op
@@ -662,14 +673,24 @@ def __enter__(self):
662673
# Store original working directory
663674
self.orig_dir = Path.cwd()
664675

665-
# Set up temporary directory
666-
self.tmp_dir = TemporaryDirectory()
667-
working_dir = Path(self.tmp_dir.name)
668-
working_dir.mkdir(parents=True, exist_ok=True)
669-
os.chdir(working_dir)
676+
if self.cwd is None:
677+
# Set up temporary directory on rank 0
678+
if self.comm.rank == 0:
679+
self._tmp_dir = TemporaryDirectory()
680+
self.cwd = self._tmp_dir.name
681+
682+
# Broadcast the path so that all ranks use the same directory
683+
self.cwd = self.comm.bcast(self.cwd)
670684

671-
# Export model and initialize OpenMC
672-
self.model.export_to_model_xml()
685+
# Create and change to specified directory
686+
self.cwd = Path(self.cwd)
687+
self.cwd.mkdir(parents=True, exist_ok=True)
688+
os.chdir(self.cwd)
689+
690+
# Export model on first rank and initialize OpenMC
691+
if self.comm.rank == 0:
692+
self.model.export_to_model_xml()
693+
self.comm.barrier()
673694
openmc.lib.init(**self.init_kwargs)
674695

675696
return self
@@ -683,7 +704,11 @@ def __exit__(self, exc_type, exc_value, traceback):
683704
finalize()
684705
finally:
685706
os.chdir(self.orig_dir)
686-
self.tmp_dir.cleanup()
707+
708+
# Make sure all ranks have finalized before deleting temporary dir
709+
self.comm.barrier()
710+
if hasattr(self, '_tmp_dir'):
711+
self._tmp_dir.cleanup()
687712

688713

689714
class _DLLGlobal:

0 commit comments

Comments (0)