13 changes: 13 additions & 0 deletions examples/train-mlp/README.md
@@ -19,3 +19,16 @@ DATASETS=. python prepare.py
cd my_experiment
DATASETS=.. lorem-train
```

## Fine-tuning from pretrained weights

The `my_experiment_finetune/` directory demonstrates restarting training from
pretrained model weights. Add `initial_weights` to `settings.yaml`:

```yaml
initial_weights: "path/to/previous_run/checkpoints/R2_E+F/model/model.msgpack"
```

Only model weights are loaded — optimizer state, step counter, and data iterator
start fresh. If the current model has layers not present in the source checkpoint,
those layers keep their random initialization.
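
The merge walks the parameter tree of the freshly initialized model and takes each
value from the checkpoint when a matching key exists there. A minimal sketch of that
behaviour (the parameter names below are illustrative, not the real tree):

```python
def merge_params(target, source):
    # Prefer values from `source`; keep `target` entries that `source` lacks.
    if isinstance(target, dict) and isinstance(source, dict):
        return {
            k: merge_params(target[k], source[k]) if k in source else target[k]
            for k in target
        }
    return source

fresh = {"mlp": {"w": "random"}, "new_head": {"w": "random"}}  # current model init
ckpt = {"mlp": {"w": "pretrained"}}                            # loaded checkpoint
assert merge_params(fresh, ckpt) == {
    "mlp": {"w": "pretrained"},
    "new_head": {"w": "random"},
}
```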
8 changes: 8 additions & 0 deletions examples/train-mlp/my_experiment_finetune/model.yaml
@@ -0,0 +1,8 @@
model:
  lorem.Lorem:
    cutoff: 5.0
    max_degree_lr: 2
    num_features: 128
    max_degree: 4
    num_message_passing: 1
    num_spherical_features: 4
21 changes: 21 additions & 0 deletions examples/train-mlp/my_experiment_finetune/settings.yaml
@@ -0,0 +1,21 @@
train: "mlp_example/train"
valid: "mlp_example/valid"
seed: 42
batcher:
  batch_size: 4
  coarse_strategy: powers_of_4
  size_strategy: powers_of_4
  fine_strategy: powers_of_4
loss_weights: {"energy": 0.5, "forces": 0.5}
checkpointers: full
optimizer: muon
valid_every_epoch: 1
decay_style: warmup_cosine
warmup_epochs: 0
start_learning_rate: 1e-4
max_epochs: 1
benchmark_pipeline: True
use_wandb: False
default_matmul_precision: "float32"
worker_count: 2
initial_weights: "../my_experiment/run/checkpoints/R2_E+F/model/model.msgpack"
23 changes: 23 additions & 0 deletions src/lorem/train.py
@@ -35,6 +35,8 @@ def main():
should_filter_mixedpbc = settings.pop("filter_mixed_pbc", False)
should_filter_above_num_atoms = settings.pop("filter_above_num_atoms", False)

initial_weights = settings.pop("initial_weights", None)

loss_weights = settings.pop("loss_weights", {"energy": 0.5, "forces": 0.5})
scale_by_variance = settings.pop("scale_by_variance", False)

@@ -166,6 +168,26 @@ def main():
num_parameters = int(sum(x.size for x in jax.tree_util.tree_leaves(params)))
comms.state(f"Parameter count: {num_parameters}")

# -- initial weights (pre-trained) --
if initial_weights:

    def merge_params(target, source):
        """Recursively prefer checkpoint values, keeping freshly initialized entries the checkpoint lacks."""
        if isinstance(target, dict) and isinstance(source, dict):
            result = {}
            for k in target:
                if k in source:
                    result[k] = merge_params(target[k], source[k])
                else:
                    result[k] = target[k]
            return result
        return source

    from marathon.io import read_msgpack

    source_params = read_msgpack(Path(initial_weights))
    params = merge_params(params, source_params)
    comms.talk(f"loaded initial weights from {initial_weights}")

# -- checkpointers --
from marathon.emit import SummedMetric

@@ -609,6 +631,7 @@ def optimizer(learning_rate):
"num_parameters": num_parameters,
"worker_count": worker_count,
"worker_buffer_size": worker_buffer_size,
"initial_weights": initial_weights,
}

metrics = {key: ["r2", "mae", "rmse"] for key in keys}
42 changes: 42 additions & 0 deletions tests/test_train.py
@@ -0,0 +1,42 @@
# Local copy of merge_params from src/lorem/train.py (it is defined inside main()
# there, so it cannot be imported directly).
def merge_params(target, source):
    if isinstance(target, dict) and isinstance(source, dict):
        result = {}
        for k in target:
            if k in source:
                result[k] = merge_params(target[k], source[k])
            else:
                result[k] = target[k]
        return result
    return source


def test_merge_params_full_match():
    """Source fully matches target structure — all weights replaced."""
    target = {"a": 1, "b": {"c": 2, "d": 3}}
    source = {"a": 10, "b": {"c": 20, "d": 30}}
    result = merge_params(target, source)
    assert result == {"a": 10, "b": {"c": 20, "d": 30}}


def test_merge_params_partial_match():
    """Target has extra keys not in source — they keep target values."""
    target = {"a": 1, "b": 2, "new_layer": 99}
    source = {"a": 10, "b": 20}
    result = merge_params(target, source)
    assert result == {"a": 10, "b": 20, "new_layer": 99}


def test_merge_params_nested_partial():
    """Nested dict with partial overlap."""
    target = {"layer1": {"w": 1}, "layer2": {"w": 2}}
    source = {"layer1": {"w": 10}}
    result = merge_params(target, source)
    assert result == {"layer1": {"w": 10}, "layer2": {"w": 2}}


def test_merge_params_source_extra_keys_ignored():
    """Keys in source but not target are ignored (target structure wins)."""
    target = {"a": 1}
    source = {"a": 10, "extra": 99}
    result = merge_params(target, source)
    assert result == {"a": 10}
1 change: 1 addition & 0 deletions tox.ini
@@ -40,6 +40,7 @@ usedevelop = true
commands =
    python examples/calculator/example.py
    bash -c 'cd examples/train-mlp && DATASETS=. python prepare.py && cd my_experiment && DATASETS=.. lorem-train'
    bash -c 'cd examples/train-mlp/my_experiment_finetune && DATASETS=.. lorem-train'
    bash -c 'cd examples/train-bec && DATASETS=. python prepare.py && cd my_experiment && DATASETS=.. lorem-train'

allowlist_externals = bash