@@ -95,19 +95,19 @@ TEST(GraphTest, Basic) {
95
95
96
96
std::unique_ptr<ir::Graph> g (new ir::Graph (prog));
97
97
std::vector<ir::Node *> nodes (g->Nodes ().begin (), g->Nodes ().end ());
98
- ASSERT_EQ ( nodes[ 0 ]-> Name (), " sum " );
99
- ASSERT_EQ (nodes[ 0 ]-> inputs [ 0 ]-> Name (), " test_a " );
100
- ASSERT_EQ (nodes[ 0 ] ->inputs [ 1 ]-> Name (), " test_b " );
101
- ASSERT_EQ (nodes[ 0 ]-> inputs [ 2 ]-> Name (), " test_c " );
102
- ASSERT_EQ (nodes[ 0 ]-> outputs [ 0 ] ->Name (), " test_out " );
103
- ASSERT_EQ (nodes[ 1 ] ->Name (), " test_a " );
104
- ASSERT_EQ (nodes[ 1 ]-> outputs [ 0 ]-> Name (), " sum " );
105
- ASSERT_EQ (nodes[ 2 ]-> Name (), " test_b " );
106
- ASSERT_EQ (nodes[ 2 ]-> outputs [ 0 ]-> Name (), " sum " );
107
- ASSERT_EQ (nodes[ 3 ]-> Name (), " test_c " );
108
- ASSERT_EQ (nodes[ 3 ] ->outputs [ 0 ]-> Name (), " sum " );
109
- ASSERT_EQ (nodes[ 4 ]-> Name (), " test_out " );
110
- ASSERT_EQ (nodes[ 4 ]-> inputs [ 0 ]-> Name (), " sum " );
98
+ for (ir::Node *n : nodes) {
99
+ if (n-> Name () == " sum " ) {
100
+ ASSERT_EQ (n ->inputs . size (), 3 );
101
+ ASSERT_EQ (n-> outputs . size (), 1 );
102
+ } else if (n-> Name () == " test_a " || n ->Name () == " test_b " ||
103
+ n ->Name () == " test_c " ) {
104
+ ASSERT_EQ (n-> inputs . size (), 0 );
105
+ ASSERT_EQ (n-> outputs . size (), 1 );
106
+ } else if (n-> Name () == " test_out " ) {
107
+ ASSERT_EQ (n-> inputs . size (), 1 );
108
+ ASSERT_EQ (n ->outputs . size (), 0 );
109
+ }
110
+ }
111
111
ASSERT_EQ (nodes.size (), 5 );
112
112
}
113
113
} // namespace framework
0 commit comments