@@ -155,17 +155,21 @@ class SparseShape {
155155 madness::AtomicInt zero_tile_count;
156156 zero_tile_count = 0 ;
157157
158+ constexpr auto real_max = std::numeric_limits<value_type>::max ();
159+
158160 if (dim == 1u ) {
159161 // This is the easy case where the data is a vector and can be
160162 // normalized directly.
161163 math::inplace_vector_op (
162- [threshold, &zero_tile_count](value_type& norm,
163- const value_type size) {
164+ [threshold, real_max, &zero_tile_count](value_type& norm,
165+ const value_type size) {
164166 if (ScaleBy_ == ScaleBy::Volume)
165167 norm *= size;
166168 else
167169 norm /= size;
168- if (Screen && norm < threshold) {
170+ if (!std::isfinite (norm)) {
171+ norm = real_max;
172+ } else if (Screen && norm < threshold) {
169173 norm = value_type (0 );
170174 ++zero_tile_count;
171175 }
@@ -203,10 +207,12 @@ class SparseShape {
203207 math::outer (
204208 left.size (), right.size (), left.data (), right.data (),
205209 tile_norms.data (),
206- [threshold, &zero_tile_count](value_type& norm, const value_type x,
207- const value_type y) {
210+ [threshold, real_max, &zero_tile_count](
211+ value_type& norm, const value_type x, const value_type y) {
208212 norm *= x * y;
209- if (Screen && norm < threshold) {
213+ if (!std::isfinite (norm)) {
214+ norm = real_max;
215+ } else if (Screen && norm < threshold) {
210216 norm = value_type (0 );
211217 ++zero_tile_count;
212218 }
@@ -280,7 +286,12 @@ class SparseShape {
280286 size_vectors_ (size_vectors),
281287 zero_tile_count_(zero_tile_count),
282288 my_threshold_(my_threshold) {
283- TA_ASSERT (check_norms_finite (tile_norms_));
289+ // Clamp non-finite shape norm estimates (inf/NaN from float overflow in
290+ // shape arithmetic) to real_max. Shape norms are upper bounds used for
291+ // tile screening; clamping preserves correctness (no tile is incorrectly
292+ // zeroed) while preventing inf/NaN from propagating to downstream
293+ // operations.
294+ clamp_nonfinite_norms (tile_norms_);
284295 }
285296
286297 public:
@@ -1273,9 +1284,13 @@ class SparseShape {
12731284 const value_type abs_factor = to_abs_factor (factor);
12741285 madness::AtomicInt zero_tile_count;
12751286 zero_tile_count = 0 ;
1276- auto op = [threshold, &zero_tile_count, abs_factor](value_type value) {
1287+ constexpr auto real_max = std::numeric_limits<value_type>::max ();
1288+ auto op = [threshold, real_max, &zero_tile_count,
1289+ abs_factor](value_type value) {
12771290 value *= abs_factor;
1278- if (value < threshold) {
1291+ if (!std::isfinite (value)) {
1292+ value = real_max;
1293+ } else if (value < threshold) {
12791294 value = value_type (0 );
12801295 ++zero_tile_count;
12811296 }
@@ -1310,9 +1325,13 @@ class SparseShape {
13101325 const value_type abs_factor = to_abs_factor (factor);
13111326 madness::AtomicInt zero_tile_count;
13121327 zero_tile_count = 0 ;
1313- auto op = [threshold, &zero_tile_count, abs_factor](value_type value) {
1328+ constexpr auto real_max = std::numeric_limits<value_type>::max ();
1329+ auto op = [threshold, real_max, &zero_tile_count,
1330+ abs_factor](value_type value) {
13141331 value *= abs_factor;
1315- if (value < threshold) {
1332+ if (!std::isfinite (value)) {
1333+ value = real_max;
1334+ } else if (value < threshold) {
13161335 value = value_type (0 );
13171336 ++zero_tile_count;
13181337 }
@@ -1661,10 +1680,16 @@ class SparseShape {
16611680 // TODO: Make this faster. It can be done without using temporaries
16621681 // for the arguments, but requires a custom matrix multiply.
16631682
1683+ // Preprocessing: multiply tile norms by k_sizes to convert per-element
1684+ // norms to Frobenius norms for the BLAS gemm. Clamp to real_max to
1685+ // prevent overflow to inf (inf * 0 in the gemm would produce NaN).
1686+ constexpr auto real_max = std::numeric_limits<value_type>::max ();
1687+
16641688 Tensor<value_type> left (tile_norms_.range ());
16651689 const size_type mk = M * K;
1666- auto left_op = [](const value_type left, const value_type right) {
1667- return left * right;
1690+ auto left_op = [real_max](const value_type left, const value_type right) {
1691+ auto v = left * right;
1692+ return v > real_max ? real_max : v;
16681693 };
16691694 for (size_type i = 0ul ; i < mk; i += K)
16701695 math::vector_op (left_op, K, left.data () + i, tile_norms_.data () + i,
@@ -1673,30 +1698,39 @@ class SparseShape {
16731698 Tensor<value_type> right (other.tile_norms_ .range ());
16741699 for (integer i = 0ul , k = 0 ; k < K; i += N, ++k) {
16751700 const value_type factor = k_sizes[k];
1676- auto right_op = [=](const value_type arg) { return arg * factor; };
1701+ auto right_op = [=](const value_type arg) {
1702+ auto v = arg * factor;
1703+ return v > real_max ? real_max : v;
1704+ };
16771705 math::vector_op (right_op, N, right.data () + i,
16781706 other.tile_norms_ .data () + i);
16791707 }
16801708
16811709 result_norms = left.gemm (right, abs_factor, gemm_helper);
16821710
1683- // Hard zero tiles that are below the zero threshold.
1711+ // Clamp non-finite results (inf/NaN from float overflow in BLAS gemm)
1712+ // and hard-zero tiles below the threshold.
16841713 result_norms.inplace_unary (
1685- [threshold, &zero_tile_count](value_type& value) {
1686- if (value < threshold) {
1714+ [threshold, real_max, &zero_tile_count](value_type& value) {
1715+ if (!std::isfinite (value)) {
1716+ value = real_max;
1717+ } else if (value < threshold) {
16871718 value = value_type (0 );
16881719 ++zero_tile_count;
16891720 }
16901721 });
16911722
16921723 } else {
16931724 // This is an outer product, so the inputs can be used directly
1725+ constexpr auto real_max = std::numeric_limits<value_type>::max ();
16941726 math::outer_fill (M, N, tile_norms_.data (), other.tile_norms_ .data (),
16951727 result_norms.data (),
1696- [threshold, &zero_tile_count, abs_factor](
1728+ [threshold, real_max, &zero_tile_count, abs_factor](
16971729 const value_type left, const value_type right) {
16981730 value_type norm = left * right * abs_factor;
1699- if (norm < threshold) {
1731+ if (!std::isfinite (norm)) {
1732+ norm = real_max;
1733+ } else if (norm < threshold) {
17001734 norm = value_type (0 );
17011735 ++zero_tile_count;
17021736 }
@@ -1760,6 +1794,19 @@ class SparseShape {
17601794 return true ;
17611795 }
17621796
1797+ /// Clamp non-finite (inf/NaN) shape norm estimates to real_max.
1798+ /// Shape norms are upper bounds; clamping to a large finite value
1799+ /// preserves screening correctness (no false zeros) while preventing
1800+ /// inf/NaN from propagating through subsequent shape arithmetic.
1801+ static void clamp_nonfinite_norms (Tensor<value_type>& norms) {
1802+ constexpr auto real_max = std::numeric_limits<value_type>::max ();
1803+ for (auto & v : norms) {
1804+ if (!std::isfinite (v)) {
1805+ v = real_max;
1806+ }
1807+ }
1808+ }
1809+
17631810 template <MemorySpace S, typename T_>
17641811 friend std::size_t size_of (const SparseShape<T_>& shape);
17651812
0 commit comments