Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions pysages/backends/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
# SPDX-License-Identifier: MIT
# See LICENSE.md and CONTRIBUTORS.md at https://github.com/SSAGESLabs/PySAGES

from .core import ( # noqa: E402, F401
JaxMDContext,
JaxMDContextState,
SamplingContext,
supported_backends,
)
from .contexts import JaxMDContext, JaxMDContextState # noqa: E402, F401
from .core import SamplingContext, supported_backends # noqa: E402, F401
11 changes: 3 additions & 8 deletions pysages/backends/ase.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
build_data_querier,
)
from pysages.backends.utils import view
from pysages.typing import Callable, NamedTuple
from pysages.typing import Callable
from pysages.utils import ToCPU, copy


Expand All @@ -29,7 +29,7 @@ class Sampler(Calculator):
"""

def __init__(self, context, method_bundle, callback: Callable):
initial_snapshot, initialize, mehod_update = method_bundle
initial_snapshot, initialize, method_update = method_bundle

atoms = context.atoms
self.implemented_properties = atoms.calc.implemented_properties
Expand All @@ -41,7 +41,7 @@ def __init__(self, context, method_bundle, callback: Callable):
self.callback = callback
self.snapshot = initial_snapshot
self.state = initialize()
self.update = mehod_update
self.update = method_update

sig = signature(atoms.calc.calculate).parameters
self._calculator = atoms.calc
Expand Down Expand Up @@ -151,10 +151,6 @@ def dimensionality():
return helpers


class View(NamedTuple):
synchronize: Callable


def bind(sampling_context: SamplingContext, callback: Callable, **kwargs):
"""
Entry point for the backend code, it gets called when the simulation
Expand All @@ -166,6 +162,5 @@ def bind(sampling_context: SamplingContext, callback: Callable, **kwargs):
helpers = build_helpers(sampling_context, sampling_method)
method_bundle = sampling_method.build(snapshot, helpers)
sampler = Sampler(context, method_bundle, callback)
sampling_context.view = View((lambda: None))
sampling_context.run = context.run
return sampler
60 changes: 60 additions & 0 deletions pysages/backends/contexts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# SPDX-License-Identifier: MIT
# See LICENSE.md and CONTRIBUTORS.md at https://github.com/SSAGESLabs/PySAGES

"""
This module defines "Context" classes for backends that do not provide a dedicated Python
class to hold the simulation data.
"""

from pysages.typing import Any, Callable, JaxArray, NamedTuple, Optional

JaxMDState = Any


class JaxMDContextState(NamedTuple):
"""
Provides an interface for the data structure returned by `JaxMDContext.init_fn` and
expected as the single argument of `JaxMDContext.step_fn`.

Arguments
---------
state: JaxMDState
Holds the particle information and corresponds to the internal state of
`jax_md.simulate` methods.

extras: Optional[dict]
Additional arguments required by `JaxMDContext.step_fn`, these might include for
instance, the neighbor list or the time step.
"""

state: JaxMDState
extras: Optional[dict]


class JaxMDContext(NamedTuple):
"""
Provides an interface for the data structure expects from `generate_context` for
`jax_md`-backed simulations.

Arguments
---------
init_fn: Callable[..., JaxMDContextState]
Initilizes the `jax_md` state. Generally, this will be the `init_fn` of any
of the simulation routines in `jax_md` (or wrappers around these).

step_fn: Callable[..., JaxMDContextState]
Takes a state and advances a `jax_md` simulation by one step. Generally, this
will be the `apply_fn` of any of the simulation routines in `jax_md` (or wrappers
around these).

box: JaxArray
Affine transformation from a unit hypercube to the simulation box.

dt: float
Step size of the simulation.
"""

init_fn: Callable[..., JaxMDContextState]
step_fn: Callable[..., JaxMDContextState]
box: JaxArray
dt: float
58 changes: 3 additions & 55 deletions pysages/backends/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,58 +3,8 @@

from importlib import import_module

from pysages.typing import Any, Callable, JaxArray, NamedTuple, Optional

JaxMDState = Any


class JaxMDContextState(NamedTuple):
"""
Provides an interface for the data structure returned by `JaxMDContext.init_fn` and
expected as the single argument of `JaxMDContext.step_fn`.

Arguments
---------
state: JaxMDState
Holds the particle information and corresponds to the internal state of
`jax_md.simulate` methods.

extras: Optional[dict]
Additional arguments required by `JaxMDContext.step_fn`, these might include for
instance, the neighbor list or the time step.
"""

state: JaxMDState
extras: Optional[dict]


class JaxMDContext(NamedTuple):
"""
Provides an interface for the data structure expects from `generate_context` for
`jax_md`-backed simulations.

Arguments
---------
init_fn: Callable[..., JaxMDContextState]
Initilizes the `jax_md` state. Generally, this will be the `init_fn` of any
of the simulation routines in `jax_md` (or wrappers around these).

step_fn: Callable[..., JaxMDContextState]
Takes a state and advances a `jax_md` simulation by one step. Generally, this
will be the `apply_fn` of any of the simulation routines in `jax_md` (or wrappers
around these).

box: JaxArray
Affine transformation from a unit hypercube to the simulation box.

dt: float
Step size of the simulation.
"""

init_fn: Callable[..., JaxMDContextState]
step_fn: Callable[..., JaxMDContextState]
box: JaxArray
dt: float
from pysages.backends.contexts import JaxMDContext
from pysages.typing import Callable, Optional


class SamplingContext:
Expand Down Expand Up @@ -95,14 +45,12 @@ def __init__(

self.context = context
self.method = sampling_method
self.view = None
self.run = None

backend = import_module("." + self._backend_name, package="pysages.backends")
self.sampler = backend.bind(self, callback, **kwargs)

# `self.view` and `self.run` *must* be set by the backend bind function.
assert self.view is not None
# `self.run` *must* be set by the backend bind function.
assert self.run is not None

@property
Expand Down
Loading