Skip to content

Commit 1745a10

Browse files
authored
Overhaul backends contexts code (#363)
2 parents a2507e4 + 71dabb6 commit 1745a10

File tree

8 files changed

+225
-224
lines changed

8 files changed

+225
-224
lines changed

pysages/backends/__init__.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,5 @@
11
# SPDX-License-Identifier: MIT
22
# See LICENSE.md and CONTRIBUTORS.md at https://github.com/SSAGESLabs/PySAGES
33

4-
from .core import ( # noqa: E402, F401
5-
JaxMDContext,
6-
JaxMDContextState,
7-
SamplingContext,
8-
supported_backends,
9-
)
4+
from .contexts import JaxMDContext, JaxMDContextState # noqa: E402, F401
5+
from .core import SamplingContext, supported_backends # noqa: E402, F401

pysages/backends/ase.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
build_data_querier,
1717
)
1818
from pysages.backends.utils import view
19-
from pysages.typing import Callable, NamedTuple
19+
from pysages.typing import Callable
2020
from pysages.utils import ToCPU, copy
2121

2222

@@ -29,7 +29,7 @@ class Sampler(Calculator):
2929
"""
3030

3131
def __init__(self, context, method_bundle, callback: Callable):
32-
initial_snapshot, initialize, mehod_update = method_bundle
32+
initial_snapshot, initialize, method_update = method_bundle
3333

3434
atoms = context.atoms
3535
self.implemented_properties = atoms.calc.implemented_properties
@@ -41,7 +41,7 @@ def __init__(self, context, method_bundle, callback: Callable):
4141
self.callback = callback
4242
self.snapshot = initial_snapshot
4343
self.state = initialize()
44-
self.update = mehod_update
44+
self.update = method_update
4545

4646
sig = signature(atoms.calc.calculate).parameters
4747
self._calculator = atoms.calc
@@ -151,10 +151,6 @@ def dimensionality():
151151
return helpers
152152

153153

154-
class View(NamedTuple):
155-
synchronize: Callable
156-
157-
158154
def bind(sampling_context: SamplingContext, callback: Callable, **kwargs):
159155
"""
160156
Entry point for the backend code, it gets called when the simulation
@@ -166,6 +162,5 @@ def bind(sampling_context: SamplingContext, callback: Callable, **kwargs):
166162
helpers = build_helpers(sampling_context, sampling_method)
167163
method_bundle = sampling_method.build(snapshot, helpers)
168164
sampler = Sampler(context, method_bundle, callback)
169-
sampling_context.view = View((lambda: None))
170165
sampling_context.run = context.run
171166
return sampler

pysages/backends/contexts.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
# SPDX-License-Identifier: MIT
2+
# See LICENSE.md and CONTRIBUTORS.md at https://github.com/SSAGESLabs/PySAGES
3+
4+
"""
5+
This module defines "Context" classes for backends that do not provide a dedicated Python
6+
class to hold the simulation data.
7+
"""
8+
9+
from pysages.typing import Any, Callable, JaxArray, NamedTuple, Optional
10+
11+
JaxMDState = Any
12+
13+
14+
class JaxMDContextState(NamedTuple):
15+
"""
16+
Provides an interface for the data structure returned by `JaxMDContext.init_fn` and
17+
expected as the single argument of `JaxMDContext.step_fn`.
18+
19+
Arguments
20+
---------
21+
state: JaxMDState
22+
Holds the particle information and corresponds to the internal state of
23+
`jax_md.simulate` methods.
24+
25+
extras: Optional[dict]
26+
Additional arguments required by `JaxMDContext.step_fn`, these might include for
27+
instance, the neighbor list or the time step.
28+
"""
29+
30+
state: JaxMDState
31+
extras: Optional[dict]
32+
33+
34+
class JaxMDContext(NamedTuple):
35+
"""
36+
Provides an interface for the data structure expects from `generate_context` for
37+
`jax_md`-backed simulations.
38+
39+
Arguments
40+
---------
41+
init_fn: Callable[..., JaxMDContextState]
42+
Initilizes the `jax_md` state. Generally, this will be the `init_fn` of any
43+
of the simulation routines in `jax_md` (or wrappers around these).
44+
45+
step_fn: Callable[..., JaxMDContextState]
46+
Takes a state and advances a `jax_md` simulation by one step. Generally, this
47+
will be the `apply_fn` of any of the simulation routines in `jax_md` (or wrappers
48+
around these).
49+
50+
box: JaxArray
51+
Affine transformation from a unit hypercube to the simulation box.
52+
53+
dt: float
54+
Step size of the simulation.
55+
"""
56+
57+
init_fn: Callable[..., JaxMDContextState]
58+
step_fn: Callable[..., JaxMDContextState]
59+
box: JaxArray
60+
dt: float

pysages/backends/core.py

Lines changed: 3 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -3,58 +3,8 @@
33

44
from importlib import import_module
55

6-
from pysages.typing import Any, Callable, JaxArray, NamedTuple, Optional
7-
8-
JaxMDState = Any
9-
10-
11-
class JaxMDContextState(NamedTuple):
12-
"""
13-
Provides an interface for the data structure returned by `JaxMDContext.init_fn` and
14-
expected as the single argument of `JaxMDContext.step_fn`.
15-
16-
Arguments
17-
---------
18-
state: JaxMDState
19-
Holds the particle information and corresponds to the internal state of
20-
`jax_md.simulate` methods.
21-
22-
extras: Optional[dict]
23-
Additional arguments required by `JaxMDContext.step_fn`, these might include for
24-
instance, the neighbor list or the time step.
25-
"""
26-
27-
state: JaxMDState
28-
extras: Optional[dict]
29-
30-
31-
class JaxMDContext(NamedTuple):
32-
"""
33-
Provides an interface for the data structure expects from `generate_context` for
34-
`jax_md`-backed simulations.
35-
36-
Arguments
37-
---------
38-
init_fn: Callable[..., JaxMDContextState]
39-
Initilizes the `jax_md` state. Generally, this will be the `init_fn` of any
40-
of the simulation routines in `jax_md` (or wrappers around these).
41-
42-
step_fn: Callable[..., JaxMDContextState]
43-
Takes a state and advances a `jax_md` simulation by one step. Generally, this
44-
will be the `apply_fn` of any of the simulation routines in `jax_md` (or wrappers
45-
around these).
46-
47-
box: JaxArray
48-
Affine transformation from a unit hypercube to the simulation box.
49-
50-
dt: float
51-
Step size of the simulation.
52-
"""
53-
54-
init_fn: Callable[..., JaxMDContextState]
55-
step_fn: Callable[..., JaxMDContextState]
56-
box: JaxArray
57-
dt: float
6+
from pysages.backends.contexts import JaxMDContext
7+
from pysages.typing import Callable, Optional
588

599

6010
class SamplingContext:
@@ -95,14 +45,12 @@ def __init__(
9545

9646
self.context = context
9747
self.method = sampling_method
98-
self.view = None
9948
self.run = None
10049

10150
backend = import_module("." + self._backend_name, package="pysages.backends")
10251
self.sampler = backend.bind(self, callback, **kwargs)
10352

104-
# `self.view` and `self.run` *must* be set by the backend bind function.
105-
assert self.view is not None
53+
# `self.run` *must* be set by the backend bind function.
10654
assert self.run is not None
10755

10856
@property

0 commit comments

Comments
 (0)