Skip to content

Commit 5f1a61a

Browse files
committed
added resizing to the homography estimator
1 parent d9400c7 commit 5f1a61a

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

endoscopy/homography_estimator.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import torch
22
from typing import Tuple, Any
3+
from kornia.geometry import resize
34

45
from .utils.loader import MODEL, load_model
56
from .utils.helpers import four_point_homography_to_matrix, image_edges
@@ -17,17 +18,19 @@ def __call__(self, img: torch.FloatTensor, wrp: torch.FloatTensor) -> Tuple[torc
1718
"""Foward pass of BoundingCircleDetector.
1819
1920
Args:
20-
img (torch.FloatTensor): Needs to be normalized in [0, 1].
21-
wrp (torch.FloatTensor): Needs to be normalized in [0, 1].
21+
img (torch.FloatTensor): Needs to be normalized in [0, 1]. Will be resized to 240x320.
22+
wrp (torch.FloatTensor): Needs to be normalized in [0, 1]. Will be resized to 240x320.
2223
Return:
2324
h (torch.Tensor): Homography of shape Bx3x3
2425
duv (torch.Tensor): Four point homography of shape Bx4x2
2526
"""
2627
if img.dim() != 4 or wrp.dim() != 4:
2728
raise RuntimeError("BoundingCircleDetector: Expected 4 dimensional input, got {} dimensional input.".format(img.dim()))
2829

29-
duv = self.model(img, wrp)
30-
uv_img = image_edges(img)
30+
img, wrp = resize(img, [240, 320]), resize(wrp, [240, 320])
31+
32+
duv = self.model(img.to(self.device), wrp.to(self.device))
33+
uv_img = image_edges(img).to(self.device)
3134
h = four_point_homography_to_matrix(uv_img, duv)
3235

3336
return h, duv

0 commit comments

Comments
 (0)