4040
4141logger = logging .getLogger (__name__ )
4242
43+ EPS_LOG = 10 ** - 10
44+
4345
4446class Wasserstein (EvasionAttack ):
4547 """
@@ -377,6 +379,7 @@ def _conjugate_sinkhorn(self, x: np.ndarray, grad: np.ndarray, cost_matrix: np.n
377379 exp_alpha = np .exp (- alpha )
378380
379381 beta = - self .regularization * grad
382+ beta = beta .astype (np .float64 )
380383 exp_beta = np .exp (- beta )
381384
382385 # Check for overflow
@@ -399,6 +402,7 @@ def _conjugate_sinkhorn(self, x: np.ndarray, grad: np.ndarray, cost_matrix: np.n
399402
400403 for _ in range (self .conjugate_sinkhorn_max_iter ):
401404 # Block coordinate descent iterates
405+ x [x == 0.0 ] = EPS_LOG # Prevent divide by zero in np.log
402406 alpha [I_nonzero_ ] = (np .log (self ._local_transport (K , exp_beta , self .kernel_size )) - np .log (x ))[I_nonzero_ ]
403407 exp_alpha = np .exp (- alpha )
404408
@@ -474,6 +478,7 @@ def _projected_sinkhorn(
474478
475479 for _ in range (self .projected_sinkhorn_max_iter ):
476480 # Block coordinate descent iterates
481+ x_init [x_init == 0.0 ] = EPS_LOG # Prevent divide by zero in np.log
477482 alpha = np .log (self ._local_transport (K , exp_beta , self .kernel_size )) - np .log (x_init )
478483 exp_alpha = np .exp (- alpha )
479484
@@ -733,8 +738,10 @@ def _check_params(self) -> None:
733738 if self .eps_step <= 0 :
734739 raise ValueError ("The perturbation step-size `eps_step` has to be positive." )
735740
736- if self .eps_step > self .eps :
737- raise ValueError ("The iteration step `eps_step` has to be smaller than the total attack `eps`." )
741+ if self .norm == "inf" and self .eps_step > self .eps :
742+ raise ValueError (
743+ "The iteration step `eps_step` has to be smaller than or equal to the total attack budget `eps`."
744+ )
738745
739746 if self .eps_iter <= 0 :
740747 raise ValueError ("The number of epsilon iterations `eps_iter` has to be a positive integer." )
0 commit comments