|
18 | 18 | import numpy as np |
19 | 19 | import PIL.Image |
20 | 20 | import torch |
| 21 | +import uuid |
21 | 22 | import torchvision.transforms.v2 as tvt_v2 |
22 | 23 | import typeguard |
23 | 24 | from lightning.pytorch.cli import instantiate_class |
@@ -2124,8 +2125,56 @@ def forward(self, *_inputs: OTXDataItem) -> OTXDataItem | None: |
2124 | 2125 |
|
2125 | 2126 | inputs.polygons = [mixup_gt_polygons[i] for i in np.where(inside_inds)[0]] |
2126 | 2127 |
|
| 2128 | + # self.visualize(inputs, output_path=f"/home/kprokofi/debug_images/{str(uuid.uuid4())}.jpg") |
2127 | 2129 | return self.convert(inputs) |
2128 | 2130 |
|
def visualize(
    self,
    inputs: OTXDataItem,
    output_path: str | None = None,
    show_blended: bool = True,
) -> np.ndarray:
    """Draw the item's bounding boxes and labels on a copy of its image.

    Debugging helper. The input item is not modified: all drawing happens on
    a copy of the array returned by ``to_np_image(inputs.image)``.

    Args:
        inputs: Data item whose ``image``, ``bboxes`` and ``label`` fields are
            rendered. ``bboxes`` or ``label`` may be ``None``; missing boxes
            yield the plain image, and boxes without a label are drawn
            without text.
        output_path: If truthy, the rendered image is also written to this
            path with ``cv2.imwrite``. Missing parent directories are created.
        show_blended: Kept for interface compatibility. It currently has no
            effect: every box is drawn in the same color, because the
            information needed to tell original and blended boxes apart is
            not available to this method.

    Returns:
        The rendered image as a numpy array.
    """
    import warnings
    from pathlib import Path

    import cv2

    box_color = (0, 255, 0)
    canvas = to_np_image(inputs.image).copy()
    bboxes = inputs.bboxes
    labels = inputs.label
    num_labels = 0 if labels is None else len(labels)

    for idx, bbox in enumerate(bboxes if bboxes is not None else ()):
        x1, y1, x2, y2 = bbox.int().tolist()
        cv2.rectangle(canvas, (x1, y1), (x2, y2), box_color, 2)

        # A box can legitimately lack a label entry; still show the box.
        if idx >= num_labels:
            continue

        raw_label = labels[idx]
        label = raw_label.item() if hasattr(raw_label, "item") else raw_label

        # The text origin is its bottom-left corner. For boxes touching the top
        # border the usual "just above the box" spot is off-image, so place
        # the text just inside the box instead.
        text_y = y1 - 5 if y1 - 5 > 12 else y1 + 15
        cv2.putText(
            canvas,
            f"{label}",
            (x1, text_y),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.5,
            box_color,
            2,
        )

    if output_path:
        # cv2.imwrite does not create directories and reports failure only via
        # its return value, so make the target writable and surface failures.
        Path(output_path).parent.mkdir(parents=True, exist_ok=True)
        # NOTE(review): cv2.imwrite interprets 3-channel data as BGR. If
        # to_np_image returns RGB, the saved file has red/blue swapped - confirm.
        if not cv2.imwrite(output_path, canvas):
            warnings.warn(
                f"Failed to write visualization to {output_path}",
                stacklevel=2,
            )

    return canvas
| 2177 | + |
2129 | 2178 | def __repr__(self): |
2130 | 2179 | repr_str = self.__class__.__name__ |
2131 | 2180 | repr_str += f"(dynamic_scale={self.dynamic_scale}, " |
@@ -2403,7 +2452,6 @@ def forward(self, *_inputs: OTXDataItem) -> OTXDataItem | None: |
2403 | 2452 | canvas_size=(img_h, img_w), |
2404 | 2453 | ) |
2405 | 2454 | inputs.label = combined_labels |
2406 | | - |
2407 | 2455 | return self.convert(inputs) |
2408 | 2456 |
|
2409 | 2457 | def visualize( |
|
0 commit comments