@@ -168,6 +168,7 @@ def get_gson(
168168        inst : np .ndarray ,
169169        type : np .ndarray ,
170170        classes : Dict [str , int ],
171+         soft_type : np .ndarray  =  None ,
171172        x_offset : int  =  0 ,
172173        y_offset : int  =  0 ,
173174        geo_format : str  =  "qupath" ,
@@ -182,6 +183,8 @@ def get_gson(
182183                Cell type labelled semantic segmentation mask. Shape: (H, W). 
183184            classes : Dict[str, int] 
184185                Class dict e.g. {"inflam":1, "epithelial":2, "connec":3} 
186+             soft_type : np.ndarray, default=None 
187+                 Softmax type mask. Shape: (C, H, W). C is the number of classes. 
185188            x_offset : int, default=0 
186189                x-coordinate offset. (to set geojson to .mrxs wsi coordinates) 
187190            y_offset : int, default=0 
@@ -211,6 +214,14 @@ def get_gson(
211214
212215            inst_type  =  [key  for  key  in  classes .keys () if  classes [key ] ==  inst_type ][0 ]
213216
217+             # type probabilities 
218+             if  soft_type  is  not   None :
219+                 type_probs  =  soft_type [..., inst_map  ==  inst_id ].mean (axis = 1 )
220+                 inst_type_soft  =  dict (zip (classes .keys (), type_probs ))
221+                 # convert to float for json serialization 
222+                 for  key  in  inst_type_soft .keys ():
223+                     inst_type_soft [key ] =  float (inst_type_soft [key ])
224+ 
214225            # get the cell contour coordinates 
215226            contours , _  =  cv2 .findContours (inst , cv2 .RETR_TREE , cv2 .CHAIN_APPROX_SIMPLE )
216227
@@ -230,6 +241,11 @@ def get_gson(
230241            poly .append (poly [0 ])  # close the polygon 
231242            geo_obj ["geometry" ]["coordinates" ] =  [poly ]
232243            geo_obj ["properties" ]["classification" ]["name" ] =  inst_type 
244+             if  soft_type  is  not   None :
245+                 geo_obj ["properties" ]["classification" ][
246+                     "probabilities" 
247+                 ] =  inst_type_soft 
248+ 
233249            geo_objs .append (geo_obj )
234250
235251        return  geo_objs 
@@ -364,6 +380,7 @@ def write_mat(
364380        sem : np .ndarray  =  None ,
365381        compute_centorids : bool  =  False ,
366382        compute_bboxes : bool  =  False ,
383+         ** kwargs ,
367384    ) ->  None :
368385        """ 
369386        Write multiple masks to .mat file. 
@@ -429,6 +446,7 @@ def write_gson(
429446        inst : np .ndarray ,
430447        type : np .ndarray  =  None ,
431448        classes : Dict [str , int ] =  None ,
449+         soft_type : np .ndarray  =  None ,
432450        x_offset : int  =  0 ,
433451        y_offset : int  =  0 ,
434452        geo_format : str  =  "qupath" ,
@@ -444,6 +462,8 @@ def write_gson(
444462            type : np.ndarray, optional 
445463                Cell type labelled semantic segmentation mask. Shape: (H, W). If None, 
446464                the classes of the objects will be set to {background: 0, foreground: 1} 
465+             soft_type : np.ndarray, default=None 
466+                 Softmax type mask. Shape: (C, H, W). C is the number of classes. 
447467            classes : Dict[str, int], optional 
448468                Class dict e.g. {"inflam":1, "epithelial":2, "connec":3}. Ignored if 
449469                `type` is None. 
@@ -489,7 +509,7 @@ def write_gson(
489509                )
490510
491511        geo_objs  =  FileHandler .get_gson (
492-             inst , type , classes , x_offset , y_offset , geo_format 
512+             inst , type , classes , soft_type ,  x_offset , y_offset , geo_format 
493513        )
494514
495515        fname  =  fname .with_suffix (".json" )
@@ -564,6 +584,7 @@ def save_masks(
564584                    inst = maps ["inst" ],
565585                    type = type_map ,
566586                    classes = classes_type ,
587+                     soft_type = maps ["soft_type" ] if  "soft_type"  in  maps .keys () else  None ,
567588                    geo_format = json_format ,
568589                    x_offset = offs ["x" ],
569590                    y_offset = offs ["y" ],
@@ -587,6 +608,7 @@ def save_masks(
587608                    inst = label_semantic (maps ["sem" ]),
588609                    type = maps ["sem" ],
589610                    classes = classes_sem ,
611+                     soft_type = maps ["soft_sem" ] if  "soft_sem"  in  maps .keys () else  None ,
590612                    geo_format = json_format ,
591613                    x_offset = offs ["x" ],
592614                    y_offset = offs ["y" ],
0 commit comments