namespace phi {
namespace funcs {

+ template <typename T>
+ static bool NaNSafeEqual(const T& a, const T& b) {
+   if constexpr (std::is_floating_point_v<T>) {
+     if (std::isnan(a) && std::isnan(b)) {
+       return &a == &b;
+     }
+     if (std::isnan(a) || std::isnan(b)) {
+       return false;
+     }
+   }
+   return a == b;
+ }
+
+ template <typename T>
+ static bool NaNSafeLess(const T& a, const T& b) {
+   if constexpr (std::is_floating_point_v<T>) {
+     if (std::isnan(a) && !std::isnan(b)) {
+       return false;
+     }
+     if (!std::isnan(a) && std::isnan(b)) {
+       return true;
+     }
+     if (std::isnan(a) && std::isnan(b)) {
+       return &a < &b;
+     }
+   }
+   return a < b;
+ }
+
template <typename Context, typename InT>
struct UniqueOpFunctor {
  const Context& dev_ctx_;
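Side note on the comparator semantics (an illustration, not part of the patch): two NaNs compare equal only when they share an address, and NaNSafeLess orders NaNs after all ordinary values, breaking ties between NaNs by address. Since every NaN copied into a container gets its own storage, a set ordered by NaNSafeLess keeps each NaN as a distinct element. A minimal standalone sketch with a simplified copy of the helper and plain float (relational comparison of unrelated addresses is formally unspecified, but consistent on mainstream platforms):

#include <cmath>
#include <cstdio>
#include <limits>
#include <set>
#include <type_traits>

// Simplified copy of the NaNSafeLess helper above, for illustration only.
template <typename T>
static bool NaNSafeLess(const T& a, const T& b) {
  if constexpr (std::is_floating_point_v<T>) {
    if (std::isnan(a) && !std::isnan(b)) return false;  // NaN sorts last
    if (!std::isnan(a) && std::isnan(b)) return true;
    if (std::isnan(a) && std::isnan(b)) return &a < &b;  // tie-break by address
  }
  return a < b;
}

int main() {
  auto cmp = [](const float& a, const float& b) { return NaNSafeLess(a, b); };
  std::set<float, decltype(cmp)> s(cmp);
  const float nan = std::numeric_limits<float>::quiet_NaN();
  for (float v : {1.0f, nan, 1.0f, nan}) {
    s.insert(v);  // each stored NaN gets its own node, hence its own address
  }
  std::printf("size = %zu\n", s.size());  // 3: one 1.0f plus two distinct NaNs
}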
@@ -122,7 +151,7 @@ static bool Equal(const DenseTensor& a, const DenseTensor& b) {
    return false;
  }
  for (int64_t i = 0; i < a.numel(); ++i) {
-   if (a.data<T>()[i] != b.data<T>()[i]) {
+   if (!NaNSafeEqual(a.data<T>()[i], b.data<T>()[i])) {
      return false;
    }
  }
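One behavioral consequence of the new elementwise check (again an illustration, not part of the patch): an element is equal to itself even when it is NaN, because both references share an address, while bit-identical NaNs in different buffers stay unequal:

#include <cmath>
#include <cstdio>
#include <limits>
#include <type_traits>

// Simplified copy of the NaNSafeEqual helper above, for illustration only.
template <typename T>
static bool NaNSafeEqual(const T& a, const T& b) {
  if constexpr (std::is_floating_point_v<T>) {
    if (std::isnan(a) && std::isnan(b)) return &a == &b;
    if (std::isnan(a) || std::isnan(b)) return false;
  }
  return a == b;
}

int main() {
  const float x = std::numeric_limits<float>::quiet_NaN();
  const float y = x;  // same bit pattern, distinct object
  std::printf("%d %d\n", NaNSafeEqual(x, x), NaNSafeEqual(x, y));  // prints: 1 0
}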
@@ -140,7 +169,15 @@ static void UniqueFlattenedTensor(const Context& dev_ctx,
                                  bool return_inverse,
                                  bool return_counts) {
  const InT* in_data = in.data<InT>();
- std::set<InT> unique(in_data, in_data + in.numel());
+
+ auto nan_safe_comp = [](const InT& a, const InT& b) {
+   return NaNSafeLess(a, b);
+ };
+ std::set<InT, decltype(nan_safe_comp)> unique(nan_safe_comp);
+ for (int64_t i = 0; i < in.numel(); ++i) {
+   unique.insert(in_data[i]);
+ }
+
  out->Resize(common::make_ddim({static_cast<int64_t>(unique.size())}));
  auto* out_data = dev_ctx.template Alloc<InT>(out);
  std::copy(unique.begin(), unique.end(), out_data);
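Why the old range constructor had to go (context, not part of the patch): under the default operator<, NaN compares false both ways against everything, so std::set considers it "equivalent" to the first element on its search path and silently drops it, violating the strict-weak-ordering contract. A minimal standalone repro of that failure mode:

#include <cstdio>
#include <limits>
#include <set>

int main() {
  const float nan = std::numeric_limits<float>::quiet_NaN();
  const float data[] = {2.0f, nan, 1.0f, nan};
  // Default std::less<float>: nan < x and x < nan are both false, so each
  // NaN is "equivalent" to the first node it is compared with and is dropped.
  std::set<float> broken(data, data + 4);
  std::printf("size = %zu\n", broken.size());  // 2: both NaNs were lost
}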
@@ -162,29 +199,27 @@ static void UniqueFlattenedTensor(const Context& dev_ctx,
  if (return_inverse) {
    index->Resize(common::make_ddim({in.numel()}));
    auto inverse_data = dev_ctx.template Alloc<IndexT>(index);
-   std::unordered_map<InT, IndexT> inverse_map;
-   inverse_map.reserve(out->numel());
-   for (int64_t i = 0; i < out->numel(); ++i) {
-     inverse_map[out_data[i]] = i;
-   }
    for (int64_t i = 0; i < in.numel(); ++i) {
-     inverse_data[i] = inverse_map[in_data[i]];
+     for (int64_t j = 0; j < out->numel(); ++j) {
+       if (NaNSafeEqual(in_data[i], out_data[j])) {
+         inverse_data[i] = j;
+         break;
+       }
+     }
    }
  }

  if (return_counts) {
    count->Resize(common::make_ddim({out->numel()}));
    auto count_data = dev_ctx.template Alloc<IndexT>(count);
-   std::unordered_map<InT, IndexT> counts_map;
-   counts_map.reserve(out->numel());
    for (int64_t i = 0; i < out->numel(); ++i) {
-     counts_map[out_data[i]] = 0;
-   }
-   for (int64_t i = 0; i < in.numel(); i++) {
-     counts_map[in_data[i]] += 1;
-   }
-   for (int64_t i = 0; i < out->numel(); i++) {
-     count_data[i] = counts_map[out_data[i]];
+     IndexT cnt = 0;
+     for (int64_t j = 0; j < in.numel(); ++j) {
+       if (NaNSafeEqual(out_data[i], in_data[j])) {
+         cnt++;
+       }
+     }
+     count_data[i] = cnt;
    }
  }
}
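A note on the map-to-nested-loop change (a reading of the diff, not stated in the patch itself): std::unordered_map keyed on a floating type cannot index NaN, because key_eq uses operator== and NaN != NaN, so every NaN lookup misses and inserts a fresh entry. The replacement scans with NaNSafeEqual are O(n·m) rather than amortized O(n), trading speed for NaN correctness. A standalone illustration of the failure mode being avoided:

#include <cstdio>
#include <limits>
#include <unordered_map>

int main() {
  const float nan = std::numeric_limits<float>::quiet_NaN();
  std::unordered_map<float, int> counts;
  counts[nan] += 1;
  counts[nan] += 1;
  // operator== on NaN is always false, so the second lookup misses and
  // inserts a new key: two entries with count 1, not one entry with count 2.
  std::printf("entries = %zu\n", counts.size());  // prints 2
}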