Skip to content

Commit 41b2e99

Browse files
authored
Merge pull request #981 from MolSSI/vendor_qcel
Vendor in some models from qcelemental
2 parents c412f9b + 7e41b46 commit 41b2e99

File tree

10 files changed

+122
-36
lines changed

10 files changed

+122
-36
lines changed

docs/source/user_guide/records/optimization.rst

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,9 +101,8 @@ when adding to datasets).
101101

102102
.. code-block:: py3
103103
104-
from qcportal.optimization import OptimizationSpecification
104+
from qcportal.optimization import OptimizationSpecification, OptimizationProtocols
105105
from qcportal.singlepoint import QCSpecification
106-
from qcelemental.models.procedures import OptimizationProtocols
107106
108107
opt_spec = OptimizationSpecification(
109108
program="geometric",

qcfractal/qcfractal/components/gridoptimization/testing_helpers.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,12 @@
88
except ImportError:
99
import pydantic
1010
from qcelemental.models import Molecule, FailedOperation, ComputeError, OptimizationResult as QCEl_OptimizationResult
11-
from qcelemental.models.procedures import OptimizationProtocols
1211

1312
from qcarchivetesting.helpers import read_procedure_data, read_record_data
1413
from qcfractal.components.gridoptimization.record_db_models import GridoptimizationRecordORM
1514
from qcfractal.testing_helpers import run_service
1615
from qcportal.gridoptimization import GridoptimizationSpecification, GridoptimizationKeywords, GridoptimizationRecord
17-
from qcportal.optimization import OptimizationSpecification
16+
from qcportal.optimization import OptimizationSpecification, OptimizationProtocols
1817
from qcportal.record_models import PriorityEnum, RecordStatusEnum, RecordTask
1918
from qcportal.singlepoint import SinglepointProtocols, QCSpecification
2019
from qcportal.utils import recursive_normalizer

qcfractal/qcfractal/components/optimization/record_socket.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ def add_specifications(
308308
to_add = []
309309

310310
for opt_spec in opt_specs:
311-
protocols_dict = opt_spec.protocols.dict(exclude_defaults=True)
311+
protocols_dict = opt_spec.protocols.dict(exclude_defaults=True, exclude_unset=True)
312312

313313
# Don't include lower specifications in the hash
314314
opt_spec_dict = opt_spec.dict(exclude={"protocols", "qc_specification"})

qcfractal/qcfractal/components/optimization/test_record_socket.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def test_optimization_socket_task_spec(
108108

109109
task_input = t.function_kwargs["input_data"]
110110
assert task_input["keywords"] == kw_with_prog
111-
assert task_input["protocols"] == spec.protocols.dict(exclude_defaults=True)
111+
assert task_input["protocols"] == spec.protocols.dict(exclude_defaults=True, exclude_unset=True)
112112

113113
# Forced to gradient in the qcschema input
114114
assert task_input["input_specification"]["driver"] == SinglepointDriver.gradient

qcfractal/qcfractal/components/singlepoint/record_socket.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ def add_specifications(
234234
to_add = []
235235

236236
for qc_spec in qc_specs:
237-
protocols_dict = qc_spec.protocols.dict(exclude_defaults=True)
237+
protocols_dict = qc_spec.protocols.dict(exclude_defaults=True, exclude_unset=True)
238238

239239
# TODO - if error_correction is manually specified as the default, then it will be an empty dict
240240
if "error_correction" in protocols_dict:

qcfractal/qcfractal/components/singlepoint/test_record_socket.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,9 @@ def test_singlepoint_socket_task_spec(
108108
for t in tasks:
109109
function_kwargs = t.function_kwargs
110110
assert function_kwargs["input_data"]["model"] == {"method": spec.method, "basis": spec.basis}
111-
assert function_kwargs["input_data"]["protocols"] == spec.protocols.dict(exclude_defaults=True)
111+
assert function_kwargs["input_data"]["protocols"] == spec.protocols.dict(
112+
exclude_defaults=True, exclude_unset=True
113+
)
112114
assert function_kwargs["input_data"]["keywords"] == spec.keywords
113115
assert function_kwargs["program"] == spec.program
114116
assert t.compute_tag == "tag1"

qcfractal/qcfractal/components/torsiondrive/testing_helpers.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,11 @@
88
except ImportError:
99
import pydantic
1010
from qcelemental.models import Molecule, FailedOperation, ComputeError, OptimizationResult as QCEl_OptimizationResult
11-
from qcelemental.models.procedures import OptimizationProtocols
1211

1312
from qcarchivetesting.helpers import read_procedure_data, read_record_data
1413
from qcfractal.components.torsiondrive.record_db_models import TorsiondriveRecordORM
1514
from qcfractal.testing_helpers import run_service
16-
from qcportal.optimization import OptimizationSpecification
15+
from qcportal.optimization import OptimizationSpecification, OptimizationProtocols
1716
from qcportal.record_models import PriorityEnum, RecordStatusEnum, RecordTask
1817
from qcportal.singlepoint import SinglepointProtocols, QCSpecification
1918
from qcportal.torsiondrive import TorsiondriveSpecification, TorsiondriveKeywords, TorsiondriveRecord

qcportal/qcportal/optimization/record_models.py

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,56 @@
11
from __future__ import annotations
22

33
from copy import deepcopy
4+
from enum import Enum
5+
from typing import Iterable
46
from typing import Optional, Union, Any, List, Dict
57

6-
try:
7-
from pydantic.v1 import BaseModel, Field, constr, validator, Extra
8-
except ImportError:
9-
from pydantic import BaseModel, Field, constr, validator, Extra
8+
from pydantic.v1 import BaseModel, Field, constr, validator, Extra
109
from qcelemental.models import Molecule
1110
from qcelemental.models.procedures import (
1211
OptimizationResult,
13-
OptimizationProtocols,
1412
QCInputSpecification,
15-
Model as AtomicResultModel,
1613
)
1714
from typing_extensions import Literal
18-
from typing import Iterable
1915

2016
from qcportal.base_models import RestModelBase
17+
from qcportal.cache import get_records_with_cache
2118
from qcportal.record_models import (
2219
BaseRecord,
2320
RecordAddBodyBase,
2421
RecordQueryFilters,
2522
RecordStatusEnum,
2623
compare_base_records,
2724
)
28-
from qcportal.utils import is_included
29-
from qcportal.cache import get_records_with_cache
3025
from qcportal.singlepoint import (
3126
SinglepointProtocols,
3227
SinglepointRecord,
3328
QCSpecification,
3429
SinglepointDriver,
3530
compare_singlepoint_records,
3631
)
32+
from qcportal.utils import is_included
33+
34+
35+
class TrajectoryProtocolEnum(str, Enum):
36+
"""
37+
Which gradient evaluations to keep in an optimization trajectory.
38+
"""
39+
40+
all = "all"
41+
initial_and_final = "initial_and_final"
42+
final = "final"
43+
none = "none"
44+
45+
46+
class OptimizationProtocols(BaseModel):
47+
"""
48+
Protocols regarding the manipulation of a Optimization output data.
49+
"""
3750

51+
trajectory: TrajectoryProtocolEnum = Field(
52+
TrajectoryProtocolEnum.all, description=str(TrajectoryProtocolEnum.__doc__)
53+
)
3854

3955
class OptimizationSpecification(BaseModel):
4056
"""
@@ -216,7 +232,7 @@ def to_qcschema_result(self) -> OptimizationResult:
216232
keywords=new_keywords,
217233
input_specification=QCInputSpecification(
218234
driver=SinglepointDriver.gradient, # forced
219-
model=AtomicResultModel(
235+
model=dict(
220236
method=self.specification.qc_specification.method,
221237
basis=self.specification.qc_specification.basis,
222238
),

qcportal/qcportal/record_models.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,25 +8,30 @@
88
from typing import Optional, Dict, Any, List, Union, Iterable, Tuple, Type, Sequence, ClassVar, TypeVar
99

1010
from dateutil.parser import parse as date_parser
11-
12-
try:
13-
from pydantic.v1 import BaseModel, Extra, constr, validator, PrivateAttr, Field, parse_obj_as, root_validator
14-
except ImportError:
15-
from pydantic import BaseModel, Extra, constr, validator, PrivateAttr, Field, parse_obj_as, root_validator
16-
from qcelemental.models.results import Provenance
11+
from pydantic.v1 import BaseModel, Extra, constr, validator, PrivateAttr, Field, root_validator
1712

1813
from qcportal.base_models import (
1914
RestModelBase,
2015
QueryModelBase,
2116
QueryIteratorBase,
2217
)
23-
2418
from qcportal.cache import RecordCache, get_records_with_cache
2519
from qcportal.compression import CompressionEnum, decompress, get_compressed_ext
2620

2721
_T = TypeVar("_T")
2822

2923

24+
class Provenance(BaseModel):
25+
"""Provenance information."""
26+
27+
creator: str = Field(..., description="The name of the program, library, or person who created the object.")
28+
version: str = Field("", description="The version of the creator, blank otherwise")
29+
routine: str = Field("", description="The name of the routine or function within the creator, blank otherwise.")
30+
31+
class Config(BaseModel.Config):
32+
extra: str = "allow"
33+
34+
3035
class PriorityEnum(int, Enum):
3136
"""
3237
The priority of a Task. Higher priority will be pulled first.

qcportal/qcportal/singlepoint/record_models.py

Lines changed: 75 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,22 +4,17 @@
44
from enum import Enum
55
from typing import Optional, Union, Any, List, Dict, Tuple
66

7-
try:
8-
from pydantic.v1 import BaseModel, Field, constr, validator, Extra, PrivateAttr
9-
except ImportError:
10-
from pydantic import BaseModel, Field, constr, validator, Extra, PrivateAttr
7+
from pydantic.v1 import BaseModel, Field, constr, validator, Extra, PrivateAttr
118
from qcelemental.models import Molecule
129
from qcelemental.models.results import (
1310
AtomicResult,
14-
Model as AtomicResultModel,
15-
AtomicResultProtocols as SinglepointProtocols,
1611
AtomicResultProperties,
1712
WavefunctionProperties,
1813
)
1914
from typing_extensions import Literal
2015

21-
from qcportal.compression import CompressionEnum, decompress
2216
from qcportal.base_models import RestModelBase
17+
from qcportal.compression import CompressionEnum, decompress
2318
from qcportal.record_models import (
2419
RecordStatusEnum,
2520
BaseRecord,
@@ -29,6 +24,24 @@
2924
)
3025

3126

27+
class Model(BaseModel):
28+
"""The computational molecular sciences model to run."""
29+
30+
method: str = Field( # type: ignore
31+
...,
32+
description="The quantum chemistry method to evaluate (e.g., B3LYP, PBE, ...). "
33+
"For MM, name of the force field.",
34+
)
35+
basis: Optional[Union[str, BasisSet]] = Field( # type: ignore
36+
None,
37+
description="The quantum chemistry basis set to evaluate (e.g., 6-31g, cc-pVDZ, ...). Can be ``None`` for "
38+
"methods without basis sets. For molecular mechanics, name of the atom-typer.",
39+
)
40+
41+
class Config(BaseModel.Config):
42+
extra: str = "allow"
43+
44+
3245
class SinglepointDriver(str, Enum):
3346
# Copied from qcelemental to add "deferred"
3447
energy = "energy"
@@ -38,6 +51,59 @@ class SinglepointDriver(str, Enum):
3851
deferred = "deferred"
3952

4053

54+
class WavefunctionProtocolEnum(str, Enum):
55+
r"""Wavefunction to keep from a computation."""
56+
57+
all = "all"
58+
orbitals_and_eigenvalues = "orbitals_and_eigenvalues"
59+
occupations_and_eigenvalues = "occupations_and_eigenvalues"
60+
return_results = "return_results"
61+
none = "none"
62+
63+
64+
class ErrorCorrectionProtocol(BaseModel):
65+
r"""Configuration for how computationaal chemistry programs handle error correction
66+
"""
67+
68+
default_policy: bool = Field(
69+
True, description="Whether to allow error corrections to be used " "if not directly specified in `policies`"
70+
)
71+
policies: Optional[Dict[str, bool]] = Field(
72+
None,
73+
description="Settings that define whether specific error corrections are allowed. "
74+
"Keys are the name of a known error and values are whether it is allowed to be used.",
75+
)
76+
77+
def allows(self, policy: str):
78+
if self.policies is None:
79+
return self.default_policy
80+
return self.policies.get(policy, self.default_policy)
81+
82+
83+
class NativeFilesProtocolEnum(str, Enum):
84+
r"""Any program-specific files to keep from a computation."""
85+
86+
all = "all"
87+
input = "input"
88+
none = "none"
89+
90+
91+
class SinglepointProtocols(BaseModel):
92+
r"""Protocols regarding the manipulation of computational result data."""
93+
94+
wavefunction: WavefunctionProtocolEnum = Field(
95+
WavefunctionProtocolEnum.none, description=str(WavefunctionProtocolEnum.__doc__)
96+
)
97+
stdout: bool = Field(True, description="Primary output file to keep from the computation")
98+
error_correction: ErrorCorrectionProtocol = Field(
99+
default_factory=ErrorCorrectionProtocol, description="Policies for error correction"
100+
)
101+
native_files: NativeFilesProtocolEnum = Field(
102+
NativeFilesProtocolEnum.none,
103+
description="Policies for keeping processed files from the computation",
104+
)
105+
106+
41107
class QCSpecification(BaseModel):
42108
class Config:
43109
extra = Extra.forbid
@@ -57,7 +123,7 @@ class Config:
57123
"methods without basis sets.",
58124
)
59125
keywords: Dict[str, Any] = Field({}, description="Program-specific keywords to use for the computation")
60-
protocols: SinglepointProtocols = Field(SinglepointProtocols(), description=str(SinglepointProtocols.__base_doc__))
126+
protocols: SinglepointProtocols = Field(SinglepointProtocols())
61127

62128
@validator("basis", pre=True)
63129
def _convert_basis(cls, v):
@@ -178,7 +244,7 @@ def to_qcschema_result(self) -> AtomicResult:
178244

179245
return AtomicResult(
180246
driver=self.specification.driver,
181-
model=AtomicResultModel(
247+
model=dict(
182248
method=self.specification.method,
183249
basis=self.specification.basis,
184250
),

0 commit comments

Comments
 (0)