@@ -29,7 +29,7 @@ class TileMerge(Generic[T_OTXDataEntity, T_OTXBatchPredEntity]):
2929 img_infos (list[ImageInfo]): Original image information before tiling.
3030 num_classes (int): Number of classes.
3131 tile_config (TileConfig): Tile configuration.
32- explain_mode (bool): Whether or not tiles have explain features. Default: False.
32+ explain_mode (bool, optional ): Whether or not tiles have explain features. Default: False.
3333 """
3434
3535 def __init__ (
@@ -119,8 +119,8 @@ def merge(
119119 img_ids = []
120120 explain_mode = self .explain_mode
121121
122- for tile_preds , tile_attrs in zip (batch_tile_preds , batch_tile_attrs ):
123- batch_size = tile_preds . batch_size
122+ for tile_preds , tile_attrs in zip (batch_tile_preds , batch_tile_attrs , strict = True ):
123+ batch_size = len ( tile_attrs )
124124 saliency_maps = tile_preds .saliency_map if explain_mode else [[] for _ in range (batch_size )]
125125 feature_vectors = tile_preds .feature_vector if explain_mode else [[] for _ in range (batch_size )]
126126 for tile_attr , tile_img_info , tile_bboxes , tile_labels , tile_scores , tile_s_map , tile_f_vect in zip (
@@ -131,6 +131,7 @@ def merge(
131131 tile_preds .scores ,
132132 saliency_maps ,
133133 feature_vectors ,
134+ strict = True ,
134135 ):
135136 offset_x , offset_y , _ , _ = tile_attr ["roi" ]
136137 tile_bboxes [:, 0 ::2 ] += offset_x
@@ -156,7 +157,7 @@ def merge(
156157
157158 return [
158159 self ._merge_entities (image_info , entities_to_merge [img_id ], explain_mode )
159- for img_id , image_info in zip (img_ids , self .img_infos )
160+ for img_id , image_info in zip (img_ids , self .img_infos , strict = True )
160161 ]
161162
162163 def _merge_entities (
@@ -319,8 +320,8 @@ def merge(
319320 img_ids = []
320321 explain_mode = self .explain_mode
321322
322- for tile_preds , tile_attrs in zip (batch_tile_preds , batch_tile_attrs ):
323- feature_vectors = tile_preds .feature_vector if explain_mode else [[] for _ in range (tile_preds . batch_size )]
323+ for tile_preds , tile_attrs in zip (batch_tile_preds , batch_tile_attrs , strict = True ):
324+ feature_vectors = tile_preds .feature_vector if explain_mode else [[] for _ in range (len ( tile_attrs ) )]
324325 for tile_attr , tile_img_info , tile_bboxes , tile_labels , tile_scores , tile_masks , tile_f_vect in zip (
325326 tile_attrs ,
326327 tile_preds .imgs_info ,
@@ -329,6 +330,7 @@ def merge(
329330 tile_preds .scores ,
330331 tile_preds .masks ,
331332 feature_vectors ,
333+ strict = True ,
332334 ):
333335 keep_indices = tile_masks .to_sparse ().sum ((1 , 2 )).to_dense () > 0
334336 keep_indices = keep_indices .nonzero (as_tuple = True )[0 ]
@@ -363,7 +365,7 @@ def merge(
363365
364366 return [
365367 self ._merge_entities (image_info , entities_to_merge [img_id ], explain_mode )
366- for img_id , image_info in zip (img_ids , self .img_infos )
368+ for img_id , image_info in zip (img_ids , self .img_infos , strict = True )
367369 ]
368370
369371 def _merge_entities (
0 commit comments