# The config file is organized into top-level sections -- `run`, `data`, `trainer`, `training_module`, and `model` -- plus a few top-level variables used for interpolation
# The config system relies on omegaconf (https://omegaconf.readthedocs.io/en/2.3_branch/index.html)
# and hydra (https://hydra.cc/docs/intro/) functionalities, such as
# - omegaconf's variable interpolation (https://omegaconf.readthedocs.io/en/2.3_branch/usage.html#variable-interpolation)
# - omegaconf's resolvers (https://omegaconf.readthedocs.io/en/2.3_branch/usage.html#resolvers)
# - hydra's instantiate (https://hydra.cc/docs/advanced/instantiate_objects/overview/)
# With hydra's instantiation (notice the `_target_`s everywhere), the config file (almost) directly corresponds to instantiating objects as one would normally do in Python.
# Much of the infrastructure is based on PyTorch Lightning (https://lightning.ai/docs/pytorch/stable/), such as the use of Lightning's Trainer, DataModule, LightningModule, Callback objects.
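# As a concrete illustration (a minimal sketch, not part of this config): a block like
#   optimizer:
#     _target_: torch.optim.AdamW
#     lr: 0.002
# is instantiated by hydra roughly as the Python call `torch.optim.AdamW(model.parameters(), lr=0.002)`,
# with the remaining arguments (here, the model parameters) supplied by the training code.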
# ===========
# RUN
# ===========
# the run types listed here are executed in sequence
# the available run types are `train`, `val`, and `test`
run: [train, test]
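# for example, `run: [train, val, test]` would additionally run a standalone validation pass between
# training and testing; this config trains and then evaluates on the held-out test split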
# The following parameters (`model_package_path`, `model_type_names`, `cutoff_radius`) are not used directly by the code.
# Parameters that take their values appear multiple times in the config, so defining them once here lets us use
# variable interpolation to keep the multiple instances consistent.
# There are two sets of atomic types to keep track of in most applications:
# the conventional chemical species (e.g. C, H), and a separate set of `type_names` known to the model.
# The model only knows types through a set of zero-based indices and the user-given `type_names` argument.
# An example where this distinction matters is a dataset containing the same chemical species in different charge states:
# we could define `chemical_symbols: [C, C]` and model `type_names: [C3, C4]` for the +3 and +4 charge states.
# There are also cases, such as coarse-graining, where we only care about the model's `type_names` (no chemical species need be defined).
# Because of this distinction, these variables show up as arguments across different sections, including data, model, metrics, and even callbacks.
# Here we fix both to be the same, so we define a single set of each and use variable interpolation to retrieve them below.
# This keeps the values in a single location and reduces the chance of misconfigured runs.
model_package_path: /content/NequIP-OAM-S-0.1.nequip.zip
model_type_names: ${type_names_from_package:${model_package_path}}
# data and model r_max can differ (the model's r_max must not exceed the data's), but we try to keep them the same
cutoff_radius: ${cutoff_radius_from_package:${model_package_path}}
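# (hypothetical alternative for illustration: when not fine-tuning from a package, these could be set explicitly, e.g.
#   model_type_names: [Si]
#   cutoff_radius: 5.0
# -- the values shown are placeholders, not values read from the packaged model)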
results_dir: ./results_ft
# ============
# DATA
# ============
# `data` is managed by `LightningDataModule`s
# NequIP provides some standard datamodules, which can be found in `nequip.data.datamodule`
# users are free to define and use their own datamodules that subclass `nequip.data.datamodule.NequIPDataModule`
data:
  _target_: nequip.data.datamodule.ASEDataModule
  seed: 456  # dataset seed for reproducibility
  # here we take an ASE-readable file (in extxyz format) and split it into train:val:test = 80:10:10
  split_dataset:
    file_path: ./sitraj.xyz
    train: 0.8
    val: 0.1
    test: 0.1
  # `transforms` convert data from the Dataset into the form expected by the ML model
  # the transforms are applied only right before data is passed to the model;
  # the underlying data is kept in its untransformed form
  transforms:
    # data doesn't usually come with a neighbor list -- this transform prepares the neighbor list
    - _target_: nequip.data.transforms.NeighborListTransform
      r_max: ${cutoff_radius}
    # the model only knows atom types, which can differ from the chemical species (e.g. C, H)
    # for instance, data may contain different charge states of carbon: all atoms are labeled
    # by the chemical species `C`, but may carry different atom type labels based on charge state
    # here the atom types coincide with the chemical species, but we still include this
    # transformation to ensure the data has the 0-indexed atom type lists used in the model's operations
    - _target_: nequip.data.transforms.ChemicalSpeciesToAtomTypeMapper
      model_type_names: ${model_type_names}
  # the following are torch.utils.data.DataLoader configs, excluding the arguments `dataset` and `collate_fn`
  # https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader
  train_dataloader:
    _target_: torch.utils.data.DataLoader
    batch_size: 5
    num_workers: 5
    shuffle: true
  val_dataloader:
    _target_: torch.utils.data.DataLoader
    batch_size: 10
    num_workers: ${data.train_dataloader.num_workers}  # we want to use the same num_workers -- variable interpolation helps
  test_dataloader: ${data.val_dataloader}  # variable interpolation comes in handy again
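  # (for example, `test_dataloader` above resolves to the same DataLoader config as `val_dataloader`:
  #  batch_size 10 and the shared num_workers value)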
  # dataset statistics can be computed and used for model initialization, e.g. for shifting, scaling, and standardization
  # it is advised to give the statistics custom names -- you will retrieve them by name later under `model`
  # to initialize certain parameters from the computed dataset statistics
  stats_manager:
    # dataset statistics are handled by the `DataStatisticsManager`
    # here we use `CommonDataStatisticsManager`, which provides a basic set of dataset statistics for general use cases
    # the dataset statistics include `num_neighbors_mean`, `per_atom_energy_mean`, `forces_rms`, `per_type_forces_rms`
    _target_: nequip.data.CommonDataStatisticsManager
    # dataloader kwargs for the data statistics computation
    # `batch_size` should ideally be as large as possible without triggering OOM
    dataloader_kwargs:
      batch_size: 10
    # we need to provide the same type names that correspond to the model's `type_names`,
    # so we interpolate the "central source of truth" model type names from above
    type_names: ${model_type_names}
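  # (for reference: when training a model from scratch, these named statistics are typically consumed in the
  #  `model` section via an omegaconf resolver, e.g. something like
  #    per_type_energy_shifts: ${training_data_stats:per_atom_energy_mean}
  #  -- check the nequip docs for the exact resolver name; this is not needed here since the model
  #  is loaded from a package below)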
# `trainer` (mandatory) is a lightning.Trainer object (https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-class-api)
trainer:
  _target_: lightning.Trainer
  accelerator: auto
  enable_checkpointing: true
  max_epochs: 10
  max_time: 03:00:00:00  # wall-time limit in DD:HH:MM:SS format (here, 3 days)
  log_every_n_steps: 1  # how often to log
  # use any Lightning-supported logger
  logger:
    _target_: lightning.pytorch.loggers.csv_logs.CSVLogger
    name: tutorial_log
    save_dir: ${results_dir}
    flush_logs_every_n_steps: 100
  # use any Lightning callbacks (https://lightning.ai/docs/pytorch/stable/api_references.html#callbacks)
  # and any custom callbacks that subclass Lightning's Callback parent class
  callbacks:
    # common callbacks used in ML
    # stop training when some criterion is met
    - _target_: lightning.pytorch.callbacks.EarlyStopping
      monitor: val0_epoch/weighted_sum  # validation metric to monitor
      min_delta: 1e-3  # minimum change in the monitored metric to count as an improvement
      patience: 20  # how many checks with no improvement before stopping
    # checkpoint based on some criterion
    - _target_: lightning.pytorch.callbacks.ModelCheckpoint
      monitor: val0_epoch/weighted_sum  # validation metric to monitor
      dirpath: ${results_dir}
      filename: best  # best.ckpt is the checkpoint name
      save_last: true  # last.ckpt will also be saved
    # log the learning rate, e.g. to monitor what the learning rate scheduler is doing
    - _target_: lightning.pytorch.callbacks.LearningRateMonitor
      logging_interval: epoch
# `training_module` refers to a NequIPLightningModule
training_module:
  _target_: nequip.train.EMALightningModule
  # We use an EMA model, i.e. we keep a separate model whose weights are an exponential moving average of the base model's weights.
  # The EMA model is configured by setting `ema_decay` to a float (e.g. 0.999) under `training_module`.
  # The default of `ema_decay` is None; if it is not explicitly configured, no EMA model is used.
  # EMA allows for smoother validation curves and thus more reliable metrics for monitoring.
  # Loading from a checkpoint for use in `nequip.ase.NequIPCalculator`, or during `nequip-compile` and `nequip-package`, will always load the EMA model if it is present.
  ema_decay: 0.999
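  # (for reference, the per-step EMA update is w_ema <- ema_decay * w_ema + (1 - ema_decay) * w,
  #  so ema_decay: 0.999 averages over roughly the last 1 / (1 - 0.999) = 1000 optimizer steps)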
  # here we use a simplified MetricsManager wrapper (see docs) to construct the energy-force loss function
  # the more general `nequip.train.MetricsManager` could also be used to configure a custom loss function
  loss:
    _target_: nequip.train.EnergyForceLoss
    per_atom_energy: true
    coeffs:
      total_energy: 1.0
      forces: 1.0
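  # (with these settings the training loss is, schematically,
  #    L = 1.0 * loss(E_pred / N, E_true / N) + 1.0 * loss(F_pred, F_true)
  #  where N is the number of atoms; `per_atom_energy: true` is what normalizes the energy term by N --
  #  see the docs for the exact error metric used for each term)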
  # again, we use a simplified MetricsManager wrapper (see docs) to construct the energy-force metrics
  # the more general `nequip.train.MetricsManager` could also be used in this case
  # validation metrics are used for monitoring and influencing training, e.g. with LR schedulers or early stopping
  val_metrics:
    _target_: nequip.train.EnergyForceMetrics
    coeffs:
      total_energy_mae: 1.0
      forces_mae: 1.0
    # the keys `total_energy_rmse`, `forces_rmse`, `per_atom_energy_rmse`, and `per_atom_energy_mae` are also available
  # train_metrics and test_metrics could differ from val_metrics, but it makes sense to keep them the same
  train_metrics: ${training_module.val_metrics}  # use variable interpolation
  test_metrics: ${training_module.val_metrics}  # use variable interpolation
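  # (the `val0_epoch/weighted_sum` quantity monitored by the callbacks and the LR scheduler is the
  #  coefficient-weighted sum of these metrics, i.e. 1.0 * total_energy_mae + 1.0 * forces_mae here;
  #  `val0` indexes the first (and only) validation dataset, and `epoch` denotes epoch-level aggregation)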
  # any torch-compatible optimizer: https://pytorch.org/docs/stable/optim.html#algorithms
  optimizer:
    _target_: torch.optim.AdamW
    lr: 0.002
    weight_decay: 1e-8
    amsgrad: false
  # see the options for the lr scheduler config at
  # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.core.LightningModule.html#lightning.pytorch.core.LightningModule.configure_optimizers
  lr_scheduler:
    scheduler:
      _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
      factor: 0.1
      patience: 5
      min_lr: 1e-6
    monitor: val0_epoch/weighted_sum
    interval: epoch
    frequency: 1
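  # (worked example: starting from lr = 0.002, each time the monitored metric fails to improve for
  #  5 consecutive epochs the LR is multiplied by 0.1: 2e-3 -> 2e-4 -> 2e-5 -> 2e-6, floored at min_lr = 1e-6)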
model:
  _target_: nequip.model.ModelFromPackage
  package_path: ${model_package_path}
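# (a typical invocation, assuming this file sits in the current working directory -- check the nequip docs
#  for the exact CLI flags: `nequip-train -cn config_finetuning.yaml`)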