Skip to content

Commit 40b1661

Browse files
GardevoirXLuthaf
andauthored
Allow model to request additional inputs from the engine (#123)
Co-authored-by: Guillaume Fraux <guillaume.fraux@epfl.ch>
1 parent 19380a9 commit 40b1661

File tree

16 files changed

+485
-137
lines changed

16 files changed

+485
-137
lines changed

docs/src/outputs/masses.rst

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,4 +40,15 @@ inputs, and must adhere to the following metadata schema:
4040
- masses must have a single property dimension named
4141
``"masses"``, with a single entry set to ``0``.
4242

43-
At the moment, masses are not integrated into any simulation engines.
43+
The following simulation engine can provide ``"masses"`` as inputs to the models.
44+
45+
.. grid:: 1 3 3 3
46+
47+
.. grid-item-card::
48+
:text-align: center
49+
:padding: 1
50+
:link: engine-ase
51+
:link-type: ref
52+
53+
|ase-logo|
54+

docs/src/outputs/momenta.rst

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,4 +48,15 @@ outputs, and must adhere to the following metadata schema:
4848
- momenta must have a single property dimension named
4949
``"momenta"``, with a single entry set to ``0``.
5050

51-
At the moment, momenta are not integrated into any simulation engines.
51+
The following simulation engine can provide ``"momenta"`` as inputs to the models.
52+
53+
.. grid:: 1 3 3 3
54+
55+
.. grid-item-card::
56+
:text-align: center
57+
:padding: 1
58+
:link: engine-ase
59+
:link-type: ref
60+
61+
|ase-logo|
62+

docs/src/outputs/velocities.rst

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,4 +43,14 @@ inputs, and must adhere to the following metadata schema:
4343
- velocities must have a single property dimension named
4444
``"velocities"``, with a single entry set to ``0``.
4545

46-
At the moment, velocities are not integrated into any simulation engines.
46+
The following simulation engine can provide ``"velocities"`` as inputs to the models.
47+
48+
.. grid:: 1 3 3 3
49+
50+
.. grid-item-card::
51+
:text-align: center
52+
:padding: 1
53+
:link: engine-ase
54+
:link-type: ref
55+
56+
|ase-logo|

docs/src/torch/reference/misc.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,3 +34,7 @@ one of the registered unit.
3434
+----------------+------------------------------------------------------------------------------------------------------------------------------------------------------+
3535
| **momentum** | ``u*A/fs``, ``u*A/ps``, ``(eV*u)^(1/2)``, ``kg*m/s`` |
3636
+----------------+------------------------------------------------------------------------------------------------------------------------------------------------------+
37+
| **mass** | ``u`` (``u``, ``Dalton``), ``kg`` (``kg``, ``kilogram``), ``g`` (``g``, ``gram``) |
38+
+----------------+------------------------------------------------------------------------------------------------------------------------------------------------------+
39+
| **velocity** | ``nm/fs``, ``A/fs``, ``m/s``, ``nm/ps`` |
40+
+----------------+------------------------------------------------------------------------------------------------------------------------------------------------------+

metatomic-torch/include/metatomic/torch/misc.hpp

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ namespace metatomic_torch {
2020
METATOMIC_TORCH_EXPORT std::string version();
2121

2222
/// Select the best device according to the list of `model_devices` from a
23-
/// model, the user-provided `desired_device` and what's available on the
23+
/// model, the user-provided `desired_device` and what's available on the
2424
/// current machine.
2525
///
2626
/// This function returns a c10::DeviceType (torch::DeviceType). It does NOT
@@ -64,6 +64,19 @@ inline System load_system_buffer(const torch::Tensor& data) {
6464
return load_system_buffer(ptr, n);
6565
}
6666

67+
namespace details {
68+
69+
/// Validate that the given `name` is valid for a model output/input
70+
///
71+
/// The function returns a tuple with:
72+
/// - a boolean indicating whether this is a known output/input
73+
/// - the name of the base output/input (empty if custom)
74+
/// - the name of the variant (empty if none)
75+
std::tuple<bool, std::string, std::string> validate_name_and_check_variant(
76+
const std::string& name
77+
);
78+
}
79+
6780
}
6881

6982
#endif

metatomic-torch/include/metatomic/torch/model_output.hpp

Whitespace-only changes.

metatomic-torch/src/misc.cpp

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include <algorithm>
88
#include <stdexcept>
99
#include <string>
10+
#include <tuple>
1011
#include <vector>
1112
#include <cstring>
1213

@@ -405,4 +406,100 @@ System load_system_buffer(const uint8_t* data, size_t size) {
405406
return read_system_from_zip(zr);
406407
}
407408

409+
410+
/// Known inputs and outputs
411+
inline std::unordered_set<std::string> KNOWN_INPUTS_OUTPUTS = {
412+
"energy",
413+
"energy_ensemble",
414+
"energy_uncertainty",
415+
"features",
416+
"non_conservative_forces",
417+
"non_conservative_stress",
418+
"positions",
419+
"momenta",
420+
"velocities",
421+
"masses"
422+
};
423+
424+
std::tuple<bool, std::string, std::string> details::validate_name_and_check_variant(
425+
const std::string& name
426+
) {
427+
if (KNOWN_INPUTS_OUTPUTS.find(name) != KNOWN_INPUTS_OUTPUTS.end()) {
428+
// known output, nothing to do
429+
return {true, name, ""};
430+
}
431+
432+
auto double_colon = name.rfind("::");
433+
if (double_colon != std::string::npos) {
434+
if (double_colon == 0 || double_colon == (name.length() - 2)) {
435+
C10_THROW_ERROR(ValueError,
436+
"Invalid name for model output: '" + name + "'. "
437+
"Non-standard names should look like '<domain>::<output>' "
438+
"with non-empty domain and output."
439+
);
440+
}
441+
442+
auto custom_name = name.substr(0, double_colon);
443+
auto output_name = name.substr(double_colon + 2);
444+
445+
auto slash = custom_name.find('/');
446+
if (slash != std::string::npos) {
447+
// "domain/variant::custom" is not allowed
448+
C10_THROW_ERROR(ValueError,
449+
"Invalid name for model output: '" + name + "'. "
450+
"Non-standard name with variant should look like "
451+
"'<domain>::<output>/<variant>'"
452+
);
453+
}
454+
455+
slash = output_name.find('/');
456+
if (slash != std::string::npos) {
457+
if (slash == 0 || slash == (name.length() - 1)) {
458+
C10_THROW_ERROR(ValueError,
459+
"Invalid name for model output: '" + name + "'. "
460+
"Non-standard name with variant should look like "
461+
"'<domain>::<output>/<variant>' with non-empty domain, "
462+
"output and variant."
463+
);
464+
}
465+
}
466+
467+
// this is a custom output, nothing more to check
468+
return {false, "", ""};
469+
}
470+
471+
auto slash = name.find('/');
472+
if (slash != std::string::npos) {
473+
if (slash == 0 || slash == (name.length() - 1)) {
474+
C10_THROW_ERROR(ValueError,
475+
"Invalid name for model output: '" + name + "'. "
476+
"Variant names should look like '<output>/<variant>' "
477+
"with non-empty output and variant."
478+
);
479+
}
480+
481+
auto base = name.substr(0, slash);
482+
auto double_colon = base.rfind("::");
483+
if (double_colon != std::string::npos) {
484+
// we don't do anything for custom outputs
485+
return {false, "", ""};
486+
}
487+
488+
if (KNOWN_INPUTS_OUTPUTS.find(base) == KNOWN_INPUTS_OUTPUTS.end()) {
489+
C10_THROW_ERROR(ValueError,
490+
"Invalid name for model output with variant: '" + name + "'. "
491+
"'" + base + "' is not a known output."
492+
);
493+
}
494+
495+
return {true, base, name};
496+
}
497+
498+
C10_THROW_ERROR(ValueError,
499+
"Invalid name for model output: '" + name + "' is not a known output. "
500+
"Variant names should be of the form '<output>/<variant>'. "
501+
"Non-standard names should have the form '<domain>::<output>'."
502+
);
503+
}
504+
408505
} // namespace metatomic_torch

metatomic-torch/src/model.cpp

Lines changed: 7 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -147,101 +147,19 @@ ModelOutput ModelOutputHolder::from_json(std::string_view json) {
147147

148148
/******************************************************************************/
149149

150-
std::unordered_set<std::string> KNOWN_OUTPUTS = {
151-
"energy",
152-
"energy_ensemble",
153-
"energy_uncertainty",
154-
"features",
155-
"non_conservative_forces",
156-
"non_conservative_stress",
157-
"positions",
158-
"momenta"
159-
};
160150

161151
void ModelCapabilitiesHolder::set_outputs(torch::Dict<std::string, ModelOutput> outputs) {
162152

163153
std::unordered_map<std::string, std::vector<std::string>> variants;
164-
165154
for (const auto& it: outputs) {
166-
const auto& name = it.key();
167-
if (KNOWN_OUTPUTS.find(name) != KNOWN_OUTPUTS.end()) {
168-
// known output, nothing to do
169-
variants[name].push_back(name);
170-
continue;
171-
}
172-
173-
auto double_colon = name.rfind("::");
174-
if (double_colon != std::string::npos) {
175-
if (double_colon == 0 || double_colon == (name.length() - 2)) {
176-
C10_THROW_ERROR(ValueError,
177-
"Invalid name for model output: '" + name + "'. "
178-
"Non-standard names should look like '<domain>::<output>' "
179-
"with non-empty domain and output."
180-
);
181-
}
182-
183-
auto custom_name = name.substr(0, double_colon);
184-
auto output_name = name.substr(double_colon + 2);
185-
186-
auto slash = custom_name.find('/');
187-
if (slash != std::string::npos) {
188-
// "domain/variant::custom" is not allowed
189-
C10_THROW_ERROR(ValueError,
190-
"Invalid name for model output: '" + name + "'. "
191-
"Non-standard name with variant should look like "
192-
"'<domain>::<output>/<variant>'"
193-
);
194-
}
195-
196-
slash = output_name.find('/');
197-
if (slash != std::string::npos) {
198-
if (slash == 0 || slash == (name.length() - 1)) {
199-
C10_THROW_ERROR(ValueError,
200-
"Invalid name for model output: '" + name + "'. "
201-
"Non-standard name with variant should look like "
202-
"'<domain>::<output>/<variant>' with non-empty domain, "
203-
"output and variant."
204-
);
205-
}
206-
}
207-
208-
// this is a custom output, nothing more to check
209-
continue;
210-
}
211-
212-
auto slash = name.find('/');
213-
if (slash != std::string::npos) {
214-
if (slash == 0 || slash == (name.length() - 1)) {
215-
C10_THROW_ERROR(ValueError,
216-
"Invalid name for model output: '" + name + "'. "
217-
"Variant names should look like '<output>/<variant>' "
218-
"with non-empty output and variant."
219-
);
220-
}
221-
222-
auto base = name.substr(0, slash);
223-
auto double_colon = base.rfind("::");
224-
if (double_colon != std::string::npos) {
225-
// we don't do anything for custom outputs
226-
continue;
227-
}
228-
229-
if (KNOWN_OUTPUTS.find(base) == KNOWN_OUTPUTS.end()) {
230-
C10_THROW_ERROR(ValueError,
231-
"Invalid name for model output with variant: '" + name + "'. "
232-
"'" + base + "' is not a known output."
233-
);
155+
auto [is_standard, base, variant] = details::validate_name_and_check_variant(it.key());
156+
if (is_standard) {
157+
if (variant.empty()) {
158+
variants[base].emplace_back(base);
159+
} else {
160+
variants[base].emplace_back(variant);
234161
}
235-
236-
variants[base].push_back(name);
237-
continue;
238-
}
239-
240-
C10_THROW_ERROR(ValueError,
241-
"Invalid name for model output: '" + name + "' is not a known output. "
242-
"Variant names should be of the form '<output>/<variant>'. "
243-
"Non-standard names should have the form '<domain>::<output>'."
244-
);
162+
};
245163
}
246164

247165
// check descriptions for each variant group

0 commit comments

Comments
 (0)