-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_dag.cpp
More file actions
52 lines (43 loc) · 1.94 KB
/
test_dag.cpp
File metadata and controls
52 lines (43 loc) · 1.94 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
#include <cstdlib>
#include <exception>
#include <iostream>
#include <string>
#include <unordered_map>
#include <variant>
#include <vector>

#include "dag.h"
// Smoke test for the DAG executor.
//
// Builds a two-node graph (multiplyNode -> divideNode), feeds one input
// mini-batch through it with Executor, and prints the values divideNode
// wrote to its "divideout" field for each input batch.
//
// Exit status:
//   EXIT_SUCCESS - graph built, executed, and outputs printed.
//   EXIT_FAILURE - the edge multiplyNode -> divideNode could not be added, or
//                  a std::exception propagated out of execution / output
//                  collection (e.g. std::bad_variant_access from std::get).
int main() {
    // Create the graph.
    Graph graph;

    // Node 1: reads "multiplyin" as a double and writes twice its value to
    // "multiplyout".
    GraphNode multiplyNode(ComputeType::CPU, [](auto& inputs, auto& outputs) {
        double inputVal = std::get<double>(inputs["multiplyin"]);
        outputs["multiplyout"] = inputVal * 2;
    });
    multiplyNode.addInput("multiplyin", DataContainer());    // declare multiplyNode's input field
    multiplyNode.addOutput("multiplyout", DataContainer());  // declare multiplyNode's output field

    // Node 2: reads "multiplyout" as a double and writes a tenth of its value
    // to "divideout".
    // NOTE(review): the input name intentionally equals multiplyNode's output
    // name; presumably the executor forwards data along edges by field name —
    // confirm against dag.h.
    GraphNode divideNode(ComputeType::CPU, [](auto& inputs, auto& outputs) {
        double inputVal = std::get<double>(inputs["multiplyout"]);
        outputs["divideout"] = inputVal / 10;
    });
    divideNode.addInput("multiplyout", DataContainer());  // declare divideNode's input field
    divideNode.addOutput("divideout", DataContainer());   // declare divideNode's output field

    // Register the nodes; the returned ids are used for the edge and for
    // looking up outputs after execution.
    size_t multiplyNodeId = graph.addNode(multiplyNode);
    size_t divideNodeId = graph.addNode(divideNode);

    // Wire multiplyNode -> divideNode. Without this edge the executor would run
    // a graph that does not describe the intended pipeline, so stop here and
    // fail instead of silently continuing and exiting 0.
    const bool edgeAdded = graph.addEdge(multiplyNodeId, divideNodeId);
    std::cout << "Adding edge multiplyNode -> divideNode: " << (edgeAdded ? "Success" : "Failed") << "\n";
    if (!edgeAdded) {
        std::cerr << "aborting: edge multiplyNode -> divideNode could not be added\n";
        return EXIT_FAILURE;
    }

    // Input: one mini-batch supplying the values 1.0, 2.0, 3.0 for "multiplyin".
    std::vector<std::unordered_map<std::string, MiniBatch>> inputBatches = {
        {{"multiplyin", MiniBatch({1.0, 2.0, 3.0})}}
    };

    try {
        // std::endl (flush) on the progress markers is deliberate: the markers
        // stay visible even if executor.run() crashes the process.
        std::cout << "executor start" << std::endl;
        Executor executor(graph, inputBatches);
        executor.run();
        std::cout << "reach end" << std::endl;

        // Print what divideNode wrote to "divideout" for every input batch.
        for (size_t batchId = 0; batchId < inputBatches.size(); ++batchId) {
            std::cout << "Batch " << batchId << " output: ";
            auto output = graph.getMiniBatch(divideNodeId, batchId, "divideout");
            for (size_t i = 0; i < output.size(); ++i) {
                std::cout << std::get<double>(output.getData(i)) << " ";
            }
            std::cout << '\n';
        }
    } catch (const std::exception& e) {
        // Report instead of letting the exception reach std::terminate.
        std::cerr << "error: " << e.what() << '\n';
        return EXIT_FAILURE;
    }

    return EXIT_SUCCESS;
}