File tree Expand file tree Collapse file tree 1 file changed +4
-1
lines changed Expand file tree Collapse file tree 1 file changed +4
-1
lines changed Original file line number Diff line number Diff line change @@ -344,14 +344,17 @@ def _compute_perturbation(
344344 # Check for NaN before normalisation an replace with 0
345345 if np .isnan (grad ).any ():
346346 logger .warning ("Elements of the loss gradient are NaN and have been replaced with 0.0." )
347- grad = np .where (np .isnan (grad ), np . zeros_like ( grad ) , grad )
347+ grad = np .where (np .isnan (grad ), 0.0 , grad )
348348
349349 # Apply mask
350350 if mask is not None :
351351 grad = np .where (mask == 0.0 , 0.0 , grad )
352352
353353 # Apply norm bound
354354 def _apply_norm (grad , object_type = False ):
355+ if np .isinf (grad ).any ():
356+ logger .info ("The loss gradient array contains at least one positive or negative infinity." )
357+
355358 if self .norm in [np .inf , "inf" ]:
356359 grad = np .sign (grad )
357360 elif self .norm == 1 :
You can’t perform that action at this time.
0 commit comments