diff --git a/tasks/mpi/runner.cpp b/tasks/mpi/runner.cpp index bca7df559..dc5256667 100644 --- a/tasks/mpi/runner.cpp +++ b/tasks/mpi/runner.cpp @@ -2,25 +2,54 @@ #include #include +#include class UnreadMessagesDetector : public ::testing::EmptyTestEventListener { public: - UnreadMessagesDetector(boost::mpi::communicator world) : world_(std::move(world)) {} + UnreadMessagesDetector(boost::mpi::communicator com) : com_(std::move(com)) {} void OnTestEnd(const ::testing::TestInfo& test_info) override { - world_.barrier(); - if (const auto msg = world_.iprobe(boost::mpi::any_source, boost::mpi::any_tag)) { + com_.barrier(); + if (const auto msg = com_.iprobe(boost::mpi::any_source, boost::mpi::any_tag)) { fprintf( stderr, "[ PROCESS %d ] [ FAILED ] %s.%s: MPI message queue has an unread message from process %d with tag %d\n", - world_.rank(), test_info.test_suite_name(), test_info.name(), msg->source(), msg->tag()); + com_.rank(), test_info.test_suite_name(), test_info.name(), msg->source(), msg->tag()); exit(2); } - world_.barrier(); + com_.barrier(); } private: - boost::mpi::communicator world_; + boost::mpi::communicator com_; +}; + +class WorkerTestFailurePrinter : public ::testing::EmptyTestEventListener { + public: + WorkerTestFailurePrinter(std::shared_ptr<::testing::TestEventListener> base, boost::mpi::communicator com) + : base_(std::move(base)), com_(std::move(com)) {} + + void OnTestEnd(const ::testing::TestInfo& test_info) override { + if (test_info.result()->Passed()) { + return; + } + PrintProcessRank(); + base_->OnTestEnd(test_info); + } + + void OnTestPartResult(const ::testing::TestPartResult& test_part_result) override { + if (test_part_result.passed() || test_part_result.skipped()) { + return; + } + PrintProcessRank(); + base_->OnTestPartResult(test_part_result); + } + + private: + void PrintProcessRank() const { printf(" [ PROCESS %d ] ", com_.rank()); } + + std::shared_ptr<::testing::TestEventListener> base_; + boost::mpi::communicator com_; }; int main(int argc, char** argv) { @@ -28,10 +57,13 @@ int main(int argc, char** argv) { boost::mpi::communicator world; ::testing::InitGoogleTest(&argc, argv); + auto& listeners = ::testing::UnitTest::GetInstance()->listeners(); - if (world.rank() != 0) { - delete listeners.Release(listeners.default_result_printer()); + if (world.rank() != 0 && (argc < 2 || argv[1] != std::string("--print-workers"))) { + auto* listener = listeners.Release(listeners.default_result_printer()); + listeners.Append(new WorkerTestFailurePrinter(std::shared_ptr<::testing::TestEventListener>(listener), world)); } listeners.Append(new UnreadMessagesDetector(world)); + return RUN_ALL_TESTS(); }