Skip to content

Commit b62d576

Browse files
authored
add readme example for confidences (#84)
1 parent 2079db2 commit b62d576

File tree

1 file changed

+41
-0
lines changed

1 file changed

+41
-0
lines changed

README.md

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,47 @@ print("Optimized Energy:", atoms.get_potential_energy())
159159
Or you can use it to run MD simulations. The script, an example input xyz file and a Colab notebook demonstration are available in the [examples directory.](./examples) This should work with any input, simply modify the input_file and cell_size parameters. We recommend using constant volume simulations.
160160

161161

162+
#### Confidence head (Orb-v3 Models Only)
163+
164+
Orb-v3 models have a confidence head which produces a per-atom discrete confidence measure based on a classifier head which learns to predict the binned MAE between predicted and true forces during training. This classifier head has 50 bins, linearly spaced between 0 and 0.4A.
165+
166+
167+
```python
168+
import ase
169+
from ase.build import molecule
170+
from seaborn import heatmap # optional, for visualization only
171+
import matplotlib.pyplot as plt # optional, for visualization only
172+
import numpy
173+
174+
from orb_models.forcefield import pretrained
175+
from orb_models.forcefield.calculator import ORBCalculator
176+
177+
device="cpu" # or device="cuda"
178+
# or choose another model using ORB_PRETRAINED_MODELS[model_name]()
179+
orbff = pretrained.orb_v3_conservative_inf_omat(
180+
device=device,
181+
)
182+
calc = ORBCalculator(orbff, device=device)
183+
# Use a molecule (OOD for Orb, so confidence plot is
184+
# more interesting than a bulk crystal)
185+
atoms = molecule("CH3CH2Cl")
186+
atoms.calc = calc
187+
188+
forces = atoms.get_forces()
189+
confidences = calc.results["confidence"]
190+
predicted_bin_per_atom = numpy.argmax(confidences, axis=-1)
191+
192+
print(forces.shape, confidences.shape) # (num_atoms, 3), (num_atoms, 50)
193+
print(predicted_bin_per_atom) # List of length num_atoms
194+
heatmap(confidences)
195+
plt.xlabel('Confidence Bin')
196+
plt.ylabel('Atom Index')
197+
plt.title('Confidence Heatmap')
198+
plt.show()
199+
200+
```
201+
202+
162203
### Floating Point Precision
163204

164205
As shown in usage snippets above, we support 3 floating point precision types: `"float32-high"`, `"float32-highest"` and `"float64"`.

0 commit comments

Comments
 (0)