Skip to content

Commit baa8851

Browse files
Add NoBatch mode (#602)
* add nobatch * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add test * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 536bcde commit baa8851

File tree

3 files changed

+74
-37
lines changed

3 files changed

+74
-37
lines changed

openff/evaluator/_tests/test_server.py

Lines changed: 46 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -2,23 +2,26 @@
22
Unit tests for the openff.evaluator.server module.
33
"""
44

5+
import os
56
import tempfile
67
from time import sleep
78

9+
import pytest
810
from openff.units import unit
911

1012
from openff.evaluator._tests.utils import create_dummy_property
1113
from openff.evaluator.backends.dask import DaskLocalCluster
1214
from openff.evaluator.client import EvaluatorClient, RequestOptions
1315
from openff.evaluator.datasets import PhysicalPropertyDataSet
16+
from openff.evaluator.forcefield import SmirnoffForceFieldSource
1417
from openff.evaluator.layers import (
1518
CalculationLayer,
1619
CalculationLayerResult,
1720
CalculationLayerSchema,
1821
calculation_layer,
1922
)
2023
from openff.evaluator.properties import Density, EnthalpyOfVaporization
21-
from openff.evaluator.server.server import Batch, EvaluatorServer
24+
from openff.evaluator.server.server import Batch, BatchMode, EvaluatorServer
2225
from openff.evaluator.substances import Substance
2326
from openff.evaluator.thermodynamics import ThermodynamicState
2427
from openff.evaluator.utils.utils import temporarily_change_directory
@@ -106,7 +109,8 @@ def test_launch_batch():
106109
assert len(batch.unsuccessful_properties) == 1
107110

108111

109-
def test_same_component_batching():
112+
@pytest.fixture
113+
def c_o_dataset():
110114
thermodynamic_state = ThermodynamicState(
111115
temperature=1.0 * unit.kelvin, pressure=1.0 * unit.atmosphere
112116
)
@@ -134,61 +138,66 @@ def test_same_component_batching():
134138
value=0.0 * unit.kilojoule / unit.mole,
135139
),
136140
)
141+
return data_set
137142

143+
144+
@pytest.fixture
145+
def dataset_submission(c_o_dataset):
138146
options = RequestOptions()
139147

140148
submission = EvaluatorClient._Submission()
141-
submission.dataset = data_set
149+
submission.dataset = c_o_dataset
142150
submission.options = options
151+
submission.force_field_source = SmirnoffForceFieldSource.from_path(
152+
"openff-2.1.0.offxml"
153+
)
154+
155+
return submission
143156

157+
158+
def test_same_component_batching(dataset_submission, tmp_path):
159+
os.chdir(tmp_path)
144160
with DaskLocalCluster() as calculation_backend:
145161
server = EvaluatorServer(calculation_backend)
146-
batches = server._batch_by_same_component(submission, "")
162+
batches = server._batch_by_same_component(dataset_submission, "")
147163

148164
assert len(batches) == 2
149165

150166
assert len(batches[0].queued_properties) == 2
151167
assert len(batches[1].queued_properties) == 2
152168

153169

154-
def test_shared_component_batching():
155-
thermodynamic_state = ThermodynamicState(
156-
temperature=1.0 * unit.kelvin, pressure=1.0 * unit.atmosphere
157-
)
158-
159-
data_set = PhysicalPropertyDataSet()
160-
data_set.add_properties(
161-
Density(
162-
thermodynamic_state=thermodynamic_state,
163-
substance=Substance.from_components("O", "C"),
164-
value=0.0 * unit.kilogram / unit.meter**3,
165-
),
166-
EnthalpyOfVaporization(
167-
thermodynamic_state=thermodynamic_state,
168-
substance=Substance.from_components("O", "C"),
169-
value=0.0 * unit.kilojoule / unit.mole,
170-
),
171-
Density(
172-
thermodynamic_state=thermodynamic_state,
173-
substance=Substance.from_components("O", "CO"),
174-
value=0.0 * unit.kilogram / unit.meter**3,
175-
),
176-
EnthalpyOfVaporization(
177-
thermodynamic_state=thermodynamic_state,
178-
substance=Substance.from_components("O", "CO"),
179-
value=0.0 * unit.kilojoule / unit.mole,
180-
),
181-
)
170+
def test_shared_component_batching(dataset_submission, tmp_path):
171+
os.chdir(tmp_path)
172+
with DaskLocalCluster() as calculation_backend:
173+
server = EvaluatorServer(calculation_backend)
174+
batches = server._batch_by_shared_component(dataset_submission, "")
182175

183-
options = RequestOptions()
176+
assert len(batches) == 1
177+
assert len(batches[0].queued_properties) == 4
184178

185-
submission = EvaluatorClient._Submission()
186-
submission.dataset = data_set
187-
submission.options = options
188179

180+
def test_nobatching(dataset_submission, tmp_path):
181+
os.chdir(tmp_path)
189182
with DaskLocalCluster() as calculation_backend:
190183
server = EvaluatorServer(calculation_backend)
191-
batches = server._batch_by_shared_component(submission, "")
184+
batches = server._no_batch(dataset_submission, "")
192185

193186
assert len(batches) == 1
194187
assert len(batches[0].queued_properties) == 4
188+
assert batches[0].id == "batch_0000"
189+
190+
191+
def test_prepare_batches_nobatch(dataset_submission, tmp_path):
192+
os.chdir(tmp_path)
193+
dataset_submission.options.batch_mode = BatchMode.NoBatch
194+
with DaskLocalCluster() as calculation_backend:
195+
server = EvaluatorServer(calculation_backend)
196+
server._batch_ids_per_client_id["request_id"] = []
197+
batches = server._prepare_batches(dataset_submission, "request_id")
198+
199+
assert len(batches) == 1
200+
assert len(batches[0].queued_properties) == 4
201+
assert batches[0].id == "batch_0000"
202+
assert server._queued_batches["batch_0000"] is batches[0]
203+
assert server._batch_ids_per_client_id["request_id"] == ["batch_0000"]

openff/evaluator/client/client.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,13 +83,16 @@ class BatchMode(Enum):
8383
common component will be batched together. E.g. the densities of 80:20 and 20:80
8484
mixtures of ethanol and water, and the pure densities of ethanol and water would
8585
be batched together.
86+
* NoBatch: No batching will be performed. All properties will be estimated together
87+
in a single batch, which is assigned a sequentially-increasing batch ID.
8688
8789
Properties will only be marked as estimated by the server when all properties in a
8890
single batch are completed.
8991
"""
9092

9193
SameComponents = "SameComponents"
9294
SharedComponents = "SharedComponents"
95+
NoBatch = "NoBatch"
9396

9497

9598
class Request(AttributeClass):

openff/evaluator/server/server.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,29 @@ def _query_request_status(self, client_request_id):
280280

281281
return request_results, None
282282

283+
def _no_batch(self, submission, force_field_id):
284+
"""Returns a single Batch."""
285+
286+
reserved_batch_ids = {
287+
*self._queued_batches.keys(),
288+
*self._finished_batches.keys(),
289+
}
290+
batch = Batch()
291+
batch.force_field_id = force_field_id
292+
batch.enable_data_caching = self._enable_data_caching
293+
batch.queued_properties = list(submission.dataset.properties)
294+
batch.options = RequestOptions.parse_json(submission.options.json())
295+
batch.parameter_gradient_keys = copy.deepcopy(
296+
submission.parameter_gradient_keys
297+
)
298+
299+
n_batchs = len(reserved_batch_ids)
300+
batch.id = f"batch_{n_batchs:04d}"
301+
while batch.id in reserved_batch_ids:
302+
n_batchs += 1
303+
batch.id = f"batch_{n_batchs:04d}"
304+
return [batch]
305+
283306
def _batch_by_same_component(self, submission, force_field_id):
284307
"""Batches a set of requested properties based on which substance they were
285308
measured for. Properties which were measured for substances containing the
@@ -447,6 +470,8 @@ def _prepare_batches(self, submission, request_id):
447470
batches = self._batch_by_same_component(submission, force_field_id)
448471
elif batch_mode == BatchMode.SharedComponents:
449472
batches = self._batch_by_shared_component(submission, force_field_id)
473+
elif batch_mode == BatchMode.NoBatch:
474+
batches = self._no_batch(submission, force_field_id)
450475
else:
451476
raise NotImplementedError()
452477

0 commit comments

Comments
 (0)