File tree Expand file tree Collapse file tree 1 file changed +5
-3
lines changed Expand file tree Collapse file tree 1 file changed +5
-3
lines changed Original file line number Diff line number Diff line change @@ -325,11 +325,13 @@ def clever_t(
325325 sample_xs = rand_pool [np .random .choice (pool_factor * batch_size , batch_size )]
326326
327327 # Compute gradients
328- grads = classifier .class_gradient (sample_xs )
329- if np .isnan (grads ).any ():
328+ grad_pred_class = classifier .class_gradient (sample_xs , label = pred_class )
329+ grad_target_class = classifier .class_gradient (sample_xs , label = target_class )
330+
331+ if np .isnan (grad_pred_class ).any () or np .isnan (grad_target_class ).any ():
330332 raise Exception ("The classifier results NaN gradients." )
331333
332- grad = grads [:, pred_class ] - grads [:, target_class ]
334+ grad = grad_pred_class - grad_target_class
333335 grad = np .reshape (grad , (batch_size , - 1 ))
334336 grad_norm = np .max (np .linalg .norm (grad , ord = norm , axis = 1 ))
335337 grad_norm_set .append (grad_norm )
You can’t perform that action at this time.
0 commit comments