Skip to content

Commit db7b04e

Browse files
authored
Merge pull request #780 from Trusted-AI/development_issue_776
Increase precision for exponentiation in Wasserstein
2 parents a900e61 + f1517d2 commit db7b04e

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

art/attacks/evasion/wasserstein.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@
4040

4141
logger = logging.getLogger(__name__)
4242

43+
EPS_LOG = 10 ** -10
44+
4345

4446
class Wasserstein(EvasionAttack):
4547
"""
@@ -377,6 +379,7 @@ def _conjugate_sinkhorn(self, x: np.ndarray, grad: np.ndarray, cost_matrix: np.n
377379
exp_alpha = np.exp(-alpha)
378380

379381
beta = -self.regularization * grad
382+
beta = beta.astype(np.float64)
380383
exp_beta = np.exp(-beta)
381384

382385
# Check for overflow
@@ -399,6 +402,7 @@ def _conjugate_sinkhorn(self, x: np.ndarray, grad: np.ndarray, cost_matrix: np.n
399402

400403
for _ in range(self.conjugate_sinkhorn_max_iter):
401404
# Block coordinate descent iterates
405+
x[x == 0.0] = EPS_LOG # Prevent divide by zero in np.log
402406
alpha[I_nonzero_] = (np.log(self._local_transport(K, exp_beta, self.kernel_size)) - np.log(x))[I_nonzero_]
403407
exp_alpha = np.exp(-alpha)
404408

@@ -474,6 +478,7 @@ def _projected_sinkhorn(
474478

475479
for _ in range(self.projected_sinkhorn_max_iter):
476480
# Block coordinate descent iterates
481+
x_init[x_init == 0.0] = EPS_LOG # Prevent divide by zero in np.log
477482
alpha = np.log(self._local_transport(K, exp_beta, self.kernel_size)) - np.log(x_init)
478483
exp_alpha = np.exp(-alpha)
479484

@@ -733,8 +738,10 @@ def _check_params(self) -> None:
733738
if self.eps_step <= 0:
734739
raise ValueError("The perturbation step-size `eps_step` has to be positive.")
735740

736-
if self.eps_step > self.eps:
737-
raise ValueError("The iteration step `eps_step` has to be smaller than the total attack `eps`.")
741+
if self.norm == "inf" and self.eps_step > self.eps:
742+
raise ValueError(
743+
"The iteration step `eps_step` has to be smaller than or equal to the total attack budget `eps`."
744+
)
738745

739746
if self.eps_iter <= 0:
740747
raise ValueError("The number of epsilon iterations `eps_iter` has to be a positive integer.")

0 commit comments

Comments
 (0)