@@ -1159,7 +1159,6 @@ def __init__(
11591159 ) -> None :
11601160 super ().__init__ ()
11611161 self ._validate_parameters (max_translate_ratio , scaling_ratio_range )
1162-
11631162 self .max_rotate_degree = max_rotate_degree
11641163 self .max_translate_ratio = max_translate_ratio
11651164 self .scaling_ratio_range = scaling_ratio_range
@@ -1238,28 +1237,28 @@ def forward(self, *_inputs: OTXDataItem) -> OTXDataItem:
12381237 homography_matrix = self ._get_random_homography_matrix (height , width )
12391238 output_shape = (height + self .border [0 ] * 2 , width + self .border [1 ] * 2 )
12401239
1241- if hasattr (inputs , "bboxes" ) and inputs .bboxes is not None and len (inputs .bboxes ) > 0 :
1240+ transformed_img = self ._warp_image (img , homography_matrix , output_shape )
1241+ inputs .image = transformed_img
1242+ inputs .img_info = _resize_image_info (inputs .img_info , transformed_img .shape [:2 ])
1243+ valid_index = None
1244+ valid_bboxes = hasattr (inputs , "bboxes" ) and inputs .bboxes is not None and len (inputs .bboxes ) > 0
1245+
1246+ if valid_bboxes :
12421247 # Test transform bboxes to see if any remain valid
12431248 valid_index = self ._transform_bboxes (inputs , homography_matrix , output_shape )
12441249 # If no valid annotations will remain after transformation, skip entirely
12451250 if not valid_index .any ():
12461251 inputs .image = img
12471252 return self .convert (inputs ) # type: ignore[return-value]
12481253
1249- # If we reach here, transformation will produce valid results, so proceed
1250- # Transform image
1251- transformed_img = self ._warp_image (img , homography_matrix , output_shape )
1252- inputs .image = transformed_img
1253- inputs .img_info = _resize_image_info (inputs .img_info , transformed_img .shape [:2 ])
1254-
1255- if hasattr (inputs , "masks" ) and inputs .masks is not None and len (inputs .masks ) > 0 :
1256- self ._transform_masks (inputs , homography_matrix , output_shape , valid_index )
1254+ if hasattr (inputs , "masks" ) and inputs .masks is not None and len (inputs .masks ) > 0 :
1255+ self ._transform_masks (inputs , homography_matrix , output_shape , valid_index )
12571256
1258- if hasattr (inputs , "polygons" ) and inputs .polygons is not None and len (inputs .polygons ) > 0 :
1259- self ._transform_polygons (inputs , homography_matrix , output_shape , valid_index )
1257+ if hasattr (inputs , "polygons" ) and inputs .polygons is not None and len (inputs .polygons ) > 0 :
1258+ self ._transform_polygons (inputs , homography_matrix , output_shape , valid_index )
12601259
1261- if self .recompute_bbox :
1262- self ._recompute_bboxes (inputs , output_shape )
1260+ if valid_bboxes and self .recompute_bbox :
1261+ self ._recompute_bboxes (inputs , output_shape )
12631262
12641263 return self .convert (inputs ) # type: ignore[return-value]
12651264
@@ -1321,7 +1320,7 @@ def _transform_masks(
13211320 inputs : OTXDataItem ,
13221321 warp_matrix : np .ndarray ,
13231322 output_size : tuple [int , int ],
1324- valid_index : np .ndarray ,
1323+ valid_index : np .ndarray | None = None ,
13251324 ) -> None :
13261325 """Transform masks using the warp matrix.
13271326
@@ -1335,11 +1334,11 @@ def _transform_masks(
13351334 return
13361335
13371336 # Convert valid_index to numpy boolean array if it's a tensor
1338- if hasattr (valid_index , "numpy" ):
1337+ if valid_index is not None and hasattr (valid_index , "numpy" ):
13391338 valid_index = valid_index .numpy ()
13401339
13411340 # Filter masks using valid_index first
1342- masks = inputs .masks [valid_index ]
1341+ masks = inputs .masks [valid_index ] if valid_index is not None else inputs . masks
13431342 masks = masks .numpy () if not isinstance (masks , np .ndarray ) else masks
13441343
13451344 if masks .ndim == 3 :
@@ -1378,15 +1377,20 @@ def _warp_single_mask(self, mask: np.ndarray, warp_matrix: np.ndarray, output_si
13781377 )
13791378 return warped_mask > 127
13801379
1381- msg = "Multi-class masks are not supported yet."
1382- raise NotImplementedError (msg )
1380+ return cv2 .warpPerspective (
1381+ mask .astype (np .uint8 ),
1382+ warp_matrix ,
1383+ dsize = (width , height ),
1384+ flags = cv2 .INTER_NEAREST ,
1385+ borderValue = 0 ,
1386+ )
13831387
13841388 def _transform_polygons (
13851389 self ,
13861390 inputs : OTXDataItem ,
13871391 warp_matrix : np .ndarray ,
13881392 output_shape : tuple [int , int ],
1389- valid_index : np .ndarray ,
1393+ valid_index : np .ndarray | None = None ,
13901394 ) -> None :
13911395 """Transform polygons using the warp matrix.
13921396
@@ -1405,11 +1409,13 @@ def _transform_polygons(
14051409 return
14061410
14071411 # Convert valid_index to numpy boolean array if it's a tensor
1408- if hasattr (valid_index , "numpy" ):
1412+ if valid_index is not None and hasattr (valid_index , "numpy" ):
14091413 valid_index = valid_index .numpy ()
14101414
1411- # Filter polygons using valid_index
1412- filtered_polygons = [p for p , keep in zip (inputs .polygons , valid_index ) if keep ]
1415+ # Filter polygons using valid_index if available
1416+ filtered_polygons = (
1417+ [p for p , keep in zip (inputs .polygons , valid_index ) if keep ] if valid_index is not None else inputs .polygons
1418+ )
14131419
14141420 if filtered_polygons :
14151421 inputs .polygons = project_polygons (filtered_polygons , warp_matrix , output_shape )
0 commit comments