Skip to content

Commit 6fe5547

Browse files
authored
switch NodeAttr to boost::varient (#12539)
1 parent 535a6e9 commit 6fe5547

File tree

4 files changed

+33
-26
lines changed

4 files changed

+33
-26
lines changed

paddle/fluid/inference/analysis/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ cc_library(analysis SRCS pass_manager.cc dot.cc node.cc data_flow_graph.cc graph
88
helper.cc
99
model_store_pass.cc
1010
DEPS framework_proto proto_desc)
11-
cc_test(test_node SRCS node_tester.cc DEPS analysis)
11+
cc_test(test_node SRCS node_tester.cc DEPS analysis gflags glog gtest)
1212
cc_test(test_dot SRCS dot_tester.cc DEPS analysis)
1313
cc_binary(inference_analyzer SRCS analyzer_main.cc DEPS analysis)
1414

paddle/fluid/inference/analysis/node.cc

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,17 +20,6 @@ namespace paddle {
2020
namespace inference {
2121
namespace analysis {
2222

23-
template <>
24-
std::string &NodeAttr::As<std::string>() {
25-
if (data_.empty()) {
26-
type_index_ = std::type_index(typeid(std::string));
27-
}
28-
PADDLE_ENFORCE_EQ(type_index_, std::type_index(typeid(std::string)));
29-
return data_;
30-
}
31-
32-
std::string &NodeAttr::String() { return As<std::string>(); }
33-
3423
std::vector<Dot::Attr> Value::dot_attrs() const {
3524
return std::vector<Dot::Attr>({Dot::Attr("style", "filled,rounded"),
3625
Dot::Attr("shape", "box"),

paddle/fluid/inference/analysis/node.h

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ limitations under the License. */
2929
#include "paddle/fluid/inference/analysis/device.h"
3030
#include "paddle/fluid/inference/analysis/dot.h"
3131
#include "paddle/fluid/inference/analysis/helper.h"
32+
#include "paddle/fluid/platform/variant.h"
3233

3334
namespace paddle {
3435
namespace inference {
@@ -38,39 +39,35 @@ class NodeMap;
3839

3940
// A helper class to maintain the status from Pass.
4041
struct NodeAttr {
42+
using any_t =
43+
boost::variant<bool, float, int32_t, int64_t, void *, std::string>;
4144
// NOTE T should be a primary type or a struct combined by several primary
4245
// types.
4346
// NOTE the STL containers should not use here.
4447
// Some usages
4548
// Attr attr;
4649
// attr.Bool() = true;
47-
4850
bool &Bool() { return As<bool>(); }
4951
float &Float() { return As<float>(); }
5052
int32_t &Int32() { return As<int32_t>(); }
5153
int64_t &Int64() { return As<int64_t>(); }
5254
void *&Pointer() { return As<void *>(); }
53-
std::string &String();
55+
std::string &String() { return As<std::string>(); }
5456

5557
private:
5658
template <typename T>
5759
T &As() {
58-
// init storage in the first usage.
59-
if (data_.empty()) {
60-
VLOG(4) << "resize data to " << sizeof(T);
61-
type_index_ = std::type_index(typeid(T));
62-
data_.resize(sizeof(T));
60+
if (type_index_ == typeid(NodeAttr)) {
61+
type_index_ = typeid(T);
62+
any_data_ = T();
63+
} else {
64+
PADDLE_ENFORCE(type_index_ == typeid(T), "fetch error type");
6365
}
64-
PADDLE_ENFORCE(framework::IsType<T>(type_index_),
65-
"type not matched, origin is %s, want %s",
66-
DataTypeNamer::Global().repr(type_index_),
67-
DataTypeNamer::Global().repr<T>());
68-
PADDLE_ENFORCE_EQ(data_.size(), sizeof(T), "Node attr type recast error");
69-
return *reinterpret_cast<T *>(&data_[0]);
66+
return boost::get<T>(any_data_);
7067
}
7168

7269
private:
73-
std::string data_;
70+
any_t any_data_;
7471
std::type_index type_index_{typeid(NodeAttr)};
7572
};
7673

paddle/fluid/inference/analysis/node_tester.cc

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,34 @@ namespace paddle {
2020
namespace inference {
2121
namespace analysis {
2222

23+
TEST(NodeAttr, bool) {
24+
NodeAttr x;
25+
x.Bool() = true;
26+
ASSERT_EQ(x.Bool(), true);
27+
}
28+
29+
TEST(NodeAttr, int32) {
30+
NodeAttr x;
31+
x.Int32() = 32;
32+
ASSERT_EQ(x.Int32(), 32);
33+
}
34+
35+
TEST(NodeAttr, string) {
36+
NodeAttr x;
37+
x.String() = "Hello";
38+
ASSERT_EQ(x.String(), "Hello");
39+
}
40+
2341
TEST(Node, Attr) {
2442
// Node is an abstract class, use Value instead for they share the same Attr
2543
// logic.
2644
NodeMap nodes;
2745
auto* node = nodes.Create(Node::Type::kValue);
2846
node->attr("v0").Int32() = 2008;
2947
ASSERT_EQ(node->attr("v0").Int32(), 2008);
48+
49+
node->attr("str").String() = "hello world";
50+
ASSERT_EQ(node->attr("str").String(), "hello world");
3051
}
3152

3253
} // namespace analysis

0 commit comments

Comments
 (0)