12
12
from typing import TYPE_CHECKING , Union , cast
13
13
14
14
import numpy as np
15
+ from joblib import Parallel , delayed
15
16
from monty .design_patterns import cached_class
16
17
from monty .json import MSONable
17
18
from monty .serialization import loadfn
30
31
)
31
32
from pymatgen .io .vasp .sets import MITRelaxSet , MPRelaxSet , VaspInputSet
32
33
from pymatgen .util .due import Doi , due
34
+ from pymatgen .util .joblib import set_python_warnings , tqdm_joblib
33
35
34
36
if TYPE_CHECKING :
35
37
from collections .abc import Sequence
@@ -538,28 +540,86 @@ def get_adjustments(self, entry: AnyComputedEntry) -> list[EnergyAdjustment]:
538
540
"""
539
541
raise NotImplementedError
540
542
541
def process_entry(self, entry: ComputedEntry, inplace: bool = True, **kwargs) -> ComputedEntry | None:
    """Process a single entry with the chosen Corrections.

    Note that with inplace=True (the default) this method changes the data
    of the original entry.

    Args:
        entry: A ComputedEntry object.
        inplace (bool): Whether to adjust the entry in place. If False, a deep
            copy is adjusted and the input entry is left untouched.
            Defaults to True.
        **kwargs: Will be passed to _process_entry_inplace(), i.e. ``clean``
            and ``on_error``.

    Returns:
        An adjusted entry if entry is compatible, else None.
    """
    if not inplace:
        entry = copy.deepcopy(entry)

    # The helper returns None when get_adjustments() raised a handled
    # CompatibilityError, otherwise an (entry, ignore_entry) pair.
    result = self._process_entry_inplace(entry, **kwargs)
    if result is None:
        return None

    processed_entry, ignore_entry = result
    # process_entries() drops entries flagged with ignore_entry, so a flagged
    # entry must be reported as incompatible here too instead of being returned.
    return None if ignore_entry else processed_entry
561
+
562
def _process_entry_inplace(
    self,
    entry: AnyComputedEntry,
    clean: bool = True,
    on_error: Literal["ignore", "warn", "raise"] = "ignore",
) -> tuple[AnyComputedEntry, bool] | None:
    """Process a single entry with the chosen Corrections. Note
    that this method will change the data of the original entry.

    Args:
        entry: A ComputedEntry object.
        clean (bool): Whether to remove any previously-applied energy adjustments.
            If True, all EnergyAdjustment are removed prior to processing the Entry.
            Defaults to True.
        on_error ('ignore' | 'warn' | 'raise'): What to do when get_adjustments(entry)
            raises CompatibilityError. Defaults to 'ignore'.

    Returns:
        None if get_adjustments(entry) raised a CompatibilityError that was
        ignored or only warned about. Otherwise a tuple (entry, ignore_entry)
        holding the adjusted entry (the same object that was passed in) and a
        flag that is True when the entry already carried an adjustment with the
        same name but a different value and should therefore be discarded.

    Raises:
        CompatibilityError: re-raised from get_adjustments(entry) when
            on_error is 'raise'.
    """
    # if clean is True, remove all previous adjustments from the entry
    if clean:
        entry.energy_adjustments = []

    try:  # get the energy adjustments
        adjustments = self.get_adjustments(entry)
    except CompatibilityError as exc:
        if on_error == "raise":
            raise
        if on_error == "warn":
            warnings.warn(str(exc))
        return None

    ignore_entry = False
    for ea in adjustments:
        # Has this correction already been applied? The checks deliberately look
        # at the live list so that adjustments appended earlier in this loop count.
        if any((ea.name, ea.cls, ea.value) == (ea2.name, ea2.cls, ea2.value) for ea2 in entry.energy_adjustments):
            # we already applied this exact correction. Do nothing.
            continue

        if any((ea.name, ea.cls) == (ea2.name, ea2.cls) for ea2 in entry.energy_adjustments):
            # we already applied a correction with the same name
            # but a different value. Something is wrong.
            ignore_entry = True
            warnings.warn(
                f"Entry {entry.entry_id} already has an energy adjustment called {ea.name}, but its "
                f"value differs from the value of {ea.value:.3f} calculated here. This "
                "Entry will be discarded."
            )
        else:
            # Add the correction to the energy_adjustments list
            entry.energy_adjustments.append(ea)

    return entry, ignore_entry
615
+
557
616
def process_entries (
558
617
self ,
559
618
entries : AnyComputedEntry | list [AnyComputedEntry ],
560
619
clean : bool = True ,
561
620
verbose : bool = False ,
562
621
inplace : bool = True ,
622
+ n_workers : int = 1 ,
563
623
on_error : Literal ["ignore" , "warn" , "raise" ] = "ignore" ,
564
624
) -> list [AnyComputedEntry ]:
565
625
"""Process a sequence of entries with the chosen Compatibility scheme.
@@ -576,6 +636,7 @@ def process_entries(
576
636
verbose (bool): Whether to display progress bar for processing multiple entries.
577
637
Defaults to False.
578
638
inplace (bool): Whether to adjust input entries in place. Defaults to True.
639
+ n_workers (int): Number of workers to use for parallel processing. Defaults to 1.
579
640
on_error ('ignore' | 'warn' | 'raise'): What to do when get_adjustments(entry)
580
641
raises CompatibilityError. Defaults to 'ignore'.
581
642
@@ -593,41 +654,28 @@ def process_entries(
593
654
if not inplace :
594
655
entries = copy .deepcopy (entries )
595
656
596
- for entry in tqdm (entries , disable = not verbose ):
597
- ignore_entry = False
598
- # if clean is True, remove all previous adjustments from the entry
599
- if clean :
600
- entry .energy_adjustments = []
601
-
602
- try : # get the energy adjustments
603
- adjustments = self .get_adjustments (entry )
604
- except CompatibilityError as exc :
605
- if on_error == "raise" :
606
- raise
607
- if on_error == "warn" :
608
- warnings .warn (str (exc ))
609
- continue
610
-
611
- for ea in adjustments :
612
- # Has this correction already been applied?
613
- if (ea .name , ea .cls , ea .value ) in [(ea2 .name , ea2 .cls , ea2 .value ) for ea2 in entry .energy_adjustments ]:
614
- # we already applied this exact correction. Do nothing.
615
- pass
616
- elif (ea .name , ea .cls ) in [(ea2 .name , ea2 .cls ) for ea2 in entry .energy_adjustments ]:
617
- # we already applied a correction with the same name
618
- # but a different value. Something is wrong.
619
- ignore_entry = True
620
- warnings .warn (
621
- f"Entry { entry .entry_id } already has an energy adjustment called { ea .name } , but its "
622
- f"value differs from the value of { ea .value :.3f} calculated here. This "
623
- "Entry will be discarded."
624
- )
625
- else :
626
- # Add the correction to the energy_adjustments list
627
- entry .energy_adjustments .append (ea )
628
-
629
- if not ignore_entry :
630
- processed_entry_list .append (entry )
657
+ if n_workers == 1 :
658
+ for entry in tqdm (entries , disable = not verbose ):
659
+ result = self ._process_entry_inplace (entry , clean , on_error )
660
+ if result is None :
661
+ continue
662
+ entry , ignore_entry = result
663
+ if not ignore_entry :
664
+ processed_entry_list .append (entry )
665
+ elif not inplace :
666
+ # set python warnings to ignore otherwise warnings will be printed multiple times
667
+ with tqdm_joblib (tqdm (total = len (entries ), disable = not verbose )), set_python_warnings ("ignore" ):
668
+ results = Parallel (n_jobs = n_workers )(
669
+ delayed (self ._process_entry_inplace )(entry , clean , on_error ) for entry in entries
670
+ )
671
+ for result in results :
672
+ if result is None :
673
+ continue
674
+ entry , ignore_entry = result
675
+ if not ignore_entry :
676
+ processed_entry_list .append (entry )
677
+ else :
678
+ raise ValueError ("Parallel processing is not possible with for 'inplace=True'" )
631
679
632
680
return processed_entry_list
633
681
@@ -1133,7 +1181,9 @@ def get_adjustments(self, entry: AnyComputedEntry) -> list[EnergyAdjustment]:
1133
1181
expected_u = float (u_settings .get (symbol , 0 ))
1134
1182
actual_u = float (calc_u .get (symbol , 0 ))
1135
1183
if actual_u != expected_u :
1136
- raise CompatibilityError (f"Invalid U value of { actual_u :.3} on { symbol } , expected { expected_u :.3} " )
1184
+ raise CompatibilityError (
1185
+ f"Invalid U value of { actual_u :.3} on { symbol } , expected { expected_u :.3} for { entry .as_dict ()} "
1186
+ )
1137
1187
if symbol in u_corrections :
1138
1188
adjustments .append (
1139
1189
CompositionEnergyAdjustment (
@@ -1450,6 +1500,7 @@ def process_entries(
1450
1500
clean : bool = False ,
1451
1501
verbose : bool = False ,
1452
1502
inplace : bool = True ,
1503
+ n_workers : int = 1 ,
1453
1504
on_error : Literal ["ignore" , "warn" , "raise" ] = "ignore" ,
1454
1505
) -> list [AnyComputedEntry ]:
1455
1506
"""Process a sequence of entries with the chosen Compatibility scheme.
@@ -1463,6 +1514,7 @@ def process_entries(
1463
1514
Default is False.
1464
1515
inplace (bool): Whether to modify the entries in place. If False, a copy of the
1465
1516
entries is made and processed. Default is True.
1517
+ n_workers (int): Number of workers to use for parallel processing. Default is 1.
1466
1518
on_error ('ignore' | 'warn' | 'raise'): What to do when get_adjustments(entry)
1467
1519
raises CompatibilityError. Defaults to 'ignore'.
1468
1520
@@ -1480,7 +1532,8 @@ def process_entries(
1480
1532
1481
1533
# pre-process entries with the given solid compatibility class
1482
1534
if self .solid_compat :
1483
- entries = self .solid_compat .process_entries (entries , clean = True )
1535
+ entries = self .solid_compat .process_entries (entries , clean = True , inplace = inplace , n_workers = n_workers )
1536
+ return [entries ]
1484
1537
1485
1538
# when processing single entries, all H2 polymorphs will get assigned the
1486
1539
# same energy
@@ -1514,7 +1567,9 @@ def process_entries(
1514
1567
h2_entries = sorted (h2_entries , key = lambda e : e .energy_per_atom )
1515
1568
self .h2_energy = h2_entries [0 ].energy_per_atom # type: ignore[assignment]
1516
1569
1517
- return super ().process_entries (entries , clean = clean , verbose = verbose , inplace = inplace , on_error = on_error )
1570
+ return super ().process_entries (
1571
+ entries , clean = clean , verbose = verbose , inplace = inplace , n_workers = n_workers , on_error = on_error
1572
+ )
1518
1573
1519
1574
1520
1575
def needs_u_correction (
0 commit comments