|
| 1 | +#include <Eigen/Core> |
| 2 | +#include <Eigen/Dense> |
| 3 | + |
| 4 | +#include <alpaka/alpaka.hpp> |
| 5 | + |
| 6 | +#define CATCH_CONFIG_MAIN |
| 7 | +#include <catch2/catch_all.hpp> |
| 8 | + |
| 9 | +#include "DataFormats/SoATemplate/interface/SoABlocks.h" |
| 10 | +#include "DataFormats/Portable/interface/PortableCollection.h" |
| 11 | +#include "HeterogeneousCore/AlpakaInterface/interface/config.h" |
| 12 | +#include "HeterogeneousCore/AlpakaInterface/interface/memory.h" |
| 13 | +#include "HeterogeneousCore/AlpakaInterface/interface/workdivision.h" |
| 14 | + |
| 15 | +using namespace ALPAKA_ACCELERATOR_NAMESPACE; |
| 16 | + |
| 17 | +// This test checks the correctness of using SoABlocks with PortableCollections. |
| 18 | + |
| 19 | +GENERATE_SOA_LAYOUT(NodesT, SOA_COLUMN(int, id), SOA_SCALAR(int, count)) |
| 20 | + |
| 21 | +using Nodes = NodesT<>; |
| 22 | + |
| 23 | +GENERATE_SOA_LAYOUT(EdgesT, SOA_COLUMN(int, src), SOA_COLUMN(int, dst), SOA_COLUMN(float, cost), SOA_SCALAR(int, count)) |
| 24 | + |
| 25 | +using Edges = EdgesT<>; |
| 26 | + |
| 27 | +GENERATE_SOA_BLOCKS(GraphT, SOA_BLOCK(nodes, NodesT), SOA_BLOCK(edges, EdgesT)) |
| 28 | + |
| 29 | +using Graph = GraphT<>; |
| 30 | +using GraphView = Graph::View; |
| 31 | +using GraphConstView = Graph::ConstView; |
| 32 | + |
| 33 | +// Fill SoAs |
| 34 | +struct FillSoAs { |
| 35 | + ALPAKA_FN_ACC void operator()(Acc1D const& acc, Nodes::View nodes, Edges::View edges) const { |
| 36 | + const int N = static_cast<int>(nodes.metadata().size()); |
| 37 | + const int E = static_cast<int>(edges.metadata().size()); |
| 38 | + |
| 39 | + // Fill nodes with the indexes |
| 40 | + for (auto i : cms::alpakatools::uniform_elements(acc, nodes.metadata().size())) { |
| 41 | + nodes[i].id() = static_cast<int>(i); |
| 42 | + } |
| 43 | + if (cms::alpakatools::once_per_grid(acc)) { |
| 44 | + nodes.count() = N; |
| 45 | + } |
| 46 | + |
| 47 | + // Fill edges with some arbitrary but deterministic values |
| 48 | + for (auto j : cms::alpakatools::uniform_elements(acc, edges.metadata().size())) { |
| 49 | + int src = static_cast<int>(j % N); |
| 50 | + int dst = static_cast<int>((j * 7 + 3) % N); |
| 51 | + edges[j].src() = src; |
| 52 | + edges[j].dst() = dst; |
| 53 | + edges[j].cost() = 0.5f * float(src + dst); |
| 54 | + } |
| 55 | + if (cms::alpakatools::once_per_grid(acc)) { |
| 56 | + edges.count() = E; |
| 57 | + } |
| 58 | + } |
| 59 | +}; |
| 60 | + |
| 61 | +// Fill SoABlocks |
| 62 | +struct FillBlocks { |
| 63 | + ALPAKA_FN_ACC void operator()(Acc1D const& acc, GraphView blocksView) const { |
| 64 | + const int N = static_cast<int>(blocksView.nodes().metadata().size()); |
| 65 | + const int E = static_cast<int>(blocksView.edges().metadata().size()); |
| 66 | + |
| 67 | + // Fill nodes with the indexes |
| 68 | + for (auto i : cms::alpakatools::uniform_elements(acc, blocksView.nodes().metadata().size())) { |
| 69 | + blocksView.nodes()[i].id() = static_cast<int>(i); |
| 70 | + } |
| 71 | + if (cms::alpakatools::once_per_grid(acc)) { |
| 72 | + blocksView.nodes().count() = N; |
| 73 | + } |
| 74 | + |
| 75 | + // Fill edges with some arbitrary but deterministic values |
| 76 | + for (auto j : cms::alpakatools::uniform_elements(acc, blocksView.edges().metadata().size())) { |
| 77 | + int src = static_cast<int>(j % N); |
| 78 | + int dst = static_cast<int>((j * 7 + 3) % N); |
| 79 | + blocksView.edges()[j].src() = src; |
| 80 | + blocksView.edges()[j].dst() = dst; |
| 81 | + blocksView.edges()[j].cost() = 0.5f * float(src + dst); |
| 82 | + } |
| 83 | + if (cms::alpakatools::once_per_grid(acc)) { |
| 84 | + blocksView.edges().count() = E; |
| 85 | + } |
| 86 | + } |
| 87 | +}; |
| 88 | + |
| 89 | +TEST_CASE("SoABlocks minimal graph in heterogeneous environment") { |
| 90 | + auto const& devices = cms::alpakatools::devices<Platform>(); |
| 91 | + if (devices.empty()) { |
| 92 | + std::cout << "No devices available for the " << EDM_STRINGIZE(ALPAKA_ACCELERATOR_NAMESPACE) |
| 93 | + << " backend, skipping.\n"; |
| 94 | + return; |
| 95 | + } |
| 96 | + |
| 97 | + for (auto const& device : devices) { |
| 98 | + std::cout << "Running on " << alpaka::getName(device) << std::endl; |
| 99 | + Queue queue(device); |
| 100 | + |
| 101 | + // Number of elements |
| 102 | + const int N = 50; |
| 103 | + const int E = 120; |
| 104 | + |
| 105 | + // Portable Collections for SoAs |
| 106 | + PortableCollection<Nodes, Device> nodesCollection(N, queue); |
| 107 | + PortableCollection<Edges, Device> edgesCollection(E, queue); |
| 108 | + Nodes::View& nodesCollectionView = nodesCollection.view(); |
| 109 | + Edges::View& edgesCollectionView = edgesCollection.view(); |
| 110 | + |
| 111 | + // Portable Collection for SoABlocks |
| 112 | + PortableCollection<Graph, Device> graphCollection(queue, N, E); |
| 113 | + GraphView& graphCollectionView = graphCollection.view(); |
| 114 | + |
| 115 | + // Work division |
| 116 | + const std::size_t blockSize = 256; |
| 117 | + const std::size_t maxElems = std::max<std::size_t>(N, E); |
| 118 | + const std::size_t numberOfBlocks = cms::alpakatools::divide_up_by(maxElems, blockSize); |
| 119 | + const auto workDiv = cms::alpakatools::make_workdiv<Acc1D>(numberOfBlocks, blockSize); |
| 120 | + |
| 121 | + // Fill: separate e blocks |
| 122 | + alpaka::exec<Acc1D>(queue, workDiv, FillSoAs{}, nodesCollectionView, edgesCollectionView); |
| 123 | + alpaka::exec<Acc1D>(queue, workDiv, FillBlocks{}, graphCollectionView); |
| 124 | + alpaka::wait(queue); |
| 125 | + |
| 126 | + // Check results on host |
| 127 | + PortableHostCollection<Nodes> nodesHost(N, cms::alpakatools::host()); |
| 128 | + PortableHostCollection<Edges> edgesHost(E, cms::alpakatools::host()); |
| 129 | + PortableHostCollection<Graph> graphHost(cms::alpakatools::host(), N, E); |
| 130 | + |
| 131 | + alpaka::memcpy(queue, nodesHost.buffer(), nodesCollection.buffer()); |
| 132 | + alpaka::memcpy(queue, edgesHost.buffer(), edgesCollection.buffer()); |
| 133 | + alpaka::memcpy(queue, graphHost.buffer(), graphCollection.buffer()); |
| 134 | + alpaka::wait(queue); |
| 135 | + |
| 136 | + const Nodes::ConstView nodesHostView = nodesHost.const_view(); |
| 137 | + const Edges::ConstView edgesHostView = edgesHost.const_view(); |
| 138 | + const GraphConstView graphHostView = graphHost.const_view(); |
| 139 | + |
| 140 | + // Nodes |
| 141 | + REQUIRE(graphHostView.nodes().count() == N); |
| 142 | + for (int i = 0; i < N; ++i) { |
| 143 | + REQUIRE(graphHostView.nodes()[i].id() == nodesHostView[i].id()); |
| 144 | + REQUIRE(graphHostView.nodes()[i].id() == i); |
| 145 | + } |
| 146 | + |
| 147 | + // Edges |
| 148 | + REQUIRE(graphHostView.edges().count() == E); |
| 149 | + for (int j = 0; j < E; ++j) { |
| 150 | + REQUIRE(graphHostView.edges()[j].src() == edgesHostView[j].src()); |
| 151 | + REQUIRE(graphHostView.edges()[j].dst() == edgesHostView[j].dst()); |
| 152 | + REQUIRE(graphHostView.edges()[j].cost() == edgesHostView[j].cost()); |
| 153 | + |
| 154 | + int src = j % N; |
| 155 | + int dst = (j * 7 + 3) % N; |
| 156 | + REQUIRE(graphHostView.edges()[j].src() == src); |
| 157 | + REQUIRE(graphHostView.edges()[j].dst() == dst); |
| 158 | + REQUIRE_THAT(graphHostView.edges()[j].cost(), Catch::Matchers::WithinAbs(0.5f * float(src + dst), 1e-6)); |
| 159 | + } |
| 160 | + } |
| 161 | +} |
0 commit comments