Skip to content

Commit d4b64c5

Browse files
sirmarcelclaude
andauthored
Use marathon-train from PyPI and clean up lazy imports
Switch marathon-train dependency from pinned git URL to PyPI package and move marathon imports to module level in backbone.py and calculator.py (the lazy pattern was a workaround for the git dependency). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent efc010c commit d4b64c5

File tree

3 files changed

+5
-8
lines changed

3 files changed

+5
-8
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ dependencies = [
1111
"jax",
1212
"jax-pme @ git+https://github.com/lab-cosmo/jax-pme.git",
1313
"jaxtyping",
14-
"marathon-train[grain] @ git+https://github.com/sirmarcel/marathon.git@release/v0.2.0",
14+
"marathon-train[grain]",
1515
"opsis",
1616
"optax",
1717
]

src/lorem/calculator.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
BaseCalculator,
99
PropertyNotImplementedError,
1010
)
11+
from marathon.emit.checkpoint import read_msgpack
12+
from marathon.io import from_dict, read_yaml
1113

1214
from lorem.neighborlist import NeighborListCache
1315

@@ -78,8 +80,6 @@ def from_checkpoint(
7880
):
7981
from pathlib import Path
8082

81-
from marathon.io import from_dict, read_yaml
82-
8383
folder = Path(folder)
8484

8585
model = from_dict(read_yaml(folder / "model/model.yaml"))
@@ -89,8 +89,6 @@ def from_checkpoint(
8989
baseline = read_yaml(folder / "model/baseline.yaml")
9090
species_to_weight = baseline["elemental"]
9191

92-
from marathon.emit.checkpoint import read_msgpack
93-
9492
params = read_msgpack(folder / "model/model.msgpack")
9593

9694
return cls(model.predict, species_to_weight, params, model.cutoff, **kwargs)

src/lorem/models/backbone.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,11 @@
88
import e3x
99
import flax.linen as nn
1010
from flax.core import FrozenDict
11+
from marathon.utils import masked
1112

1213

1314
def _masked(fn, x, mask):
14-
"""Apply fn only where mask is True. Lazy import from marathon."""
15-
from marathon.utils import masked
16-
15+
"""Apply fn only where mask is True."""
1716
return masked(fn, x, mask)
1817

1918

0 commit comments

Comments
 (0)