Skip to content

Commit 8e26dfa

Browse files
Irina NicolaeIrina Nicolae
authored andcommitted
Fix Clever data type bug for PyTorch
1 parent e79ebf7 commit 8e26dfa

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

art/metrics.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from art.attacks import FastGradientMethod
1616
from art.utils import random_sphere
17+
from art import NUMPY_DTYPE
1718

1819
logger = logging.getLogger(__name__)
1920

@@ -265,6 +266,7 @@ def clever_t(classifier, x, target_class, nb_batches, batch_size, radius, norm,
265266
rand_pool = np.reshape(random_sphere(nb_points=pool_factor * batch_size, nb_dims=dim, radius=radius, norm=norm),
266267
shape)
267268
rand_pool += np.repeat(np.array([x]), pool_factor * batch_size, 0)
269+
rand_pool = rand_pool.astype(NUMPY_DTYPE)
268270
np.clip(rand_pool, classifier.clip_values[0], classifier.clip_values[1], out=rand_pool)
269271

270272
# Change norm since q = p / (p-1)

0 commit comments

Comments
 (0)