@@ -68,112 +68,41 @@ class Segmentation:
6868 uid : str
6969 groundtruths : list [Bitmask ]
7070 predictions : list [Bitmask ]
71- shape : tuple [int , ...] = field ( default_factory = lambda : ( 0 , 0 ))
71+ shape : tuple [int , ...]
7272 size : int = field (default = 0 )
7373
7474 def __post_init__ (self ):
7575
76- groundtruth_shape = {
77- groundtruth .mask .shape for groundtruth in self .groundtruths
78- }
79- prediction_shape = {
80- prediction .mask .shape for prediction in self .predictions
81- }
82- if len (groundtruth_shape ) == 0 :
83- raise ValueError ("The segmenation is missing ground truths." )
84- elif len (prediction_shape ) == 0 :
85- raise ValueError ("The segmenation is missing predictions." )
86- elif (
87- len (groundtruth_shape ) != 1
88- or len (prediction_shape ) != 1
89- or groundtruth_shape != prediction_shape
90- ):
76+ if len (self .shape ) != 2 or self .shape [0 ] <= 0 or self .shape [1 ] <= 0 :
9177 raise ValueError (
92- "A shape mismatch exists within the segmentation. "
78+ f"segmentations must be 2-dimensional and have non-zero dimensions. Recieved shape ' { self . shape } ' "
9379 )
94-
95- self .shape = groundtruth_shape .pop ()
96- self .size = int (np .prod (np .array (self .shape )))
97-
98-
99- def generate_segmentation (
100- datum_uid : str ,
101- number_of_unique_labels : int ,
102- mask_height : int ,
103- mask_width : int ,
104- ) -> Segmentation :
105- """
106- Generates a semantic segmentation annotation.
107-
108- Parameters
109- ----------
110- datum_uid : str
111- The datum UID for the generated segmentation.
112- number_of_unique_labels : int
113- The number of unique labels.
114- mask_height : int
115- The height of the mask in pixels.
116- mask_width : int
117- The width of the mask in pixels.
118-
119- Returns
120- -------
121- Segmentation
122- A generated semantic segmenatation annotation.
123- """
124-
125- if number_of_unique_labels > 1 :
126- common_proba = 0.4 / (number_of_unique_labels - 1 )
127- min_proba = min (common_proba , 0.1 )
128- labels = [str (i ) for i in range (number_of_unique_labels )] + [None ]
129- proba = (
130- [0.5 ]
131- + [common_proba for _ in range (number_of_unique_labels - 1 )]
132- + [0.1 ]
133- )
134- elif number_of_unique_labels == 1 :
135- labels = ["0" , None ]
136- proba = [0.9 , 0.1 ]
137- min_proba = 0.1
138- else :
139- raise ValueError (
140- "The number of unique labels should be greater than zero."
141- )
142-
143- probabilities = np .array (proba , dtype = np .float64 )
144- weights = (probabilities / min_proba ).astype (np .int32 )
145-
146- indices = np .random .choice (
147- np .arange (len (weights )),
148- size = (mask_height * 2 , mask_width ),
149- p = probabilities ,
150- )
151-
152- N = len (labels )
153-
154- masks = np .arange (N )[:, None , None ] == indices
155-
156- gts = []
157- pds = []
158- for lidx in range (N ):
159- label = labels [lidx ]
160- if label is None :
161- continue
162- gts .append (
163- Bitmask (
164- mask = masks [lidx , :mask_height , :],
165- label = label ,
166- )
167- )
168- pds .append (
169- Bitmask (
170- mask = masks [lidx , mask_height :, :],
171- label = label ,
172- )
173- )
174-
175- return Segmentation (
176- uid = datum_uid ,
177- groundtruths = gts ,
178- predictions = pds ,
179- )
80+ self .size = self .shape [0 ] * self .shape [1 ]
81+
82+ mask_accumulation = None
83+ for groundtruth in self .groundtruths :
84+ if self .shape != groundtruth .mask .shape :
85+ raise ValueError (
86+ f"ground truth masks for datum '{ self .uid } ' should have shape '{ self .shape } '. Received mask with shape '{ groundtruth .mask .shape } '"
87+ )
88+
89+ if mask_accumulation is None :
90+ mask_accumulation = groundtruth .mask .copy ()
91+ elif np .logical_and (mask_accumulation , groundtruth .mask ).any ():
92+ raise ValueError ("ground truth masks cannot overlap" )
93+ else :
94+ mask_accumulation = mask_accumulation | groundtruth .mask
95+
96+ mask_accumulation = None
97+ for prediction in self .predictions :
98+ if self .shape != prediction .mask .shape :
99+ raise ValueError (
100+ f"prediction masks for datum '{ self .uid } ' should have shape '{ self .shape } '. Received mask with shape '{ prediction .mask .shape } '"
101+ )
102+
103+ if mask_accumulation is None :
104+ mask_accumulation = prediction .mask .copy ()
105+ elif np .logical_and (mask_accumulation , prediction .mask ).any ():
106+ raise ValueError ("prediction masks cannot overlap" )
107+ else :
108+ mask_accumulation = mask_accumulation | prediction .mask
0 commit comments