Skip to content

Commit 9432123

Browse files
committed
SparseShape: clamp non-finite norms from float overflow in shape arithmetic
Shape norm estimates (used for tile screening) are stored as float and computed via upper-bound arithmetic (gemm, scale, mult, add). Chained operations can overflow float to inf; subsequent inf*0 produces NaN. On multi-rank runs, NaN norms cause non-deterministic tile dropping because gop.max(NaN, x) is order-dependent across ranks. Fix: clamp non-finite results (inf/NaN) to float_max (std::numeric_limits<float>::max()) at multiple levels: - Private constructor: clamp_nonfinite_norms() as the single chokepoint - gemm: clamp preprocessed tile_norm*k_size and post-BLAS results - scale: clamp norm*factor overflow - scale_tile_norms: clamp norm*volume overflow - gemm outer product path: clamp norm*norm*factor overflow Clamping to float_max preserves screening correctness (no tile is incorrectly zeroed) while preventing inf/NaN propagation.
1 parent f51582e commit 9432123

File tree

1 file changed

+66
-19
lines changed

1 file changed

+66
-19
lines changed

src/TiledArray/sparse_shape.h

Lines changed: 66 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -155,17 +155,21 @@ class SparseShape {
155155
madness::AtomicInt zero_tile_count;
156156
zero_tile_count = 0;
157157

158+
constexpr auto real_max = std::numeric_limits<value_type>::max();
159+
158160
if (dim == 1u) {
159161
// This is the easy case where the data is a vector and can be
160162
// normalized directly.
161163
math::inplace_vector_op(
162-
[threshold, &zero_tile_count](value_type& norm,
163-
const value_type size) {
164+
[threshold, real_max, &zero_tile_count](value_type& norm,
165+
const value_type size) {
164166
if (ScaleBy_ == ScaleBy::Volume)
165167
norm *= size;
166168
else
167169
norm /= size;
168-
if (Screen && norm < threshold) {
170+
if (!std::isfinite(norm)) {
171+
norm = real_max;
172+
} else if (Screen && norm < threshold) {
169173
norm = value_type(0);
170174
++zero_tile_count;
171175
}
@@ -203,10 +207,12 @@ class SparseShape {
203207
math::outer(
204208
left.size(), right.size(), left.data(), right.data(),
205209
tile_norms.data(),
206-
[threshold, &zero_tile_count](value_type& norm, const value_type x,
207-
const value_type y) {
210+
[threshold, real_max, &zero_tile_count](
211+
value_type& norm, const value_type x, const value_type y) {
208212
norm *= x * y;
209-
if (Screen && norm < threshold) {
213+
if (!std::isfinite(norm)) {
214+
norm = real_max;
215+
} else if (Screen && norm < threshold) {
210216
norm = value_type(0);
211217
++zero_tile_count;
212218
}
@@ -280,7 +286,12 @@ class SparseShape {
280286
size_vectors_(size_vectors),
281287
zero_tile_count_(zero_tile_count),
282288
my_threshold_(my_threshold) {
283-
TA_ASSERT(check_norms_finite(tile_norms_));
289+
// Clamp non-finite shape norm estimates (inf/NaN from float overflow in
290+
// shape arithmetic) to real_max. Shape norms are upper bounds used for
291+
// tile screening; clamping preserves correctness (no tile is incorrectly
292+
// zeroed) while preventing inf/NaN from propagating to downstream
293+
// operations.
294+
clamp_nonfinite_norms(tile_norms_);
284295
}
285296

286297
public:
@@ -1273,9 +1284,13 @@ class SparseShape {
12731284
const value_type abs_factor = to_abs_factor(factor);
12741285
madness::AtomicInt zero_tile_count;
12751286
zero_tile_count = 0;
1276-
auto op = [threshold, &zero_tile_count, abs_factor](value_type value) {
1287+
constexpr auto real_max = std::numeric_limits<value_type>::max();
1288+
auto op = [threshold, real_max, &zero_tile_count,
1289+
abs_factor](value_type value) {
12771290
value *= abs_factor;
1278-
if (value < threshold) {
1291+
if (!std::isfinite(value)) {
1292+
value = real_max;
1293+
} else if (value < threshold) {
12791294
value = value_type(0);
12801295
++zero_tile_count;
12811296
}
@@ -1310,9 +1325,13 @@ class SparseShape {
13101325
const value_type abs_factor = to_abs_factor(factor);
13111326
madness::AtomicInt zero_tile_count;
13121327
zero_tile_count = 0;
1313-
auto op = [threshold, &zero_tile_count, abs_factor](value_type value) {
1328+
constexpr auto real_max = std::numeric_limits<value_type>::max();
1329+
auto op = [threshold, real_max, &zero_tile_count,
1330+
abs_factor](value_type value) {
13141331
value *= abs_factor;
1315-
if (value < threshold) {
1332+
if (!std::isfinite(value)) {
1333+
value = real_max;
1334+
} else if (value < threshold) {
13161335
value = value_type(0);
13171336
++zero_tile_count;
13181337
}
@@ -1661,10 +1680,16 @@ class SparseShape {
16611680
// TODO: Make this faster. It can be done without using temporaries
16621681
// for the arguments, but requires a custom matrix multiply.
16631682

1683+
// Preprocessing: multiply tile norms by k_sizes to convert per-element
1684+
// norms to Frobenius norms for the BLAS gemm. Clamp to real_max to
1685+
// prevent overflow to inf (inf * 0 in the gemm would produce NaN).
1686+
constexpr auto real_max = std::numeric_limits<value_type>::max();
1687+
16641688
Tensor<value_type> left(tile_norms_.range());
16651689
const size_type mk = M * K;
1666-
auto left_op = [](const value_type left, const value_type right) {
1667-
return left * right;
1690+
auto left_op = [real_max](const value_type left, const value_type right) {
1691+
auto v = left * right;
1692+
return v > real_max ? real_max : v;
16681693
};
16691694
for (size_type i = 0ul; i < mk; i += K)
16701695
math::vector_op(left_op, K, left.data() + i, tile_norms_.data() + i,
@@ -1673,30 +1698,39 @@ class SparseShape {
16731698
Tensor<value_type> right(other.tile_norms_.range());
16741699
for (integer i = 0ul, k = 0; k < K; i += N, ++k) {
16751700
const value_type factor = k_sizes[k];
1676-
auto right_op = [=](const value_type arg) { return arg * factor; };
1701+
auto right_op = [=](const value_type arg) {
1702+
auto v = arg * factor;
1703+
return v > real_max ? real_max : v;
1704+
};
16771705
math::vector_op(right_op, N, right.data() + i,
16781706
other.tile_norms_.data() + i);
16791707
}
16801708

16811709
result_norms = left.gemm(right, abs_factor, gemm_helper);
16821710

1683-
// Hard zero tiles that are below the zero threshold.
1711+
// Clamp non-finite results (inf/NaN from float overflow in BLAS gemm)
1712+
// and hard-zero tiles below the threshold.
16841713
result_norms.inplace_unary(
1685-
[threshold, &zero_tile_count](value_type& value) {
1686-
if (value < threshold) {
1714+
[threshold, real_max, &zero_tile_count](value_type& value) {
1715+
if (!std::isfinite(value)) {
1716+
value = real_max;
1717+
} else if (value < threshold) {
16871718
value = value_type(0);
16881719
++zero_tile_count;
16891720
}
16901721
});
16911722

16921723
} else {
16931724
// This is an outer product, so the inputs can be used directly
1725+
constexpr auto real_max = std::numeric_limits<value_type>::max();
16941726
math::outer_fill(M, N, tile_norms_.data(), other.tile_norms_.data(),
16951727
result_norms.data(),
1696-
[threshold, &zero_tile_count, abs_factor](
1728+
[threshold, real_max, &zero_tile_count, abs_factor](
16971729
const value_type left, const value_type right) {
16981730
value_type norm = left * right * abs_factor;
1699-
if (norm < threshold) {
1731+
if (!std::isfinite(norm)) {
1732+
norm = real_max;
1733+
} else if (norm < threshold) {
17001734
norm = value_type(0);
17011735
++zero_tile_count;
17021736
}
@@ -1760,6 +1794,19 @@ class SparseShape {
17601794
return true;
17611795
}
17621796

1797+
/// Clamp non-finite (inf/NaN) shape norm estimates to real_max.
1798+
/// Shape norms are upper bounds; clamping to a large finite value
1799+
/// preserves screening correctness (no false zeros) while preventing
1800+
/// inf/NaN from propagating through subsequent shape arithmetic.
1801+
static void clamp_nonfinite_norms(Tensor<value_type>& norms) {
1802+
constexpr auto real_max = std::numeric_limits<value_type>::max();
1803+
for (auto& v : norms) {
1804+
if (!std::isfinite(v)) {
1805+
v = real_max;
1806+
}
1807+
}
1808+
}
1809+
17631810
template <MemorySpace S, typename T_>
17641811
friend std::size_t size_of(const SparseShape<T_>& shape);
17651812

0 commit comments

Comments
 (0)