Skip to content

Commit b267ec6

Browse files
authored
Merge pull request #539 from joemathai/faster_clever_metric
faster clever_t computation
2 parents 763dc38 + bd5f8ff commit b267ec6

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

art/metrics/metrics.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -325,11 +325,13 @@ def clever_t(
325325
sample_xs = rand_pool[np.random.choice(pool_factor * batch_size, batch_size)]
326326

327327
# Compute gradients
328-
grads = classifier.class_gradient(sample_xs)
329-
if np.isnan(grads).any():
328+
grad_pred_class = classifier.class_gradient(sample_xs, label=pred_class)
329+
grad_target_class = classifier.class_gradient(sample_xs, label=target_class)
330+
331+
if np.isnan(grad_pred_class).any() or np.isnan(grad_target_class).any():
330332
raise Exception("The classifier results NaN gradients.")
331333

332-
grad = grads[:, pred_class] - grads[:, target_class]
334+
grad = grad_pred_class - grad_target_class
333335
grad = np.reshape(grad, (batch_size, -1))
334336
grad_norm = np.max(np.linalg.norm(grad, ord=norm, axis=1))
335337
grad_norm_set.append(grad_norm)

0 commit comments

Comments
 (0)