Skip to content

Commit 3ec31cb

Browse files
committed
output ForceFieldStructureTaskDocument or ForceFieldMoleculeTaskDocument based on the input type of mol_or_struct.
change name ForceFieldTaskDocument => ForceFieldStructureTaskDocument output ForceFieldStructureTaskDocument or ForceFieldMoleculeTaskDocument based on type of mol_or_struct update ForceFieldTaskDocument => ForceFieldStructureTaskDocument in the tests import Union from typing include Union in forcefield/md.py take the suggestions from the formatter take ruff's suggestions try again with ruff format ruff format again try again ruff ruff again fix the mypy error Take inputs of both Molecule and Structure update docstring add molecule test for forcefield
1 parent eef0175 commit 3ec31cb

File tree

7 files changed

+200
-79
lines changed

7 files changed

+200
-79
lines changed

src/atomate2/ase/schemas.py

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -232,27 +232,6 @@ class AseStructureTaskDoc(StructureMetadata):
232232

233233
tags: list[str] | None = Field(None, description="List of tags for the task.")
234234

235-
@classmethod
236-
def from_ase_task_doc(
237-
cls, ase_task_doc: AseTaskDoc, **task_document_kwargs
238-
) -> AseStructureTaskDoc:
239-
"""Create an AseStructureTaskDoc for a task that has ASE-compatible outputs.
240-
241-
Parameters
242-
----------
243-
ase_task_doc : AseTaskDoc
244-
Task doc for the calculation
245-
task_document_kwargs : dict
246-
Additional keyword args passed to :obj:`.AseStructureTaskDoc()`.
247-
"""
248-
task_document_kwargs.update(
249-
{k: getattr(ase_task_doc, k) for k in _task_doc_translation_keys},
250-
structure=ase_task_doc.mol_or_struct,
251-
)
252-
return cls.from_structure(
253-
meta_structure=ase_task_doc.mol_or_struct, **task_document_kwargs
254-
)
255-
256235

257236
class AseMoleculeTaskDoc(MoleculeMetadata):
258237
"""Document containing information on molecule manipulation using ASE."""

src/atomate2/forcefields/jobs.py

Lines changed: 30 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,19 @@
1515

1616
from atomate2.ase.jobs import AseRelaxMaker
1717
from atomate2.forcefields import MLFF, _get_formatted_ff_name
18-
from atomate2.forcefields.schemas import ForceFieldTaskDocument
18+
from atomate2.forcefields.schemas import (
19+
ForceFieldMoleculeTaskDocument,
20+
ForceFieldStructureTaskDocument,
21+
ForceFieldTaskDocument,
22+
)
1923
from atomate2.forcefields.utils import ase_calculator, revert_default_dtype
2024

2125
if TYPE_CHECKING:
2226
from collections.abc import Callable
2327
from pathlib import Path
2428

2529
from ase.calculators.calculator import Calculator
26-
from pymatgen.core.structure import Structure
30+
from pymatgen.core.structure import Molecule, Structure
2731

2832
logger = logging.getLogger(__name__)
2933

@@ -50,7 +54,8 @@ def forcefield_job(method: Callable) -> job:
5054
This is a thin wrapper around :obj:`~jobflow.core.job.Job` that configures common
5155
settings for all forcefield jobs. For example, it ensures that large data objects
5256
(currently only trajectories) are all stored in the atomate2 data store.
53-
It also configures the output schema to be a ForceFieldTaskDocument :obj:`.TaskDoc`.
57+
It also configures the output schema to be a
58+
ForceFieldStructureTaskDocument :obj:`.TaskDoc`.
5459
5560
Any makers that return forcefield jobs (not flows) should decorate the
5661
``make`` method with @forcefield_job. For example:
@@ -74,9 +79,7 @@ def make(structure):
7479
callable
7580
A decorated version of the make function that will generate forcefield jobs.
7681
"""
77-
return job(
78-
method, data=_FORCEFIELD_DATA_OBJECTS, output_schema=ForceFieldTaskDocument
79-
)
82+
return job(method, data=_FORCEFIELD_DATA_OBJECTS)
8083

8184

8285
@dataclass
@@ -120,7 +123,7 @@ class ForceFieldRelaxMaker(AseRelaxMaker):
120123
tags : list[str] or None
121124
A list of tags for the task.
122125
task_document_kwargs : dict (deprecated)
123-
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
126+
Additional keyword args passed to :obj:`.ForceFieldStructureTaskDocument()`.
124127
"""
125128

126129
name: str = "Force field relax"
@@ -148,15 +151,15 @@ def __post_init__(self) -> None:
148151

149152
@forcefield_job
150153
def make(
151-
self, structure: Structure, prev_dir: str | Path | None = None
152-
) -> ForceFieldTaskDocument:
154+
self, structure: Molecule | Structure, prev_dir: str | Path | None = None
155+
) -> ForceFieldStructureTaskDocument | ForceFieldMoleculeTaskDocument:
153156
"""
154157
Perform a relaxation of a structure using a force field.
155158
156159
Parameters
157160
----------
158-
structure: .Structure
159-
pymatgen structure.
161+
structure: .Structure or Molecule
162+
pymatgen structure or molecule.
160163
prev_dir : str or Path or None
161164
A previous calculation directory to copy output files from. Unused, just
162165
added to match the method signature of other makers.
@@ -172,7 +175,7 @@ def make(
172175
stacklevel=1,
173176
)
174177

175-
return ForceFieldTaskDocument.from_ase_compatible_result(
178+
return ForceFieldTaskDocument.from_ase_compatible_result_forcefield(
176179
str(self.force_field_name), # make mypy happy
177180
ase_result,
178181
self.steps,
@@ -214,7 +217,7 @@ class ForceFieldStaticMaker(ForceFieldRelaxMaker):
214217
calculator_kwargs : dict
215218
Keyword arguments that will get passed to the ASE calculator.
216219
task_document_kwargs : dict (deprecated)
217-
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
220+
Additional keyword args passed to :obj:`.ForceFieldStructureTaskDocument()`.
218221
"""
219222

220223
name: str = "Force field static"
@@ -257,7 +260,7 @@ class CHGNetRelaxMaker(ForceFieldRelaxMaker):
257260
calculator_kwargs : dict
258261
Keyword arguments that will get passed to the ASE calculator.
259262
task_document_kwargs : dict (deprecated)
260-
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
263+
Additional keyword args passed to :obj:`.ForceFieldStructureTaskDocument()`.
261264
"""
262265

263266
name: str = f"{MLFF.CHGNet} relax"
@@ -293,7 +296,7 @@ class CHGNetStaticMaker(ForceFieldStaticMaker):
293296
calculator_kwargs : dict
294297
Keyword arguments that will get passed to the ASE calculator.
295298
task_document_kwargs : dict (deprecated)
296-
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
299+
Additional keyword args passed to :obj:`.ForceFieldStructureTaskDocument()`.
297300
"""
298301

299302
name: str = f"{MLFF.CHGNet} static"
@@ -336,7 +339,7 @@ class M3GNetRelaxMaker(ForceFieldRelaxMaker):
336339
calculator_kwargs : dict
337340
Keyword arguments that will get passed to the ASE calculator.
338341
task_document_kwargs : dict (deprecated)
339-
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
342+
Additional keyword args passed to :obj:`.ForceFieldStructureTaskDocument()`.
340343
"""
341344

342345
name: str = f"{MLFF.M3GNet} relax"
@@ -374,7 +377,7 @@ class M3GNetStaticMaker(ForceFieldStaticMaker):
374377
calculator_kwargs : dict
375378
Keyword arguments that will get passed to the ASE calculator.
376379
task_document_kwargs : dict (deprecated)
377-
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
380+
Additional keyword args passed to :obj:`.ForceFieldStructureTaskDocument()`.
378381
"""
379382

380383
name: str = f"{MLFF.M3GNet} static"
@@ -417,7 +420,7 @@ class NEPRelaxMaker(ForceFieldRelaxMaker):
417420
calculator_kwargs : dict
418421
Keyword arguments that will get passed to the ASE calculator.
419422
task_document_kwargs : dict (deprecated)
420-
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
423+
Additional keyword args passed to :obj:`.ForceFieldStructureTaskDocument()`.
421424
"""
422425

423426
name: str = f"{MLFF.NEP} relax"
@@ -453,7 +456,7 @@ class NEPStaticMaker(ForceFieldStaticMaker):
453456
calculator_kwargs : dict
454457
Keyword arguments that will get passed to the ASE calculator.
455458
task_document_kwargs : dict (deprecated)
456-
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
459+
Additional keyword args passed to :obj:`.ForceFieldStructureTaskDocument()`.
457460
"""
458461

459462
name: str = f"{MLFF.NEP} static"
@@ -496,7 +499,7 @@ class NequipRelaxMaker(ForceFieldRelaxMaker):
496499
calculator_kwargs : dict
497500
Keyword arguments that will get passed to the ASE calculator.
498501
task_document_kwargs : dict (deprecated)
499-
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
502+
Additional keyword args passed to :obj:`.ForceFieldStructureTaskDocument()`.
500503
"""
501504

502505
name: str = f"{MLFF.Nequip} relax"
@@ -531,7 +534,7 @@ class NequipStaticMaker(ForceFieldStaticMaker):
531534
calculator_kwargs : dict
532535
Keyword arguments that will get passed to the ASE calculator.
533536
task_document_kwargs : dict (deprecated)
534-
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
537+
Additional keyword args passed to :obj:`.ForceFieldStructureTaskDocument()`.
535538
"""
536539

537540
name: str = f"{MLFF.Nequip} static"
@@ -578,7 +581,7 @@ class MACERelaxMaker(ForceFieldRelaxMaker):
578581
trained for Matbench Discovery on the MPtrj dataset available at
579582
https://figshare.com/articles/dataset/22715158.
580583
task_document_kwargs : dict (deprecated)
581-
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
584+
Additional keyword args passed to :obj:`.ForceFieldStructureTaskDocument()`.
582585
"""
583586

584587
name: str = f"{MLFF.MACE_MP_0} relax"
@@ -618,7 +621,7 @@ class MACEStaticMaker(ForceFieldStaticMaker):
618621
trained for Matbench Discovery on the MPtrj dataset available at
619622
https://figshare.com/articles/dataset/22715158.
620623
task_document_kwargs : dict (deprecated)
621-
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
624+
Additional keyword args passed to :obj:`.ForceFieldStructureTaskDocument()`.
622625
"""
623626

624627
name: str = f"{MLFF.MACE_MP_0} static"
@@ -667,7 +670,7 @@ class SevenNetRelaxMaker(ForceFieldRelaxMaker):
667670
trained for Matbench Discovery on the MPtrj dataset available at
668671
https://figshare.com/articles/dataset/22715158.
669672
task_document_kwargs : dict (deprecated)
670-
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
673+
Additional keyword args passed to :obj:`.ForceFieldStructureTaskDocument()`.
671674
"""
672675

673676
name: str = f"{MLFF.SevenNet} relax"
@@ -709,7 +712,7 @@ class SevenNetStaticMaker(ForceFieldStaticMaker):
709712
trained for Matbench Discovery on the MPtrj dataset available at
710713
https://figshare.com/articles/dataset/22715158.
711714
task_document_kwargs : dict (deprecated)
712-
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
715+
Additional keyword args passed to :obj:`.ForceFieldStructureTaskDocument()`.
713716
"""
714717

715718
name: str = f"{MLFF.SevenNet} static"
@@ -749,7 +752,7 @@ class GAPRelaxMaker(ForceFieldRelaxMaker):
749752
calculator_kwargs : dict
750753
Keyword arguments that will get passed to the ASE calculator.
751754
task_document_kwargs : dict (deprecated)
752-
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
755+
Additional keyword args passed to :obj:`.ForceFieldStructureTaskDocument()`.
753756
"""
754757

755758
name: str = f"{MLFF.GAP} relax"
@@ -785,7 +788,7 @@ class GAPStaticMaker(ForceFieldStaticMaker):
785788
calculator_kwargs : dict
786789
Keyword arguments that will get passed to the ASE calculator.
787790
task_document_kwargs : dict (deprecated)
788-
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
791+
Additional keyword args passed to :obj:`.ForceFieldStructureTaskDocument()`.
789792
"""
790793

791794
name: str = f"{MLFF.GAP} static"

src/atomate2/forcefields/md.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,18 @@
1515
_DEFAULT_CALCULATOR_KWARGS,
1616
_FORCEFIELD_DATA_OBJECTS,
1717
)
18-
from atomate2.forcefields.schemas import ForceFieldTaskDocument
18+
from atomate2.forcefields.schemas import (
19+
ForceFieldMoleculeTaskDocument,
20+
ForceFieldStructureTaskDocument,
21+
ForceFieldTaskDocument,
22+
)
1923
from atomate2.forcefields.utils import ase_calculator, revert_default_dtype
2024

2125
if TYPE_CHECKING:
2226
from pathlib import Path
2327

2428
from ase.calculators.calculator import Calculator
25-
from pymatgen.core.structure import Structure
29+
from pymatgen.core.structure import Molecule, Structure
2630

2731

2832
@dataclass
@@ -126,19 +130,18 @@ def __post_init__(self) -> None:
126130

127131
@job(
128132
data=[*_FORCEFIELD_DATA_OBJECTS, "ionic_steps"],
129-
output_schema=ForceFieldTaskDocument,
130133
)
131134
def make(
132135
self,
133-
structure: Structure,
136+
structure: Molecule | Structure,
134137
prev_dir: str | Path | None = None,
135-
) -> ForceFieldTaskDocument:
138+
) -> ForceFieldStructureTaskDocument | ForceFieldMoleculeTaskDocument:
136139
"""
137140
Perform MD on a structure using forcefields and jobflow.
138141
139142
Parameters
140143
----------
141-
structure: .Structure
144+
structure: .Structure or Molecule
142145
pymatgen structure.
143146
prev_dir : str or Path or None
144147
A previous calculation directory to copy output files from. Unused, just
@@ -156,7 +159,7 @@ def make(
156159
stacklevel=1,
157160
)
158161

159-
return ForceFieldTaskDocument.from_ase_compatible_result(
162+
return ForceFieldTaskDocument.from_ase_compatible_result_forcefield(
160163
str(self.force_field_name), # make mypy happy
161164
md_result,
162165
relax_cell=(self.ensemble == MDEnsemble.npt),

0 commit comments

Comments
 (0)