@@ -26,6 +26,113 @@ limitations under the License. */
26
26
27
27
namespace paddle {
28
28
namespace framework {
29
+
30
+ template <typename T>
31
+ struct ExtractAttribute {
32
+ explicit ExtractAttribute (const std::string& attr_name)
33
+ : attr_name_(attr_name) {}
34
+
35
+ T* operator ()(Attribute& attr) const {
36
+ T* attr_value = nullptr ;
37
+ try {
38
+ attr_value = &boost::get<T>(attr);
39
+ } catch (boost::bad_get& bad_get) {
40
+ PADDLE_THROW (" Cannot get attribute %s by type %s, its type is %s" ,
41
+ attr_name_, paddle::platform::demangle (typeid (T).name ()),
42
+ paddle::platform::demangle (attr.type ().name ()));
43
+ }
44
+ return attr_value;
45
+ }
46
+
47
+ const std::string& attr_name_;
48
+ };
49
+
50
+ // special handle bool
51
+ // FIXME(yuyang18): Currently we cast bool into int in python binding. It is
52
+ // hard to change the logic there. In another way, we should correct handle
53
+ // if the user set `some_flag=1`.
54
+ //
55
+ // FIX ME anytime if there is a better solution.
56
+ template <>
57
+ struct ExtractAttribute <bool > {
58
+ explicit ExtractAttribute (const std::string& attr_name)
59
+ : attr_name_(attr_name) {}
60
+
61
+ bool * operator ()(Attribute& attr) const {
62
+ if (attr.type () == typeid (int )) { // NOLINT
63
+ int val = boost::get<int >(attr);
64
+ attr = static_cast <bool >(val);
65
+ } else if (attr.type () == typeid (float )) { // NOLINT
66
+ float val = boost::get<float >(attr);
67
+ attr = static_cast <bool >(val);
68
+ }
69
+ bool * attr_value = nullptr ;
70
+ try {
71
+ attr_value = &boost::get<bool >(attr);
72
+ } catch (boost::bad_get& bad_get) {
73
+ PADDLE_THROW (" Cannot get attribute %s by type bool, its type is %s" ,
74
+ attr_name_, paddle::platform::demangle (attr.type ().name ()));
75
+ }
76
+ return attr_value;
77
+ }
78
+
79
+ const std::string& attr_name_;
80
+ };
81
+
82
+ template <>
83
+ struct ExtractAttribute <int64_t > {
84
+ explicit ExtractAttribute (const std::string& attr_name)
85
+ : attr_name_(attr_name) {}
86
+
87
+ int64_t * operator ()(Attribute& attr) const {
88
+ if (attr.type () == typeid (int )) { // NOLINT
89
+ int val = boost::get<int >(attr);
90
+ attr = static_cast <int64_t >(val);
91
+ } else if (attr.type () == typeid (float )) { // NOLINT
92
+ int val = boost::get<float >(attr);
93
+ attr = static_cast <int64_t >(val);
94
+ }
95
+ int64_t * attr_value = nullptr ;
96
+ try {
97
+ attr_value = &boost::get<int64_t >(attr);
98
+ } catch (boost::bad_get& bad_get) {
99
+ PADDLE_THROW (" Cannot get attribute %s by type int64_t, its type is %s" ,
100
+ attr_name_, paddle::platform::demangle (attr.type ().name ()));
101
+ }
102
+ return attr_value;
103
+ }
104
+
105
+ const std::string& attr_name_;
106
+ };
107
+
108
+ template <>
109
+ struct ExtractAttribute <std::vector<int64_t >> {
110
+ explicit ExtractAttribute (const std::string& attr_name)
111
+ : attr_name_(attr_name) {}
112
+
113
+ std::vector<int64_t >* operator ()(Attribute& attr) const {
114
+ if (attr.type () == typeid (std::vector<int >)) { // NOLINT
115
+ std::vector<int > val = boost::get<std::vector<int >>(attr);
116
+ std::vector<int64_t > vec (val.begin (), val.end ());
117
+ attr = vec;
118
+ } else if (attr.type () == typeid (std::vector<float >)) { // NOLINT
119
+ std::vector<float > val = boost::get<std::vector<float >>(attr);
120
+ std::vector<int64_t > vec (val.begin (), val.end ());
121
+ attr = vec;
122
+ }
123
+ std::vector<int64_t >* attr_value = nullptr ;
124
+ try {
125
+ attr_value = &boost::get<std::vector<int64_t >>(attr);
126
+ } catch (boost::bad_get& bad_get) {
127
+ PADDLE_THROW (" Cannot get attribute %s by type int64_t, its type is %s" ,
128
+ attr_name_, paddle::platform::demangle (attr.type ().name ()));
129
+ }
130
+ return attr_value;
131
+ }
132
+
133
+ const std::string& attr_name_;
134
+ };
135
+
29
136
template <typename T>
30
137
inline proto::AttrType AttrTypeID () {
31
138
Attribute tmp = T ();
@@ -42,7 +149,11 @@ class AttrReader {
42
149
inline const T& Get (const std::string& name) const {
43
150
PADDLE_ENFORCE (attrs_.count (name) != 0 , " %s should be in AttributeMap" ,
44
151
name);
45
- return boost::get<T>(attrs_.at (name));
152
+
153
+ Attribute& attr = const_cast <Attribute&>(attrs_.at (name));
154
+ ExtractAttribute<T> extract_attr (name);
155
+ T* attr_value = extract_attr (attr);
156
+ return *attr_value;
46
157
}
47
158
48
159
private:
@@ -82,7 +193,7 @@ class DefaultValueSetter {
82
193
public:
83
194
explicit DefaultValueSetter (T default_value)
84
195
: default_value_(default_value) {}
85
- void operator ()(T& value) const { value = default_value_; }
196
+ void operator ()(T& value) const { value = default_value_; } // NOLINT
86
197
87
198
private:
88
199
T default_value_;
@@ -117,84 +228,6 @@ class EnumInContainer {
117
228
std::unordered_set<T> container_;
118
229
};
119
230
120
- template <typename T>
121
- struct ExtractAttribute {
122
- explicit ExtractAttribute (const std::string& attr_name)
123
- : attr_name_(attr_name) {}
124
-
125
- T* operator ()(Attribute& attr) const {
126
- T* attr_value = nullptr ;
127
- try {
128
- attr_value = &boost::get<T>(attr);
129
- } catch (boost::bad_get& bad_get) {
130
- PADDLE_THROW (" Cannot get attribute %s by type %s, its type is %s" ,
131
- attr_name_, paddle::platform::demangle (typeid (T).name ()),
132
- paddle::platform::demangle (attr.type ().name ()));
133
- }
134
- return attr_value;
135
- }
136
-
137
- const std::string& attr_name_;
138
- };
139
-
140
- // special handle bool
141
- // FIXME(yuyang18): Currently we cast bool into int in python binding. It is
142
- // hard to change the logic there. In another way, we should correct handle
143
- // if the user set `some_flag=1`.
144
- //
145
- // FIX ME anytime if there is a better solution.
146
- template <>
147
- struct ExtractAttribute <bool > {
148
- explicit ExtractAttribute (const std::string& attr_name)
149
- : attr_name_(attr_name) {}
150
-
151
- bool * operator ()(Attribute& attr) const {
152
- if (attr.type () == typeid (int )) { // NOLINT
153
- int val = boost::get<int >(attr);
154
- attr = static_cast <bool >(val);
155
- } else if (attr.type () == typeid (float )) { // NOLINT
156
- float val = boost::get<float >(attr);
157
- attr = static_cast <bool >(val);
158
- }
159
- bool * attr_value = nullptr ;
160
- try {
161
- attr_value = &boost::get<bool >(attr);
162
- } catch (boost::bad_get& bad_get) {
163
- PADDLE_THROW (" Cannot get attribute %s by type bool, its type is %s" ,
164
- attr_name_, paddle::platform::demangle (attr.type ().name ()));
165
- }
166
- return attr_value;
167
- }
168
-
169
- const std::string& attr_name_;
170
- };
171
-
172
- template <>
173
- struct ExtractAttribute <int64_t > {
174
- explicit ExtractAttribute (const std::string& attr_name)
175
- : attr_name_(attr_name) {}
176
-
177
- int64_t * operator ()(Attribute& attr) const {
178
- if (attr.type () == typeid (int )) { // NOLINT
179
- int val = boost::get<int >(attr);
180
- attr = static_cast <int64_t >(val);
181
- } else if (attr.type () == typeid (float )) { // NOLINT
182
- int val = boost::get<float >(attr);
183
- attr = static_cast <int64_t >(val);
184
- }
185
- int64_t * attr_value = nullptr ;
186
- try {
187
- attr_value = &boost::get<int64_t >(attr);
188
- } catch (boost::bad_get& bad_get) {
189
- PADDLE_THROW (" Cannot get attribute %s by type int64_t, its type is %s" ,
190
- attr_name_, paddle::platform::demangle (attr.type ().name ()));
191
- }
192
- return attr_value;
193
- }
194
-
195
- const std::string& attr_name_;
196
- };
197
-
198
231
// check whether a certain attribute fit its limits
199
232
// an attribute can have more than one limits
200
233
template <typename T>
@@ -235,7 +268,7 @@ class TypedAttrChecker {
235
268
return *this ;
236
269
}
237
270
238
- void operator ()(AttributeMap& attr_map) const {
271
+ void operator ()(AttributeMap& attr_map) const { // NOLINT
239
272
if (!attr_map.count (attr_name_)) {
240
273
// user do not set this attr
241
274
PADDLE_ENFORCE (!default_value_setter_.empty (),
@@ -271,7 +304,7 @@ class OpAttrChecker {
271
304
return *(checker.target <TypedAttrChecker<T>>());
272
305
}
273
306
274
- void Check (AttributeMap& attr_map) const {
307
+ void Check (AttributeMap& attr_map) const { // NOLINT
275
308
for (const auto & checker : attr_checkers_) {
276
309
checker (attr_map);
277
310
}
0 commit comments