Skip to content

Commit 86ac918

Browse files
committed
adding docstrings
1 parent 6cf1e8a commit 86ac918

File tree

2 files changed

+41
-3
lines changed

2 files changed

+41
-3
lines changed

src/eca/dataset.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,24 +5,42 @@
55
from PIL import Image
66

77
class DataSource(Flag):
8+
"""Flag to indicate origin dataset."""
89
CHOLEC = auto()
910
ROBUST = auto()
1011
BOTH = CHOLEC | ROBUST
1112

1213
class AnnotationType(Flag):
14+
"""Flag to indicate type of annotation."""
1315
AREA = auto()
1416
MASK = auto()
1517
BOTH = AREA | MASK
1618

1719
class ECADataset():
18-
def __init__(self, data_directory="eca-data", data_source=DataSource.BOTH, annotation_type=AnnotationType.AREA, include_cropped=True, include_source_info=False) -> None:
20+
"""
21+
Dataloader for the ECA dataset.
22+
23+
Args:
24+
data_directory: root directory of the ECA data.
25+
data_source: flag denoting the origin dataset(s) samples will be taken from.
26+
annotation_type: flag denoting the type of annotation(s) provided.
27+
include_cropped: whether or not to include additional cropped samples.
28+
include_source_info: whether or not to include the source information of the sample.
29+
"""
30+
def __init__(
31+
self,
32+
data_directory: str = "eca-data",
33+
data_source: DataSource = DataSource.BOTH,
34+
annotation_type: AnnotationType = AnnotationType.AREA,
35+
include_cropped: bool = True,
36+
include_source_info: bool = False
37+
) -> None:
1938
super().__init__()
2039
self.data_directory = data_directory
2140
self.annotation_type = annotation_type
2241
self.data_source = data_source
2342
self.include_cropped = include_cropped
2443
self.include_source_info = include_source_info
25-
2644
try:
2745
self.sample_list = self.__get_sample_list()
2846
except FileNotFoundError:

src/eca/scoring.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import numpy as np
2+
from typing import Sequence, Optional, Tuple
23

34
class Line():
45
def __init__(self, a, b) -> None:
@@ -179,7 +180,26 @@ def get_smallest_dist(polygon, point):
179180
other_point = closest
180181
return smallest_dist, other_point
181182

182-
def content_area_hausdorff(circle_a, circle_b, frame_size, n_points=100, normalise=True):
183+
def content_area_hausdorff(
184+
circle_a: Optional[Sequence[int]],
185+
circle_b: Optional[Sequence[int]],
186+
frame_size: Sequence[int],
187+
n_points: int = 100,
188+
normalise: bool = True
189+
) -> Tuple[float, Optional[Sequence[Sequence[int]]]]:
190+
"""
191+
Hausdorff distance between two content areas.
192+
193+
Args:
194+
circle_a: circle for the first content area (can be None).
195+
circle_b: circle for the second content area (can be None).
196+
frame_size: size of the image in question.
197+
n_points: number of points used when discretising the edges of the content areas.
198+
normalise: whether or not to normalise the result as if the image were 1080x1920.
199+
Returns:
200+
float: score in pixels (optionally normalised)
201+
tuple: the coordinates of the two points found to give the final score.
202+
"""
183203

184204
if (circle_a == circle_b):
185205
return 0.0, None

0 commit comments

Comments
 (0)