Skip to content

Commit f51582e

Browse files
committed
SparseShape: assert tile norms are finite
1 parent eb01419 commit f51582e

File tree

1 file changed

+40
-2
lines changed

1 file changed

+40
-2
lines changed

src/TiledArray/sparse_shape.h

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -269,14 +269,19 @@ class SparseShape {
269269
return zero_tile_count;
270270
}
271271

272+
// clang-format off
273+
/// \post `std::ranges::all_of(data(), [](auto x) { return std::isfinite(x); })`
274+
// clang-format of
272275
SparseShape(const Tensor<T>& tile_norms,
273276
const std::shared_ptr<vector_type>& size_vectors,
274277
const size_type zero_tile_count,
275278
const value_type my_threshold = threshold_)
276279
: tile_norms_(tile_norms),
277280
size_vectors_(size_vectors),
278281
zero_tile_count_(zero_tile_count),
279-
my_threshold_(my_threshold) {}
282+
my_threshold_(my_threshold) {
283+
TA_ASSERT(check_norms_finite(tile_norms_));
284+
}
280285

281286
public:
282287
/// Default constructor
@@ -292,20 +297,26 @@ class SparseShape {
292297
/// \note this ctor *does not* scale tile norms
293298
/// \note if @c tile_norm is less than the threshold then all tile norms are
294299
/// set to zero
300+
/// \pre `std::isfinite(tile_norm)`
295301
SparseShape(const value_type& tile_norm, const TiledRange& trange)
296302
: tile_norms_(trange.tiles_range(),
297303
(tile_norm < threshold_ ? 0 : tile_norm)),
298304
size_vectors_(initialize_size_vectors(trange)),
299305
zero_tile_count_(tile_norm < threshold_ ? trange.tiles_range().area()
300-
: 0ul) {}
306+
: 0ul) {
307+
TA_ASSERT(std::isfinite(tile_norm));
308+
}
301309

310+
// clang-format off
302311
/// Constructs a SparseShape from a functor returning norm values
303312

304313
/// \tparam Op callable of signature `value_type(const Range::index&)`
305314
/// \param tile_norm_op a functor that returns Frobenius norms of tiles
306315
/// \param trange The tiled range of the tensor
307316
/// \param do_not_scale if true, assume that the tile norms in \c tile_norms
308317
/// are already scaled
318+
/// \post `std::ranges::all_of(data(), [](auto x) { return std::isfinite(x); })`
319+
// clang-format on
309320
template <
310321
typename Op,
311322
typename = std::enable_if_t<std::is_invocable_r_v<
@@ -317,6 +328,7 @@ class SparseShape {
317328
zero_tile_count_(0ul) {
318329
TA_ASSERT(!tile_norms_.empty());
319330
TA_ASSERT(tile_norms_.range() == trange.tiles_range());
331+
TA_ASSERT(check_norms_finite(tile_norms_));
320332

321333
if (!do_not_scale) {
322334
zero_tile_count_ = scale_tile_norms<ScaleBy::InverseVolume>(
@@ -326,19 +338,23 @@ class SparseShape {
326338
}
327339
}
328340

341+
// clang-format off
329342
/// Constructor from a tensor of (scaled/unscaled) norm values
330343

331344
/// \param tile_norms The Frobenius norm of tiles
332345
/// \param trange The tiled range of the tensor
333346
/// \param do_not_scale if true, assume that the tile norms in \c tile_norms
334347
/// are already scaled
348+
/// \post `std::ranges::all_of(data(), [](auto x) { return std::isfinite(x); })`
349+
// clang-format on
335350
SparseShape(const Tensor<value_type>& tile_norms, const TiledRange& trange,
336351
bool do_not_scale = false)
337352
: tile_norms_(tile_norms.clone()),
338353
size_vectors_(initialize_size_vectors(trange)),
339354
zero_tile_count_(0ul) {
340355
TA_ASSERT(!tile_norms_.empty());
341356
TA_ASSERT(tile_norms_.range() == trange.tiles_range());
357+
TA_ASSERT(check_norms_finite(tile_norms_));
342358

343359
if (!do_not_scale) {
344360
zero_tile_count_ = scale_tile_norms<ScaleBy::InverseVolume>(
@@ -348,6 +364,7 @@ class SparseShape {
348364
}
349365
}
350366

367+
// clang-format off
351368
/// "Sparse" constructor
352369

353370
/// This constructor uses tile norms given as a sparse tensor,
@@ -361,6 +378,8 @@ class SparseShape {
361378
/// \param trange The tiled range of the tensor
362379
/// \param do_not_scale if true, assume that the tile norms in \c tile_norms
363380
/// are already scaled
381+
/// \post `std::ranges::all_of(data(), [](auto x) { return std::isfinite(x); })`
382+
// clang-format on
364383
template <typename SparseNormSequence,
365384
typename = std::enable_if_t<
366385
TiledArray::detail::has_member_function_begin_anyreturn<
@@ -388,8 +407,10 @@ class SparseShape {
388407
--zero_tile_count_;
389408
}
390409
}
410+
TA_ASSERT(check_norms_finite(tile_norms_));
391411
}
392412

413+
// clang-format off
393414
/// Collective "dense" constructor
394415

395416
/// This constructor uses tile norms given as a dense tensor.
@@ -404,13 +425,16 @@ class SparseShape {
404425
/// \param trange The tiled range of the tensor
405426
/// \param do_not_scale if true, assume that the tile norms in \c tile_norms
406427
/// are already scaled
428+
/// \post `std::ranges::all_of(data(), [](auto x) { return std::isfinite(x); })`
429+
// clang-format on
407430
SparseShape(World& world, const Tensor<value_type>& tile_norms,
408431
const TiledRange& trange, bool do_not_scale = false)
409432
: tile_norms_(tile_norms.clone()),
410433
size_vectors_(initialize_size_vectors(trange)),
411434
zero_tile_count_(0ul) {
412435
TA_ASSERT(!tile_norms_.empty());
413436
TA_ASSERT(tile_norms_.range() == trange.tiles_range());
437+
TA_ASSERT(check_norms_finite(tile_norms_));
414438

415439
// reduce norm data from all processors
416440
world.gop.max(tile_norms_.data(), tile_norms_.size());
@@ -423,6 +447,7 @@ class SparseShape {
423447
}
424448
}
425449

450+
// clang-format off
426451
/// Collective "sparse" constructor
427452

428453
/// This constructor uses tile norms given as a sparse tensor,
@@ -438,6 +463,8 @@ class SparseShape {
438463
/// nonzeros
439464
/// for this rank's subset of tiles, or be replicated.
440465
/// \param trange The tiled range of the tensor
466+
/// \post `std::ranges::all_of(data(), [](auto x) { return std::isfinite(x); })`
467+
// clang-format on
441468
template <typename SparseNormSequence>
442469
SparseShape(World& world, const SparseNormSequence& tile_norms,
443470
const TiledRange& trange)
@@ -1722,6 +1749,17 @@ class SparseShape {
17221749
return cast_abs_factor;
17231750
}
17241751

1752+
/// checks for finite tile norms
1753+
/// @return true if all elements of @p norms are finite
1754+
static bool check_norms_finite(const Tensor<value_type>& norms) {
1755+
for (const auto& v : norms) {
1756+
if (!std::isfinite(v)) {
1757+
return false;
1758+
}
1759+
}
1760+
return true;
1761+
}
1762+
17251763
template <MemorySpace S, typename T_>
17261764
friend std::size_t size_of(const SparseShape<T_>& shape);
17271765

0 commit comments

Comments
 (0)