Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 26 additions & 36 deletions openff/evaluator/_tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
Units tests for the openff.evaluator.server module.
"""

import os
import tempfile
from time import sleep

import pytest
from openff.units import unit

from openff.evaluator._tests.utils import create_dummy_property
Expand Down Expand Up @@ -106,7 +108,8 @@ def test_launch_batch():
assert len(batch.unsuccessful_properties) == 1


def test_same_component_batching():
@pytest.fixture
def c_o_dataset():
thermodynamic_state = ThermodynamicState(
temperature=1.0 * unit.kelvin, pressure=1.0 * unit.atmosphere
)
Expand Down Expand Up @@ -134,61 +137,48 @@ def test_same_component_batching():
value=0.0 * unit.kilojoule / unit.mole,
),
)
return data_set


@pytest.fixture
def dataset_submission(c_o_dataset):
options = RequestOptions()

submission = EvaluatorClient._Submission()
submission.dataset = data_set
submission.dataset = c_o_dataset
submission.options = options

return submission


def test_same_component_batching(dataset_submission, tmp_path):
os.chdir(tmp_path)
with DaskLocalCluster() as calculation_backend:
server = EvaluatorServer(calculation_backend)
batches = server._batch_by_same_component(submission, "")
batches = server._batch_by_same_component(dataset_submission, "")

assert len(batches) == 2

assert len(batches[0].queued_properties) == 2
assert len(batches[1].queued_properties) == 2


def test_shared_component_batching():
thermodynamic_state = ThermodynamicState(
temperature=1.0 * unit.kelvin, pressure=1.0 * unit.atmosphere
)

data_set = PhysicalPropertyDataSet()
data_set.add_properties(
Density(
thermodynamic_state=thermodynamic_state,
substance=Substance.from_components("O", "C"),
value=0.0 * unit.kilogram / unit.meter**3,
),
EnthalpyOfVaporization(
thermodynamic_state=thermodynamic_state,
substance=Substance.from_components("O", "C"),
value=0.0 * unit.kilojoule / unit.mole,
),
Density(
thermodynamic_state=thermodynamic_state,
substance=Substance.from_components("O", "CO"),
value=0.0 * unit.kilogram / unit.meter**3,
),
EnthalpyOfVaporization(
thermodynamic_state=thermodynamic_state,
substance=Substance.from_components("O", "CO"),
value=0.0 * unit.kilojoule / unit.mole,
),
)
def test_shared_component_batching(dataset_submission, tmp_path):
os.chdir(tmp_path)
with DaskLocalCluster() as calculation_backend:
server = EvaluatorServer(calculation_backend)
batches = server._batch_by_shared_component(dataset_submission, "")

options = RequestOptions()
assert len(batches) == 1
assert len(batches[0].queued_properties) == 4

submission = EvaluatorClient._Submission()
submission.dataset = data_set
submission.options = options

def test_nobatching(dataset_submission, tmp_path):
os.chdir(tmp_path)
with DaskLocalCluster() as calculation_backend:
server = EvaluatorServer(calculation_backend)
batches = server._batch_by_shared_component(submission, "")
batches = server._no_batch(dataset_submission, "")

assert len(batches) == 1
assert len(batches[0].queued_properties) == 4
assert batches[0].id == "batch_0000"
3 changes: 3 additions & 0 deletions openff/evaluator/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,16 @@ class BatchMode(Enum):
common component will be batched together. E.g.The densities of 80:20 and 20:80
mixtures of ethanol and water, and the pure densities of ethanol and water would
be batched together.
* NoBatch: No batching will be performed. Each property will be estimated in a
single, sequentially-increasing batch.

Properties will only be marked as estimated by the server when all properties in a
single batch are completed.
"""

SameComponents = "SameComponents"
SharedComponents = "SharedComponents"
NoBatch = "NoBatch"


class Request(AttributeClass):
Expand Down
25 changes: 25 additions & 0 deletions openff/evaluator/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,29 @@ def _query_request_status(self, client_request_id):

return request_results, None

def _no_batch(self, submission, force_field_id):
"""Returns a single Batch."""

reserved_batch_ids = {
*self._queued_batches.keys(),
*self._finished_batches.keys(),
}
batch = Batch()
batch.force_field_id = force_field_id
batch.enable_data_caching = self._enable_data_caching
batch.queued_properties = list(submission.dataset.properties)
batch.options = RequestOptions.parse_json(submission.options.json())
batch.parameter_gradient_keys = copy.deepcopy(
submission.parameter_gradient_keys
)

n_batchs = len(reserved_batch_ids)
batch.id = f"batch_{n_batchs:04d}"
while batch.id in reserved_batch_ids:
n_batchs += 1
batch.id = f"batch_{n_batchs:04d}"
return [batch]

def _batch_by_same_component(self, submission, force_field_id):
"""Batches a set of requested properties based on which substance they were
measured for. Properties which were measured for substances containing the
Expand Down Expand Up @@ -434,6 +457,8 @@ def _prepare_batches(self, submission, request_id):
batches = self._batch_by_same_component(submission, force_field_id)
elif batch_mode == BatchMode.SharedComponents:
batches = self._batch_by_shared_component(submission, force_field_id)
elif batch_mode == BatchMode.NoBatch:
batches = self._no_batch(submission, force_field_id)
Comment on lines +460 to +461
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a test that covers these lines?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure -- added a test for _prepare_batches.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PS This code path should be used in the tests in #604 and #608.

else:
raise NotImplementedError()

Expand Down
Loading