Skip to content

Commit edd947d

Browse files
authored
Merge pull request #11872 from sneaxiy/type_compare
Fix type comparison bugs using std::type_index::hash_code()
2 parents 982dabe + 3f9292c commit edd947d

File tree

9 files changed

+44
-36
lines changed

9 files changed

+44
-36
lines changed

paddle/fluid/framework/lod_tensor.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ limitations under the License. */
2020
#include "paddle/fluid/framework/data_type.h"
2121
#include "paddle/fluid/framework/framework.pb.h"
2222
#include "paddle/fluid/framework/lod_tensor.h"
23+
#include "paddle/fluid/framework/var_type.h"
2324

2425
#include "paddle/fluid/memory/memcpy.h"
2526
#include "paddle/fluid/memory/memory.h"
@@ -68,9 +69,9 @@ std::ostream &operator<<(std::ostream &os, const LoDTensor &t) {
6869
// only print first ten elements
6970
int64_t size = t.numel() < 10 ? t.numel() : 10;
7071
for (int64_t i = 0; i < size; ++i) {
71-
if (t.type().hash_code() == typeid(float).hash_code()) { // NOLINT
72+
if (IsType<float>(t.type())) {
7273
os << t.data<float>()[i] << " ";
73-
} else if (t.type().hash_code() == typeid(int64_t).hash_code()) {
74+
} else if (IsType<int64_t>(t.type())) {
7475
os << t.data<int64_t>()[i] << " ";
7576
} else {
7677
PADDLE_THROW("LoDTensor data type not in [float, int64_t]");
@@ -384,7 +385,7 @@ void LoDTensor::MergeLoDTensor(
384385
LoD new_lod = lod_tensors[0]->lod();
385386
for (size_t i = 1; i < lod_tensors.size(); ++i) {
386387
auto *t = lod_tensors[i];
387-
PADDLE_ENFORCE_EQ(new_type.hash_code(), t->type().hash_code());
388+
PADDLE_ENFORCE_EQ(new_type, t->type());
388389
PADDLE_ENFORCE_EQ(new_layout, t->layout());
389390

390391
PADDLE_ENFORCE_EQ(framework::product(new_dim) / new_dim[0],

paddle/fluid/framework/operator.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -592,8 +592,7 @@ static void CheckTensorNANOrInf(const std::string& name,
592592
if (tensor.memory_size() == 0) {
593593
return;
594594
}
595-
if (tensor.type().hash_code() != typeid(float).hash_code() && // NOLINT
596-
tensor.type().hash_code() != typeid(double).hash_code()) { // NOLINT
595+
if (!IsType<float>(tensor.type()) && !IsType<double>(tensor.type())) {
597596
return;
598597
}
599598
PADDLE_ENFORCE(!framework::TensorContainsInf(tensor),

paddle/fluid/framework/var_type.h

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,18 +24,24 @@ limitations under the License. */
2424

2525
namespace paddle {
2626
namespace framework {
27+
28+
template <typename T>
29+
bool IsType(const std::type_index& type_index) {
30+
return type_index == std::type_index(typeid(T));
31+
}
32+
2733
inline proto::VarType::Type ToVarType(std::type_index type) {
28-
if (type.hash_code() == typeid(LoDTensor).hash_code()) {
34+
if (IsType<LoDTensor>(type)) {
2935
return proto::VarType_Type_LOD_TENSOR;
30-
} else if (type.hash_code() == typeid(LoDRankTable).hash_code()) {
36+
} else if (IsType<LoDRankTable>(type)) {
3137
return proto::VarType_Type_LOD_RANK_TABLE;
32-
} else if (type.hash_code() == typeid(LoDTensorArray).hash_code()) {
38+
} else if (IsType<LoDTensorArray>(type)) {
3339
return proto::VarType_Type_LOD_TENSOR_ARRAY;
34-
} else if (type.hash_code() == typeid(SelectedRows).hash_code()) {
40+
} else if (IsType<SelectedRows>(type)) {
3541
return proto::VarType_Type_SELECTED_ROWS;
36-
} else if (type.hash_code() == typeid(ReaderHolder).hash_code()) {
42+
} else if (IsType<ReaderHolder>(type)) {
3743
return proto::VarType_Type_READER;
38-
} else if (type.hash_code() == typeid(ChannelHolder).hash_code()) {
44+
} else if (IsType<ChannelHolder>(type)) {
3945
return proto::VarType_Type_CHANNEL;
4046
} else {
4147
PADDLE_THROW("ToVarType:Unsupported type %s", type.name());

paddle/fluid/inference/analysis/helper.h

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ limitations under the License. */
1616

1717
#include <cstdio>
1818
#include <string>
19+
#include <typeindex>
1920
#include <unordered_map>
2021
#include <vector>
2122

@@ -41,7 +42,7 @@ int AccuDims(Vec &&vec, int size) {
4142
return res;
4243
}
4344

44-
#define SET_TYPE(type__) dic_[typeid(type__).hash_code()] = #type__;
45+
#define SET_TYPE(type__) dic_[std::type_index(typeid(type__))] = #type__;
4546
/*
4647
* Map typeid to representation.
4748
*/
@@ -53,14 +54,14 @@ struct DataTypeNamer {
5354

5455
template <typename T>
5556
const std::string &repr() const {
56-
auto x = typeid(T).hash_code();
57+
auto x = std::type_index(typeid(T));
5758
PADDLE_ENFORCE(dic_.count(x), "unknown type for representation");
5859
return dic_.at(x);
5960
}
6061

61-
const std::string &repr(size_t &hash) const { // NOLINT
62-
PADDLE_ENFORCE(dic_.count(hash), "unknown type for representation");
63-
return dic_.at(hash);
62+
const std::string &repr(const std::type_index &type) const { // NOLINT
63+
PADDLE_ENFORCE(dic_.count(type), "unknown type for representation");
64+
return dic_.at(type);
6465
}
6566

6667
private:
@@ -71,9 +72,7 @@ struct DataTypeNamer {
7172
SET_TYPE(void *);
7273
}
7374

74-
std::unordered_map<decltype(typeid(int).hash_code()), // NOLINT
75-
std::string>
76-
dic_;
75+
std::unordered_map<std::type_index, std::string> dic_;
7776
};
7877
#undef SET_TYPE
7978

paddle/fluid/inference/analysis/node.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,9 @@ namespace analysis {
2323
template <>
2424
std::string &NodeAttr::As<std::string>() {
2525
if (data_.empty()) {
26-
type_hash_ = typeid(std::string).hash_code();
26+
type_index_ = std::type_index(typeid(std::string));
2727
}
28-
PADDLE_ENFORCE_EQ(type_hash_, typeid(std::string).hash_code());
28+
PADDLE_ENFORCE_EQ(type_index_, std::type_index(typeid(std::string)));
2929
return data_;
3030
}
3131

paddle/fluid/inference/analysis/node.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ limitations under the License. */
2525
#include <unordered_map>
2626
#include <vector>
2727

28+
#include "paddle/fluid/framework/var_type.h"
2829
#include "paddle/fluid/inference/analysis/device.h"
2930
#include "paddle/fluid/inference/analysis/dot.h"
3031
#include "paddle/fluid/inference/analysis/helper.h"
@@ -57,20 +58,20 @@ struct NodeAttr {
5758
// init storage in the first usage.
5859
if (data_.empty()) {
5960
VLOG(4) << "resize data to " << sizeof(T);
60-
type_hash_ = typeid(T).hash_code();
61+
type_index_ = std::type_index(typeid(T));
6162
data_.resize(sizeof(T));
6263
}
63-
PADDLE_ENFORCE(type_hash_ == typeid(T).hash_code(),
64+
PADDLE_ENFORCE(framework::IsType<T>(type_index_),
6465
"type not matched, origin is %s, want %s",
65-
DataTypeNamer::Global().repr(type_hash_),
66+
DataTypeNamer::Global().repr(type_index_),
6667
DataTypeNamer::Global().repr<T>());
6768
PADDLE_ENFORCE_EQ(data_.size(), sizeof(T), "Node attr type recast error");
6869
return *reinterpret_cast<T *>(&data_[0]);
6970
}
7071

7172
private:
7273
std::string data_;
73-
size_t type_hash_{std::numeric_limits<size_t>::max()};
74+
std::type_index type_index_{typeid(NodeAttr)};
7475
};
7576

7677
/*

paddle/fluid/operators/conditional_block_op.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ limitations under the License. */
1414
#include <algorithm>
1515
#include "paddle/fluid/framework/executor.h"
1616
#include "paddle/fluid/framework/op_registry.h"
17+
#include "paddle/fluid/framework/var_type.h"
1718

1819
namespace paddle {
1920
namespace operators {
@@ -47,7 +48,7 @@ class ConditionalOp : public framework::OperatorBase {
4748
if (!(ips.size() == 1UL && ips[0]->IsInitialized())) {
4849
PADDLE_THROW("should have one initialized input as condition");
4950
}
50-
if (!(ips[0]->type().hash_code() == typeid(bool).hash_code() && // NOLINT
51+
if (!(framework::IsType<bool>(ips[0]->type()) && // NOLINT
5152
ips[0]->numel() == 1)) {
5253
PADDLE_THROW(
5354
"condition input's data type should be bool, "

paddle/fluid/operators/print_op.cc

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include <ctime>
1717

1818
#include "paddle/fluid/framework/op_registry.h"
19+
#include "paddle/fluid/framework/var_type.h"
1920
#include "paddle/fluid/framework/variable.h"
2021

2122
namespace paddle {
@@ -62,7 +63,7 @@ struct Formater {
6263
}
6364
}
6465
void PrintDtype() {
65-
if (dtype.hash_code() != typeid(const char).hash_code()) {
66+
if (!framework::IsType<const char>(dtype)) {
6667
CLOG << "\tdtype: " << dtype.name() << std::endl;
6768
}
6869
}
@@ -83,15 +84,15 @@ struct Formater {
8384
void PrintData(size_t size) {
8485
PADDLE_ENFORCE_NOT_NULL(data);
8586
// print float
86-
if (dtype.hash_code() == typeid(const float).hash_code()) {
87+
if (framework::IsType<const float>(dtype)) {
8788
Display<float>(size);
88-
} else if (dtype.hash_code() == typeid(const double).hash_code()) {
89+
} else if (framework::IsType<const double>(dtype)) {
8990
Display<double>(size);
90-
} else if (dtype.hash_code() == typeid(const int).hash_code()) {
91+
} else if (framework::IsType<const int>(dtype)) {
9192
Display<int>(size);
92-
} else if (dtype.hash_code() == typeid(const int64_t).hash_code()) {
93+
} else if (framework::IsType<const int64_t>(dtype)) {
9394
Display<int64_t>(size);
94-
} else if (dtype.hash_code() == typeid(const bool).hash_code()) {
95+
} else if (framework::IsType<const bool>(dtype)) {
9596
Display<bool>(size);
9697
} else {
9798
CLOG << "\tdata: unprintable type: " << dtype.name() << std::endl;

paddle/fluid/operators/while_op.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ limitations under the License. */
1717
#include "paddle/fluid/framework/lod_tensor_array.h"
1818
#include "paddle/fluid/framework/op_registry.h"
1919
#include "paddle/fluid/framework/operator.h"
20+
#include "paddle/fluid/framework/var_type.h"
2021
#include "paddle/fluid/operators/detail/safe_ref.h"
2122

2223
namespace paddle {
@@ -135,15 +136,14 @@ class WhileGradOp : public framework::OperatorBase {
135136
auto &og_inside =
136137
detail::Ref(cur_scope.Var(inside_og_name),
137138
"Cannot find inside gradient %s", inside_og_name);
138-
if (og_outside.Type().hash_code() ==
139-
typeid(framework::LoDTensor).hash_code()) {
139+
if (framework::IsType<framework::LoDTensor>(og_outside.Type())) {
140140
auto &outside_tensor = og_outside.Get<framework::LoDTensor>();
141141
auto &inside_tensor =
142142
detail::Ref(og_inside.GetMutable<framework::LoDTensor>());
143143
inside_tensor.set_lod(outside_tensor.lod());
144144
inside_tensor.ShareDataWith(outside_tensor);
145-
} else if (og_outside.Type().hash_code() ==
146-
typeid(framework::LoDTensorArray).hash_code()) {
145+
} else if (framework::IsType<framework::LoDTensorArray>(
146+
og_outside.Type())) {
147147
auto &outside_array = og_outside.Get<framework::LoDTensorArray>();
148148
auto &inside_array =
149149
detail::Ref(og_inside.GetMutable<framework::LoDTensorArray>());

0 commit comments

Comments (0)