@@ -485,12 +485,16 @@ def merge_vectors(self, feature_vectors: List[np.ndarray]) -> np.ndarray:
485485 merged_vectors (List[np.ndarray]): Merged vectors for each image.
486486 """
487487
488- vect_per_image = len (feature_vectors ) // self .num_images
489- # split vectors on chunks of vectors related to the same image
490- image_vectors = [
491- feature_vectors [x : x + vect_per_image ] for x in range (0 , len (feature_vectors ), vect_per_image )
492- ]
493- return np .average (image_vectors , axis = 1 )
488+ image_vectors : dict = {}
489+ for vector , tile in zip (feature_vectors , self .tiles ):
490+ data_idx = tile .get ("index" , None ) if "index" in tile else tile .get ("dataset_idx" , None )
491+ if data_idx in image_vectors :
492+ # tile vectors
493+ image_vectors [data_idx ].append (vector )
494+ else :
495+ # whole image vector
496+ image_vectors [data_idx ] = [vector ]
497+ return [np .average (image , axis = 0 ) for idx , image in image_vectors .items ()]
494498
495499 def merge_maps (self , saliency_maps : Union [List [List [np .ndarray ]], List [np .ndarray ]]) -> List :
496500 """Merge tile-level saliency maps to image-level saliency map.
@@ -502,11 +506,24 @@ def merge_maps(self, saliency_maps: Union[List[List[np.ndarray]], List[np.ndarra
502506 Returns:
503507 merged_maps (List[list | np.ndarray | None]): Merged saliency maps for each image.
504508 """
509+
510+ dtype = None
511+ for map in saliency_maps :
512+ for cl_map in map :
513+ # find first class map which is not None
514+ if cl_map is not None and dtype is None :
515+ dtype = map [0 ].dtype
516+ feat_h , feat_w = map [0 ].shape
517+ break
518+ if dtype is not None :
519+ break
520+ else :
521+ # if None for each class for each image
522+ return saliency_maps [: self .num_images ]
523+
505524 merged_maps = []
506525 ratios = {}
507526 num_classes = len (saliency_maps [0 ])
508- feat_h , feat_w = saliency_maps [0 ][0 ].shape
509- dtype = saliency_maps [0 ][0 ][0 ].dtype
510527
511528 for orig_image in self .cached_results :
512529 img_idx = orig_image ["index" ]
0 commit comments