This guide explains how to use `finetune.py` to finetune Orb models with custom loss weights and reference energies.

- Loss weights: control the relative importance of energy, forces, and stress in the loss function.
- Reference energies:
  - Load from file: provide your own reference energies
  - Fixed or trainable: keep them fixed during training or optimize them
  - Multiple formats: JSON and plain-text files are supported
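In practice, each weight simply scales its term in the combined training loss. A minimal sketch of that idea (illustrative only, not Orb's actual loss code):

```python
# Illustrative sketch: per-target loss weights combine into one scalar loss.
def weighted_loss(losses, weights):
    """losses/weights: dicts keyed by target name, e.g. 'energy', 'forces'."""
    return sum(weights.get(name, 0.0) * value for name, value in losses.items())

# Hypothetical per-target losses, with stress disabled via a zero weight:
total = weighted_loss(
    {"energy": 0.25, "forces": 0.8, "stress": 0.1},
    {"energy": 0.1, "forces": 1.0, "stress": 0.0},
)
```

Setting a weight to zero removes that target from training entirely, which is how `--stress_loss_weight 0.0` disables stress training below.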
The finetuning script expects data in ASE SQLite database format. This is a standard format used by the Atomic Simulation Environment (ASE) library.
Each structure in your database should have:
- Positions: Atomic positions (automatically stored with the Atoms object)
- Atomic numbers: Element types (automatically stored with the Atoms object)
- Cell: Unit cell vectors (for periodic systems)
- Energy: Total energy of the structure (in eV)
- Forces: Forces on each atom (in eV/Å), shape (n_atoms, 3)
- Stress (optional): Stress tensor in Voigt notation (6 components), in eV/ų
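Units are a common pitfall: many DFT codes report stress in GPa rather than the eV/ų expected here. A quick conversion sketch (the constant matches ASE's `ase.units.GPa`):

```python
# 1 eV/Å³ ≈ 160.21766208 GPa, so divide GPa values to get eV/Å³.
EV_PER_A3_IN_GPA = 160.21766208

stress_gpa = [1.0, 1.0, 1.0, 0.0, 0.0, 0.0]  # Voigt notation, in GPa
stress_ev = [s / EV_PER_A3_IN_GPA for s in stress_gpa]  # now in eV/Å³
```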
Here's how to convert your data into the required format:
```python
import ase.db
from ase import Atoms
from ase.calculators.singlepoint import SinglePointCalculator

# Create a database file
db = ase.db.connect('my_training_data.db')

# For each structure in your dataset:
for structure in your_structures:
    # Create an ASE Atoms object
    atoms = Atoms(
        symbols=structure['symbols'],      # e.g., ['H', 'H', 'O']
        positions=structure['positions'],  # shape: (n_atoms, 3) in Angstroms
        cell=structure['cell'],            # shape: (3, 3) in Angstroms
        pbc=True,                          # set to True for periodic systems
    )

    # Attach energy, forces, and stress using SinglePointCalculator
    calc = SinglePointCalculator(
        atoms=atoms,
        energy=structure['energy'],  # total energy in eV
        forces=structure['forces'],  # shape: (n_atoms, 3) in eV/Å
        stress=structure['stress'],  # shape: (6,) in eV/ų (Voigt notation)
    )
    atoms.calc = calc

    # Write to database
    db.write(atoms)

print(f"Created database with {len(db)} structures")
```

If you have structures in XYZ format with energies/forces in separate files:
```python
import ase.io
import ase.db
import numpy as np
from ase.calculators.singlepoint import SinglePointCalculator

# Read structures
atoms_list = ase.io.read('structures.xyz', index=':')

# Load your energies and forces (example)
energies = np.loadtxt('energies.txt')  # one energy per structure
forces_list = [...]  # list of (n_atoms, 3) arrays

# Create database
db = ase.db.connect('training_data.db')
for atoms, energy, forces in zip(atoms_list, energies, forces_list):
    calc = SinglePointCalculator(atoms=atoms, energy=energy, forces=forces)
    atoms.calc = calc
    db.write(atoms)
```

If your data includes stress, it should be in Voigt notation (6 components):
```python
stress = [σ_xx, σ_yy, σ_zz, σ_yz, σ_xz, σ_xy]  # in eV/ų
```

If you have a full 3×3 stress tensor, convert it to Voigt notation:

```python
stress_voigt = [
    stress_3x3[0, 0],  # σ_xx
    stress_3x3[1, 1],  # σ_yy
    stress_3x3[2, 2],  # σ_zz
    stress_3x3[1, 2],  # σ_yz
    stress_3x3[0, 2],  # σ_xz
    stress_3x3[0, 1],  # σ_xy
]
```

(Recent ASE versions also provide `ase.stress.full_3x3_to_voigt_6_stress` for this conversion.)

Check that your database is formatted correctly:
```python
import ase.db

db = ase.db.connect('my_training_data.db')
print(f"Total structures: {len(db)}")

# Check first structure
row = db.get(1)
atoms = row.toatoms()
print(f"Formula: {row.formula}")
print(f"Energy: {row.energy} eV")
print(f"Forces shape: {row.forces.shape}")
# Use row.get: accessing row.stress directly raises if stress is absent
print(f"Has stress: {row.get('stress') is not None}")
print(f"Number of atoms: {row.natoms}")
```

Expected output:

```
Total structures: 1000
Formula: H2O
Energy: -14.2234 eV
Forces shape: (3, 3)
Has stress: True
Number of atoms: 3
```
Basic finetuning with custom loss weights:

```bash
python finetune.py \
    --data_path /path/to/your/dataset.db \
    --base_model orb_v3_conservative_omol \
    --energy_loss_weight 0.1 \
    --forces_loss_weight 1.0 \
    --stress_loss_weight 0.0 \
    --batch_size 100 \
    --max_epochs 50
```

With custom reference energies (kept fixed):

```bash
python finetune.py \
    --data_path /path/to/your/dataset.db \
    --base_model orb_v3_conservative_omol \
    --custom_reference_energies /path/to/reference_energies.json \
    --energy_loss_weight 0.1 \
    --forces_loss_weight 1.0
```

With custom reference energies as a trainable starting point:

```bash
python finetune.py \
    --data_path /path/to/your/dataset.db \
    --base_model orb_v3_conservative_omol \
    --custom_reference_energies /path/to/reference_energies.json \
    --trainable_reference_energies \
    --energy_loss_weight 0.1 \
    --forces_loss_weight 1.0
```

With the pretrained reference energies made trainable:

```bash
python finetune.py \
    --data_path /path/to/your/dataset.db \
    --base_model orb_v3_conservative_omol \
    --trainable_reference_energies \
    --energy_loss_weight 0.1 \
    --forces_loss_weight 1.0
```

You can use either element symbols or atomic numbers as keys:
With element symbols:

```json
{
    "H": -13.6,
    "C": -1030.5,
    "N": -1400.0,
    "O": -2000.0
}
```

With atomic numbers:

```json
{
    "1": -13.6,
    "6": -1030.5,
    "7": -1400.0,
    "8": -2000.0
}
```

A plain-text format is also supported. With element symbols:

```
H -13.6
C -1030.5
N -1400.0
O -2000.0
```

With atomic numbers:

```
1 -13.6
6 -1030.5
7 -1400.0
8 -2000.0
```
Lines starting with # are treated as comments and ignored.
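To make the two formats concrete, here is a sketch of how such a file's contents might be parsed into a `{atomic_number: energy}` dict. This is illustrative only, not the actual `finetune.py` implementation, and the symbol table covers just a few elements:

```python
import json

# Truncated symbol table for illustration; a real parser would cover all elements.
SYMBOL_TO_Z = {"H": 1, "C": 6, "N": 7, "O": 8}

def parse_reference_energies(text):
    """Return {atomic_number: energy_eV} from JSON or 'KEY VALUE' text lines."""
    try:
        items = json.loads(text).items()  # JSON format
    except json.JSONDecodeError:
        # Fall back to the plain-text format: one "KEY VALUE" pair per line.
        items = []
        for line in text.splitlines():
            line = line.strip()
            if not line or line.startswith("#"):  # skip blanks and comments
                continue
            key, value = line.split()
            items.append((key, value))
    # Keys may be element symbols or atomic-number strings.
    return {
        (int(k) if str(k).isdigit() else SYMBOL_TO_Z[k]): float(v)
        for k, v in items
    }
```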
- `--energy_loss_weight`: Weight for energy loss (default: 1.0)
- `--forces_loss_weight`: Weight for forces loss (default: 1.0)
- `--stress_loss_weight`: Weight for stress loss (set to 0 to disable)
- `--equigrad_loss_weight`: Weight for the Equigrad loss (turned off by default). Only available for the conservative models. Note: we've found that the Equigrad loss should be ≳1000× smaller than the other losses.
- `--custom_reference_energies`: Path to reference energies file (JSON or text format)
- `--trainable_reference_energies`: Make reference energies trainable during finetuning
The script automatically handles the differences between conservative and direct models:

- Conservative models (e.g., `orb_v3_conservative_omol`):
  - Use `grad_forces` and `grad_stress` as loss-weight keys
  - Compute forces via automatic differentiation of the energy
- Direct models (e.g., `orb_v3_direct_omol`):
  - Use `forces` and `stress` as loss-weight keys
  - Predict forces directly

When you specify loss weights via the command line (e.g., `--forces_loss_weight 10.0`), the script automatically maps them to the correct key (`grad_forces` for conservative models, `forces` for direct models).
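That mapping could be sketched as follows (illustrative only; the function name and signature are hypothetical, not `finetune.py`'s actual code):

```python
# Hypothetical sketch of the CLI-to-loss-key mapping described above.
def build_loss_weights(energy_w, forces_w, stress_w, conservative):
    # Conservative models use grad_-prefixed keys; direct models do not.
    prefix = "grad_" if conservative else ""
    return {
        "energy": energy_w,
        f"{prefix}forces": forces_w,
        f"{prefix}stress": stress_w,
    }
```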
If you prefer to write your own finetuning script, you can use the clean API directly:
```python
from orb_models.forcefield import pretrained

# Load model with custom configuration
model, atoms_adapter = pretrained.orb_v3_conservative_omol(
    device='cuda',
    precision='float32-high',
    train=True,
    train_reference_energies=True,  # make reference energies trainable
    loss_weights={
        'energy': 1.0,
        'grad_forces': 10.0,  # use 'grad_forces' for conservative models
        'grad_stress': 100.0,  # use 'grad_stress' for conservative models
    },
)

# For direct models, use 'forces' and 'stress' keys:
model, atoms_adapter = pretrained.orb_v3_direct_omol(
    device='cuda',
    train=True,
    loss_weights={
        'energy': 1.0,
        'forces': 10.0,   # use 'forces' for direct models
        'stress': 100.0,  # use 'stress' for direct models
    },
)

# The model is now ready for training with your custom configuration!
```

- Without custom reference energies: the model uses the pretrained reference energies from the checkpoint
- With `--custom_reference_energies`: your custom values replace the pretrained ones
- With `--trainable_reference_energies` (or `train_reference_energies=True` in the API): reference energies become learnable parameters that are optimized during training
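Conceptually, per-element reference energies act as additive offsets on top of the model's predicted interaction energy, which is why swapping or training them shifts total energies without touching forces. A sketch of that convention, with hypothetical values (not Orb's actual internals):

```python
# Hypothetical per-element reference energies in eV (illustrative values only).
reference_energies = {1: -13.6, 8: -2000.0}

def total_energy(interaction_energy_ev, atomic_numbers):
    """Sketch: total = model interaction energy + per-atom reference offsets."""
    return interaction_energy_ev + sum(reference_energies[z] for z in atomic_numbers)

# A water molecule (H, H, O) with a -0.5 eV predicted interaction energy:
e_total = total_energy(-0.5, [1, 1, 8])
```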
When you save a checkpoint after finetuning, the reference energies (whether custom or trained) are saved in the state dict. When you load the checkpoint later:
```python
import torch
from orb_models.forcefield import pretrained

# Load model architecture (set train=False for inference)
model, atoms_adapter = pretrained.orb_v3_conservative_omol(train=False)

# Load your finetuned checkpoint
model.load_state_dict(torch.load('path/to/finetuned_checkpoint.pt'))

# The custom/trained reference energies and any modified parameters are now loaded!
```

You can also specify loss weights when loading for further finetuning:
```python
# Load for continued finetuning with different loss weights
model, atoms_adapter = pretrained.orb_v3_conservative_omol(
    train=True,
    loss_weights={'energy': 0.5, 'grad_forces': 20.0},
)
model.load_state_dict(torch.load('path/to/finetuned_checkpoint.pt'))
```

Finetuning on ORCA wB97M-V data with a different reference scheme:
1. Create your reference energies file (`my_refs.json`):

```json
{
    "H": -13.6,
    "C": -1030.5,
    "N": -1400.0,
    "O": -2000.0
}
```

2. Run finetuning with fixed references:

```bash
python finetune.py \
    --data_path my_dataset.db \
    --base_model orb_v3_conservative_omol \
    --custom_reference_energies my_refs.json \
    --energy_loss_weight 1.0 \
    --forces_loss_weight 10.0 \
    --max_epochs 50
```

3. Use the finetuned model:
```python
import torch
from orb_models.forcefield import pretrained

model, atoms_adapter = pretrained.orb_v3_conservative_omol(train=False)
model.load_state_dict(torch.load('checkpoints/my_finetuned_model.pt'))
# Reference energies from my_refs.json are now loaded!
```

A complete custom training script:

```python
import torch
from torch.utils.data import DataLoader

from orb_models.forcefield import pretrained
from orb_models.dataset.ase_sqlite_dataset import AseSqliteDataset

# Load model with configuration
model, atoms_adapter = pretrained.orb_v3_conservative_omol(
    device='cuda',
    train=True,
    train_reference_energies=False,  # fixed reference energies
    loss_weights={
        'energy': 1.0,
        'grad_forces': 10.0,
    },
)

# Load your data
dataset = AseSqliteDataset(
    name='my_dataset',
    path='my_dataset.db',
    system_config=model.system_config,
    target_config={'graph': ['energy'], 'node': ['forces']},
)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# Set up optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Training loop
for epoch in range(50):
    for batch in dataloader:
        batch = batch.to('cuda')
        output = model.loss(batch)
        loss = output.loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(f"Epoch {epoch}, Loss: {loss.item()}")

# Save checkpoint
torch.save(model.state_dict(), 'my_finetuned_model.pt')
```
- Energy vs. forces weighting: If forces matter more for your application, use a higher `--forces_loss_weight` (e.g., 10.0) and a lower `--energy_loss_weight` (e.g., 0.1)
- Fixed vs. trainable references:
  - Use fixed references if you know the correct reference energies for your method
  - Use trainable references if you want the model to learn the best reference energies from your data
- Starting from pretrained: If you don't provide custom reference energies, the model starts with the pretrained values (e.g., ORCA-fitted for OMol models, VASP for OMAT models)
- Stress training: Set `--stress_loss_weight 0.0` if your dataset doesn't include stress information