|
1 | 1 | import ast |
2 | 2 | import inspect |
| 3 | +import math |
3 | 4 | from dataclasses import dataclass |
4 | 5 | from typing import Any, Union |
5 | 6 |
|
6 | 7 | import matplotlib.pyplot as plt |
| 8 | +import PIL |
7 | 9 | from browsergym.core.action.highlevel import ACTION_SUBSETS |
8 | 10 | from PIL import Image, ImageDraw |
9 | 11 |
|
@@ -289,17 +291,54 @@ def overlay_rectangle( |
289 | 291 | bbox: tuple[float, float, float, float], |
290 | 292 | color: Union[str, tuple[int, int, int]] = "red", |
291 | 293 | width: int = 1, |
| 294 | + dashed: bool = True, |
292 | 295 | ) -> Image.Image: |
293 | 296 | draw = ImageDraw.Draw(img) |
294 | 297 |
|
295 | 298 | x, y, w, h = bbox |
296 | 299 |
|
297 | | - # Draw rectangle outline |
298 | | - draw.rectangle([x, y, x + w, y + h], outline=color, width=width) |
| 300 | + if dashed: |
| 301 | + # Draw dashed rectangle |
| 302 | + print("Drawing dashed rectangle") |
| 303 | + linedashed(draw, x, y, x + w, y, color, width) |
| 304 | + linedashed(draw, x + w, y, x + w, y + h, color, width) |
| 305 | + linedashed(draw, x + w, y + h, x, y + h, color, width) |
| 306 | + linedashed(draw, x, y + h, x, y, color, width) |
| 307 | + else: |
| 308 | + draw.rectangle([x, y, x + w, y + h], outline=color, width=width) |
299 | 309 |
|
300 | 310 | return img |
301 | 311 |
|
302 | 312 |
|
| 313 | +# Adapted from https://stackoverflow.com/questions/51908563/dotted-or-dashed-line-with-python-pillow/58885306#58885306 |
| 314 | +def linedashed( |
| 315 | + draw: PIL.ImageDraw.Draw, x0, y0, x1, y1, fill, width, dash_length=4, nodash_length=8 |
| 316 | +): |
| 317 | + line_dx = x1 - x0 # delta x (can be negative) |
| 318 | + line_dy = y1 - y0 # delta y (can be negative) |
| 319 | + line_length = math.hypot(line_dx, line_dy) # line length (positive) |
| 320 | + if line_length == 0: |
| 321 | + return # Avoid division by zero in case the line length is 0 |
| 322 | + pixel_dx = line_dx / line_length # x add for 1px line length |
| 323 | + pixel_dy = line_dy / line_length # y add for 1px line length |
| 324 | + dash_start = 0 |
| 325 | + while dash_start < line_length: |
| 326 | + dash_end = dash_start + dash_length |
| 327 | + if dash_end > line_length: |
| 328 | + dash_end = line_length |
| 329 | + draw.line( |
| 330 | + ( |
| 331 | + round(x0 + pixel_dx * dash_start), |
| 332 | + round(y0 + pixel_dy * dash_start), |
| 333 | + round(x0 + pixel_dx * dash_end), |
| 334 | + round(y0 + pixel_dy * dash_end), |
| 335 | + ), |
| 336 | + fill=fill, |
| 337 | + width=width, |
| 338 | + ) |
| 339 | + dash_start += dash_length + nodash_length |
| 340 | + |
| 341 | + |
303 | 342 | def annotate_action( |
304 | 343 | img: Image.Image, action_string: str, properties: dict[str, tuple], colormap: str = "tab10" |
305 | 344 | ) -> str: |
|
0 commit comments