44from enum import Enum
55from typing import Optional , Union , Any , List , Dict , Tuple
66
7- try :
8- from pydantic .v1 import BaseModel , Field , constr , validator , Extra , PrivateAttr
9- except ImportError :
10- from pydantic import BaseModel , Field , constr , validator , Extra , PrivateAttr
7+ from pydantic .v1 import BaseModel , Field , constr , validator , Extra , PrivateAttr
118from qcelemental .models import Molecule
129from qcelemental .models .results import (
1310 AtomicResult ,
14- Model as AtomicResultModel ,
15- AtomicResultProtocols as SinglepointProtocols ,
1611 AtomicResultProperties ,
1712 WavefunctionProperties ,
1813)
1914from typing_extensions import Literal
2015
21- from qcportal .compression import CompressionEnum , decompress
2216from qcportal .base_models import RestModelBase
17+ from qcportal .compression import CompressionEnum , decompress
2318from qcportal .record_models import (
2419 RecordStatusEnum ,
2520 BaseRecord ,
2924)
3025
3126
27+ class Model (BaseModel ):
28+ """The computational molecular sciences model to run."""
29+
30+ method : str = Field ( # type: ignore
31+ ...,
32+ description = "The quantum chemistry method to evaluate (e.g., B3LYP, PBE, ...). "
33+ "For MM, name of the force field." ,
34+ )
35+ basis : Optional [Union [str , BasisSet ]] = Field ( # type: ignore
36+ None ,
37+ description = "The quantum chemistry basis set to evaluate (e.g., 6-31g, cc-pVDZ, ...). Can be ``None`` for "
38+ "methods without basis sets. For molecular mechanics, name of the atom-typer." ,
39+ )
40+
41+ class Config (BaseModel .Config ):
42+ extra : str = "allow"
43+
44+
3245class SinglepointDriver (str , Enum ):
3346 # Copied from qcelemental to add "deferred"
3447 energy = "energy"
@@ -38,6 +51,59 @@ class SinglepointDriver(str, Enum):
3851 deferred = "deferred"
3952
4053
54+ class WavefunctionProtocolEnum (str , Enum ):
55+ r"""Wavefunction to keep from a computation."""
56+
57+ all = "all"
58+ orbitals_and_eigenvalues = "orbitals_and_eigenvalues"
59+ occupations_and_eigenvalues = "occupations_and_eigenvalues"
60+ return_results = "return_results"
61+ none = "none"
62+
63+
64+ class ErrorCorrectionProtocol (BaseModel ):
65+ r"""Configuration for how computationaal chemistry programs handle error correction
66+ """
67+
68+ default_policy : bool = Field (
69+ True , description = "Whether to allow error corrections to be used " "if not directly specified in `policies`"
70+ )
71+ policies : Optional [Dict [str , bool ]] = Field (
72+ None ,
73+ description = "Settings that define whether specific error corrections are allowed. "
74+ "Keys are the name of a known error and values are whether it is allowed to be used." ,
75+ )
76+
77+ def allows (self , policy : str ):
78+ if self .policies is None :
79+ return self .default_policy
80+ return self .policies .get (policy , self .default_policy )
81+
82+
83+ class NativeFilesProtocolEnum (str , Enum ):
84+ r"""Any program-specific files to keep from a computation."""
85+
86+ all = "all"
87+ input = "input"
88+ none = "none"
89+
90+
91+ class SinglepointProtocols (BaseModel ):
92+ r"""Protocols regarding the manipulation of computational result data."""
93+
94+ wavefunction : WavefunctionProtocolEnum = Field (
95+ WavefunctionProtocolEnum .none , description = str (WavefunctionProtocolEnum .__doc__ )
96+ )
97+ stdout : bool = Field (True , description = "Primary output file to keep from the computation" )
98+ error_correction : ErrorCorrectionProtocol = Field (
99+ default_factory = ErrorCorrectionProtocol , description = "Policies for error correction"
100+ )
101+ native_files : NativeFilesProtocolEnum = Field (
102+ NativeFilesProtocolEnum .none ,
103+ description = "Policies for keeping processed files from the computation" ,
104+ )
105+
106+
41107class QCSpecification (BaseModel ):
42108 class Config :
43109 extra = Extra .forbid
@@ -57,7 +123,7 @@ class Config:
57123 "methods without basis sets." ,
58124 )
59125 keywords : Dict [str , Any ] = Field ({}, description = "Program-specific keywords to use for the computation" )
60- protocols : SinglepointProtocols = Field (SinglepointProtocols (), description = str ( SinglepointProtocols . __base_doc__ ) )
126+ protocols : SinglepointProtocols = Field (SinglepointProtocols ())
61127
62128 @validator ("basis" , pre = True )
63129 def _convert_basis (cls , v ):
@@ -178,7 +244,7 @@ def to_qcschema_result(self) -> AtomicResult:
178244
179245 return AtomicResult (
180246 driver = self .specification .driver ,
181- model = AtomicResultModel (
247+ model = dict (
182248 method = self .specification .method ,
183249 basis = self .specification .basis ,
184250 ),
0 commit comments