1010
1111
1212class BaseMultiTaskSegModel (nn .ModuleDict ):
13+ def forward_encoder (self , x : torch .Tensor ) -> List [torch .Tensor ]:
14+ """Forward the model encoder."""
15+ self ._check_input_shape (x )
16+ feats = self .encoder (x )
17+
18+ return feats
19+
20+ def forward_style (self , feat : torch .Tensor ) -> torch .Tensor :
21+ """Forward the style domain adaptation layer.
22+
23+ NOTE: returns None if style channels are not given at model init.
24+ """
25+ style = None
26+ if self .make_style is not None :
27+ style = self .make_style (feat )
28+
29+ return style
30+
1331 def forward_dec_features (
1432 self , feats : List [torch .Tensor ], style : torch .Tensor = None
15- ) -> Dict [str , torch .Tensor ]:
16- """Forward pass of the decoders in a multi-task seg model."""
33+ ) -> Dict [str , List [torch .Tensor ]]:
34+ """Forward pass of all the decoder features mappings in the model.
35+
36+ NOTE: returns all the features from diff decoder stages in a list.
37+ """
1738 res = {}
1839 decoders = [k for k in self .keys () if "decoder" in k ]
1940
2041 for dec in decoders :
21- x = self [dec ](* feats , style = style )
42+ featlist = self [dec ](* feats , style = style )
2243 branch = dec .split ("_" )[0 ]
23- res [branch ] = x
44+ res [branch ] = featlist
2445
2546 return res
2647
@@ -30,10 +51,9 @@ def forward_heads(
3051 """Forward pass of the seg heads in a multi-task seg model."""
3152 res = {}
3253 heads = [k for k in self .keys () if "head" in k ]
33-
3454 for head in heads :
3555 branch = head .split ("_" )[0 ]
36- x = self [head ](dec_feats [branch ])
56+ x = self [head ](dec_feats [branch ][ - 1 ]) # the last decoder stage feat map
3757 res [branch ] = x
3858
3959 return res
0 commit comments