|
7 | 7 | #include <algorithm> |
8 | 8 | #include <stdexcept> |
9 | 9 | #include <string> |
| 10 | +#include <tuple> |
10 | 11 | #include <vector> |
11 | 12 | #include <cstring> |
12 | 13 |
|
@@ -405,4 +406,100 @@ System load_system_buffer(const uint8_t* data, size_t size) { |
405 | 406 | return read_system_from_zip(zr); |
406 | 407 | } |
407 | 408 |
|
| 409 | + |
| 410 | +/// Known inputs and outputs |
| 411 | +inline std::unordered_set<std::string> KNOWN_INPUTS_OUTPUTS = { |
| 412 | + "energy", |
| 413 | + "energy_ensemble", |
| 414 | + "energy_uncertainty", |
| 415 | + "features", |
| 416 | + "non_conservative_forces", |
| 417 | + "non_conservative_stress", |
| 418 | + "positions", |
| 419 | + "momenta", |
| 420 | + "velocities", |
| 421 | + "masses" |
| 422 | +}; |
| 423 | + |
| 424 | +std::tuple<bool, std::string, std::string> details::validate_name_and_check_variant( |
| 425 | + const std::string& name |
| 426 | +) { |
| 427 | + if (KNOWN_INPUTS_OUTPUTS.find(name) != KNOWN_INPUTS_OUTPUTS.end()) { |
| 428 | + // known output, nothing to do |
| 429 | + return {true, name, ""}; |
| 430 | + } |
| 431 | + |
| 432 | + auto double_colon = name.rfind("::"); |
| 433 | + if (double_colon != std::string::npos) { |
| 434 | + if (double_colon == 0 || double_colon == (name.length() - 2)) { |
| 435 | + C10_THROW_ERROR(ValueError, |
| 436 | + "Invalid name for model output: '" + name + "'. " |
| 437 | + "Non-standard names should look like '<domain>::<output>' " |
| 438 | + "with non-empty domain and output." |
| 439 | + ); |
| 440 | + } |
| 441 | + |
| 442 | + auto custom_name = name.substr(0, double_colon); |
| 443 | + auto output_name = name.substr(double_colon + 2); |
| 444 | + |
| 445 | + auto slash = custom_name.find('/'); |
| 446 | + if (slash != std::string::npos) { |
| 447 | + // "domain/variant::custom" is not allowed |
| 448 | + C10_THROW_ERROR(ValueError, |
| 449 | + "Invalid name for model output: '" + name + "'. " |
| 450 | + "Non-standard name with variant should look like " |
| 451 | + "'<domain>::<output>/<variant>'" |
| 452 | + ); |
| 453 | + } |
| 454 | + |
| 455 | + slash = output_name.find('/'); |
| 456 | + if (slash != std::string::npos) { |
| 457 | + if (slash == 0 || slash == (name.length() - 1)) { |
| 458 | + C10_THROW_ERROR(ValueError, |
| 459 | + "Invalid name for model output: '" + name + "'. " |
| 460 | + "Non-standard name with variant should look like " |
| 461 | + "'<domain>::<output>/<variant>' with non-empty domain, " |
| 462 | + "output and variant." |
| 463 | + ); |
| 464 | + } |
| 465 | + } |
| 466 | + |
| 467 | + // this is a custom output, nothing more to check |
| 468 | + return {false, "", ""}; |
| 469 | + } |
| 470 | + |
| 471 | + auto slash = name.find('/'); |
| 472 | + if (slash != std::string::npos) { |
| 473 | + if (slash == 0 || slash == (name.length() - 1)) { |
| 474 | + C10_THROW_ERROR(ValueError, |
| 475 | + "Invalid name for model output: '" + name + "'. " |
| 476 | + "Variant names should look like '<output>/<variant>' " |
| 477 | + "with non-empty output and variant." |
| 478 | + ); |
| 479 | + } |
| 480 | + |
| 481 | + auto base = name.substr(0, slash); |
| 482 | + auto double_colon = base.rfind("::"); |
| 483 | + if (double_colon != std::string::npos) { |
| 484 | + // we don't do anything for custom outputs |
| 485 | + return {false, "", ""}; |
| 486 | + } |
| 487 | + |
| 488 | + if (KNOWN_INPUTS_OUTPUTS.find(base) == KNOWN_INPUTS_OUTPUTS.end()) { |
| 489 | + C10_THROW_ERROR(ValueError, |
| 490 | + "Invalid name for model output with variant: '" + name + "'. " |
| 491 | + "'" + base + "' is not a known output." |
| 492 | + ); |
| 493 | + } |
| 494 | + |
| 495 | + return {true, base, name}; |
| 496 | + } |
| 497 | + |
| 498 | + C10_THROW_ERROR(ValueError, |
| 499 | + "Invalid name for model output: '" + name + "' is not a known output. " |
| 500 | + "Variant names should be of the form '<output>/<variant>'. " |
| 501 | + "Non-standard names should have the form '<domain>::<output>'." |
| 502 | + ); |
| 503 | +} |
| 504 | + |
408 | 505 | } // namespace metatomic_torch |
0 commit comments