1212from mipcandy .data .dataset import SupervisedDataset
1313from mipcandy .data .geometric import crop
1414from mipcandy .layer import HasDevice
15- from mipcandy .types import Device
15+ from mipcandy .types import Device , Shape , AmbiguousShape
1616
1717
1818def format_bbox (bbox : Sequence [int ]) -> tuple [int , int , int , int ] | tuple [int , int , int , int , int , int ]:
@@ -26,11 +26,11 @@ def format_bbox(bbox: Sequence[int]) -> tuple[int, int, int, int] | tuple[int, i
2626
2727@dataclass
2828class InspectionAnnotation (object ):
29- shape : tuple [ int , ...]
29+ shape : AmbiguousShape
3030 foreground_bbox : tuple [int , int , int , int ] | tuple [int , int , int , int , int , int ]
3131 ids : tuple [int , ...]
3232
33- def foreground_shape (self ) -> tuple [ int , int ] | tuple [ int , int , int ] :
33+ def foreground_shape (self ) -> Shape :
3434 r = (self .foreground_bbox [1 ] - self .foreground_bbox [0 ], self .foreground_bbox [3 ] - self .foreground_bbox [2 ])
3535 return r if len (self .foreground_bbox ) == 4 else r + (self .foreground_bbox [5 ] - self .foreground_bbox [4 ],)
3636
@@ -50,13 +50,13 @@ def __init__(self, dataset: SupervisedDataset, background: int, *annotations: In
5050 self ._dataset : SupervisedDataset = dataset
5151 self ._background : int = background
5252 self ._annotations : tuple [InspectionAnnotation , ...] = annotations
53- self ._shapes : tuple [tuple [ int , ...] | None , tuple [ int , ...], tuple [ int , ...] ] | None = None
54- self ._foreground_shapes : tuple [tuple [ int , ...] | None , tuple [ int , ...], tuple [ int , ...] ] | None = None
55- self ._statistical_foreground_shape : tuple [ int , int ] | tuple [ int , int , int ] | None = None
53+ self ._shapes : tuple [AmbiguousShape | None , AmbiguousShape , AmbiguousShape ] | None = None
54+ self ._foreground_shapes : tuple [AmbiguousShape | None , AmbiguousShape , AmbiguousShape ] | None = None
55+ self ._statistical_foreground_shape : Shape | None = None
5656 self ._foreground_heatmap : torch .Tensor | None = None
5757 self ._center_of_foregrounds : tuple [int , int ] | tuple [int , int , int ] | None = None
5858 self ._foreground_offsets : tuple [int , int ] | tuple [int , int , int ] | None = None
59- self ._roi_shape : tuple [ int , int ] | tuple [ int , int , int ] | None = None
59+ self ._roi_shape : Shape | None = None
6060
6161 def dataset (self ) -> SupervisedDataset :
6262 return self ._dataset
@@ -79,8 +79,8 @@ def save(self, path: str | PathLike[str]) -> None:
7979 with open (path , "w" ) as f :
8080 dump ({"background" : self ._background , "annotations" : [a .to_dict () for a in self ._annotations ]}, f )
8181
82- def _get_shapes (self , get_shape : Callable [[InspectionAnnotation ], tuple [ int , ...] ]) -> tuple [
83- tuple [ int , ...] | None , tuple [ int , ...], tuple [ int , ...] ]:
82+ def _get_shapes (self , get_shape : Callable [[InspectionAnnotation ], AmbiguousShape ]) -> tuple [
83+ AmbiguousShape | None , AmbiguousShape , AmbiguousShape ]:
8484 depths = []
8585 widths = []
8686 heights = []
@@ -95,19 +95,19 @@ def _get_shapes(self, get_shape: Callable[[InspectionAnnotation], tuple[int, ...
9595 widths .append (shape [2 ])
9696 return tuple (depths ) if depths else None , tuple (heights ), tuple (widths )
9797
98- def shapes (self ) -> tuple [tuple [ int , ...] | None , tuple [ int , ...], tuple [ int , ...] ]:
98+ def shapes (self ) -> tuple [AmbiguousShape | None , AmbiguousShape , AmbiguousShape ]:
9999 if self ._shapes :
100100 return self ._shapes
101101 self ._shapes = self ._get_shapes (lambda annotation : annotation .shape )
102102 return self ._shapes
103103
104- def foreground_shapes (self ) -> tuple [tuple [ int , ...] | None , tuple [ int , ...], tuple [ int , ...] ]:
104+ def foreground_shapes (self ) -> tuple [AmbiguousShape | None , AmbiguousShape , AmbiguousShape ]:
105105 if self ._foreground_shapes :
106106 return self ._foreground_shapes
107107 self ._foreground_shapes = self ._get_shapes (lambda annotation : annotation .foreground_shape ())
108108 return self ._foreground_shapes
109109
110- def statistical_foreground_shape (self , * , percentile : float = .95 ) -> tuple [ int , int ] | tuple [ int , int , int ] :
110+ def statistical_foreground_shape (self , * , percentile : float = .95 ) -> Shape :
111111 if self ._statistical_foreground_shape :
112112 return self ._statistical_foreground_shape
113113 depths , heights , widths = self .foreground_shapes ()
@@ -172,7 +172,7 @@ def center_of_foregrounds_offsets(self) -> tuple[int, int] | tuple[int, int, int
172172 self ._foreground_offsets = offsets + (round (center [2 ] - max_shape [2 ] * .5 ),) if depths else offsets
173173 return self ._foreground_offsets
174174
175- def set_roi_shape (self , roi_shape : tuple [ int , int ] | tuple [ int , int , int ] | None ) -> None :
175+ def set_roi_shape (self , roi_shape : Shape | None ) -> None :
176176 if roi_shape is not None :
177177 depths , heights , widths = self .shapes ()
178178 if depths :
@@ -183,7 +183,7 @@ def set_roi_shape(self, roi_shape: tuple[int, int] | tuple[int, int, int] | None
183183 raise ValueError (f"ROI shape { roi_shape } exceeds minimum image shape ({ min (heights )} , { min (widths )} )" )
184184 self ._roi_shape = roi_shape
185185
186- def roi_shape (self , * , percentile : float = .95 ) -> tuple [ int , int ] | tuple [ int , int , int ] :
186+ def roi_shape (self , * , percentile : float = .95 ) -> Shape :
187187 if self ._roi_shape :
188188 return self ._roi_shape
189189 sfs = self .statistical_foreground_shape (percentile = percentile )
0 commit comments