Merged
26 commits
8fb51e3
Rename package loremjax -> lorem, add train script and BEC example
sirmarcel Feb 11, 2026
a1cc95f
Add train-mlp example (energy + forces on cumulene)
sirmarcel Feb 11, 2026
1264f55
Fix tests for CI and add backbone tests
sirmarcel Feb 11, 2026
55f1c0d
Fix CI: add jaxpme as git dep, make __init__.py tolerate missing deps
sirmarcel Feb 11, 2026
3c7b1e9
Fix jaxpme dependency name: jax-pme matches package metadata
sirmarcel Feb 11, 2026
fd6c894
Fix remaining marathon imports breaking CI
sirmarcel Feb 11, 2026
d0d7ce9
Fix ASE calculator and add usage example
sirmarcel Feb 14, 2026
4c81277
Fill in README with installation, usage, and hyperparameter docs
sirmarcel Feb 19, 2026
7a5ff16
Clarify deps
sirmarcel Feb 19, 2026
4507139
Allow wandb_project and wandb_name to be set via settings.yaml
sirmarcel Feb 22, 2026
869b070
Apply ruff formatting to calculator and example
sirmarcel Feb 22, 2026
57f9265
Add i-PI driven dynamics example with BEC support
sirmarcel Feb 23, 2026
1a57216
flatten ipi driver file structure
PicoCentauri Feb 26, 2026
a5cc2d0
Inherit from ASE BaseCalculator for i-PI compatibility
sirmarcel Mar 3, 2026
354785d
make BEC ase compatible
PicoCentauri Mar 3, 2026
7e3c9c1
Implement stress in ASE calculator
sirmarcel Mar 3, 2026
f306235
fix bec wrapper
PicoCentauri Mar 3, 2026
b9bac8e
Remove halfspace kwarg from get_batch calls (jax-pme API update)
sirmarcel Mar 3, 2026
5ac0a71
Fix remaining test failures after recent refactors
sirmarcel Mar 3, 2026
90d010f
Update imports from marathon.extra.hermes to marathon.grain
sirmarcel Mar 7, 2026
888269b
Add marathon as git dependency, fix k_grid size counting in ToBatch
sirmarcel Mar 7, 2026
2c853f4
Add tox examples runner, fix marathon dep, modernize example workflow
sirmarcel Mar 9, 2026
a14365a
Remove Windows from CI test matrix
sirmarcel Mar 9, 2026
bdf55a2
Expand README training docs, default to float32 matmul precision
sirmarcel Mar 9, 2026
07b3060
Remove try/except guards for comms and marathon imports
sirmarcel Mar 9, 2026
5690468
Add PyPI release workflow (trusted publishing on tag push)
sirmarcel Mar 9, 2026
40 changes: 40 additions & 0 deletions .github/workflows/release.yml
@@ -0,0 +1,40 @@
name: Release

on:
  push:
    tags:
      - "v*"

jobs:
  build:
    runs-on: ubuntu-22.04

    steps:
      - uses: actions/checkout@v6
        with:
          fetch-depth: 0
      - name: Set up Python
        uses: actions/setup-python@v6
        with:
          python-version: "3.13"
      - run: python -m pip install build
      - name: Build package
        run: python -m build
      - uses: actions/upload-artifact@v4
        with:
          name: dist
          path: dist/

  publish:
    needs: build
    runs-on: ubuntu-22.04
    environment: release
    permissions:
      id-token: write

    steps:
      - uses: actions/download-artifact@v4
        with:
          name: dist
          path: dist/
      - uses: pypa/gh-action-pypi-publish@release/v1
17 changes: 15 additions & 2 deletions .github/workflows/tests.yml
@@ -18,8 +18,6 @@ jobs:
            python-version: "3.14"
          - os: macos-14
            python-version: "3.14"
          - os: windows-2022
            python-version: "3.14"

    steps:
      - uses: actions/checkout@v6
@@ -32,3 +30,18 @@ jobs:
      - run: python -m pip install tox coverage[toml]
      - name: run Python tests
        run: tox -e tests

  examples:
    runs-on: ubuntu-22.04

    steps:
      - uses: actions/checkout@v6
        with:
          fetch-depth: 0
      - name: Set up Python
        uses: actions/setup-python@v6
        with:
          python-version: "3.13"
      - run: python -m pip install tox
      - name: run examples
        run: tox -e examples
1 change: 1 addition & 0 deletions .gitignore
@@ -12,3 +12,4 @@ _version.py
.tox/
build/
dist/
uv.lock
196 changes: 193 additions & 3 deletions README.md
@@ -1,11 +1,201 @@
# LOREM-JAX

JAX implementation of [LOREM](https://arxiv.org/abs/2504.20462) (Learning Long-Range Representations with Equivariant Messages), a machine learning interatomic potential with equivariant long-range message passing.

Built on [JAX](https://github.com/jax-ml/jax), [Flax](https://github.com/google/flax), [e3x](https://github.com/google-research/e3x), and [jax-pme](https://github.com/lab-cosmo/jax-pme).

## Installation

Requires Python >= 3.11.

```bash
pip install .
```

## Usage

### ASE calculator

```python
import jax
from ase.build import bulk
from lorem.models.mlip import Lorem
from lorem.calculator import Calculator

model = Lorem(cutoff=5.0)
params = model.init(jax.random.key(42), *model.dummy_inputs())
calc = Calculator.from_model(model, params=params)

atoms = bulk("Ar") * [2, 2, 2]
calc.calculate(atoms)
print(calc.results["energy"], calc.results["forces"].shape)
```

To load a trained model from a checkpoint:

```python
calc = Calculator.from_checkpoint("path/to/checkpoint")
```

### Training

Training a model involves three steps: preparing the data, configuring the model and training settings, and running the training script.

#### 1. Prepare data

Training data is stored in [marathon](https://github.com/sirmarcel/marathon) format. Convert your extended XYZ dataset using a preparation script (see `examples/train-mlp/prepare.py` for a template):

```python
from marathon.data import datasets, get_splits
from marathon.grain import prepare

# datasets is a Path resolved from the $DATASETS environment variable
prepare(train_atoms, folder=datasets / "my_project/train", ...)
prepare(valid_atoms, folder=datasets / "my_project/valid", ...)
```

The `$DATASETS` environment variable sets the root directory where prepared datasets are stored. All dataset paths in `settings.yaml` are resolved relative to this directory.
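
As a rough illustration of that resolution step, the sketch below joins a relative dataset path onto the `$DATASETS` root using only `pathlib`. The helper name `resolve_dataset` and the example root `/data/marathon` are hypothetical and chosen for illustration; marathon's own resolution logic may differ.

```python
import os
from pathlib import Path


def resolve_dataset(relative_path, environ=None):
    """Join a settings.yaml dataset path onto the $DATASETS root.

    Hypothetical helper for illustration; not part of marathon or lorem.
    """
    environ = os.environ if environ is None else environ
    root = environ.get("DATASETS")
    if root is None:
        raise RuntimeError("set the DATASETS environment variable first")
    return Path(root).expanduser() / relative_path


# Use an explicit mapping here instead of the real environment
env = {"DATASETS": "/data/marathon"}
print(resolve_dataset("my_project/train", env))
```

Failing loudly when `$DATASETS` is unset is a design choice of this sketch: it avoids silently writing prepared data into the current directory.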

#### 2. Configure the experiment

Each experiment lives in its own directory containing two YAML files:

**`model.yaml`** defines the model architecture:

```yaml
model:
  lorem.Lorem:
    cutoff: 5.0
    max_degree: 4
    max_degree_lr: 2
    num_features: 128
    num_spherical_features: 4
    num_message_passing: 1
```

Use `lorem.LoremBEC` instead of `lorem.Lorem` to train a model that additionally predicts Born effective charges.

**`settings.yaml`** configures training:

```yaml
train: "my_project/train" # path relative to $DATASETS
valid: "my_project/valid" # path relative to $DATASETS
seed: 23
batcher:
  batch_size: 4
loss_weights: {"energy": 0.5, "forces": 0.5}
optimizer: adam # adam or muon
start_learning_rate: 1e-3
min_learning_rate: 1e-6
max_epochs: 2000
valid_every_epoch: 2
decay_style: linear # linear, exponential, or warmup_cosine
use_wandb: True
```
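
To make the interplay of `start_learning_rate`, `min_learning_rate`, and `decay_style: linear` concrete, here is a minimal sketch of one plausible linear schedule. It assumes the rate is held constant until `start_decay_after` and then interpolated linearly down to `min_learning_rate` at `stop_decay_after`, using the defaults from the settings table. The function `linear_decay_lr` is hypothetical; the schedule used by `lorem-train` may be implemented differently (for example per optimizer step rather than per epoch).

```python
def linear_decay_lr(epoch, start_lr=1e-3, min_lr=1e-6,
                    start_decay_after=10, stop_decay_after=2000):
    """Learning rate at a given epoch under an ASSUMED linear schedule."""
    if epoch <= start_decay_after:
        return start_lr
    if epoch >= stop_decay_after:
        return min_lr
    # Fraction of the decay window that has elapsed, strictly between 0 and 1
    progress = (epoch - start_decay_after) / (stop_decay_after - start_decay_after)
    return start_lr + progress * (min_lr - start_lr)


for epoch in (0, 10, 1005, 2000):
    print(epoch, linear_decay_lr(epoch))
```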

<details>
<summary>All training settings</summary>

| Setting | Default | Description |
|---|---|---|
| `train` | *required* | Training dataset path (relative to `$DATASETS`) |
| `valid` | *required* | Validation dataset path (relative to `$DATASETS`) |
| `test_datasets` | `{}` | Extra test datasets: `{name: [path, save_predictions]}` |
| `batcher.batch_size` | *required* | Samples per batch |
| `batcher.size_strategy` | `powers_of_4` | Padding strategy for batch dimensions |
| `loss_weights` | `{"energy": 0.5, "forces": 0.5}` | Per-target loss weights |
| `scale_by_variance` | `False` | Scale loss weights by validation set variance |
| `optimizer` | `adam` | Optimizer (`adam`, `muon`, or any optax optimizer) |
| `start_learning_rate` | `1e-3` | Initial learning rate |
| `min_learning_rate` | `1e-6` | Minimum learning rate |
| `max_epochs` | `2000` | Maximum training epochs |
| `valid_every_epoch` | `2` | Validate every N epochs |
| `decay_style` | `linear` | LR schedule: `linear`, `exponential`, or `warmup_cosine` |
| `start_decay_after` | `10` | Epoch to begin LR decay |
| `stop_decay_after` | `max_epochs` | Epoch to end LR decay (linear only) |
| `warmup_epochs` | `0` | Warmup epochs (`warmup_cosine` only) |
| `gradient_clip` | `0` | Gradient clipping threshold (0 = disabled) |
| `seed` | `0` | Random seed |
| `rotational_augmentation` | `False` | Apply random rotations to training data |
| `filter_mixed_pbc` | `False` | Filter out structures with mixed periodic boundary conditions |
| `filter_above_num_atoms` | `False` | Filter out structures above this atom count |
| `checkpointers` | `default` | `default` or `full` (adds RMSE checkpointers) |
| `use_wandb` | `True` | Log to Weights & Biases |
| `wandb_project` | auto | W&B project name (default: derived from folder names) |
| `wandb_name` | auto | W&B run name (default: experiment folder name) |
| `benchmark_pipeline` | `True` | Benchmark data pipeline before training |
| `compilation_cache` | `False` | Enable JAX persistent compilation cache |
| `default_matmul_precision` | `float32` | JAX matmul precision (`default`, `float32`) |
| `debug_nans` | `False` | Enable JAX NaN debugging (~50% slowdown) |
| `enable_x64` | `False` | Enable 64-bit floating point |
| `worker_count` | `4` | Data loading workers (training) |
| `worker_count_valid` | `worker_count` | Data loading workers (validation) |
| `worker_buffer_size` | `2` | Prefetch buffer per worker (training) |

</details>
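
The name of the default `batcher.size_strategy`, `powers_of_4`, suggests that padded batch dimensions are rounded up to the next power of four. Padding to a small set of sizes limits the number of distinct array shapes, and JAX recompiles a jitted function for each new input shape. The sketch below shows that rounding rule only; it is an assumption based on the setting's name, and `next_power_of_4` is a hypothetical helper rather than marathon's implementation.

```python
def next_power_of_4(n):
    """Smallest power of 4 that is >= n (for n >= 1)."""
    if n < 1:
        raise ValueError("n must be a positive integer")
    size = 1
    while size < n:
        size *= 4
    return size


# Sizes that would be padded to 1, 16, 16, 64, and 1024 respectively
for num_atoms in (1, 5, 16, 17, 300):
    print(num_atoms, "->", next_power_of_4(num_atoms))
```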

#### 3. Run training

```bash
cd my_experiment
DATASETS=/path/to/datasets lorem-train
```

Training writes checkpoints, logs, and plots to a `run/` directory inside the experiment folder. If a `run/` directory already exists, training resumes from the latest checkpoint.

See `examples/train-mlp/` and `examples/train-bec/` for complete examples including data preparation and configuration files.

### Model variants

- **`Lorem`** -- the standard MLIP model (energy + forces + stress)
- **`LoremBEC`** -- predicts Born effective charges in addition to energy/forces

### Key hyperparameters

| Parameter | Default | Description |
|---|---|---|
| `cutoff` | 5.0 | Short-range cutoff radius (A) |
| `max_degree` | 6 | Maximum angular momentum for spherical features |
| `max_degree_lr` | 2 | Maximum angular momentum for long-range charges |
| `num_features` | 128 | Number of scalar features |
| `num_spherical_features` | 8 | Number of spherical feature channels |
| `num_radial` | 32 | Number of radial basis functions |
| `num_message_passing` | 0 | Number of short-range message passing steps |
| `lr` | True | Enable long-range (Ewald) interaction |

## Installing the i-PI driver

After installation of the package, install the i-PI driver via:

```bash
lorem-install-ipi-driver
```

This copies the LOREM driver into the i-PI `pes` directory. You can rerun `lorem-install-ipi-driver` at any time (it is idempotent) if you switch environments or reinstall i-PI.
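
For readers unfamiliar with the term, an idempotent install step can be sketched as a copy that is skipped when an identical file is already in place. The code below is a generic illustration using the standard library. The function `install_driver`, the file name `lorem.py`, and the directory layout are placeholders, not the actual logic of `lorem-install-ipi-driver`.

```python
import filecmp
import shutil
import tempfile
from pathlib import Path


def install_driver(source, pes_dir):
    """Copy `source` into `pes_dir`; return True only if a copy was made."""
    source, pes_dir = Path(source), Path(pes_dir)
    target = pes_dir / source.name
    # Skip the copy when a byte-identical file is already installed
    if target.exists() and filecmp.cmp(source, target, shallow=False):
        return False
    pes_dir.mkdir(parents=True, exist_ok=True)
    shutil.copy2(source, target)
    return True


# Demonstration in a temporary directory (all paths are placeholders)
with tempfile.TemporaryDirectory() as tmp:
    src = Path(tmp) / "lorem.py"
    src.write_text("# driver stub\n")
    pes = Path(tmp) / "ipi" / "pes"
    first = install_driver(src, pes)   # copies the file
    second = install_driver(src, pes)  # finds an identical file, does nothing
    print(first, second)
```

Running the step twice leaves the target directory in the same state as running it once, which is what makes rerunning it after an environment switch safe.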

## Development

Format and lint:

```bash
ruff format . && ruff check --fix .
```

Run tests:

```bash
python -m pytest tests/ -v --override-ini="addopts="
```

Or use tox:

```bash
tox -e lint # check formatting + linting
tox -e tests # run unit tests
tox -e examples # run examples as smoke tests
tox -e format # auto-format
```

## License

BSD-3-Clause
49 changes: 49 additions & 0 deletions examples/calculator/example.py
@@ -0,0 +1,49 @@
"""Example: using the LOREM ASE calculator with randomly initialised weights."""

import jax

from ase.build import bulk, molecule

from lorem.calculator import Calculator
from lorem.models.mlip import Lorem

# Create a Lorem model with default hyperparameters
model = Lorem(cutoff=5.0)

# Initialize with random parameters
key = jax.random.key(42)
params = model.init(key, *model.dummy_inputs())

# Create calculator (no baseline offset with random weights)
calc = Calculator.from_model(model, params=params)

# -- periodic bulk system --
atoms = bulk("Ar") * [2, 2, 2]
calc.calculate(atoms)

energy = calc.results["energy"]
forces = calc.results["forces"]

print("=== Periodic bulk Ar (2x2x2) ===")
print(f"Number of atoms: {len(atoms)}")
print(f"Energy: {energy:.6f} eV")
print(f"Forces shape: {forces.shape}")
print(f"Max force component: {abs(forces).max():.6f} eV/A")

# -- non-periodic molecule --
atoms_mol = molecule("H2O")
atoms_mol.center(vacuum=5.0)

calc_mol = Calculator.from_model(model, params=params)
calc_mol.calculate(atoms_mol)

print("\n=== Non-periodic H2O molecule ===")
print(f"Number of atoms: {len(atoms_mol)}")
print(f"Energy: {calc_mol.results['energy']:.6f} eV")
print(f"Forces:\n{calc_mol.results['forces']}")

# -- using ASE interface --
atoms.calc = Calculator.from_model(model, params=params)
print("\n=== ASE interface ===")
print(f"get_potential_energy: {atoms.get_potential_energy():.6f} eV")
print(f"get_forces max: {abs(atoms.get_forces()).max():.6f} eV/A")
52 changes: 52 additions & 0 deletions examples/md-ipi/README.md
@@ -0,0 +1,52 @@
# Driven dynamics with LOREM via i-PI

This example shows how to run an applied electric field simulation using a LOREM BEC model and i-PI's driven dynamics (Electric Dipole Approximation).

Unlike the standard i-PI efield example, which reads fixed Born effective charges (BECs) from a file, LOREM computes BECs on the fly from its learned model (`<bec mode="driver"/>`).

## Prerequisites

- Trained `LoremBEC` checkpoint (with `model/model.yaml`, `model/baseline.yaml`, `model/model.msgpack`)
- i-PI installed (`pip install ipi`)
- LOREM i-PI driver installed: `lorem-install-ipi-driver`

## Files

- `input.xml` -- i-PI configuration for driven dynamics (EDA-NVE). Adapt the E-field parameters (`amp`, `freq`, `peak`, `sigma`), cell size, and `start.xyz` to your system.
- `start.xyz` -- Starting structure in atomic units. Replace with your system's geometry.
- `run.sh` -- Launch script. Set `MODEL_PATH` to your trained checkpoint folder.

## Running

```bash
MODEL_PATH="/path/to/checkpoint"

# Install LOREM driver into i-PI (idempotent)
lorem-install-ipi-driver

# Start i-PI server
i-pi input.xml > i-pi.out &
sleep 5

# Start LOREM driver
i-pi-driver -a lorem -u -m lorem \
    -o model_path=${MODEL_PATH},template=start.xyz \
    > driver.out &

wait
```

See `run.sh` for a self-contained reference script.

i-PI outputs will be written to `i-pi.*` files. Key outputs:
- `i-pi.properties.out` -- energy, kinetic energy, E-field strength over time
- `i-pi.positions_*` -- nuclear trajectories
- `i-pi.bec{x,y,z}_*` -- BEC tensor components over time

## Adapting to your system

1. Replace `start.xyz` with your starting geometry (atomic units).
2. Set the cell size in `input.xml` (large for isolated systems, physical for periodic).
3. Set `pbc='True'` in `<ffsocket>` if your system is periodic.
4. Adjust E-field parameters to match your target excitation.
5. Point `MODEL_PATH` in `run.sh` to your trained LoremBEC checkpoint.
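
If your geometry is stored in angstrom, step 1 requires a unit conversion first. The sketch below assumes that "atomic units" here means lengths in bohr. The helper `angstrom_to_bohr` and the coordinates are illustrative only, and writing the result in the file layout that i-PI expects is left out.

```python
BOHR_IN_ANGSTROM = 0.529177210903  # Bohr radius in angstrom (CODATA 2018)


def angstrom_to_bohr(positions):
    """Convert a list of (x, y, z) tuples from angstrom to bohr."""
    return [tuple(c / BOHR_IN_ANGSTROM for c in xyz) for xyz in positions]


# Placeholder coordinates in angstrom, for illustration only
positions_angstrom = [(0.0, 0.0, 0.0), (0.529177210903, 0.0, 1.0)]
positions_bohr = angstrom_to_bohr(positions_angstrom)
print(positions_bohr)
```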