-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathconfig.yaml
More file actions
234 lines (198 loc) · 12.5 KB
/
config.yaml
File metadata and controls
234 lines (198 loc) · 12.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
# The config file is divided into 4 sections -- `data`, `train`, `model`, and `global_options`
# The config system relies on omegaconf (https://omegaconf.readthedocs.io/en/2.3_branch/index.html)
# and hydra (https://hydra.cc/docs/intro/) functionalities, such as
# - omegaconf's variable interpolation (https://omegaconf.readthedocs.io/en/2.3_branch/usage.html#variable-interpolation)
# - omegaconf's resolvers (https://omegaconf.readthedocs.io/en/2.3_branch/usage.html#resolvers)
# - hydra's instantiate (https://hydra.cc/docs/advanced/instantiate_objects/overview/)
# With hydra's instantiation (notice the `_target_`s everywhere), the config file (almost) directly corresponds to instantiating objects as one would normally do in Python.
# Much of the infrastructure is based on PyTorch Lightning (https://lightning.ai/docs/pytorch/stable/), such as the use of Lightning's Trainer, DataModule, LightningModule, Callback objects.
# ===========
# RUN
# ===========
# the run types will be completed in sequence
# one can do `train`, `val`, `test` run types
run: [train, test]
# the following parameters (cutoff_radius, chemical_symbols, model_type_names) are not used directly by the code
# parameters that take their values show up multiple times in the config, so this allows us to use
# variable interpolation to keep their multiple instances consistent
# data and model r_max can be different (model's r_max should be smaller), but we try to make them the same
cutoff_radius: 5.0
# There are two sets of atomic types to keep track of in most applications
# -- there is the conventional atomic species (e.g. C, H), and a separate `type_names` known to the model.
# The model only knows types based on a set of zero-based indices and user-given `type_names` argument.
# An example where this distinction is necessary is a dataset containing the same atomic species in different charge states:
# we could define `chemical_symbols: [C, C]` and model `type_names: [C3, C4]` for +3 and +4 charge states.
# There could also be instances, such as coarse-graining, where we only care about the model's `type_names` (no need to define chemical species).
# Because of this distinction, these variables show up as arguments across different categories, including data, model, metrics and even callbacks.
# In this case, we fix both to be the same, so we define a single set of each here and use variable interpolation to retrieve them below.
# This ensures a single location where the values are set to reduce the chances of mis-configuring runs.
chemical_symbols: [Si]
model_type_names: ${chemical_symbols}
# ============
# DATA
# ============
# `data` is managed by `LightningDataModule`s
# NequIP provides some standard datamodules that can be found in `nequip.data.datamodule`
# Users are free to define and use their own datamodules that subclass nequip.data.datamodule.NequIPDataModule
data:
  _target_: nequip.data.datamodule.ASEDataModule
  seed: 456 # dataset seed for reproducibility
  # here we take an ASE-readable file (in extxyz format) and split it into train:val:test = 80:10:10
  split_dataset:
    file_path: ./sitraj.xyz
    train: 0.8
    val: 0.1
    test: 0.1
  # `transforms` convert data from the Dataset to a form that can be used by the ML model
  # the transforms are only performed right before data is given to the model
  # data is kept in its untransformed form
  transforms:
    # data doesn't usually come with a neighborlist -- this transform prepares the neighborlist
    - _target_: nequip.data.transforms.NeighborListTransform
      r_max: ${cutoff_radius}
    # the models only know atom types, which can be different from the chemical species (e.g. C, H)
    # for instance we can have data with different charge states of carbon, which means they are
    # all labeled by chemical species `C`, but may have different atom type labels based on the charge states
    # in this case, the atom types are the same as the chemical species, but we still have to include this
    # transformation to ensure that the data has 0-indexed atom type lists used in the various model operations
    - _target_: nequip.data.transforms.ChemicalSpeciesToAtomTypeMapper
      model_type_names: ${model_type_names}
      chemical_species_to_atom_type_map: ${list_to_identity_dict:${chemical_symbols}}
  # the following are torch.utils.data.DataLoader configs excluding the arguments `dataset` and `collate_fn`
  # https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader
  train_dataloader:
    _target_: torch.utils.data.DataLoader
    batch_size: 5
    num_workers: 5
    shuffle: true
  val_dataloader:
    _target_: torch.utils.data.DataLoader
    batch_size: 10
    num_workers: ${data.train_dataloader.num_workers} # we want to use the same num_workers -- variable interpolation helps
  test_dataloader: ${data.val_dataloader} # variable interpolation comes in handy again
  # dataset statistics can be calculated to be used for model initialization such as for shifting, scaling and standardizing.
  # it is advised to provide custom names -- you will have to retrieve them later under model to initialize certain parameters to the dataset statistics computed
  stats_manager:
    # dataset statistics is handled by the `DataStatisticsManager`
    # here, we use `CommonDataStatisticsManager` for a basic set of dataset statistics for general use cases
    # the dataset statistics include `num_neighbors_mean`, `per_atom_energy_mean`, `forces_rms`, `per_type_forces_rms`
    _target_: nequip.data.CommonDataStatisticsManager
    # dataloader kwargs for data statistics computation
    # `batch_size` should ideally be as large as possible without triggering OOM
    dataloader_kwargs:
      batch_size: 10
    # we need to provide the same type names that correspond to the model's `type_names`
    # so we interpolate the "central source of truth" model type names from above
    type_names: ${model_type_names}
# `trainer` (mandatory) is a Lightning.Trainer object (https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-class-api)
trainer:
  _target_: lightning.Trainer
  accelerator: auto
  enable_checkpointing: true
  max_epochs: 1000
  # wall-time limit in DD:HH:MM:SS format; quoted so it is always read as a string --
  # unquoted colon-separated values can be parsed as base-60 (sexagesimal) integers
  # under YAML 1.1 rules (e.g. an unquoted `10:00:00:00` becomes the int 2160000)
  max_time: "03:00:00:00"
  log_every_n_steps: 1 # how often to log
  # use any Lightning supported logger
  logger:
    _target_: lightning.pytorch.loggers.csv_logs.CSVLogger
    name: tutorial_log
    save_dir: ./results # directory where the CSV log files are written
    flush_logs_every_n_steps: 100
  # use any Lightning callbacks https://lightning.ai/docs/pytorch/stable/api_references.html#callbacks
  # and any custom callbacks that subclass Lightning's Callback parent class
  callbacks:
    # Common callbacks used in ML
    # stop training when some criterion is met
    - _target_: lightning.pytorch.callbacks.EarlyStopping
      monitor: val0_epoch/weighted_sum # validation metric to monitor
      min_delta: 0.001 # how much to be considered a "change" (decimal form -- a bare `1e-3` without a dot is a string, not a float, to strict YAML 1.1 parsers)
      patience: 20 # how many instances of "no change" before stopping
    # checkpoint based on some criterion
    - _target_: lightning.pytorch.callbacks.ModelCheckpoint
      monitor: val0_epoch/weighted_sum # validation metric to monitor
      dirpath: ./results
      filename: best # best.ckpt is the checkpoint name
      save_last: true # last.ckpt will be saved
    # log learning rate, e.g. to monitor what the learning rate scheduler is doing
    - _target_: lightning.pytorch.callbacks.LearningRateMonitor
      logging_interval: epoch
# training_module refers to a NequIPLightningModule
training_module:
  _target_: nequip.train.EMALightningModule
  # We are using an EMA model (i.e. we keep a separate model whose weights are an exponential moving average of the base model's weights)
  # The use of an EMA model is configured by setting `ema_decay` to be a float (e.g. 0.999) under `training_module` (it is a `NequIPLightningModule` argument). The default of `ema_decay` is None, which means an EMA model is not used, if `ema_decay` is not explicitly configured
  # EMA allows for smoother validation curves and thus more reliable metrics for monitoring
  # Loading from a checkpoint for use in the `nequip.ase.NequIPCalculator` or during `nequip-compile` and `nequip-package` will always load the EMA model if it's present
  ema_decay: 0.999
  # here, we use a simplified MetricsManager wrapper (see docs) to construct the energy-force loss function
  # the more general `nequip.train.MetricsManager` could also be used to configure a custom loss function
  loss:
    _target_: nequip.train.EnergyForceLoss
    per_atom_energy: true
    coeffs:
      total_energy: 1.0
      forces: 1.0
  # again, we use a simplified MetricsManager wrapper (see docs) to construct the energy-force metrics
  # the more general `nequip.train.MetricsManager` could also be used in this case
  # validation metrics are used for monitoring and influencing training, e.g. with LR schedulers or early stopping, etc
  val_metrics:
    _target_: nequip.train.EnergyForceMetrics
    coeffs:
      total_energy_mae: 1.0
      forces_mae: 1.0
    # keys `total_energy_rmse` and `forces_rmse`, `per_atom_energy_rmse` and `per_atom_energy_mae` are also available
  # we could have train_metrics and test_metrics be different from val_metrics, but it makes sense to have them be the same
  train_metrics: ${training_module.val_metrics} # use variable interpolation
  test_metrics: ${training_module.val_metrics} # use variable interpolation
  # any torch compatible optimizer: https://pytorch.org/docs/stable/optim.html#algorithms
  optimizer:
    _target_: torch.optim.Adam
    lr: 0.03
  # see options for lr_scheduler_config
  # https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.core.LightningModule.html#lightning.pytorch.core.LightningModule.configure_optimizers
  lr_scheduler:
    # any torch compatible lr scheduler
    scheduler:
      _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
      factor: 0.6
      patience: 5
      threshold: 0.2
      # written as `1.0e-6` so strict YAML 1.1 parsers resolve a float (a bare `1e-6` without a dot can be read as a string)
      min_lr: 1.0e-6
    monitor: val0_epoch/weighted_sum
    interval: epoch
    frequency: 1
  # model details
  model:
    _target_: nequip.model.NequIPGNNModel
    # == basic model params ==
    seed: 456
    model_dtype: float32
    type_names: ${model_type_names}
    r_max: ${cutoff_radius}
    # == bessel encoding ==
    num_bessels: 8 # number of basis functions used in the radial Bessel basis, the default of 8 usually works well
    bessel_trainable: false # set true to train the bessel weights (default false)
    polynomial_cutoff_p: 6 # p-exponent used in polynomial cutoff function, smaller p corresponds to stronger decay with distance
    # == convnet layers ==
    num_layers: 3 # number of interaction blocks, we find 3-5 to work best
    l_max: 1 # the maximum irrep order (rotation order) for the network's features, l=1 is a good default, l=2 is more accurate but slower
    parity: true # whether to include features with odd mirror parity; often turning parity off gives equally good results but faster networks, so do consider this
    num_features: 32 # the multiplicity of the features, 32 is a good default for accurate network, if you want to be more accurate, go larger, if you want to be faster, go lower
    # == radial network ==
    radial_mlp_depth: 2 # number of radial layers, usually 1-3 works best, smaller is faster
    radial_mlp_width: 64 # number of hidden neurons in radial function, smaller is faster
    # dataset statistics used to inform the model's initial parameters for normalization, shifting and rescaling
    # we use omegaconf's resolvers (https://omegaconf.readthedocs.io/en/2.3_branch/usage.html#resolvers)
    # to facilitate getting the dataset statistics from the DataStatisticsManager
    # average number of neighbors for edge sum normalization
    avg_num_neighbors: ${training_data_stats:num_neighbors_mean}
    # == per-type per-atom scales and shifts ==
    per_type_energy_scales: ${training_data_stats:per_type_forces_rms}
    per_type_energy_shifts: ${training_data_stats:per_atom_energy_mean}
    per_type_energy_scales_trainable: false
    per_type_energy_shifts_trainable: false
    # == ZBL pair potential ==
    pair_potential:
      _target_: nequip.nn.pair_potential.ZBL
      units: metal # LAMMPS unit names; allowed values "metal" and "real"
      chemical_species: ${chemical_symbols} # must tell ZBL the chemical species of the various model atom types