1- from typing import Dict , Optional , Tuple
1+ from typing import Dict , List , Optional , Tuple , Union
22
33import torch
44import torch .nn as nn
@@ -207,12 +207,43 @@ def __init__(
207207 if enc_freeze :
208208 self .freeze_encoder ()
209209
210- def forward (self , x : torch .Tensor ) -> Dict [str , torch .Tensor ]:
211- """Forward pass of Stardist."""
212- feats = self .forward_encoder (x )
213- style = self .forward_style (feats [0 ])
210+ def forward (
211+ self ,
212+ x : torch .Tensor ,
213+ return_feats : bool = False ,
214+ ) -> Union [
215+ Dict [str , torch .Tensor ],
216+ Tuple [
217+ List [torch .Tensor ],
218+ Dict [str , torch .Tensor ],
219+ Dict [str , torch .Tensor ],
220+ ],
221+ ]:
222+ """Forward pass of Stardist.
223+
224+ Parameters
225+ ----------
226+ x : torch.Tensor
227+ Input image batch. Shape: (B, C, H, W).
228+ return_feats : bool, default=False
229+ If True, encoder, decoder, and head outputs will all be returned
230+
231+ Returns
232+ -------
233+ Union[
234+ Dict[str, torch.Tensor],
235+ Tuple[
236+ List[torch.Tensor],
237+ Dict[str, torch.Tensor],
238+ Dict[str, torch.Tensor],
239+ ],
240+ ]:
241+ Dictionary mapping of output names to outputs or if `return_feats == True`
242+ returns also the encoder features in a list, decoder features as a dict
243+ mapping decoder names to outputs and the final head outputs dict.
244+ """
245+ feats , dec_feats = self .forward_features (x )
214246
215- dec_feats = self .forward_dec_features (feats , style )
216247 # Extra convs after decoders
217248 for e in self .extra_convs .keys ():
218249 for extra_conv in self .extra_convs [e ].keys ():
@@ -230,6 +261,9 @@ def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
230261
231262 out = self .forward_heads (dec_feats )
232263
264+ if return_feats :
265+ return feats , dec_feats , out
266+
233267 return out
234268
235269
0 commit comments