Skip to content

Commit aa4837f

Browse files
Finetune script (#13)
* save initial files * add one more file * save current progress * tidy up, stress augmentation to fix * remove some redundant code * run tested fine wandb ok * remove ema * modify lr scheduler * some review comments * update version of black and rerun * remove e3nn dependency * Update finetune.py Co-authored-by: Ben Rhodes <benjamin.rhodes26@gmail.com> * Update finetune.py Co-authored-by: Ben Rhodes <benjamin.rhodes26@gmail.com> * Update finetune.py Co-authored-by: Ben Rhodes <benjamin.rhodes26@gmail.com> * Update finetune.py Co-authored-by: Ben Rhodes <benjamin.rhodes26@gmail.com> * save * second reviewcomments * update readme * revert epoch saving * update readme and tidy up wandb reporting * Update README.md Co-authored-by: Ben Rhodes <benjamin.rhodes26@gmail.com> * update readme --------- Co-authored-by: Ben Rhodes <benjamin.rhodes26@gmail.com>
1 parent 32484f0 commit aa4837f

23 files changed

+936
-54
lines changed

README.md

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,8 @@ For more information on the models, please see the [MODELS.md](MODELS.md) file.
3838

3939
import ase
4040
from ase.build import bulk
41-
from orb_models.forcefield import pretrained
42-
from orb_models.forcefield import atomic_system
41+
42+
from orb_models.forcefield import atomic_system, pretrained
4343
from orb_models.forcefield.base import batch_graphs
4444

4545
device = "cpu" # or device="cuda"
@@ -66,10 +66,10 @@ atoms = atomic_system.atom_graphs_to_ase_atoms(
6666
```python
6767
import ase
6868
from ase.build import bulk
69+
6970
from orb_models.forcefield import pretrained
7071
from orb_models.forcefield.calculator import ORBCalculator
7172

72-
7373
device="cpu" # or device="cuda"
7474
orbff = pretrained.orb_v1(device=device) # or choose another model using ORB_PRETRAINED_MODELS[model_name]()
7575
calc = ORBCalculator(orbff, device=device)
@@ -95,6 +95,21 @@ print("Optimized Energy:", atoms.get_potential_energy())
9595
```
9696

9797

98+
### Finetuning
99+
You can finetune the model using your custom dataset.
100+
The dataset should be an [ASE sqlite database](https://wiki.fysik.dtu.dk/ase/ase/db/db.html#module-ase.db.core).
101+
```bash
102+
python finetune.py --dataset=<dataset_name> --data_path=<your_data_path>
103+
```
104+
After the model is finetuned, checkpoints will, by default, be saved to the ckpts folder in the directory you ran the finetuning script from.
105+
106+
You can use the new model and load the checkpoint by:
107+
```python
108+
from orb_models.forcefield import pretrained
109+
110+
model = pretrained.orb_v1(weights_path=<path_to_ckpt>)
111+
```
112+
98113
### Citing
99114

100115
We are currently preparing a preprint for publication.

finetune.py

Lines changed: 344 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,344 @@
1+
"""Finetuning loop."""
2+
3+
import argparse
4+
import logging
5+
import os
6+
from typing import Dict, Optional, Union
7+
8+
import torch
9+
import tqdm
10+
from torch.optim.lr_scheduler import _LRScheduler
11+
from torch.utils.data import BatchSampler, DataLoader, RandomSampler
12+
13+
import wandb
14+
from orb_models import utils
15+
from orb_models.dataset.ase_dataset import AseSqliteDataset
16+
from orb_models.forcefield import base, pretrained
17+
from wandb import wandb_run
18+
19+
logging.basicConfig(
20+
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
21+
)
22+
23+
24+
def finetune(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    dataloader: DataLoader,
    lr_scheduler: Optional[_LRScheduler] = None,
    num_steps: Optional[int] = None,
    clip_grad: Optional[float] = None,
    log_freq: float = 10,
    device: torch.device = torch.device("cpu"),
    epoch: int = 0,
):
    """Train for a fixed number of steps.

    Args:
        model: The model to optimize.
        optimizer: The optimizer for the model.
        dataloader: A Pytorch Dataloader, which may be infinite if num_steps is passed.
        lr_scheduler: Optional, a learning rate scheduler for modifying the learning rate.
        num_steps: The number of training steps to take. This is required for distributed
            training, because controlling parallelism is easier if all processes take
            exactly the same number of steps (this particularly applies when using
            dynamic batching).
        clip_grad: Optional, the gradient clipping threshold.
        log_freq: The logging frequency for step metrics.
        device: The device to use for training.
        epoch: The number of epochs the model has been finetuned.

    Returns:
        A dictionary of metrics.
    """
    run: Optional[wandb_run.Run] = wandb.run

    if clip_grad is not None:
        hook_handles = utils.gradient_clipping(model, clip_grad)

    metrics = utils.ScalarMetricTracker()

    # Set the model to "train" mode.
    model.train()

    # Get tqdm for the training batches.
    batch_generator = iter(dataloader)
    num_training_batches: Union[int, float]
    if num_steps is not None:
        num_training_batches = num_steps
    else:
        try:
            num_training_batches = len(dataloader)
        except TypeError:
            raise ValueError("Dataloader has no length, you must specify num_steps.")

    batch_generator_tqdm = tqdm.tqdm(batch_generator, total=num_training_batches)

    i = 0
    batch_iterator = iter(batch_generator_tqdm)
    while True:
        if num_steps and i == num_steps:
            break

        optimizer.zero_grad(set_to_none=True)

        step_metrics = {
            "batch_size": 0.0,
            "batch_num_edges": 0.0,
            "batch_num_nodes": 0.0,
        }

        # Reset metrics so that it reports raw values for each step but still
        # does averages on the gradient accumulation.
        if i % log_freq == 0:
            metrics.reset()

        # FIX: when num_steps is None the loop relies on the (finite) dataloader
        # being exhausted, but the original code let StopIteration escape from
        # next(), crashing at the end of the epoch. Break out cleanly instead.
        try:
            batch = next(batch_iterator)
        except StopIteration:
            break
        batch = batch.to(device)
        step_metrics["batch_size"] += len(batch.n_node)
        step_metrics["batch_num_edges"] += batch.n_edge.sum()
        step_metrics["batch_num_nodes"] += batch.n_node.sum()

        # Autocast is explicitly disabled: finetuning runs in full precision.
        with torch.cuda.amp.autocast(enabled=False):
            batch_outputs = model.loss(batch)
            loss = batch_outputs.loss
            metrics.update(batch_outputs.log)
            if torch.isnan(loss):
                raise ValueError("nan loss encountered")
            loss.backward()

        optimizer.step()

        if lr_scheduler is not None:
            lr_scheduler.step()

        metrics.update(step_metrics)

        if i != 0 and i % log_freq == 0:
            metrics_dict = metrics.get_metrics()
            if run is not None:
                # Global step across epochs, so wandb plots are monotonic.
                step = (epoch * num_training_batches) + i
                if run.sweep_id is not None:
                    # Sweeps optimize on a top-level "loss" key; log it uncommitted
                    # so it lands in the same wandb step as the metrics below.
                    run.log(
                        {"loss": metrics_dict["loss"]},
                        commit=False,
                    )
                run.log(
                    {"step": step},
                    commit=False,
                )
                run.log(utils.prefix_keys(metrics_dict, "finetune_step"), commit=True)

        # Finished a single full step!
        i += 1

    # Remove the gradient-clipping hooks registered at the top of this call.
    if clip_grad is not None:
        for h in hook_handles:
            h.remove()

    return metrics.get_metrics()
139+
140+
141+
def build_train_loader(
    dataset_path: str,
    num_workers: int,
    batch_size: int,
    augmentation: Optional[bool] = True,
    target_config: Optional[Dict] = None,
    **kwargs,
) -> DataLoader:
    """Builds the train dataloader from a config file.

    Args:
        dataset_path: Dataset path.
        num_workers: The number of workers for each dataset.
        batch_size: The batch_size config for each dataset.
        augmentation: If rotation augmentation is used.
        target_config: The target config.

    Returns:
        The train Dataloader.
    """
    train_dataset = AseSqliteDataset(
        dataset_path, target_config=target_config, augmentation=augmentation, **kwargs
    )
    logging.info(
        "Loading train datasets:\n"
        + f"Total train dataset size: {len(train_dataset)} samples"
    )

    # Random shuffling, batched without dropping the trailing partial batch.
    batch_sampler = BatchSampler(
        RandomSampler(train_dataset),
        batch_size=batch_size,
        drop_last=False,
    )

    # Only apply a worker timeout when workers actually exist; a non-zero
    # timeout with num_workers=0 is invalid in PyTorch.
    worker_timeout = 10 * 60 if num_workers > 0 else 0
    loader: DataLoader = DataLoader(
        train_dataset,
        num_workers=num_workers,
        worker_init_fn=utils.worker_init_fn,
        collate_fn=base.batch_graphs,
        batch_sampler=batch_sampler,
        timeout=worker_timeout,
    )
    return loader
186+
187+
188+
def run(args):
    """Training loop.

    Args:
        args: Parsed command-line arguments (see ``main`` for the full set of
            options: dataset/data paths, optimizer, epochs, wandb, etc.).
    """
    device = utils.init_device(device_id=args.device_id)
    utils.seed_everything(args.random_seed)

    # Make sure to use this flag for matmuls on A100 and H100 GPUs.
    torch.set_float32_matmul_precision("high")

    # Instantiate the pretrained model and unfreeze every parameter so the
    # whole network is finetuned.
    model = pretrained.orb_v1(device=device)
    for param in model.parameters():
        param.requires_grad = True
    model_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logging.info(f"Model has {model_params} trainable parameters.")

    # Move model to correct device.
    model.to(device=device)
    total_steps = args.max_epochs * args.num_steps
    optimizer, lr_scheduler = utils.get_optim(args.lr, total_steps, model)

    wandb_run = None
    # Logger instantiation/configuration
    if args.wandb:
        logging.info("Instantiating WandbLogger.")
        wandb_run = utils.init_wandb_from_config(
            dataset=args.dataset, job_type="finetuning", entity=args.wandb_entity
        )

        wandb.define_metric("step")
        wandb.define_metric("finetune_step/*", step_metric="step")

    loader_args = dict(
        dataset_path=args.data_path,
        num_workers=args.num_workers,
        batch_size=args.batch_size,
        target_config={"graph": ["energy", "stress"], "node": ["forces"]},
    )
    train_loader = build_train_loader(
        **loader_args,
        augmentation=True,
    )
    logging.info("Starting training!")

    num_steps = args.num_steps

    start_epoch = 0

    for epoch in range(start_epoch, args.max_epochs):
        print(f"Start epoch: {epoch} training...")
        finetune(
            model=model,
            optimizer=optimizer,
            dataloader=train_loader,
            lr_scheduler=lr_scheduler,
            clip_grad=args.gradient_clip_val,
            device=device,
            num_steps=num_steps,
            epoch=epoch,
        )

        # Save a checkpoint from the last epoch only.
        if epoch == args.max_epochs - 1:
            # FIX: create the ckpts folder race-safely. The original
            # exists()/makedirs() pair could raise FileExistsError if the
            # directory appeared between the check and the call.
            os.makedirs(args.checkpoint_path, exist_ok=True)
            torch.save(
                model.state_dict(),
                os.path.join(args.checkpoint_path, f"checkpoint_epoch{epoch}.ckpt"),
            )
            logging.info(f"Checkpoint saved to {args.checkpoint_path}")

    if wandb_run is not None:
        wandb_run.finish()
265+
266+
267+
def main():
    """Parse command-line arguments and launch the finetuning run."""
    parser = argparse.ArgumentParser(
        description="Finetune orb model",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--random_seed", default=1234, type=int, help="Random seed for finetuning."
    )
    parser.add_argument(
        "--device_id", default=0, type=int, help="GPU index to use if GPU is available."
    )
    # FIX: the original used `default=True, action="store_true"`, which makes
    # the flag permanently True — there was no way to disable wandb from the
    # CLI. BooleanOptionalAction (Python 3.9+) keeps `--wandb` working and
    # additionally accepts `--no-wandb`.
    parser.add_argument(
        "--wandb",
        default=True,
        action=argparse.BooleanOptionalAction,
        help="If the run is logged to Weights and Biases (requires installation).",
    )
    parser.add_argument(
        "--wandb_entity",
        default="orbitalmaterials",
        type=str,
        help="Entity to log the run to in Weights and Biases.",
    )
    parser.add_argument(
        "--dataset",
        default="mp-traj",
        type=str,
        help="Dataset name for wandb run logging.",
    )
    parser.add_argument(
        "--data_path",
        default=os.path.join(os.getcwd(), "datasets/mptraj/finetune.db"),
        type=str,
        help="Dataset path to an ASE sqlite database (you must convert your data into this format).",
    )
    parser.add_argument(
        "--num_workers",
        default=8,
        type=int,
        help="Number of cpu workers for the pytorch data loader.",
    )
    parser.add_argument(
        "--batch_size", default=100, type=int, help="Batch size for finetuning."
    )
    parser.add_argument(
        "--gradient_clip_val", default=0.5, type=float, help="Gradient clip value."
    )
    parser.add_argument(
        "--max_epochs",
        default=50,
        type=int,
        help="Maximum number of epochs to finetune.",
    )
    parser.add_argument(
        "--num_steps",
        default=100,
        type=int,
        help="Number of steps in each epoch.",
    )
    parser.add_argument(
        "--checkpoint_path",
        default=os.path.join(os.getcwd(), "ckpts"),
        type=str,
        help="Path to save the model checkpoint.",
    )
    parser.add_argument(
        "--lr",
        default=3e-04,
        type=float,
        help="Learning rate. 3e-4 is purely a sensible default; you may want to tune this for your problem.",
    )
    args = parser.parse_args()
    run(args)

0 commit comments

Comments
 (0)