Skip to content

Commit 976942c

Browse files
authored
Parallel Joblib Process Entries (#3933)
Add joblib backend to process entries in parallel
1 parent 940eb60 commit 976942c

File tree

3 files changed

+175
-42
lines changed

3 files changed

+175
-42
lines changed

src/pymatgen/entries/compatibility.py

Lines changed: 97 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from typing import TYPE_CHECKING, Union, cast
1313

1414
import numpy as np
15+
from joblib import Parallel, delayed
1516
from monty.design_patterns import cached_class
1617
from monty.json import MSONable
1718
from monty.serialization import loadfn
@@ -30,6 +31,7 @@
3031
)
3132
from pymatgen.io.vasp.sets import MITRelaxSet, MPRelaxSet, VaspInputSet
3233
from pymatgen.util.due import Doi, due
34+
from pymatgen.util.joblib import set_python_warnings, tqdm_joblib
3335

3436
if TYPE_CHECKING:
3537
from collections.abc import Sequence
@@ -538,28 +540,86 @@ def get_adjustments(self, entry: AnyComputedEntry) -> list[EnergyAdjustment]:
538540
"""
539541
raise NotImplementedError
540542

541-
def process_entry(self, entry: ComputedEntry, **kwargs) -> ComputedEntry | None:
543+
def process_entry(self, entry: ComputedEntry, inplace: bool = True, **kwargs) -> ComputedEntry | None:
    """Process a single entry with the chosen Corrections.

    Note that with the default ``inplace=True`` this method will change the
    data of the original entry.

    Args:
        entry: A ComputedEntry object.
        inplace (bool): Whether to adjust the entry in place. If False, a deep
            copy of the entry is processed instead. Defaults to True.
        **kwargs: Will be passed to _process_entry_inplace() (``clean`` and
            ``on_error``). ``verbose`` and ``n_workers`` are accepted for
            compatibility with process_entries() but have no effect on a
            single entry.

    Returns:
        An adjusted entry if entry is compatible, else None. None is also
        returned when the entry was flagged to be discarded because of a
        conflicting, previously-applied energy adjustment.
    """
    # Batch-only options: _process_entry_inplace() does not take them, so
    # forwarding them would raise TypeError for callers that used to pass them.
    for batch_only_option in ("verbose", "n_workers"):
        kwargs.pop(batch_only_option, None)

    if not inplace:
        entry = copy.deepcopy(entry)

    result = self._process_entry_inplace(entry, **kwargs)
    if result is None:
        # get_adjustments() raised CompatibilityError and it was not re-raised
        return None

    processed_entry, ignore_entry = result
    # Honour the discard flag; process_entries() drops such entries as well.
    return None if ignore_entry else processed_entry
561+
562+
def _process_entry_inplace(
563+
self,
564+
entry: AnyComputedEntry,
565+
clean: bool = True,
566+
on_error: Literal["ignore", "warn", "raise"] = "ignore",
567+
) -> ComputedEntry | None:
568+
"""Process a single entry with the chosen Corrections. Note
569+
that this method will change the data of the original entry.
570+
571+
Args:
572+
entry: A ComputedEntry object.
573+
clean (bool): Whether to remove any previously-applied energy adjustments.
574+
If True, all EnergyAdjustment are removed prior to processing the Entry.
575+
Defaults to True.
576+
on_error ('ignore' | 'warn' | 'raise'): What to do when get_adjustments(entry)
577+
raises CompatibilityError. Defaults to 'ignore'.
578+
579+
Returns:
580+
An adjusted entry if entry is compatible, else None.
581+
"""
582+
ignore_entry = False
583+
# if clean is True, remove all previous adjustments from the entry
584+
if clean:
585+
entry.energy_adjustments = []
586+
587+
try: # get the energy adjustments
588+
adjustments = self.get_adjustments(entry)
589+
except CompatibilityError as exc:
590+
if on_error == "raise":
591+
raise
592+
if on_error == "warn":
593+
warnings.warn(str(exc))
555594
return None
556595

596+
for ea in adjustments:
597+
# Has this correction already been applied?
598+
if (ea.name, ea.cls, ea.value) in [(ea2.name, ea2.cls, ea2.value) for ea2 in entry.energy_adjustments]:
599+
# we already applied this exact correction. Do nothing.
600+
pass
601+
elif (ea.name, ea.cls) in [(ea2.name, ea2.cls) for ea2 in entry.energy_adjustments]:
602+
# we already applied a correction with the same name
603+
# but a different value. Something is wrong.
604+
ignore_entry = True
605+
warnings.warn(
606+
f"Entry {entry.entry_id} already has an energy adjustment called {ea.name}, but its "
607+
f"value differs from the value of {ea.value:.3f} calculated here. This "
608+
"Entry will be discarded."
609+
)
610+
else:
611+
# Add the correction to the energy_adjustments list
612+
entry.energy_adjustments.append(ea)
613+
614+
return entry, ignore_entry
615+
557616
def process_entries(
558617
self,
559618
entries: AnyComputedEntry | list[AnyComputedEntry],
560619
clean: bool = True,
561620
verbose: bool = False,
562621
inplace: bool = True,
622+
n_workers: int = 1,
563623
on_error: Literal["ignore", "warn", "raise"] = "ignore",
564624
) -> list[AnyComputedEntry]:
565625
"""Process a sequence of entries with the chosen Compatibility scheme.
@@ -576,6 +636,7 @@ def process_entries(
576636
verbose (bool): Whether to display progress bar for processing multiple entries.
577637
Defaults to False.
578638
inplace (bool): Whether to adjust input entries in place. Defaults to True.
639+
n_workers (int): Number of workers to use for parallel processing. Defaults to 1.
579640
on_error ('ignore' | 'warn' | 'raise'): What to do when get_adjustments(entry)
580641
raises CompatibilityError. Defaults to 'ignore'.
581642
@@ -593,41 +654,28 @@ def process_entries(
593654
if not inplace:
594655
entries = copy.deepcopy(entries)
595656

596-
for entry in tqdm(entries, disable=not verbose):
597-
ignore_entry = False
598-
# if clean is True, remove all previous adjustments from the entry
599-
if clean:
600-
entry.energy_adjustments = []
601-
602-
try: # get the energy adjustments
603-
adjustments = self.get_adjustments(entry)
604-
except CompatibilityError as exc:
605-
if on_error == "raise":
606-
raise
607-
if on_error == "warn":
608-
warnings.warn(str(exc))
609-
continue
610-
611-
for ea in adjustments:
612-
# Has this correction already been applied?
613-
if (ea.name, ea.cls, ea.value) in [(ea2.name, ea2.cls, ea2.value) for ea2 in entry.energy_adjustments]:
614-
# we already applied this exact correction. Do nothing.
615-
pass
616-
elif (ea.name, ea.cls) in [(ea2.name, ea2.cls) for ea2 in entry.energy_adjustments]:
617-
# we already applied a correction with the same name
618-
# but a different value. Something is wrong.
619-
ignore_entry = True
620-
warnings.warn(
621-
f"Entry {entry.entry_id} already has an energy adjustment called {ea.name}, but its "
622-
f"value differs from the value of {ea.value:.3f} calculated here. This "
623-
"Entry will be discarded."
624-
)
625-
else:
626-
# Add the correction to the energy_adjustments list
627-
entry.energy_adjustments.append(ea)
628-
629-
if not ignore_entry:
630-
processed_entry_list.append(entry)
657+
if n_workers == 1:
658+
for entry in tqdm(entries, disable=not verbose):
659+
result = self._process_entry_inplace(entry, clean, on_error)
660+
if result is None:
661+
continue
662+
entry, ignore_entry = result
663+
if not ignore_entry:
664+
processed_entry_list.append(entry)
665+
elif not inplace:
666+
# set python warnings to ignore otherwise warnings will be printed multiple times
667+
with tqdm_joblib(tqdm(total=len(entries), disable=not verbose)), set_python_warnings("ignore"):
668+
results = Parallel(n_jobs=n_workers)(
669+
delayed(self._process_entry_inplace)(entry, clean, on_error) for entry in entries
670+
)
671+
for result in results:
672+
if result is None:
673+
continue
674+
entry, ignore_entry = result
675+
if not ignore_entry:
676+
processed_entry_list.append(entry)
677+
else:
678+
raise ValueError("Parallel processing is not possible with for 'inplace=True'")
631679

632680
return processed_entry_list
633681

@@ -1133,7 +1181,9 @@ def get_adjustments(self, entry: AnyComputedEntry) -> list[EnergyAdjustment]:
11331181
expected_u = float(u_settings.get(symbol, 0))
11341182
actual_u = float(calc_u.get(symbol, 0))
11351183
if actual_u != expected_u:
1136-
raise CompatibilityError(f"Invalid U value of {actual_u:.3} on {symbol}, expected {expected_u:.3}")
1184+
raise CompatibilityError(
1185+
f"Invalid U value of {actual_u:.3} on {symbol}, expected {expected_u:.3} for {entry.as_dict()}"
1186+
)
11371187
if symbol in u_corrections:
11381188
adjustments.append(
11391189
CompositionEnergyAdjustment(
@@ -1450,6 +1500,7 @@ def process_entries(
14501500
clean: bool = False,
14511501
verbose: bool = False,
14521502
inplace: bool = True,
1503+
n_workers: int = 1,
14531504
on_error: Literal["ignore", "warn", "raise"] = "ignore",
14541505
) -> list[AnyComputedEntry]:
14551506
"""Process a sequence of entries with the chosen Compatibility scheme.
@@ -1463,6 +1514,7 @@ def process_entries(
14631514
Default is False.
14641515
inplace (bool): Whether to modify the entries in place. If False, a copy of the
14651516
entries is made and processed. Default is True.
1517+
n_workers (int): Number of workers to use for parallel processing. Default is 1.
14661518
on_error ('ignore' | 'warn' | 'raise'): What to do when get_adjustments(entry)
14671519
raises CompatibilityError. Defaults to 'ignore'.
14681520
@@ -1480,7 +1532,8 @@ def process_entries(
14801532

14811533
# pre-process entries with the given solid compatibility class
14821534
if self.solid_compat:
1483-
entries = self.solid_compat.process_entries(entries, clean=True)
1535+
entries = self.solid_compat.process_entries(entries, clean=True, inplace=inplace, n_workers=n_workers)
1536+
return [entries]
14841537

14851538
# when processing single entries, all H2 polymorphs will get assigned the
14861539
# same energy
@@ -1514,7 +1567,9 @@ def process_entries(
15141567
h2_entries = sorted(h2_entries, key=lambda e: e.energy_per_atom)
15151568
self.h2_energy = h2_entries[0].energy_per_atom # type: ignore[assignment]
15161569

1517-
return super().process_entries(entries, clean=clean, verbose=verbose, inplace=inplace, on_error=on_error)
1570+
return super().process_entries(
1571+
entries, clean=clean, verbose=verbose, inplace=inplace, n_workers=n_workers, on_error=on_error
1572+
)
15181573

15191574

15201575
def needs_u_correction(

src/pymatgen/util/joblib.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
"""This module provides utility functions for getting progress bar with joblib."""
2+
3+
from __future__ import annotations
4+
5+
import contextlib
6+
import os
7+
from typing import TYPE_CHECKING, Any
8+
9+
import joblib
10+
11+
if TYPE_CHECKING:
12+
from collections.abc import Iterator
13+
14+
from tqdm import tqdm
15+
16+
17+
@contextlib.contextmanager
def tqdm_joblib(tqdm_object: tqdm) -> Iterator[tqdm]:
    """Context manager to patch joblib to report into tqdm progress bar given
    as argument.

    While the context is active, ``joblib.parallel.BatchCompletionCallBack`` is
    replaced by a subclass that advances ``tqdm_object`` by the batch size each
    time it is called. On exit the original class is restored and the progress
    bar is closed, even if the body raised.

    Args:
        tqdm_object: The progress bar to update.

    Yields:
        The same ``tqdm_object`` that was passed in.
    """

    class TqdmBatchCompletionCallback(joblib.parallel.BatchCompletionCallBack):
        def __call__(self, *args: Any, **kwargs: Any) -> Any:
            """This will be called after each batch, to update the progress bar."""
            tqdm_object.update(n=self.batch_size)
            return super().__call__(*args, **kwargs)

    # NOTE: this swaps a module-level attribute of joblib, so the patch is
    # visible to every Parallel call made while the context is active.
    old_batch_callback = joblib.parallel.BatchCompletionCallBack
    joblib.parallel.BatchCompletionCallBack = TqdmBatchCompletionCallback
    try:
        yield tqdm_object
    finally:
        joblib.parallel.BatchCompletionCallBack = old_batch_callback
        tqdm_object.close()
36+
37+
38+
@contextlib.contextmanager
def set_python_warnings(warnings: str) -> Iterator[None]:
    """Context manager to set the PYTHONWARNINGS environment variable to the
    given value. This is useful for preventing spam when using parallel processing.

    On exit the variable is restored to the value it had on entry, or removed
    if it was not set on entry.

    Args:
        warnings: Value to assign to PYTHONWARNINGS inside the context,
            e.g. "ignore".
    """
    original_warnings = os.environ.get("PYTHONWARNINGS")
    os.environ["PYTHONWARNINGS"] = warnings
    try:
        yield
    finally:
        if original_warnings is None:
            # pop() rather than del: if the body already removed the variable,
            # cleanup must not raise KeyError and mask the body's own exception.
            os.environ.pop("PYTHONWARNINGS", None)
        else:
            os.environ["PYTHONWARNINGS"] = original_warnings

tests/entries/test_compatibility.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -481,6 +481,17 @@ def test_process_entries(self):
481481
entries = self.compat.process_entries([self.entry1, self.entry2, self.entry3, self.entry4])
482482
assert len(entries) == 2
483483

484+
def test_parallel_process_entries(self):
    """Check process_entries with n_workers=2: in-place mode must raise,
    copy mode must return two entries.
    """
    batch = (self.entry1, self.entry2, self.entry3, self.entry4)

    # Requesting several workers together with inplace=True is expected to fail.
    with pytest.raises(ValueError, match="Parallel processing is not possible with for 'inplace=True'"):
        self.compat.process_entries(list(batch), inplace=True, n_workers=2)

    # On copies, two of the four entries are expected to come back.
    entries = self.compat.process_entries(list(batch), inplace=False, n_workers=2)
    assert len(entries) == 2
494+
484495
def test_msonable(self):
485496
compat_dict = self.compat.as_dict()
486497
decoder = MontyDecoder()
@@ -1879,6 +1890,22 @@ def test_processing_entries_inplace(self):
18791890
MaterialsProjectAqueousCompatibility().process_entries(entries, inplace=False)
18801891
assert all(e.correction == e_copy.correction for e, e_copy in zip(entries, entries_copy))
18811892

1893+
def test_parallel_process_entries(self):
    """Check aqueous process_entries with n_workers=2: in-place mode must
    raise, copy mode must keep both entries.
    """
    entry_list = [
        ComputedEntry(Composition("FeH4O2"), -10),  # nH2O = 2
        ComputedEntry(Composition("Li2O2H2"), -10),  # nH2O = 0
    ]

    compat = MaterialsProjectAqueousCompatibility(
        o2_energy=-10, h2o_energy=-20, h2o_adjustments=-0.5, solid_compat=None
    )

    # Requesting several workers together with inplace=True is expected to fail.
    with pytest.raises(ValueError, match="Parallel processing is not possible with for 'inplace=True'"):
        compat.process_entries(entry_list, inplace=True, n_workers=2)

    # on_error="raise" makes any CompatibilityError fail the test instead of
    # silently dropping the entry.
    entries = compat.process_entries(entry_list, inplace=False, n_workers=2, on_error="raise")
    assert len(entries) == 2
1908+
18821909

18831910
class TestAqueousCorrection(TestCase):
18841911
def setUp(self):

0 commit comments

Comments
 (0)