@@ -126,39 +126,59 @@ def class_gradient(self, x, label=None, logits=False):

        :param x: Sample input with shape as expected by the model.
        :type x: `np.ndarray`
-        :param label: Index of a specific per-class derivative. If `None`, then gradients for all
-                      classes will be computed.
-        :type label: `int`
+        :param label: Index of a specific per-class derivative. If an integer is provided, the gradient of that class
+                      output is computed for all samples. If multiple values are provided, the first dimension should
+                      match the batch size of `x`, and each value will be used as target for its corresponding sample in
+                      `x`. If `None`, then gradients for all classes will be computed for each sample.
+        :type label: `int` or `list`
        :param logits: `True` if the prediction should be done at the logits layer.
        :type logits: `bool`
        :return: Array of gradients of input features w.r.t. each class in the form
                 `(batch_size, nb_classes, input_shape)` when computing for all classes, otherwise shape becomes
                 `(batch_size, 1, input_shape)` when `label` parameter is specified.
        :rtype: `np.ndarray`
        """
-        if label is not None and label not in range(self._nb_classes):
-            raise ValueError('Label %s is out of range.' % label)
+        # Check value of label for computing gradients
+        if not (label is None or (isinstance(label, (int, np.integer)) and label in range(self.nb_classes))
+                or (type(label) is np.ndarray and len(label.shape) == 1 and (label < self.nb_classes).all()
+                    and label.shape[0] == x.shape[0])):
+            raise ValueError('Label %s is out of range.' % str(label))

        self._init_class_grads(label=label, logits=logits)

        x_ = self._apply_processing(x)

-        if label is not None:
+        if label is None:
+            # Compute the gradients w.r.t. all classes
+            if logits:
+                grads = np.swapaxes(np.array(self._class_grads_logits([x_])), 0, 1)
+            else:
+                grads = np.swapaxes(np.array(self._class_grads([x_])), 0, 1)
+
+            grads = self._apply_processing_gradient(grads)
+
+        elif isinstance(label, (int, np.integer)):
+            # Compute the gradients only w.r.t. the provided label
            if logits:
                grads = np.swapaxes(np.array(self._class_grads_logits_idx[label]([x_])), 0, 1)
            else:
                grads = np.swapaxes(np.array(self._class_grads_idx[label]([x_])), 0, 1)

            grads = self._apply_processing_gradient(grads)
            assert grads.shape == (x_.shape[0], 1) + self.input_shape
+
        else:
+            # For each sample, compute the gradients w.r.t. the indicated target class (possibly distinct)
+            unique_label = list(np.unique(label))
            if logits:
-                grads = np.swapaxes(np.array(self._class_grads_logits([x_])), 0, 1)
+                grads = np.array([self._class_grads_logits_idx[l]([x_]) for l in unique_label])
            else:
-                grads = np.swapaxes(np.array(self._class_grads([x_])), 0, 1)
+                grads = np.array([self._class_grads_idx[l]([x_]) for l in unique_label])
+            grads = np.swapaxes(np.squeeze(grads, axis=1), 0, 1)
+            lst = [unique_label.index(i) for i in label]
+            grads = np.expand_dims(grads[np.arange(len(grads)), lst], axis=1)

            grads = self._apply_processing_gradient(grads)
-            assert grads.shape == (x_.shape[0], self.nb_classes) + self.input_shape

        return grads

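For reference, a minimal usage sketch of the three `label` modes introduced above, assuming a constructed `KerasClassifier` instance named `classifier` and an input batch `x` of shape `(batch_size,) + input_shape` (the variable names and the class index 3 are illustrative only):

import numpy as np

# Gradients w.r.t. all classes: shape (batch_size, nb_classes) + input_shape
grads_all = classifier.class_gradient(x)

# A single class index shared by every sample: shape (batch_size, 1) + input_shape
grads_single = classifier.class_gradient(x, label=3)

# One (possibly distinct) target class per sample, passed as a 1-D array of length
# batch_size with values below nb_classes: shape (batch_size, 1) + input_shape
targets = np.random.randint(0, classifier.nb_classes, size=x.shape[0])
grads_per_sample = classifier.class_gradient(x, label=targets)
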
@@ -278,35 +298,49 @@ def _init_class_grads(self, label=None, logits=False):
        import keras.backend as k
        k.set_learning_phase(0)

-        if label is not None:
-            logger.debug('Computing class gradients for class %i.', label)
-            if logits:
-                if not hasattr(self, '_class_grads_logits_idx'):
-                    self._class_grads_logits_idx = [None for _ in range(self.nb_classes)]
-
-                if self._class_grads_logits_idx[label] is None:
-                    class_grads_logits = [k.gradients(self._preds_op[:, label], self._input)[0]]
-                    self._class_grads_logits_idx[label] = k.function([self._input], class_grads_logits)
-            else:
-                if not hasattr(self, '_class_grads_idx'):
-                    self._class_grads_idx = [None for _ in range(self.nb_classes)]
-
-                if self._class_grads_idx[label] is None:
-                    class_grads = [k.gradients(k.softmax(self._preds_op)[:, label], self._input)[0]]
-                    self._class_grads_idx[label] = k.function([self._input], class_grads)
+        if len(self._output.shape) == 2:
+            nb_outputs = self._output.shape[1]
        else:
+            raise ValueError('Unexpected output shape for classification in Keras model.')
+
+        if label is None:
            logger.debug('Computing class gradients for all %i classes.', self.nb_classes)
            if logits:
                if not hasattr(self, '_class_grads_logits'):
                    class_grads_logits = [k.gradients(self._preds_op[:, i], self._input)[0]
-                                          for i in range(self.nb_classes)]
+                                          for i in range(nb_outputs)]
                    self._class_grads_logits = k.function([self._input], class_grads_logits)
            else:
                if not hasattr(self, '_class_grads'):
                    class_grads = [k.gradients(k.softmax(self._preds_op)[:, i], self._input)[0]
-                                   for i in range(self.nb_classes)]
+                                   for i in range(nb_outputs)]
                    self._class_grads = k.function([self._input], class_grads)

+        else:
+            if type(label) is int:
+                unique_labels = [label]
+                logger.debug('Computing class gradients for class %i.', label)
+            else:
+                unique_labels = np.unique(label)
+                logger.debug('Computing class gradients for classes %s.', str(unique_labels))
+
+            if logits:
+                if not hasattr(self, '_class_grads_logits_idx'):
+                    self._class_grads_logits_idx = [None for _ in range(nb_outputs)]
+
+                for l in unique_labels:
+                    if self._class_grads_logits_idx[l] is None:
+                        class_grads_logits = [k.gradients(self._preds_op[:, l], self._input)[0]]
+                        self._class_grads_logits_idx[l] = k.function([self._input], class_grads_logits)
+            else:
+                if not hasattr(self, '_class_grads_idx'):
+                    self._class_grads_idx = [None for _ in range(nb_outputs)]
+
+                for l in unique_labels:
+                    if self._class_grads_idx[l] is None:
+                        class_grads = [k.gradients(k.softmax(self._preds_op)[:, l], self._input)[0]]
+                        self._class_grads_idx[l] = k.function([self._input], class_grads)
+
    def _get_layers(self):
        """
        Return the hidden layers in the model, if applicable.
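The per-label branch added above builds one Keras backend gradient function per requested class and stores it in a `None`-initialized list, so later calls with the same targets reuse the compiled function instead of rebuilding the graph. A minimal standalone sketch of that lazy-caching idiom, not the project's API (`make_grad_fn` is a hypothetical factory standing in for the `k.gradients` + `k.function` pair):

class LazyPerClassCache:
    """Cache one gradient function per class index, built on first use."""

    def __init__(self, nb_outputs, make_grad_fn):
        # make_grad_fn: hypothetical callable mapping a class index to a gradient function
        self._make_grad_fn = make_grad_fn
        self._fns = [None] * nb_outputs  # mirrors self._class_grads_idx in the diff

    def get(self, labels):
        # Build only the missing entries, like the `for l in unique_labels` loop above
        for l in sorted(set(int(l) for l in labels)):
            if self._fns[l] is None:
                self._fns[l] = self._make_grad_fn(l)
        return [self._fns[int(l)] for l in labels]
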