Skip to content

Commit 20dc5cb

Browse files
adam-grofeAdam Grofexinyi-joffrekikomiss
authored
Add batching support to sdk (#647)
* Add batching to sdk by uploading qcschema files to a container. * Refactor/simplify tests to use pytest regressions * Change input_data_uri back to blob based uri * Add toc blob to submission * Add further xyz validation * Correct the qcshema format. * Add pytest-regressions to conda env * Add pytest-regressions to ci * Add support for submitting batches of qcschema --------- Co-authored-by: Adam Grofe <[email protected]> Co-authored-by: Xinyi Joffre <[email protected]> Co-authored-by: kikomiss <[email protected]>
1 parent 6420d3e commit 20dc5cb

File tree

33 files changed

+10054
-11
lines changed

33 files changed

+10054
-11
lines changed

.ado/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ jobs:
5858
displayName: Set Python version
5959

6060
- script: |
61-
pip install pytest pytest-azurepipelines pytest-cov
61+
pip install pytest pytest-azurepipelines pytest-cov pytest-regressions
6262
displayName: Install pytest dependencies
6363
6464
- script: |

azure-quantum/azure/quantum/target/microsoft/elements/dft/job.py

Lines changed: 108 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
11
import collections.abc
2-
from typing import Any, Dict, Union
2+
import logging
3+
from typing import Any, Dict, Union, Optional
34
from azure.quantum.job import JobFailedWithResultsError
5+
from azure.quantum.job.base_job import BaseJob, ContentType
46
from azure.quantum.job.job import Job, DEFAULT_TIMEOUT
57
from azure.quantum._client.models import JobDetails
8+
from azure.quantum.workspace import Workspace
9+
10+
logger = logging.getLogger(__name__)
611

712
class MicrosoftElementsDftJob(Job):
813
"""
@@ -62,4 +67,105 @@ def _is_dft_failure_results(failure_results: Union[Dict[str, Any], str]) -> bool
6267
and "error" in failure_results["results"][0] \
6368
and isinstance(failure_results["results"][0]["error"], dict) \
6469
and "error_type" in failure_results["results"][0]["error"] \
65-
and "error_message" in failure_results["results"][0]["error"]
70+
and "error_message" in failure_results["results"][0]["error"]
71+
72+
@classmethod
73+
def from_input_data_container(
74+
cls,
75+
workspace: "Workspace",
76+
name: str,
77+
target: str,
78+
input_data: bytes,
79+
batch_input_blobs: Dict[str, bytes],
80+
content_type: ContentType = ContentType.json,
81+
blob_name: str = "inputData",
82+
encoding: str = "",
83+
job_id: str = None,
84+
container_name: str = None,
85+
provider_id: str = None,
86+
input_data_format: str = None,
87+
output_data_format: str = None,
88+
input_params: Dict[str, Any] = None,
89+
session_id: Optional[str] = None,
90+
**kwargs
91+
) -> "BaseJob":
92+
"""Create a new Azure Quantum job based on a list of input_data.
93+
94+
:param workspace: Azure Quantum workspace to submit the input_data to
95+
:type workspace: Workspace
96+
:param name: Name of the job
97+
:type name: str
98+
:param target: Azure Quantum target
99+
:type target: str
100+
:param input_data: Raw input data to submit
101+
:type input_data: Dict
102+
:param blob_name: Dict of Input data json to gives a table of contents
103+
:type batch_input_blobs: Dict
104+
:param blob_name: Dict of QcSchema Data where the key is the blob name to store it in the container
105+
:type blob_name: str
106+
:param content_type: Content type, e.g. "application/json"
107+
:type content_type: ContentType
108+
:param encoding: input_data encoding, e.g. "gzip", defaults to empty string
109+
:type encoding: str
110+
:param job_id: Job ID, defaults to None
111+
:type job_id: str
112+
:param container_name: Container name, defaults to None
113+
:type container_name: str
114+
:param provider_id: Provider ID, defaults to None
115+
:type provider_id: str
116+
:param input_data_format: Input data format, defaults to None
117+
:type input_data_format: str
118+
:param output_data_format: Output data format, defaults to None
119+
:type output_data_format: str
120+
:param input_params: Input parameters, defaults to None
121+
:type input_params: Dict[str, Any]
122+
:param input_params: Input params for job
123+
:type input_params: Dict[str, Any]
124+
:return: Azure Quantum Job
125+
:rtype: Job
126+
"""
127+
# Generate job ID if not specified
128+
if job_id is None:
129+
job_id = cls.create_job_id()
130+
131+
# Create container if it does not yet exist
132+
container_uri = workspace.get_container_uri(
133+
job_id=job_id,
134+
container_name=container_name
135+
)
136+
logger.debug(f"Container URI: {container_uri}")
137+
138+
# Upload Input Data
139+
input_data_uri = cls.upload_input_data(
140+
container_uri=container_uri,
141+
input_data=input_data,
142+
content_type=content_type,
143+
blob_name=blob_name,
144+
encoding=encoding,
145+
)
146+
147+
# Upload data to container
148+
for blob_name, input_data_item in batch_input_blobs.items():
149+
cls.upload_input_data(
150+
container_uri=container_uri,
151+
input_data=input_data_item,
152+
content_type=content_type,
153+
blob_name=blob_name,
154+
encoding=encoding,
155+
)
156+
157+
# Create and submit job
158+
return cls.from_storage_uri(
159+
workspace=workspace,
160+
job_id=job_id,
161+
target=target,
162+
input_data_uri=input_data_uri,
163+
container_uri=container_uri,
164+
name=name,
165+
input_data_format=input_data_format,
166+
output_data_format=output_data_format,
167+
provider_id=provider_id,
168+
input_params=input_params,
169+
session_id=session_id,
170+
**kwargs
171+
)

azure-quantum/azure/quantum/target/microsoft/elements/dft/target.py

Lines changed: 157 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,11 @@
55
from azure.quantum.target.target import Target
66
from azure.quantum.workspace import Workspace
77
from azure.quantum.target.params import InputParams
8-
from typing import Any, Dict, Type, Union
8+
from typing import Any, Dict, Type, Union, List
99
from .job import MicrosoftElementsDftJob
10+
from pathlib import Path
11+
import copy
12+
import json
1013

1114

1215
class MicrosoftElementsDft(Target):
@@ -73,15 +76,161 @@ def submit(self,
7376
if shots is not None:
7477
warnings.warn("The 'shots' parameter is ignored in Microsoft Elements Dft job.")
7578

76-
return super().submit(
77-
input_data=input_data,
78-
name=name,
79-
shots=shots,
80-
input_params=input_params,
81-
**kwargs
82-
)
79+
if isinstance(input_data, list):
80+
81+
qcschema_data = self._assemble_qcshema_from_files(input_data, input_params)
82+
83+
qcschema_blobs = {}
84+
for i in range(len(qcschema_data)):
85+
qcschema_blobs[f"inputData_{i}"] = self._encode_input_data(qcschema_data[i])
8386

87+
toc_str = self._create_table_of_contents(input_data, list(qcschema_blobs.keys()))
88+
toc = self._encode_input_data(toc_str)
89+
90+
return self._get_job_class().from_input_data_container(
91+
workspace=self.workspace,
92+
name=name,
93+
target=self.name,
94+
input_data=toc,
95+
batch_input_blobs=qcschema_blobs,
96+
input_params={ 'numberOfFiles': len(qcschema_data), "inputFiles": list(qcschema_blobs.keys()), **input_params },
97+
content_type=kwargs.pop('content_type', self.content_type),
98+
encoding=kwargs.pop('encoding', self.encoding),
99+
provider_id=self.provider_id,
100+
input_data_format=kwargs.pop('input_data_format', 'microsoft.qc-schema.v1'),
101+
output_data_format=kwargs.pop('output_data_format', self.output_data_format),
102+
session_id=self.get_latest_session_id(),
103+
**kwargs
104+
)
105+
else:
106+
return super().submit(
107+
input_data=input_data,
108+
name=name,
109+
shots=shots,
110+
input_params=input_params,
111+
**kwargs
112+
)
113+
114+
115+
116+
@classmethod
117+
def _assemble_qcshema_from_files(self, input_data: List[str], input_params: Dict) -> str:
118+
"""
119+
Convert a list of files to a list of qcshema objects serialized in json.
120+
"""
121+
122+
qcshema_objects = []
123+
for file in input_data:
124+
file_path = Path(file)
125+
if not file_path.exists():
126+
raise FileNotFoundError(f"File {file} does not exist.")
127+
128+
file_data = file_path.read_text()
129+
if file_path.suffix == '.xyz':
130+
mol = self._xyz_to_qcschema_mol(file_data)
131+
new_qcschema = self._new_qcshema( input_params, mol )
132+
qcshema_objects.append(new_qcschema)
133+
elif file_path.suffix == '.json':
134+
if input_params is not None and len(input_params.keys()) > 0:
135+
warnings.warn('Input parameters were given along with a QcSchema file which contains parameters, using QcSchema parameters as is.')
136+
with open(file_path, 'r') as f:
137+
qcshema_objects.append( json.load(f) )
138+
else:
139+
raise ValueError(f"File type '{file_path.suffix}' for file '{file_path}' is not supported. Please use xyz or QcSchema file formats.")
140+
141+
return qcshema_objects
142+
143+
@classmethod
144+
def _new_qcshema( self, input_params: Dict[str,Any], mol: Dict[str,Any], ) -> Dict[str, Any]:
145+
"""
146+
Create a new default qcshema object.
147+
"""
148+
149+
if input_params.get("driver") == "go":
150+
copy_input_params = copy.deepcopy(input_params)
151+
copy_input_params["driver"] = "gradient"
152+
new_object = {
153+
"schema_name": "qcschema_optimization_input",
154+
"schema_version": 1,
155+
"initial_molecule": mol,
156+
}
157+
if copy_input_params.get("keywords") and copy_input_params["keywords"].get("geometryOptimization"):
158+
new_object["keywords"] = copy_input_params["keywords"].pop("geometryOptimization")
159+
new_object["input_specification"] = copy_input_params
160+
return new_object
161+
elif input_params.get("driver") == "bomd":
162+
copy_input_params = copy.deepcopy(input_params)
163+
copy_input_params["driver"] = "gradient"
164+
new_object = {
165+
"schema_name": "madft_molecular_dynamics_input",
166+
"schema_version": 1,
167+
"initial_molecule": mol,
168+
}
169+
if copy_input_params.get("keywords") and copy_input_params["keywords"].get("molecularDynamics"):
170+
new_object["keywords"] = copy_input_params["keywords"].pop("molecularDynamics")
171+
new_object["input_specification"] = copy_input_params
172+
return new_object
173+
else:
174+
new_object = copy.deepcopy(input_params)
175+
new_object.update({
176+
"schema_name": "qcschema_input",
177+
"schema_version": 1,
178+
"molecule": mol,
179+
})
180+
return new_object
181+
182+
183+
@classmethod
184+
def _xyz_to_qcschema_mol(self, file_data: str ) -> Dict[str, Any]:
185+
"""
186+
Convert xyz format to qcschema molecule.
187+
"""
188+
189+
lines = file_data.split("\n")
190+
if len(lines) < 3:
191+
raise ValueError("Invalid xyz format.")
192+
n_atoms = int(lines.pop(0))
193+
comment = lines.pop(0)
194+
mol = {
195+
"geometry": [],
196+
"symbols": [],
197+
}
198+
for line in lines:
199+
if line:
200+
elements = line.split()
201+
if len(elements) < 4:
202+
raise ValueError("Invalid xyz format.")
203+
symbol, x, y, z = elements
204+
mol["symbols"].append(symbol)
205+
mol["geometry"] += [float(x), float(y), float(z)]
206+
else:
207+
break
208+
209+
if len(mol["symbols"]) != n_atoms:
210+
raise ValueError("Number of inputs does not match the number of atoms in xyz file.")
211+
212+
return mol
84213

85214
@classmethod
86215
def _get_job_class(cls) -> Type[Job]:
87216
return MicrosoftElementsDftJob
217+
218+
@classmethod
219+
def _create_table_of_contents(cls, input_files: List[str], input_blobs: List[str]) -> Dict[str,Any]:
220+
"""Create the table of contents for a batched job that contains a description of file and the mapping between the file names and the blob names"""
221+
222+
assert len(input_files) == len(input_blobs), "Internal error: number of blobs is not that same as the number of files."
223+
224+
toc = []
225+
for i in range(len(input_files)):
226+
toc.append(
227+
{
228+
"inputFileName": input_files[i],
229+
"qcschemaBlobName": input_blobs[i],
230+
}
231+
)
232+
233+
return {
234+
"description": "This files contains the mapping between the xyz file name that were submitted and the qcschema blobs that are used for the calculation.",
235+
"tableOfContents": toc,
236+
}

azure-quantum/environment.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,6 @@ dependencies:
66
- python=3.9
77
- pip>=22.3.1
88
- pytest>=7.1.2
9+
- pytest-regressions
910
- pip:
1011
- -e .[all]

0 commit comments

Comments
 (0)