|
22 | 22 | "cell_type": "markdown",
|
23 | 23 | "metadata": {},
|
24 | 24 | "source": [
|
25 |
| - "As a first step, we will run the installation of the *mlip* library directly from pip. We also install the appropriate Jax CUDA backend to run on GPU (comment it out to run on CPU). In this notebook, we will not run any simulation and therefore do not install Jax-MD, for details on how to do so, please refer to our *simulation* tutorial. Note that if you have ran another tutorial in the same environment, this installation is not required. Please refer to [our installation page](https://mlip-jax-dot-int-research-tpu.uc.r.appspot.com/installation/index.html) for more information." |
| 25 | + "As a first step, we will run the installation of the *mlip* library directly from pip. We also install the appropriate Jax CUDA backend to run on GPU (comment it out to run on CPU). In this notebook, we will not run any simulation and therefore do not install Jax-MD, for details on how to do so, please refer to our *simulation* tutorial. Note that if you have ran another tutorial in the same environment, this installation is not required. Please refer to [our installation page](https://instadeepai.github.io/mlip/installation/index.html) for more information." |
26 | 26 | ]
|
27 | 27 | },
|
28 | 28 | {
|
|
74 | 74 | "- [`MLIPNetwork`][MLIPNetwork] is a base class for GNNs that **computes node-wise energy** summands from edge vectors, node species, and graph edges passed as `senders` and `receivers` index arrays.\n",
|
75 | 75 | "- [`ForceFieldPredictor`][ForceFieldPredictor] is a generic wrapper around any [`MLIPNetwork`][MLIPNetwork].\n",
|
76 | 76 | "\n",
|
77 |
| - " It gathers **total energy, forces (and, if required, stress)** in the [`Prediction`](https://mlip-jax-dot-int-research-tpu.uc.r.appspot.com/api_reference/models/prediction.html) dataclass, by summing the node energies obtained from [`MLIPNetwork`][MLIPNetwork] on a [`jraph.GraphsTuple`](https://jraph.readthedocs.io/en/latest/api.html) object, and differentiating with respect to positions (and unit cell).\n", |
| 77 | + " It gathers **total energy, forces (and, if required, stress)** in the [`Prediction`](https://instadeepai.github.io/mlip/api_reference/models/prediction.html) dataclass, by summing the node energies obtained from [`MLIPNetwork`][MLIPNetwork] on a [`jraph.GraphsTuple`](https://jraph.readthedocs.io/en/latest/api.html) object, and differentiating with respect to positions (and unit cell).\n", |
78 | 78 | "\n",
|
79 | 79 | "\n",
|
80 |
| - "For convenience, our training loop and simulation engines finally work with [`ForceField`](https://mlip-jax-dot-int-research-tpu.uc.r.appspot.com/api_reference/models/force_field.html) objects that **wrap a force field predictor and its learnable parameters within a frozen dataclass object**.\n", |
| 80 | + "For convenience, our training loop and simulation engines finally work with [`ForceField`](https://instadeepai.github.io/mlip/api_reference/models/force_field.html) objects that **wrap a force field predictor and its learnable parameters within a frozen dataclass object**.\n", |
81 | 81 | "\n",
|
82 | 82 | "For illustration, in this notebook we will\n",
|
83 | 83 | "\n",
|
84 | 84 | "2. Define a very simple model that returns constant energies,\n",
|
85 | 85 | "3. Define a more involved GNN model without equivariance constraints.\n",
|
86 | 86 | "\n",
|
87 |
| - "[MLIPNetwork]: https://mlip-jax-dot-int-research-tpu.uc.r.appspot.com/api_reference/models/mlip_network.html\n", |
88 |
| - "[ForceFieldPredictor]: https://mlip-jax-dot-int-research-tpu.uc.r.appspot.com/api_reference/models/predictor.html\n", |
89 |
| - "[ForceField]: https://mlip-jax-dot-int-research-tpu.uc.r.appspot.com/api_reference/models/force_field.html\n", |
90 |
| - "[Prediction]: https://mlip-jax-dot-int-research-tpu.uc.r.appspot.com/api_reference/models/prediction.html" |
| 87 | + "[MLIPNetwork]: https://instadeepai.github.io/mlip/api_reference/models/mlip_network.html\n", |
| 88 | + "[ForceFieldPredictor]: https://instadeepai.github.io/mlip/api_reference/models/predictor.html\n", |
| 89 | + "[ForceField]: https://instadeepai.github.io/mlip/api_reference/models/force_field.html\n", |
| 90 | + "[Prediction]: https://instadeepai.github.io/mlip/api_reference/models/prediction.html" |
91 | 91 | ]
|
92 | 92 | },
|
93 | 93 | {
|
|
102 | 102 | "\n",
|
103 | 103 | "### a. *Config and DatasetInfo*\n",
|
104 | 104 | "\n",
|
105 |
| - "To facilitate model loading and saving, our [`MLIPNetwork`](https://mlip-jax-dot-int-research-tpu.uc.r.appspot.com/api_reference/models/mlip_network.html) class **gathers (almost) all of their hyperparameters within a `pydantic.BaseModel` subclass**. Their class attribute `.Config` points to this configuration class. Only exceptions consist of hyperparameters that are data dependent, and might\n", |
| 105 | + "To facilitate model loading and saving, our [`MLIPNetwork`](https://instadeepai.github.io/mlip/api_reference/models/mlip_network.html) class **gathers (almost) all of their hyperparameters within a `pydantic.BaseModel` subclass**. Their class attribute `.Config` points to this configuration class. Only exceptions consist of hyperparameters that are data dependent, and might\n", |
106 | 106 | "conflict with the data processing pipeline.\n",
|
107 | 107 | "\n",
|
108 |
| - "This is why [`MLIPNetwork`](https://mlip-jax-dot-int-research-tpu.uc.r.appspot.com/api_reference/models/mlip_network.html) **also accept a [`DatasetInfo`](https://mlip-jax-dot-int-research-tpu.uc.r.appspot.com/api_reference/data/dataset_info.html) object** upon initialization, that notably stores:\n", |
| 108 | + "This is why [`MLIPNetwork`](https://instadeepai.github.io/mlip/api_reference/models/mlip_network.html) **also accept a [`DatasetInfo`](https://instadeepai.github.io/mlip/api_reference/data/dataset_info.html) object** upon initialization, that notably stores:\n", |
109 | 109 | "- `cutoff_distance_angstrom : float`\n",
|
110 | 110 | "- `atomic_energies_map : dict[int, float]`\n",
|
111 | 111 | "- `avg_num_neighbours : float`\n",
|
112 | 112 | "- and some other data computed when processing the dataset.\n",
|
113 | 113 | "\n",
|
114 |
| - "This way, we are sure that our models can only be used in the context they were trained for, and will not be evaluated e.g. on atomic numbers they have never seen. We create a dummy [`DatasetInfo`](https://mlip-jax-dot-int-research-tpu.uc.r.appspot.com/api_reference/data/dataset_info.html) for the purpose of this example:" |
| 114 | + "This way, we are sure that our models can only be used in the context they were trained for, and will not be evaluated e.g. on atomic numbers they have never seen. We create a dummy [`DatasetInfo`](https://instadeepai.github.io/mlip/api_reference/data/dataset_info.html) for the purpose of this example:" |
115 | 115 | ]
|
116 | 116 | },
|
117 | 117 | {
|
|
197 | 197 | "source": [
|
198 | 198 | "### c. *Constant force field*\n",
|
199 | 199 | "\n",
|
200 |
| - "Now that we have defined this simple `ConstantMLIP` subclass, we can already define a state-holding [`ForceField`](https://mlip-jax-dot-int-research-tpu.uc.r.appspot.com/api_reference/models/force_field.html) object. The quickest (but slightly opaque) way is to use the helper classmethod `ForceField.from_mlip_network()`:" |
| 200 | + "Now that we have defined this simple `ConstantMLIP` subclass, we can already define a state-holding [`ForceField`](https://instadeepai.github.io/mlip/api_reference/models/force_field.html) object. The quickest (but slightly opaque) way is to use the helper classmethod `ForceField.from_mlip_network()`:" |
201 | 201 | ]
|
202 | 202 | },
|
203 | 203 | {
|
|
237 | 237 | "source": [
|
238 | 238 | "For the sake of transparency, let us detail what is actually being done here.\n",
|
239 | 239 | "\n",
|
240 |
| - "First, a [`ForceFieldPredictor`](https://mlip-jax-dot-int-research-tpu.uc.r.appspot.com/api_reference/models/predictor.html) instance is created on top of the `constant_mlip` model.\n", |
| 240 | + "First, a [`ForceFieldPredictor`](https://instadeepai.github.io/mlip/api_reference/models/predictor.html) instance is created on top of the `constant_mlip` model.\n", |
241 | 241 | "\n",
|
242 |
| - "Then, random parameters are initialized by calling the predictor's `.init()` method on a random seed and a dummy graph. These two objects (the predictor and its parameter dict) are wrapped for convenience inside the [`ForceField`](https://mlip-jax-dot-int-research-tpu.uc.r.appspot.com/api_reference/models/force_field.html) dataclass. The following is thus equivalent:" |
| 242 | + "Then, random parameters are initialized by calling the predictor's `.init()` method on a random seed and a dummy graph. These two objects (the predictor and its parameter dict) are wrapped for convenience inside the [`ForceField`](https://instadeepai.github.io/mlip/api_reference/models/force_field.html) dataclass. The following is thus equivalent:" |
243 | 243 | ]
|
244 | 244 | },
|
245 | 245 | {
|
|
265 | 265 | "id": "Po9NjwTo3tvX"
|
266 | 266 | },
|
267 | 267 | "source": [
|
268 |
| - "We'll see below how to manually initialize parameters, and call the [`ForceField`](https://mlip-jax-dot-int-research-tpu.uc.r.appspot.com/api_reference/models/force_field.html) default constructor : this only requires an input graph.\n", |
| 268 | + "We'll see below how to manually initialize parameters, and call the [`ForceField`](https://instadeepai.github.io/mlip/api_reference/models/force_field.html) default constructor : this only requires an input graph.\n", |
269 | 269 | "\n",
|
270 |
| - "**N.B.** The [`ForceField`](https://mlip-jax-dot-int-research-tpu.uc.r.appspot.com/api_reference/models/force_field.html) dataclass is frozen: this is to prevent any stateful operations to be performed on the parameters, which would be incompatible with JAX compilation and tracing mechanisms. You can think of [`ForceField`](https://mlip-jax-dot-int-research-tpu.uc.r.appspot.com/api_reference/models/force_field.html) as holding the _state_ of a learnable [`ForceFieldPredictor`](https://mlip-jax-dot-int-research-tpu.uc.r.appspot.com/api_reference/models/predictor.html), although _it remains immutable_." |
| 270 | + "**N.B.** The [`ForceField`](https://instadeepai.github.io/mlip/api_reference/models/force_field.html) dataclass is frozen: this is to prevent any stateful operations to be performed on the parameters, which would be incompatible with JAX compilation and tracing mechanisms. You can think of [`ForceField`](https://instadeepai.github.io/mlip/api_reference/models/force_field.html) as holding the _state_ of a learnable [`ForceFieldPredictor`](https://instadeepai.github.io/mlip/api_reference/models/predictor.html), although _it remains immutable_." |
271 | 271 | ]
|
272 | 272 | },
|
273 | 273 | {
|
|
360 | 360 | "source": [
|
361 | 361 | "### e. *Wrapping the model state in ForceField*\n",
|
362 | 362 | "\n",
|
363 |
| - "In order to hide the `flax` logic for downstream applications, our `TrainingLoop` class takes in and returns a [`ForceField`](https://mlip-jax-dot-int-research-tpu.uc.r.appspot.com/api_reference/models/force_field.html) object that simply wraps the predictor with its initial and final parameters respectively.\n", |
| 363 | + "In order to hide the `flax` logic for downstream applications, our `TrainingLoop` class takes in and returns a [`ForceField`](https://instadeepai.github.io/mlip/api_reference/models/force_field.html) object that simply wraps the predictor with its initial and final parameters respectively.\n", |
364 | 364 | "\n",
|
365 |
| - "This frozen dataclass can then be easily passed to the [`SimulationEngine`](https://mlip-jax-dot-int-research-tpu.uc.r.appspot.com/api_reference/simulation/simulation_engine.html), or just saved for later (by JSON-serializing the MLIPNetwork's `.config` and `.dataset_info`, and dumping the flattened parameter dict as `.npz`)." |
| 365 | + "This frozen dataclass can then be easily passed to the [`SimulationEngine`](https://instadeepai.github.io/mlip/api_reference/simulation/simulation_engine.html), or just saved for later (by JSON-serializing the MLIPNetwork's `.config` and `.dataset_info`, and dumping the flattened parameter dict as `.npz`)." |
366 | 366 | ]
|
367 | 367 | },
|
368 | 368 | {
|
|
393 | 393 | "id": "m0lNIXkOovaI"
|
394 | 394 | },
|
395 | 395 | "source": [
|
396 |
| - "Note that [`ForceField`](https://mlip-jax-dot-int-research-tpu.uc.r.appspot.com/api_reference/models/force_field.html) instances are also callable, and morally equivalent to `functools.partial(predictor.apply, params)`.\n", |
| 396 | + "Note that [`ForceField`](https://instadeepai.github.io/mlip/api_reference/models/force_field.html) instances are also callable, and morally equivalent to `functools.partial(predictor.apply, params)`.\n", |
397 | 397 | "\n",
|
398 | 398 | "This means they can be directly evaluated on a graph by forgetting about the (frozen) learnable parameters, as done during simulation."
|
399 | 399 | ]
|
|
420 | 420 | "id": "_AnG7n_vszX5"
|
421 | 421 | },
|
422 | 422 | "source": [
|
423 |
| - "In theory, the [`ForceField`](https://mlip-jax-dot-int-research-tpu.uc.r.appspot.com/api_reference/models/force_field.html) class is duck-typed for the [`SimulationEngine`](https://mlip-jax-dot-int-research-tpu.uc.r.appspot.com/api_reference/simulation/simulation_engine.html), and you could provide any other object with the following methods and properties (e.g. to wrap models defined in another JAX framework):\n", |
| 423 | + "In theory, the [`ForceField`](https://instadeepai.github.io/mlip/api_reference/models/force_field.html) class is duck-typed for the [`SimulationEngine`](https://instadeepai.github.io/mlip/api_reference/simulation/simulation_engine.html), and you could provide any other object with the following methods and properties (e.g. to wrap models defined in another JAX framework):\n", |
424 | 424 | "- `.__call__(graph: GraphsTuple) -> Prediction`\n",
|
425 | 425 | "- `.cutoff_distance: float`\n",
|
426 | 426 | "- `.allowed_atomic_numbers: set[int]`\n",
|
|
472 | 472 | "id": "0MdUJqneXoay"
|
473 | 473 | },
|
474 | 474 | "source": [
|
475 |
| - "Having defined our config, we can now create our MLIP model class. Our custom model must inherits the [`MLIPNetwork`](https://mlip-jax-dot-int-research-tpu.uc.r.appspot.com/api_reference/models/mlip_network.html) class, which is itself a `flax.linen.Module` object. As such, we can easily define our network using flax `@nn.compact` decorator, see [the flax docs](https://flax-linen.readthedocs.io/en/latest/quick_start.html) for more information.\n", |
| 475 | + "Having defined our config, we can now create our MLIP model class. Our custom model must inherits the [`MLIPNetwork`](https://instadeepai.github.io/mlip/api_reference/models/mlip_network.html) class, which is itself a `flax.linen.Module` object. As such, we can easily define our network using flax `@nn.compact` decorator, see [the flax docs](https://flax-linen.readthedocs.io/en/latest/quick_start.html) for more information.\n", |
476 | 476 | "\n",
|
477 |
| - "Our model must also have a dataset_info attribute of type [`DatasetInfo`](https://mlip-jax-dot-int-research-tpu.uc.r.appspot.com/api_reference/data/dataset_info.html). This object encapsulates the relevant informations about the dataset at hand that can be used to create the model. For instance, this attribute contains the average number of neighbors per atom in the dataset, which is used in models like [MACE](https://arxiv.org/pdf/2206.07697) to normalize the messages passed to each nodes.\n", |
| 477 | + "Our model must also have a dataset_info attribute of type [`DatasetInfo`](https://instadeepai.github.io/mlip/api_reference/data/dataset_info.html). This object encapsulates the relevant informations about the dataset at hand that can be used to create the model. For instance, this attribute contains the average number of neighbors per atom in the dataset, which is used in models like [MACE](https://arxiv.org/pdf/2206.07697) to normalize the messages passed to each nodes.\n", |
478 | 478 | "\n",
|
479 | 479 | "We provide a very simple example of MPNN below, which computes messages through an `MLP` encoding of sender and receiver features with edge distances."
|
480 | 480 | ]
|
|
583 | 583 | "id": "MgYcKGGXZZmO"
|
584 | 584 | },
|
585 | 585 | "source": [
|
586 |
| - "Having defined both our model and its associated config classes, we can now instantiate our model and turn it into a [`ForceField`](https://mlip-jax-dot-int-research-tpu.uc.r.appspot.com/api_reference/models/force_field.html) object that can be used for training and simulations." |
| 586 | + "Having defined both our model and its associated config classes, we can now instantiate our model and turn it into a [`ForceField`](https://instadeepai.github.io/mlip/api_reference/models/force_field.html) object that can be used for training and simulations." |
587 | 587 | ]
|
588 | 588 | },
|
589 | 589 | {
|
|
0 commit comments