diff --git a/examples/two_dimensional_distribution.cpp b/examples/two_dimensional_distribution.cpp index aca0171..c78e53e 100644 --- a/examples/two_dimensional_distribution.cpp +++ b/examples/two_dimensional_distribution.cpp @@ -33,27 +33,20 @@ int main() auto integrand = hep::make_integrand( gauss, 2, - hep::distribution_parameters{100, 100, min, max, min, max, ""} + hep::distribution_parameters{100, 100, min, max, min, max, "gauss"} ); // now integrate and record the differential distributions - auto const chkpt = hep::plain(integrand, std::vector(1, 10000000)); + auto const chkpt = hep::plain( + integrand, + std::vector(1, 10000000), + hep::make_plain_chkpt(), + hep::callback>(hep::callback_mode::verbose_and_write_chkpt, "dist_chkpt") + ); + auto const result = chkpt.results().back(); // integral is zero - std::cout << "integral is I = " << result.value() << " +- " << result.error() << "\n\n"; - - auto const& distribution = result.distributions()[0]; - auto const& mid_points_x = hep::mid_points_x(distribution); - auto const& mid_points_y = hep::mid_points_y(distribution); - - std::cout.setf(std::ios_base::scientific); - - // print the distribution - for (std::size_t i = 0; i != mid_points_x.size(); ++i) - { - std::cout << mid_points_x[i] << '\t' << mid_points_y[i] << '\t' - << distribution.results()[i].value() << '\t' - << distribution.results()[i].error() << '\n'; - } + std::cout << "integral is I = " << result.value() << " +- " << result.error() << "\n\n" + "to view the distribution use the checkpoint viewer: `chkpt dists dist_chkpt`\n"; } diff --git a/meson.build b/meson.build index 33dfa2a..452bcae 100644 --- a/meson.build +++ b/meson.build @@ -31,4 +31,5 @@ if get_option('examples') subdir('examples') endif +subdir('src') subdir('tests') diff --git a/src/chkpt.cpp b/src/chkpt.cpp new file mode 100644 index 0000000..5c360a8 --- /dev/null +++ b/src/chkpt.cpp @@ -0,0 +1,38 @@ +#include "read_type.hpp" +#include "operations.hpp" + +#include +#include +#include +#include +#include + +int main(int argc, char* argv[]) +{ + std::ios_base::sync_with_stdio(false); + + if (argc < 3) + { + std::cerr << hep::operations_help_string(); + + return 1; + } + + // capture all arguments as strings, neglecting the first, second, and last argument + std::vector arguments(&argv[2], &argv[argc - 1]); + + try + { + std::string operation_name{argv[1]}; + std::ifstream file{argv[argc - 1]}; + auto const& type = hep::read_type(file); + + hep::operations().at(type)(operation_name, arguments, file); + } + catch (std::runtime_error const& exception) + { + std::cerr << "Error: " << exception.what() << '\n'; + + return 1; + } +} diff --git a/src/make_chkpt.hpp b/src/make_chkpt.hpp new file mode 100644 index 0000000..883646c --- /dev/null +++ b/src/make_chkpt.hpp @@ -0,0 +1,71 @@ +#ifndef MAKE_CHKPT_HPP +#define MAKE_CHKPT_HPP + +#include "hep/mc/chkpt.hpp" +#include "hep/mc/multi_channel_chkpt.hpp" +#include "hep/mc/plain_chkpt.hpp" +#include "hep/mc/vegas_chkpt.hpp" + +namespace hep +{ + +template +hep::chkpt_with_rng make_chkpt(std::istream& in); + +template <> +hep::plain_chkpt_with_rng make_chkpt>(std::istream& in) +{ + return hep::make_plain_chkpt(in); +} + +template <> +hep::plain_chkpt_with_rng make_chkpt>(std::istream& in) +{ + return hep::make_plain_chkpt(in); +} + +template <> +hep::plain_chkpt_with_rng make_chkpt>(std::istream& in) +{ + return hep::make_plain_chkpt(in); +} + +template <> +hep::vegas_chkpt_with_rng make_chkpt>(std::istream& in) +{ + return hep::make_vegas_chkpt(in); +} + +template <> +hep::vegas_chkpt_with_rng make_chkpt>(std::istream& in) +{ + return hep::make_vegas_chkpt(in); +} + +template <> +hep::vegas_chkpt_with_rng make_chkpt>(std::istream& in) +{ + return hep::make_vegas_chkpt(in); +} + +template <> +hep::multi_channel_chkpt_with_rng make_chkpt>(std::istream& in) +{ + return hep::make_multi_channel_chkpt(in); +} + +template <> +hep::multi_channel_chkpt_with_rng make_chkpt>(std::istream& in) +{ + return hep::make_multi_channel_chkpt(in); +} + +template <> +hep::multi_channel_chkpt_with_rng make_chkpt>(std::istream& in) +{ + return hep::make_multi_channel_chkpt(in); +} + +} + +#endif diff --git a/src/meson.build b/src/meson.build new file mode 100644 index 0000000..ab898ed --- /dev/null +++ b/src/meson.build @@ -0,0 +1,8 @@ +chkpt_srcs = [ + 'chkpt.cpp', + 'operations.cpp', + 'read_type.cpp', + 'stream_rng.cpp', +] + +executable('chkpt', chkpt_srcs, dependencies : hep_mc_dep) diff --git a/src/operations.cpp b/src/operations.cpp new file mode 100644 index 0000000..f2aa36c --- /dev/null +++ b/src/operations.cpp @@ -0,0 +1,263 @@ +#include "operations.hpp" +#include "stream_rng.hpp" +#include "make_chkpt.hpp" + +#include "hep/mc/callback.hpp" +#include "hep/mc/chkpt.hpp" +#include "hep/mc/multi_channel_chkpt.hpp" +#include "hep/mc/multi_channel_result.hpp" +#include "hep/mc/plain_chkpt.hpp" +#include "hep/mc/plain_result.hpp" +#include "hep/mc/vegas_chkpt.hpp" +#include "hep/mc/vegas_result.hpp" + +#include +#include +#include + +namespace +{ + +template +int perform_iter( + std::vector const& arguments, + hep::chkpt_with_rng const& chkpt +) { + if (!arguments.empty()) + { + std::cerr << "Warning: additional arguments ignored\n"; + } + + std::cout << chkpt.results().size() << '\n'; + + return 0; +} + +template +int perform_print( + std::vector const& arguments, + hep::chkpt_with_rng const& chkpt +) { + auto const io_flags = std::cout.flags(); + auto const io_precision = std::cout.precision(); + + for (std::size_t i = 0; i != arguments.size(); ++i) + { + if (arguments.at(i) == "-p") + { + if (i + 1 == arguments.size()) + { + throw std::runtime_error("argument " + arguments.at(i) + " is missing the " + "`precision` parameter"); + } + + unsigned long precision; + + try + { + precision = std::stoul(arguments.at(++i)); + } + catch (std::invalid_argument const& exception) + { + throw std::runtime_error("argument " + arguments.at(i) + " could not be converted " + " to a number"); + } + + std::cout.precision(precision); + } + else if (arguments.at(i) == "-s") + { + std::cout.setf(std::ios_base::scientific, std::ios_base::floatfield); + } + else + { + std::cerr << "Warning: additional argument `" + arguments.at(i) + "` ignored\n"; + } + } + + using T = typename Chkpt::result_type::numeric_type; + + hep::callback callback{hep::callback_mode::verbose, "", T()}; + + for (std::size_t i = 0; i != chkpt.results().size(); ++i) + { + auto copy = chkpt; + copy.rollback(i + 1); + callback(copy); + } + + std::cout.precision(io_precision); + std::cout.flags(io_flags); + + return 0; +} + +template +int perform_dists( + std::vector const& arguments, + hep::chkpt_with_rng const& chkpt +) { + auto const io_flags = std::cout.flags(); + auto const io_precision = std::cout.precision(); + + for (std::size_t i = 0; i != arguments.size(); ++i) + { + if (arguments.at(i) == "-p") + { + if (i + 1 == arguments.size()) + { + throw std::runtime_error("argument " + arguments.at(i) + " is missing the " + "`precision` parameter"); + } + + unsigned long precision; + + try + { + precision = std::stoul(arguments.at(++i)); + } + catch (std::invalid_argument const& exception) + { + throw std::runtime_error("argument " + arguments.at(i) + " could not be converted " + " to a number"); + } + + std::cout.precision(precision); + } + else if (arguments.at(i) == "-s") + { + std::cout.setf(std::ios_base::scientific, std::ios_base::floatfield); + } + else + { + std::cerr << "Warning: additional argument `" + arguments.at(i) + "` ignored\n"; + } + } + + if (chkpt.results().empty() || chkpt.results().front().distributions().empty()) + { + return 0; + } + + std::size_t const distributions = chkpt.results().front().distributions().size(); + + for (std::size_t i = 0; i != distributions; ++i) + { + auto const& parameters = chkpt.results().front().distributions().at(i).parameters(); + + std::cout << "# " << parameters.name() << '\n'; + + auto const& mid_points_x = hep::mid_points_x(chkpt.results().front().distributions().at(i)); + auto const& mid_points_y = hep::mid_points_y(chkpt.results().front().distributions().at(i)); + + for (std::size_t j = 0; j != mid_points_x.size(); ++j) + { + std::cout << mid_points_x.at(j); + + if (parameters.bins_y() > 1) + { + std::cout << ' ' << mid_points_y.at(j); + } + + for (auto const& result : chkpt.results()) + { + auto const& mc_result = result.distributions().at(i).results().at(j); + + std::cout << ' ' << mc_result.value() << ' ' << mc_result.error() << ' ' + << mc_result.calls(); + } + + std::cout << '\n'; + } + + if (i != (distributions - 1)) + { + std::cout << '\n'; + } + } + + std::cout.precision(io_precision); + std::cout.flags(io_flags); + + return 0; +} + +template +int dispatch_operations( + std::string const& operation, + std::vector const& arguments, + std::istream& in +) { + auto const& chkpt = hep::make_chkpt(in); + + if (operation == "iter") + { + return perform_iter(arguments, chkpt); + } + else if (operation == "print") + { + return perform_print(arguments, chkpt); + } + else if (operation == "dists") + { + return perform_dists(arguments, chkpt); + } + else + { + throw std::runtime_error("operation `" + operation + "`not recognized"); + } +} + +} + +namespace hep +{ + +std::unordered_map const& operations() +{ + static std::unordered_map operations_ = { + { std::type_index(typeid(multi_channel_result)), + dispatch_operations> }, + { std::type_index(typeid(multi_channel_result)), + dispatch_operations> }, + { std::type_index(typeid(multi_channel_result)), + dispatch_operations> }, + { std::type_index(typeid(plain_result)), + dispatch_operations> }, + { std::type_index(typeid(plain_result)), + dispatch_operations> }, + { std::type_index(typeid(plain_result)), + dispatch_operations> }, + { std::type_index(typeid(vegas_result)), + dispatch_operations> }, + { std::type_index(typeid(vegas_result)), + dispatch_operations> }, + { std::type_index(typeid(vegas_result)), + dispatch_operations> }, + }; + + return operations_; +} + +std::string operations_help_string() +{ + return /* 72 character limit /////////////////////////////////////////////// */ + "Usage: chkpt [opt_args...] file\n" + " Commands:\n" + " - iter:\n" + " returns the number of iterations stored in this checkpoint file\n" + " - print:\n" + " Uses the standard callback function to print all results of this\n" + " checkpoint. This command accepts the optional parameters `-p` and\n" + " `-s`\n" + " - dists:\n" + " Prints all distributions collected in the checkpoint. This command\n" + " accepts the optional parameters `-p` and `-s`\n" + " Optional arguments:\n" + " - `-s`:\n" + " switches the output to the scientific format\n" + " - `-p `:\n" + " parameter to set the precision\n"; +} + +} diff --git a/src/operations.hpp b/src/operations.hpp new file mode 100644 index 0000000..024e16f --- /dev/null +++ b/src/operations.hpp @@ -0,0 +1,23 @@ +#ifndef OPERATIONS_HPP +#define OPERATIONS_HPP + +#include +#include +#include +#include +#include +#include + +namespace hep +{ + +using chkpt_operation = + std::function const&, std::istream& in)>; + +std::unordered_map const& operations(); + +std::string operations_help_string(); + +} + +#endif diff --git a/src/read_type.cpp b/src/read_type.cpp new file mode 100644 index 0000000..1db83d9 --- /dev/null +++ b/src/read_type.cpp @@ -0,0 +1,145 @@ +#include "read_type.hpp" + +#include "hep/mc/mc_result.hpp" +#include "hep/mc/multi_channel_result.hpp" +#include "hep/mc/plain_result.hpp" +#include "hep/mc/vegas_result.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace hep +{ + +std::type_index read_type(std::istream& in) +{ + auto const pos = in.tellg(); + std::string line; + std::getline(in, line); + in.seekg(pos); + std::istringstream stream{line}; + std::string token; + std::vector tokens; + tokens.reserve(4); + + while (std::getline(stream, token, ' ')) + { + tokens.push_back(token); + } + + if ((tokens.size() != 4) || (tokens.at(0) != "#")) + { + throw std::runtime_error("no header found/invalid header"); + } + + if ((tokens.at(1) != "plain_result") && (tokens.at(1) != "multi_channel_result") && + (tokens.at(1) != "vegas_result")) + { + throw std::runtime_error("invalid result type"); + } + + unsigned long version; + unsigned long max_digits10; + + try + { + version = std::stoul(tokens.at(2)); + } + catch (...) + { + throw std::runtime_error("could not parse `version` field"); + } + + if (version > 1) + { + throw std::runtime_error("version not supported"); + } + + try + { + max_digits10 = std::stoul(tokens.at(3)); + } + catch (...) + { + throw std::runtime_error("could parse `max_digits10` field"); + } + + constexpr auto max_digits10_float = std::numeric_limits::max_digits10; + constexpr auto max_digits10_double = std::numeric_limits::max_digits10; + constexpr auto max_digits10_long_double = std::numeric_limits::max_digits10; + + if (max_digits10 <= max_digits10_float) + { + if (max_digits10 != max_digits10_float) + { + std::cerr << "Warning: `max_digits10` is " << max_digits10 << " but `float` has " + << max_digits10_float; + } + + if (tokens.at(1) == "multi_channel_result") + { + return std::type_index(typeid(hep::multi_channel_result)); + } + else if (tokens.at(1) == "plain_result") + { + return std::type_index(typeid(hep::plain_result)); + } + else if (tokens.at(1) == "vegas_result") + { + return std::type_index(typeid(hep::vegas_result)); + } + } + else if (max_digits10 <= max_digits10_double) + { + if (max_digits10 != max_digits10_double) + { + std::cerr << "Warning: `max_digits10` is " << max_digits10 << " but `double` has " + << max_digits10_double; + } + + if (tokens.at(1) == "multi_channel_result") + { + return std::type_index(typeid(hep::multi_channel_result)); + } + else if (tokens.at(1) == "plain_result") + { + return std::type_index(typeid(hep::plain_result)); + } + else if (tokens.at(1) == "vegas_result") + { + return std::type_index(typeid(hep::vegas_result)); + } + } + else + { + if (max_digits10 != max_digits10_long_double) + { + std::cerr << "Warning: `max_digits10` is " << max_digits10 + << " but `long double` has " << max_digits10_long_double; + } + + if (tokens.at(1) == "multi_channel_result") + { + return std::type_index(typeid(hep::multi_channel_result)); + } + else if (tokens.at(1) == "plain_result") + { + return std::type_index(typeid(hep::plain_result)); + } + else if (tokens.at(1) == "vegas_result") + { + return std::type_index(typeid(hep::vegas_result)); + } + } + + // if this happens, we didn't cover all the cases + assert( false ); +} + +} diff --git a/src/read_type.hpp b/src/read_type.hpp new file mode 100644 index 0000000..cfbc5fc --- /dev/null +++ b/src/read_type.hpp @@ -0,0 +1,15 @@ +#ifndef READ_TYPE_HPP +#define READ_TYPE_HPP + +#include +#include +#include + +namespace hep +{ + +std::type_index read_type(std::istream& in); + +} + +#endif diff --git a/src/stream_rng.cpp b/src/stream_rng.cpp new file mode 100644 index 0000000..88e739b --- /dev/null +++ b/src/stream_rng.cpp @@ -0,0 +1,36 @@ +#include "stream_rng.hpp" + +#include + +namespace hep +{ + +stream_rng::stream_rng() = default; + +std::string const& stream_rng::state() const +{ + return state_; +} + +void stream_rng::state(std::string const& state) +{ + state_ = state; +} + +std::istream& operator>>(std::istream& in, stream_rng& rng) +{ + std::string state; + std::getline(in, state); + rng.state(state); + + return in; +} + +std::ostream& operator<<(std::ostream& out, stream_rng const& rng) +{ + out << rng.state(); + + return out; +} + +} diff --git a/src/stream_rng.hpp b/src/stream_rng.hpp new file mode 100644 index 0000000..0cf67d1 --- /dev/null +++ b/src/stream_rng.hpp @@ -0,0 +1,29 @@ +#ifndef STREAM_RNG_HPP +#define STREAM_RNG_HPP + +#include +#include + +namespace hep +{ + +class stream_rng +{ +public: + stream_rng(); + + std::string const& state() const; + + void state(std::string const& state); + +private: + std::string state_; +}; + +std::istream& operator>>(std::istream& in, stream_rng& rng); + +std::ostream& operator<<(std::ostream& in, stream_rng const& rng); + +} + +#endif