Skip to content

Commit 155f5af

Browse files
authored
Added a module for working with RMG reaction families (#754)
Now we can identify for each ARC reaction all possible pathways between the reactants and products according to reaction families supported by RMG. Also added dummy ARC families (w/o kinetics) to assist in identifying reaction templates. Tests added.
2 parents f92a478 + 18cb90e commit 155f5af

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

48 files changed

+4306
-3016
lines changed

.github/workflows/cont_int.yml

Lines changed: 45 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,9 @@ jobs:
3636

3737
- name: Cache RMG-Py
3838
id: cache-rmg-py
39-
uses: actions/cache@v2
39+
uses: actions/cache@v3
40+
env:
41+
CACHE_NUMBER: 2
4042
with:
4143
path: RMG-Py
4244
key: ${{ runner.os }}-rmg-main
@@ -54,7 +56,9 @@ jobs:
5456

5557
- name: Cache RMG-database
5658
id: cache-rmg-db
57-
uses: actions/cache@v2
59+
uses: actions/cache@v3
60+
env:
61+
CACHE_NUMBER: 2
5862
with:
5963
path: RMG-database
6064
key: ${{ runner.os }}-rmgdb-main
@@ -71,7 +75,7 @@ jobs:
7175

7276
- name: Cache AutoTST
7377
id: cache-autotst
74-
uses: actions/cache@v2
78+
uses: actions/cache@v3
7579
with:
7680
path: AutoTST
7781
key: ${{ runner.os }}-autotst-main
@@ -86,26 +90,26 @@ jobs:
8690
ref: main
8791
fetch-depth: 1
8892

89-
- name: Cache TS-GCN
90-
id: cache-tsgcn
91-
uses: actions/cache@v2
92-
with:
93-
path: TS-GCN
94-
key: ${{ runner.os }}-tsgcn-main
95-
restore-keys: |
96-
${{ runner.os }}-tsgcn-
97-
- name: Checkout TS-GCN
98-
if: steps.cache-tsgcn.outputs.cache-hit != 'true'
99-
uses: actions/checkout@v3
100-
with:
101-
repository: ReactionMechanismGenerator/TS-GCN
102-
path: TS-GCN
103-
ref: main
104-
fetch-depth: 1
105-
93+
# - name: Cache TS-GCN
94+
# id: cache-tsgcn
95+
# uses: actions/cache@v3
96+
# with:
97+
# path: TS-GCN
98+
# key: ${{ runner.os }}-tsgcn-main
99+
# restore-keys: |
100+
# ${{ runner.os }}-tsgcn-
101+
# - name: Checkout TS-GCN
102+
# if: steps.cache-tsgcn.outputs.cache-hit != 'true'
103+
# uses: actions/checkout@v3
104+
# with:
105+
# repository: ReactionMechanismGenerator/TS-GCN
106+
# path: TS-GCN
107+
# ref: main
108+
# fetch-depth: 1
109+
106110
- name: Cache KinBot
107111
id: cache-kinbot
108-
uses: actions/cache@v2
112+
uses: actions/cache@v3
109113
with:
110114
path: KinBot
111115
key: ${{ runner.os }}-kinbot-2.0.6
@@ -122,14 +126,14 @@ jobs:
122126
fetch-depth: 1
123127

124128
- name: Cache Packages
125-
uses: actions/cache@v2
129+
uses: actions/cache@v3
126130
env:
127131
CACHE_NUMBER: 0
128132
with:
129133
path: ~/conda_pkgs_dir
130134
key:
131135
${{ runner.os }}-conda-${{ env.CACHE_NUMBER }}-${{ hashFiles('environment.yml') }}
132-
136+
133137
- name: Set up miniconda
134138
uses: conda-incubator/setup-miniconda@v3
135139
with:
@@ -140,7 +144,7 @@ jobs:
140144
conda-solver: libmamba
141145

142146
- name: Cache ARC env
143-
uses: actions/cache@v2
147+
uses: actions/cache@v3
144148
with:
145149
path: ${{ env.CONDA }}/envs/arc_env
146150
key: conda-${{ runner.os }}--${{ runner.arch }}-arcenv-${{ env.CACHE_NUMBER}}
@@ -164,14 +168,14 @@ jobs:
164168
echo 'export PATH=$PATH:'"$(pwd)" >> ~/.bashrc
165169
make
166170
cd ..
167-
171+
168172
- name: Set ARC Path
169173
run: |
170174
echo 'export PYTHONPATH=$PYTHONPATH:'"$(pwd)" >> $GITHUB_ENV
171175
echo 'export PATH=$PATH:'"$(pwd)" >> $GITHUB_ENV
172176
echo 'export PYTHONPATH=$PYTHONPATH:'"$(pwd)" >> ~/.bashrc
173177
echo 'export PATH=$PATH:'"$(pwd)" >> ~/.bashrc
174-
178+
175179
- name: Install AutoTST
176180
run: |
177181
cd AutoTST
@@ -183,18 +187,18 @@ jobs:
183187
# install pyaml
184188
conda install -n tst_env -c conda-forge -y pyyaml
185189
cd ..
186-
187-
- name: Install TS-GCN
188-
run: |
189-
cd TS-GCN
190-
echo 'export PYTHONPATH=$PYTHONPATH:'"$(pwd)" >> $GITHUB_ENV
191-
echo 'export PATH=$PATH:'"$(pwd)" >> $GITHUB_ENV
192-
echo 'export PYTHONPATH=$PYTHONPATH:'"$(pwd)" >> ~/.bashrc
193-
echo 'export PATH=$PATH:'"$(pwd)" >> ~/.bashrc
194-
bash devtools/create_env_cpu.sh
195-
conda env update -n ts_gcn -f environment.yml
196-
cd ..
197-
190+
191+
# - name: Install TS-GCN
192+
# run: |
193+
# cd TS-GCN
194+
# echo 'export PYTHONPATH=$PYTHONPATH:'"$(pwd)" >> $GITHUB_ENV
195+
# echo 'export PATH=$PATH:'"$(pwd)" >> $GITHUB_ENV
196+
# echo 'export PYTHONPATH=$PYTHONPATH:'"$(pwd)" >> ~/.bashrc
197+
# echo 'export PATH=$PATH:'"$(pwd)" >> ~/.bashrc
198+
# bash devtools/create_env_cpu.sh
199+
# conda env update -n ts_gcn -f environment.yml
200+
# cd ..
201+
198202
- name: Install KinBot
199203
run: |
200204
cd KinBot
@@ -218,13 +222,13 @@ jobs:
218222
- name: Install Torch Ani
219223
run: |
220224
conda env create -f devtools/tani_environment.yml
221-
225+
222226
- name: Install XTB
223227
run: |
224228
conda env create -f devtools/xtb_environment.yml
225-
229+
226230
- name: Test with pytest
227-
shell: bash -el {0}
231+
shell: bash -el {0}
228232
run: |
229233
source ~/.bashrc
230234
conda activate arc_env

ARC.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,8 @@ def parse_command_line_arguments(command_line_args=None):
2121
command_line_args: The command line arguments.
2222
2323
Returns:
24-
The parsed command-line arguments by key words.
24+
The parsed command-line arguments by keywords.
2525
"""
26-
2726
parser = argparse.ArgumentParser(description='Automatic Rate Calculator (ARC)')
2827
parser.add_argument('file', metavar='FILE', type=str, nargs=1,
2928
help='a file describing the job to execute')
@@ -45,7 +44,6 @@ def main():
4544
"""
4645
The main ARC executable function
4746
"""
48-
4947
args = parse_command_line_arguments()
5048
input_file = args.file
5149
project_directory = os.path.abspath(os.path.dirname(args.file))

arc/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
import arc.utils
1313

1414
import arc.job
15+
import arc.family
16+
import arc.reaction
1517
import arc.settings
1618
import arc.species
1719
import arc.statmech

arc/checks/ts.py

Lines changed: 9 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import numpy as np
99
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
1010

11-
import arc.rmgdb as rmgdb
1211
from arc import parser
1312
from arc.common import (ARC_PATH,
1413
convert_list_index_0_to_1,
@@ -18,11 +17,9 @@
1817
read_yaml_file,
1918
sum_list_entries,
2019
)
20+
from arc.family.family import get_reaction_family_products
2121
from arc.imports import settings
2222
from arc.species.converter import check_xyz_dict, displace_xyz, xyz_to_dmat
23-
from arc.mapping.engine import (get_atom_indices_of_labeled_atoms_in_an_rmg_reaction,
24-
get_rmg_reactions_from_arc_reaction,
25-
)
2623
from arc.statmech.factory import statmech_factory
2724

2825
if TYPE_CHECKING:
@@ -302,12 +299,10 @@ def check_normal_mode_displacement(reaction: 'ARCReaction',
302299
"""
303300
if job is None:
304301
return
305-
if reaction.family is None:
306-
rmgdb.determine_family(reaction)
307302
amplitudes = amplitudes or [0.1, 0.2, 0.4, 0.6, 0.8, 1]
308303
amplitudes = [amplitudes] if isinstance(amplitudes, float) else amplitudes
309304
reaction.ts_species.ts_checks['NMD'] = False
310-
rmg_reactions = get_rmg_reactions_from_arc_reaction(arc_reaction=reaction) or list()
305+
product_dicts = get_reaction_family_products(rxn=reaction, rmg_family_set=reaction.family if reaction.family else None)
311306
freqs, normal_modes_disp = parser.parse_normal_mode_displacement(path=job.local_path_to_output_file, raise_error=False)
312307
if not len(normal_modes_disp):
313308
return
@@ -332,11 +327,8 @@ def check_normal_mode_displacement(reaction: 'ARCReaction',
332327
tolerance=1.5,
333328
bond_lone_hydrogens=bond_lone_hs)
334329
got_expected_changing_bonds = False
335-
for i, rmg_reaction in enumerate(rmg_reactions):
336-
r_label_dict = get_atom_indices_of_labeled_atoms_in_an_rmg_reaction(arc_reaction=reaction,
337-
rmg_reaction=rmg_reaction)[0]
338-
if r_label_dict is None:
339-
continue
330+
for i, product_dict in enumerate(product_dicts):
331+
r_label_dict = product_dict['r_label_map']
340332
expected_breaking_bonds, expected_forming_bonds = reaction.get_expected_changing_bonds(r_label_dict=r_label_dict)
341333
if expected_breaking_bonds is None or expected_forming_bonds is None:
342334
continue
@@ -355,7 +347,7 @@ def check_normal_mode_displacement(reaction: 'ARCReaction',
355347
'breaking/forming bonds due to a missing RMG template; '
356348
reaction.ts_species.ts_checks['NMD'] = True
357349
break
358-
if not len(rmg_reactions):
350+
if not len(product_dicts):
359351
# Just check that some bonds break/form, and that this is not a torsional saddle point.
360352
warning = f'Cannot check normal mode displacement for reaction {reaction} since a corresponding ' \
361353
f'RMG template could not be generated'
@@ -535,18 +527,17 @@ def get_rxn_normal_mode_disp_atom_number(rxn_family: Optional[str] = None,
535527
Returns:
536528
int: The respective number of atoms.
537529
"""
530+
if rxn_family is None and reaction is None:
531+
raise ValueError('Either `rxn_family` or `reaction` must be given.')
538532
default = 3
539533
if rms_list is not None \
540534
and (not isinstance(rms_list, list) or not all(isinstance(entry, float) for entry in rms_list)):
541535
raise TypeError(f'rms_list must be a non empty list, got {rms_list} of type {type(rms_list)}.')
542-
family = rxn_family
543-
if family is None and reaction is not None and reaction.family is not None:
544-
family = reaction.family.label
545-
if family is None:
536+
if reaction is not None and reaction.family is None:
546537
logger.warning(f'Cannot deduce a reaction family for {reaction}, assuming {default} atoms in the reaction zone.')
547538
return default
548539
content = read_yaml_file(os.path.join(ARC_PATH, 'data', 'rxn_normal_mode_disp.yml'))
549-
number_by_family = content.get(rxn_family, default)
540+
number_by_family = content.get(rxn_family or reaction.family, default)
550541
if rms_list is None or not len(rms_list):
551542
return number_by_family
552543
entry = None

arc/checks/ts_test.py

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from arc.level import Level
1818
from arc.parser import parse_normal_mode_displacement, parse_xyz_from_file
1919
from arc.reaction import ARCReaction
20-
from arc.rmgdb import load_families_only, make_rmg_database_object
2120
from arc.species.species import ARCSpecies, TSGuess
2221
from arc.utils.wip import work_in_progress
2322

@@ -32,8 +31,6 @@ def setUpClass(cls):
3231
A method that is run before all unit tests in this class.
3332
"""
3433
cls.maxDiff = None
35-
cls.rmgdb = make_rmg_database_object()
36-
load_families_only(cls.rmgdb)
3734

3835
cls.rms_list_1 = [0.01414213562373095, 0.05, 0.04, 0.5632938842203065, 0.7993122043357026, 0.08944271909999159,
3936
0.10677078252031312, 0.09000000000000001, 0.05, 0.09433981132056604]
@@ -185,15 +182,6 @@ def setUpClass(cls):
185182
xyz=os.path.join(ts.ARC_PATH, 'arc', 'testing', 'freq', 'TS_nC3H7-iC3H7.out'))
186183
cls.rxn_8.ts_label = cls.rxn_8.ts_species.label
187184

188-
cls.rxn_2a.determine_family(rmg_database=cls.rmgdb, save_order=True)
189-
cls.rxn_2b.determine_family(rmg_database=cls.rmgdb, save_order=True)
190-
cls.rxn_3.determine_family(rmg_database=cls.rmgdb, save_order=True)
191-
cls.rxn_4.determine_family(rmg_database=cls.rmgdb, save_order=True)
192-
cls.rxn_5.determine_family(rmg_database=cls.rmgdb, save_order=True)
193-
cls.rxn_6.determine_family(rmg_database=cls.rmgdb, save_order=True)
194-
cls.rxn_7.determine_family(rmg_database=cls.rmgdb, save_order=True)
195-
cls.rxn_8.determine_family(rmg_database=cls.rmgdb, save_order=True)
196-
197185
cls.ccooj_xyz = {'symbols': ('C', 'C', 'O', 'O', 'H', 'H', 'H', 'H', 'H'),
198186
'isotopes': (12, 12, 16, 16, 1, 1, 1, 1, 1),
199187
'coords': ((-1.0558210286905791, -0.033295741345331475, -0.10080257427276477),
@@ -380,7 +368,6 @@ def test_check_normal_mode_displacement(self):
380368
self.assertFalse(self.rxn_2a.ts_species.ts_checks['NMD'])
381369
self.job1.local_path_to_output_file = os.path.join(ts.ARC_PATH, 'arc', 'testing', 'composite',
382370
'TS_intra_H_migration_CBS-QB3.out')
383-
self.rxn_2a.determine_family(rmg_database=self.rmgdb)
384371
ts.check_normal_mode_displacement(reaction=self.rxn_2a, job=self.job1)
385372
self.assertTrue(self.rxn_2a.ts_species.ts_checks['NMD'])
386373
self.rxn_2a.ts_species.populate_ts_checks()
@@ -659,7 +646,8 @@ def test_get_rxn_normal_mode_disp_atom_number(self):
659646
ts.get_rxn_normal_mode_disp_atom_number('family', rms_list=['family'])
660647
with self.assertRaises(TypeError):
661648
ts.get_rxn_normal_mode_disp_atom_number('family', rms_list=15.215)
662-
self.assertEqual(ts.get_rxn_normal_mode_disp_atom_number(), 3)
649+
with self.assertRaises(ValueError):
650+
self.assertEqual(ts.get_rxn_normal_mode_disp_atom_number(), 3)
663651
self.assertEqual(ts.get_rxn_normal_mode_disp_atom_number('default'), 3)
664652
self.assertEqual(ts.get_rxn_normal_mode_disp_atom_number('intra_H_migration'), 3)
665653
self.assertEqual(ts.get_rxn_normal_mode_disp_atom_number('intra_H_migration', rms_list=self.rms_list_1), 4)

0 commit comments

Comments
 (0)