
Commit 35e1889

sirmarcel and claude authored
Add initial_weights support for fine-tuning
Adds an initial_weights setting to settings.yaml that loads model weights from a previous checkpoint (.msgpack) while starting the optimizer, step counter, and data iterator fresh. A recursive merge_params supports partial architecture matches (new layers keep their random init). Includes a fine-tuning example and unit tests.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 6805441 commit 35e1889

6 files changed: +108, -0 lines

examples/train-mlp/README.md

Lines changed: 13 additions & 0 deletions
@@ -19,3 +19,16 @@ DATASETS=. python prepare.py
 cd my_experiment
 DATASETS=.. lorem-train
 ```
+
+## Fine-tuning from pretrained weights
+
+The `my_experiment_finetune/` directory demonstrates restarting training from
+pretrained model weights. Add `initial_weights` to `settings.yaml`:
+
+```yaml
+initial_weights: "path/to/previous_run/checkpoints/R2_E+F/model/model.msgpack"
+```
+
+Only model weights are loaded — optimizer state, step counter, and data iterator
+start fresh. If the current model has layers not present in the source checkpoint,
+those layers keep their random initialization.
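As an aside for anyone setting this up: before launching a fine-tuning run it can help to confirm that the `initial_weights` path actually resolves and contains the expected parameter groups. A minimal sketch, assuming `marathon.io.read_msgpack` (used by `src/lorem/train.py` later in this commit) returns the model parameters as a nested dict; the checkpoint path below is the README's placeholder, not a real file:

```python
# Sanity-check an initial_weights checkpoint before fine-tuning (sketch only).
# Assumes read_msgpack returns a nested dict of parameter arrays, as it is
# used in src/lorem/train.py in this commit; the path is a placeholder.
from pathlib import Path

from marathon.io import read_msgpack

ckpt = Path("path/to/previous_run/checkpoints/R2_E+F/model/model.msgpack")
assert ckpt.exists(), f"checkpoint not found: {ckpt}"

source_params = read_msgpack(ckpt)
print(sorted(source_params))  # top-level parameter groups in the pretrained model
```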
Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
+model:
+  lorem.Lorem:
+    cutoff: 5.0
+    max_degree_lr: 2
+    num_features: 128
+    max_degree: 4
+    num_message_passing: 1
+    num_spherical_features: 4
Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+train: "mlp_example/train"
+valid: "mlp_example/valid"
+seed: 42
+batcher:
+  batch_size: 4
+  coarse_strategy: powers_of_4
+  size_strategy: powers_of_4
+  fine_strategy: powers_of_4
+loss_weights: {"energy": 0.5, "forces": 0.5}
+checkpointers: full
+optimizer: muon
+valid_every_epoch: 1
+decay_style: warmup_cosine
+warmup_epochs: 0
+start_learning_rate: 1e-4
+max_epochs: 1
+benchmark_pipeline: True
+use_wandb: False
+default_matmul_precision: "float32"
+worker_count: 2
+initial_weights: "../my_experiment/run/checkpoints/R2_E+F/model/model.msgpack"
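The `initial_weights` entry here is relative to the directory the example is launched from, so it points back at the checkpoint written by the plain `my_experiment` run. A minimal sketch of how that path resolves when read from inside `my_experiment_finetune/`; PyYAML's `yaml.safe_load` is used purely for illustration and is not part of the commit:

```python
# Resolve the relative initial_weights path from my_experiment_finetune/ (sketch).
from pathlib import Path

import yaml  # standard PyYAML, used here only to illustrate the path resolution

settings = yaml.safe_load(Path("settings.yaml").read_text())
checkpoint = Path(settings["initial_weights"]).resolve()
print(checkpoint)
# -> <repo>/examples/train-mlp/my_experiment/run/checkpoints/R2_E+F/model/model.msgpack
```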

src/lorem/train.py

Lines changed: 23 additions & 0 deletions
@@ -35,6 +35,8 @@ def main():
     should_filter_mixedpbc = settings.pop("filter_mixed_pbc", False)
     should_filter_above_num_atoms = settings.pop("filter_above_num_atoms", False)

+    initial_weights = settings.pop("initial_weights", None)
+
     loss_weights = settings.pop("loss_weights", {"energy": 0.5, "forces": 0.5})
     scale_by_variance = settings.pop("scale_by_variance", False)

@@ -166,6 +168,26 @@ def main():
     num_parameters = int(sum(x.size for x in jax.tree_util.tree_leaves(params)))
     comms.state(f"Parameter count: {num_parameters}")

+    # -- initial weights (pre-trained) --
+    if initial_weights:
+
+        def merge_params(target, source):
+            if isinstance(target, dict) and isinstance(source, dict):
+                result = {}
+                for k in target:
+                    if k in source:
+                        result[k] = merge_params(target[k], source[k])
+                    else:
+                        result[k] = target[k]
+                return result
+            return source
+
+        from marathon.io import read_msgpack
+
+        source_params = read_msgpack(Path(initial_weights))
+        params = merge_params(params, source_params)
+        comms.talk(f"loaded initial weights from {initial_weights}")
+
     # -- checkpointers --
     from marathon.emit import SummedMetric

@@ -609,6 +631,7 @@ def optimizer(learning_rate):
         "num_parameters": num_parameters,
         "worker_count": worker_count,
         "worker_buffer_size": worker_buffer_size,
+        "initial_weights": initial_weights,
     }

     metrics = {key: ["r2", "mae", "rmse"] for key in keys}
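One property the merge relies on, but does not check, is that every parameter present in both trees has the same shape in the checkpoint and in the freshly initialised model. A small follow-up check one could add is sketched below; it uses `jax.tree_util`, which train.py already relies on, but the helper itself is not part of this commit:

```python
import jax


def assert_shapes_match(fresh_params, merged_params):
    """Sketch: verify merged weights still have the shapes the model expects.

    Call as assert_shapes_match(params, merge_params(params, source_params))
    before overwriting params with the merged tree.
    """
    fresh_shapes = jax.tree_util.tree_map(lambda leaf: leaf.shape, fresh_params)
    merged_shapes = jax.tree_util.tree_map(lambda leaf: leaf.shape, merged_params)
    assert fresh_shapes == merged_shapes, (
        "initial_weights do not match the model's parameter shapes"
    )
```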

tests/test_train.py

Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
+def merge_params(target, source):
+    if isinstance(target, dict) and isinstance(source, dict):
+        result = {}
+        for k in target:
+            if k in source:
+                result[k] = merge_params(target[k], source[k])
+            else:
+                result[k] = target[k]
+        return result
+    return source
+
+
+def test_merge_params_full_match():
+    """Source fully matches target structure — all weights replaced."""
+    target = {"a": 1, "b": {"c": 2, "d": 3}}
+    source = {"a": 10, "b": {"c": 20, "d": 30}}
+    result = merge_params(target, source)
+    assert result == {"a": 10, "b": {"c": 20, "d": 30}}
+
+
+def test_merge_params_partial_match():
+    """Target has extra keys not in source — they keep target values."""
+    target = {"a": 1, "b": 2, "new_layer": 99}
+    source = {"a": 10, "b": 20}
+    result = merge_params(target, source)
+    assert result == {"a": 10, "b": 20, "new_layer": 99}
+
+
+def test_merge_params_nested_partial():
+    """Nested dict with partial overlap."""
+    target = {"layer1": {"w": 1}, "layer2": {"w": 2}}
+    source = {"layer1": {"w": 10}}
+    result = merge_params(target, source)
+    assert result == {"layer1": {"w": 10}, "layer2": {"w": 2}}
+
+
+def test_merge_params_source_extra_keys_ignored():
+    """Keys in source but not target are ignored (target structure wins)."""
+    target = {"a": 1}
+    source = {"a": 10, "extra": 99}
+    result = merge_params(target, source)
+    assert result == {"a": 10}
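The test module carries its own copy of `merge_params` rather than importing it, since in train.py the helper is defined locally inside `main()`. Assuming the project's test runner is pytest, these tests can also be run on their own with `pytest tests/test_train.py`.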

tox.ini

Lines changed: 1 addition & 0 deletions
@@ -40,6 +40,7 @@ usedevelop = true
 commands =
     python examples/calculator/example.py
     bash -c 'cd examples/train-mlp && DATASETS=. python prepare.py && cd my_experiment && DATASETS=.. lorem-train'
+    bash -c 'cd examples/train-mlp/my_experiment_finetune && DATASETS=.. lorem-train'
     bash -c 'cd examples/train-bec && DATASETS=. python prepare.py && cd my_experiment && DATASETS=.. lorem-train'

 allowlist_externals = bash
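Note that the new command is placed directly after the existing train-mlp command for a reason: the fine-tuning settings point at `../my_experiment/run/checkpoints/R2_E+F/model/model.msgpack`, which only exists once that first example has trained and written its checkpoint.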
