Skip to content

Commit 58f990f

Browse files
authored
Fix RandomAffine and RTDetr training with IoURandomCrop (#4718)
1 parent bfece5d commit 58f990f

File tree

8 files changed

+99
-59
lines changed

8 files changed

+99
-59
lines changed

lib/src/otx/data/transform_libs/torchvision.py

Lines changed: 29 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1159,7 +1159,6 @@ def __init__(
11591159
) -> None:
11601160
super().__init__()
11611161
self._validate_parameters(max_translate_ratio, scaling_ratio_range)
1162-
11631162
self.max_rotate_degree = max_rotate_degree
11641163
self.max_translate_ratio = max_translate_ratio
11651164
self.scaling_ratio_range = scaling_ratio_range
@@ -1238,28 +1237,28 @@ def forward(self, *_inputs: OTXDataItem) -> OTXDataItem:
12381237
homography_matrix = self._get_random_homography_matrix(height, width)
12391238
output_shape = (height + self.border[0] * 2, width + self.border[1] * 2)
12401239

1241-
if hasattr(inputs, "bboxes") and inputs.bboxes is not None and len(inputs.bboxes) > 0:
1240+
transformed_img = self._warp_image(img, homography_matrix, output_shape)
1241+
inputs.image = transformed_img
1242+
inputs.img_info = _resize_image_info(inputs.img_info, transformed_img.shape[:2])
1243+
valid_index = None
1244+
valid_bboxes = hasattr(inputs, "bboxes") and inputs.bboxes is not None and len(inputs.bboxes) > 0
1245+
1246+
if valid_bboxes:
12421247
# Test transform bboxes to see if any remain valid
12431248
valid_index = self._transform_bboxes(inputs, homography_matrix, output_shape)
12441249
# If no valid annotations will remain after transformation, skip entirely
12451250
if not valid_index.any():
12461251
inputs.image = img
12471252
return self.convert(inputs) # type: ignore[return-value]
12481253

1249-
# If we reach here, transformation will produce valid results, so proceed
1250-
# Transform image
1251-
transformed_img = self._warp_image(img, homography_matrix, output_shape)
1252-
inputs.image = transformed_img
1253-
inputs.img_info = _resize_image_info(inputs.img_info, transformed_img.shape[:2])
1254-
1255-
if hasattr(inputs, "masks") and inputs.masks is not None and len(inputs.masks) > 0:
1256-
self._transform_masks(inputs, homography_matrix, output_shape, valid_index)
1254+
if hasattr(inputs, "masks") and inputs.masks is not None and len(inputs.masks) > 0:
1255+
self._transform_masks(inputs, homography_matrix, output_shape, valid_index)
12571256

1258-
if hasattr(inputs, "polygons") and inputs.polygons is not None and len(inputs.polygons) > 0:
1259-
self._transform_polygons(inputs, homography_matrix, output_shape, valid_index)
1257+
if hasattr(inputs, "polygons") and inputs.polygons is not None and len(inputs.polygons) > 0:
1258+
self._transform_polygons(inputs, homography_matrix, output_shape, valid_index)
12601259

1261-
if self.recompute_bbox:
1262-
self._recompute_bboxes(inputs, output_shape)
1260+
if valid_bboxes and self.recompute_bbox:
1261+
self._recompute_bboxes(inputs, output_shape)
12631262

12641263
return self.convert(inputs) # type: ignore[return-value]
12651264

@@ -1321,7 +1320,7 @@ def _transform_masks(
13211320
inputs: OTXDataItem,
13221321
warp_matrix: np.ndarray,
13231322
output_size: tuple[int, int],
1324-
valid_index: np.ndarray,
1323+
valid_index: np.ndarray | None = None,
13251324
) -> None:
13261325
"""Transform masks using the warp matrix.
13271326
@@ -1335,11 +1334,11 @@ def _transform_masks(
13351334
return
13361335

13371336
# Convert valid_index to numpy boolean array if it's a tensor
1338-
if hasattr(valid_index, "numpy"):
1337+
if valid_index is not None and hasattr(valid_index, "numpy"):
13391338
valid_index = valid_index.numpy()
13401339

13411340
# Filter masks using valid_index first
1342-
masks = inputs.masks[valid_index]
1341+
masks = inputs.masks[valid_index] if valid_index is not None else inputs.masks
13431342
masks = masks.numpy() if not isinstance(masks, np.ndarray) else masks
13441343

13451344
if masks.ndim == 3:
@@ -1378,15 +1377,20 @@ def _warp_single_mask(self, mask: np.ndarray, warp_matrix: np.ndarray, output_si
13781377
)
13791378
return warped_mask > 127
13801379

1381-
msg = "Multi-class masks are not supported yet."
1382-
raise NotImplementedError(msg)
1380+
return cv2.warpPerspective(
1381+
mask.astype(np.uint8),
1382+
warp_matrix,
1383+
dsize=(width, height),
1384+
flags=cv2.INTER_NEAREST,
1385+
borderValue=0,
1386+
)
13831387

13841388
def _transform_polygons(
13851389
self,
13861390
inputs: OTXDataItem,
13871391
warp_matrix: np.ndarray,
13881392
output_shape: tuple[int, int],
1389-
valid_index: np.ndarray,
1393+
valid_index: np.ndarray | None = None,
13901394
) -> None:
13911395
"""Transform polygons using the warp matrix.
13921396
@@ -1405,11 +1409,13 @@ def _transform_polygons(
14051409
return
14061410

14071411
# Convert valid_index to numpy boolean array if it's a tensor
1408-
if hasattr(valid_index, "numpy"):
1412+
if valid_index is not None and hasattr(valid_index, "numpy"):
14091413
valid_index = valid_index.numpy()
14101414

1411-
# Filter polygons using valid_index
1412-
filtered_polygons = [p for p, keep in zip(inputs.polygons, valid_index) if keep]
1415+
# Filter polygons using valid_index if available
1416+
filtered_polygons = (
1417+
[p for p, keep in zip(inputs.polygons, valid_index) if keep] if valid_index is not None else inputs.polygons
1418+
)
14131419

14141420
if filtered_polygons:
14151421
inputs.polygons = project_polygons(filtered_polygons, warp_matrix, output_shape)

lib/src/otx/recipe/detection/dfine_x.yaml

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -66,26 +66,25 @@ overrides:
6666
batch_size: 8
6767
num_workers: 4
6868
transforms:
69-
- class_path: torchvision.transforms.v2.RandomPhotometricDistort
70-
init_args:
71-
p: 0.5
7269
- class_path: torchvision.transforms.v2.RandomZoomOut
70+
enable: true
7371
init_args:
7472
fill: 0
7573
- class_path: otx.data.transform_libs.torchvision.RandomIoUCrop
74+
enable: true
7675
init_args:
7776
probability: 0.8
7877
- class_path: torchvision.transforms.v2.SanitizeBoundingBoxes
7978
init_args:
8079
min_size: 1
81-
- class_path: otx.data.transform_libs.torchvision.RandomFlip
82-
init_args:
83-
probability: 0.5
8480
- class_path: otx.data.transform_libs.torchvision.Resize
8581
init_args:
8682
scale: $(input_size)
8783
transform_bbox: true
8884
keep_ratio: false
85+
- class_path: otx.data.transform_libs.torchvision.RandomFlip
86+
init_args:
87+
probability: 0.5
8988
- class_path: torchvision.transforms.v2.RandomPhotometricDistort
9089
enable: false
9190
init_args:

lib/src/otx/recipe/detection/dfine_x_tile.yaml

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -68,26 +68,25 @@ overrides:
6868
num_workers: 4
6969
to_tv_image: true
7070
transforms:
71-
- class_path: torchvision.transforms.v2.RandomPhotometricDistort
72-
init_args:
73-
p: 0.5
7471
- class_path: torchvision.transforms.v2.RandomZoomOut
72+
enable: true
7573
init_args:
7674
fill: 0
7775
- class_path: otx.data.transform_libs.torchvision.RandomIoUCrop
76+
enable: true
7877
init_args:
7978
probability: 0.8
8079
- class_path: torchvision.transforms.v2.SanitizeBoundingBoxes
8180
init_args:
8281
min_size: 1
83-
- class_path: otx.data.transform_libs.torchvision.RandomFlip
84-
init_args:
85-
probability: 0.5
8682
- class_path: otx.data.transform_libs.torchvision.Resize
8783
init_args:
8884
scale: $(input_size)
8985
transform_bbox: true
9086
keep_ratio: false
87+
- class_path: otx.data.transform_libs.torchvision.RandomFlip
88+
init_args:
89+
probability: 0.5
9190
- class_path: torchvision.transforms.v2.RandomPhotometricDistort
9291
enable: false
9392
init_args:

lib/src/otx/recipe/detection/rtdetr_101.yaml

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,13 @@ overrides:
6363
train_subset:
6464
batch_size: 4
6565
transforms:
66+
- class_path: otx.data.transform_libs.torchvision.MinIoURandomCrop
67+
enable: false
68+
- class_path: otx.data.transform_libs.torchvision.Resize
69+
init_args:
70+
scale: $(input_size)
71+
keep_ratio: false
72+
transform_bbox: true
6673
- class_path: torchvision.transforms.v2.RandomPhotometricDistort
6774
enable: false
6875
init_args:
@@ -79,11 +86,6 @@ overrides:
7986
- -0.05
8087
- 0.05
8188
p: 0.5
82-
- class_path: otx.data.transform_libs.torchvision.Resize
83-
init_args:
84-
scale: $(input_size)
85-
keep_ratio: false
86-
transform_bbox: true
8789
- class_path: otx.data.transform_libs.torchvision.RandomAffine
8890
enable: false
8991
init_args:
@@ -94,6 +96,7 @@ overrides:
9496
- 1.5
9597
max_shear_degree: 2.0
9698
- class_path: otx.data.transform_libs.torchvision.RandomFlip
99+
enable: true
97100
init_args:
98101
probability: 0.5
99102
- class_path: torchvision.transforms.v2.RandomVerticalFlip

lib/src/otx/recipe/detection/rtdetr_18.yaml

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,13 @@ overrides:
6262
train_subset:
6363
batch_size: 4
6464
transforms:
65+
- class_path: otx.data.transform_libs.torchvision.MinIoURandomCrop
66+
enable: false
67+
- class_path: otx.data.transform_libs.torchvision.Resize
68+
init_args:
69+
scale: $(input_size)
70+
keep_ratio: false
71+
transform_bbox: true
6572
- class_path: torchvision.transforms.v2.RandomPhotometricDistort
6673
enable: false
6774
init_args:
@@ -78,11 +85,6 @@ overrides:
7885
- -0.05
7986
- 0.05
8087
p: 0.5
81-
- class_path: otx.data.transform_libs.torchvision.Resize
82-
init_args:
83-
scale: $(input_size)
84-
keep_ratio: false
85-
transform_bbox: true
8688
- class_path: otx.data.transform_libs.torchvision.RandomAffine
8789
enable: false
8890
init_args:
@@ -93,6 +95,7 @@ overrides:
9395
- 1.5
9496
max_shear_degree: 2.0
9597
- class_path: otx.data.transform_libs.torchvision.RandomFlip
98+
enable: true
9699
init_args:
97100
probability: 0.5
98101
- class_path: torchvision.transforms.v2.RandomVerticalFlip

lib/src/otx/recipe/detection/rtdetr_50.yaml

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,13 @@ overrides:
6363
train_subset:
6464
batch_size: 4
6565
transforms:
66+
- class_path: otx.data.transform_libs.torchvision.MinIoURandomCrop
67+
enable: false
68+
- class_path: otx.data.transform_libs.torchvision.Resize
69+
init_args:
70+
scale: $(input_size)
71+
keep_ratio: false
72+
transform_bbox: true
6673
- class_path: torchvision.transforms.v2.RandomPhotometricDistort
6774
enable: false
6875
init_args:
@@ -79,11 +86,6 @@ overrides:
7986
- -0.05
8087
- 0.05
8188
p: 0.5
82-
- class_path: otx.data.transform_libs.torchvision.Resize
83-
init_args:
84-
scale: $(input_size)
85-
keep_ratio: false
86-
transform_bbox: true
8789
- class_path: otx.data.transform_libs.torchvision.RandomAffine
8890
enable: false
8991
init_args:
@@ -94,6 +96,7 @@ overrides:
9496
- 1.5
9597
max_shear_degree: 2.0
9698
- class_path: otx.data.transform_libs.torchvision.RandomFlip
99+
enable: true
97100
init_args:
98101
probability: 0.5
99102
- class_path: torchvision.transforms.v2.RandomVerticalFlip

lib/src/otx/recipe/detection/rtmdet_tiny.yaml

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -81,15 +81,6 @@ overrides:
8181
- class_path: otx.data.transform_libs.torchvision.RandomCrop
8282
init_args:
8383
crop_size: $(input_size)
84-
- class_path: otx.data.transform_libs.torchvision.RandomAffine
85-
enable: false
86-
init_args:
87-
max_rotate_degree: 10.0
88-
max_translate_ratio: 0.1
89-
scaling_ratio_range:
90-
- 0.5
91-
- 1.5
92-
max_shear_degree: 2.0
9384
- class_path: torchvision.transforms.v2.RandomPhotometricDistort
9485
enable: false
9586
init_args:
@@ -106,6 +97,15 @@ overrides:
10697
- -0.05
10798
- 0.05
10899
p: 0.5
100+
- class_path: otx.data.transform_libs.torchvision.RandomAffine
101+
enable: false
102+
init_args:
103+
max_rotate_degree: 10.0
104+
max_translate_ratio: 0.1
105+
scaling_ratio_range:
106+
- 0.5
107+
- 1.5
108+
max_shear_degree: 2.0
109109
- class_path: otx.data.transform_libs.torchvision.YOLOXHSVRandomAug
110110
- class_path: otx.data.transform_libs.torchvision.RandomFlip
111111
init_args:

lib/tests/unit/data/transform_libs/test_torchvision.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,17 @@ def close(self):
5151
return
5252

5353

54+
@pytest.fixture()
55+
def seg_data_entity() -> OTXDataItem:
56+
masks = torch.randint(low=0, high=2, size=(1, 112, 224), dtype=torch.uint8)
57+
return OTXDataItem(
58+
image=tv_tensors.Image(torch.randint(low=0, high=256, size=(3, 112, 224), dtype=torch.uint8)),
59+
img_info=ImageInfo(img_idx=0, img_shape=(112, 224), ori_shape=(112, 224)),
60+
masks=tv_tensors.Mask(masks),
61+
label=LongTensor([1]),
62+
)
63+
64+
5465
@pytest.fixture()
5566
def det_data_entity() -> OTXDataItem:
5667
return OTXDataItem(
@@ -359,6 +370,22 @@ def test_forward(self, random_affine: RandomAffine, det_data_entity: OTXDataItem
359370
assert results.bboxes.dtype == torch.float32
360371
assert results.img_info.img_shape == results.image.shape[:2]
361372

373+
def test_segmentation_transform(
374+
self, random_affine_with_mask_transform: RandomAffine, seg_data_entity: OTXDataItem
375+
) -> None:
376+
"""Test forward for segmentation task."""
377+
original_entity = deepcopy(seg_data_entity)
378+
results = random_affine_with_mask_transform(original_entity)
379+
380+
assert hasattr(results, "masks")
381+
assert results.masks is not None
382+
assert results.masks.shape[0] > 0 # Should have masks
383+
assert results.masks.shape[1:] == results.image.shape[:2] # Same spatial dimensions as image
384+
385+
# Check that the number of masks matches the number of labels
386+
assert results.masks.shape[0] == results.label.shape[0]
387+
assert isinstance(results.masks, tv_tensors.Mask)
388+
362389
def test_forward_with_masks_transform_enabled(
363390
self,
364391
random_affine_with_mask_transform: RandomAffine,

0 commit comments

Comments
 (0)