2727import numpy as np
2828import tensorflow as tf
2929import cv2
30+ from skimage .transform import resize as skimage_resize
3031
3132from .guided_grad import replace_grad_to_guided_grad
3233from .candidate_ops import candidate_featuremap_op_names , candidate_predict_op_names
@@ -76,10 +77,11 @@ class Gradcam:
7677 def __init__ (self , x_placeholder , num_classes , featuremap_op_name , predict_op_name = None , graph = None ):
7778 self ._x_placeholder = x_placeholder
7879 graph = graph if graph is not None else tf .get_default_graph ()
80+ self .graph = graph
7981
8082 predict_op_name = self ._find_prob_layer (predict_op_name , graph )
81- self ._prob_ts = graph .get_operation_by_name (predict_op_name ).outputs
82- self ._target_ts = graph .get_operation_by_name (featuremap_op_name ).outputs
83+ self ._prob_ts = graph .get_operation_by_name (predict_op_name ).outputs [ 0 ]
84+ self ._target_ts = graph .get_operation_by_name (featuremap_op_name ).outputs [ 0 ]
8385
8486 self ._class_idx = tf .placeholder (tf .int32 )
8587 top1 = tf .argmax (tf .reshape (self ._prob_ts , [- 1 ]))
@@ -91,10 +93,10 @@ def __init__(self, x_placeholder, num_classes, featuremap_op_name, predict_op_na
9193
9294 replace_grad_to_guided_grad (graph )
9395
94- max_output = tf .reduce_max (self ._target_ts , axis = 3 )
96+ max_output = tf .reduce_max (self ._target_ts , axis = 2 )
9597 self ._saliency_map = tf .gradients (tf .reduce_sum (max_output ), x_placeholder )[0 ]
9698
97- def gradcam (self , sess , input_data , target_index = None ):
99+ def gradcam (self , sess , input_data , target_index = None , feed_options = dict () ):
98100 """ Calculate Grad-CAM (class activation map) and Guided Grad-CAM for given input on target class
99101
100102 Parameters
@@ -106,6 +108,8 @@ def gradcam(self, sess, input_data, target_index=None):
106108 target_index : int
107109 Target class index
108110 If None, predicted class index is used
111+ feed_options : dict
112+ Optional parameters to graph
109113
110114 Returns
111115 -------
@@ -120,41 +124,54 @@ def gradcam(self, sess, input_data, target_index=None):
120124 * guided_backprop: Guided backprop result
121125
122126 """
123-
124127 input_feed = np .expand_dims (input_data , axis = 0 )
125- image_height , image_width = input_data .shape [:2 ]
128+ if input_data .ndim == 3 :
129+ is_image = True
130+ image_height , image_width = input_data .shape [:2 ]
131+ if input_data .ndim == 1 :
132+ is_image = False
133+ input_length = input_data .shape [0 ]
126134
127135 if target_index is not None :
128- conv_out_eval , grad_eval = sess . run (
129- [ self . _target_ts , self . _grad_by_idx ],
130- feed_dict = { self ._x_placeholder : input_feed , self ._class_idx : target_index } )
136+ feed_dict = { self . _x_placeholder : input_feed , self . _class_idx : target_index }
137+ feed_dict . update ( feed_options )
138+ conv_out_eval , grad_eval = sess . run ([ self ._target_ts , self ._grad_by_idx ], feed_dict = feed_dict )
131139 else :
132- conv_out_eval , grad_eval = sess . run (
133- [ self . _target_ts , self . _grad_by_top1 ],
134- feed_dict = { self ._x_placeholder : input_feed } )
140+ feed_dict = { self . _x_placeholder : input_feed }
141+ feed_dict . update ( feed_options )
142+ conv_out_eval , grad_eval = sess . run ([ self . _target_ts , self ._grad_by_top1 ], feed_dict = feed_dict )
135143
136144 weights = np .mean (grad_eval , axis = (0 , 1 , 2 ))
137145 conv_out_eval = np .squeeze (conv_out_eval , axis = 0 )
138- cam = np .ones (conv_out_eval .shape [:2 ], dtype = np .float32 )
146+ cam = np .zeros (conv_out_eval .shape [:2 ], dtype = np .float32 )
139147
140148 for i , w in enumerate (weights ):
141149 cam += w * conv_out_eval [:, :, i ]
142- cam = cv2 .resize (cam , (image_height , image_width ))
150+
151+ if is_image :
152+ cam += 1
153+ cam = cv2 .resize (cam , (image_height , image_width ))
154+ saliency_val = sess .run (self ._saliency_map , feed_dict = {self ._x_placeholder : input_feed })
155+ saliency_val = np .squeeze (saliency_val , axis = 0 )
156+ else :
157+ cam = skimage_resize (cam , (input_length , 1 ), preserve_range = True , mode = 'reflect' )
158+ cam = np .transpose (cam )
159+
143160 cam = np .maximum (cam , 0 )
144161 heatmap = cam / np .max (cam )
145162
146- saliency_val = sess .run (self ._saliency_map , feed_dict = {self ._x_placeholder : input_feed })
147- saliency_val = np .squeeze (saliency_val , axis = 0 )
163+ ret = {'heatmap' : heatmap }
148164
149- return {
150- 'gradcam_img' : self .overlay_gradcam (input_data , heatmap ),
151- 'guided_gradcam_img' : _deprocess_image (saliency_val * heatmap [..., None ]),
152- 'heatmap' : heatmap ,
153- 'guided_backprop' : saliency_val
154- }
165+ if is_image :
166+ ret .update ({
167+ 'gradcam_img' : self .overlay_gradcam (input_data , heatmap ),
168+ 'guided_gradcam_img' : _deprocess_image (saliency_val * heatmap [..., None ]),
169+ 'guided_backprop' : saliency_val
170+ })
171+ return ret
155172
156173 @staticmethod
157- def candidate_featuremap_op_names (sess , graph = None ):
174+ def candidate_featuremap_op_names (sess , graph = None , feed_options = dict () ):
158175 """ Returns the list of candidates for operation names of CNN feature map layer
159176
160177 Parameters
@@ -163,17 +180,19 @@ def candidate_featuremap_op_names(sess, graph=None):
163180 Tensorflow session
164181 graph: tf.Graph
165182 Tensorflow graph
183+ feed_options: dict
184+ Optional parameters to graph
166185 Returns
167186 -------
168187 list
169188 String list of candidates
170189
171190 """
172191 graph = graph if graph is not None else tf .get_default_graph ()
173- return candidate_featuremap_op_names (sess , graph )
192+ return candidate_featuremap_op_names (sess , graph , feed_options )
174193
175194 @staticmethod
176- def candidate_predict_op_names (sess , num_classes , graph = None ):
195+ def candidate_predict_op_names (sess , num_classes , graph = None , feed_options = dict () ):
177196 """ Returns the list of candidate for operation names of prediction layer
178197
179198 Parameters
@@ -184,14 +203,16 @@ def candidate_predict_op_names(sess, num_classes, graph=None):
184203 Number of prediction classes
185204 graph: tf.Graph
186205 Tensorflow graph
206+ feed_options: dict
207+ Optional parameters to graph
187208 Returns
188209 -------
189210 list
190211 String list of candidates
191212
192213 """
193214 graph = graph if graph is not None else tf .get_default_graph ()
194- return candidate_predict_op_names (sess , num_classes , graph )
215+ return candidate_predict_op_names (sess , num_classes , graph , feed_options )
195216
196217 @staticmethod
197218 def overlay_gradcam (image , heatmap ):
0 commit comments