Skip to content

Commit de81ccb

Browse files
authored
feature/analysis node representation (#10522)
1 parent 8231960 commit de81ccb

File tree

7 files changed

+496
-1
lines changed

7 files changed

+496
-1
lines changed
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
1-
cc_library(dot SRCS dot.cc)
1+
cc_library(analysis SRCS dot.cc node.cc node.h)
2+
cc_test(test_node SRCS node_tester.cc DEPS analysis)
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
namespace paddle {
16+
namespace inference {
17+
namespace analysis {
18+
19+
enum class Device { CPU, GPU };
20+
21+
} // namespace analysis
22+
} // namespace inference
23+
} // namespace paddle
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/inference/analysis/dot.h"
16+
17+
#include <gtest/gtest.h>
18+
#include <memory>
19+
#include "paddle/fluid/inference/analysis/data_flow_graph.h"
20+
21+
namespace paddle {
22+
namespace inference {
23+
namespace analysis {
24+
25+
class DotTester : public ::testing::Test {
26+
protected:
27+
void SetUp() override {
28+
std::vector<Dot::Attr> attrs({{"title", "hello"}});
29+
dot.reset(new Dot(attrs));
30+
dot->AddNode("a", {Dot::Attr{"shape", "box"}, Dot::Attr("color", "blue")});
31+
dot->AddNode("b", {});
32+
dot->AddNode("c", {});
33+
dot->AddEdge("a", "b", {});
34+
dot->AddEdge("b", "c", {});
35+
dot->AddEdge("a", "c", {});
36+
}
37+
38+
std::unique_ptr<Dot> dot;
39+
};
40+
41+
TEST_F(DotTester, Build) {
42+
auto codes = dot->Build();
43+
// Output the DOT language code, the generated codes are too long to compare
44+
// the string.
45+
//
46+
// The output is
47+
//
48+
// digraph G {
49+
// title="hello"
50+
// node_1
51+
// node_2
52+
// node_0[label="a" shape="box" color="blue"]
53+
// node_0->node_1
54+
// node_1->node_2
55+
// node_0->node_2
56+
// } // end G
57+
LOG(INFO) << '\n' << codes;
58+
}
59+
60+
} // namespace analysis
61+
} // namespace inference
62+
} // namespace paddle
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
17+
#include <string>
18+
#include <unordered_map>
19+
#include <vector>
20+
21+
#include "paddle/fluid/platform/enforce.h"
22+
23+
namespace paddle {
24+
namespace inference {
25+
namespace analysis {
26+
27+
template <typename IteratorT>
28+
class iterator_range {
29+
IteratorT begin_, end_;
30+
31+
public:
32+
template <typename Container>
33+
explicit iterator_range(Container &&c) : begin_(c.begin()), end_(c.end()) {}
34+
35+
iterator_range(const IteratorT &begin, const IteratorT &end)
36+
: begin_(begin), end_(end) {}
37+
38+
const IteratorT &begin() const { return begin_; }
39+
const IteratorT &end() const { return end_; }
40+
};
41+
42+
/*
43+
* An registry helper class, with its records keeps the order they registers.
44+
*/
45+
template <typename T>
46+
class OrderedRegistry {
47+
public:
48+
T *Register(const std::string &name, T *x) {
49+
PADDLE_ENFORCE(!dic_.count(name));
50+
dic_[name] = data_.size();
51+
data_.emplace_back(std::unique_ptr<T>(x));
52+
return data_.back().get();
53+
}
54+
55+
T *Lookup(const std::string &name) {
56+
auto it = dic_.find(name);
57+
if (it == dic_.end()) return nullptr;
58+
return data_[it->second].get();
59+
}
60+
61+
protected:
62+
std::unordered_map<std::string, int> dic_;
63+
std::vector<std::unique_ptr<T>> data_;
64+
};
65+
66+
} // namespace analysis
67+
} // namespace inference
68+
} // namespace paddle
69+
70+
#define PADDLE_DISALLOW_COPY_AND_ASSIGN(type__) \
71+
\
72+
type__(const type__ &) = delete; \
73+
\
74+
void operator=(const type__ &) = delete;
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/inference/analysis/node.h"
16+
#include "glog/logging.h"
17+
#include "paddle/fluid/platform/enforce.h"
18+
19+
namespace paddle {
20+
namespace inference {
21+
namespace analysis {
22+
23+
std::vector<Dot::Attr> Value::dot_attrs() const {
24+
return std::vector<Dot::Attr>({Dot::Attr("style", "filled,rounded"),
25+
Dot::Attr("shape", "box"),
26+
Dot::Attr("fillcolor", "red")});
27+
}
28+
29+
std::vector<Dot::Attr> Function::dot_attrs() const {
30+
return std::vector<Dot::Attr>({Dot::Attr("style", "filled,rounded"),
31+
Dot::Attr("shape", "diamond"),
32+
Dot::Attr("fillcolor", "yellow")});
33+
}
34+
35+
Node *NodeMap::Create(Node::Type type) {
36+
switch (type) {
37+
case Node::Type::kFunction:
38+
nodes_.emplace_back(new Function);
39+
break;
40+
case Node::Type::kValue:
41+
nodes_.emplace_back(new Value);
42+
break;
43+
default:
44+
PADDLE_THROW("Not supported node type.");
45+
}
46+
nodes_.back()->id_ = size() - 1;
47+
return nodes_.back().get();
48+
}
49+
50+
Node *NodeMap::GetMutable(size_t id) {
51+
PADDLE_ENFORCE_GT(size(), id);
52+
return nodes_[id].get();
53+
}
54+
55+
const Node &NodeMap::Get(size_t id) const {
56+
PADDLE_ENFORCE_GT(size(), id);
57+
return *nodes_[id].get();
58+
}
59+
60+
void NodeMap::Delete(size_t id) {
61+
PADDLE_ENFORCE_LT(id, size());
62+
nodes_[id]->SetDeleted();
63+
}
64+
65+
} // namespace analysis
66+
} // namespace inference
67+
} // namespace paddle

0 commit comments

Comments
 (0)