11import functools
2+ import torch
23import torch .nn as nn
34
45# from ..head.build import HEAD_REGISTRY
@@ -10,16 +11,23 @@ def __init__(
1011 self ,
1112 in_features = 2048 ,
1213 hidden_layers = [],
14+ out_features = None ,
1315 activation = "relu" ,
1416 bn = True ,
1517 dropout = 0.0 ,
18+
1619 ):
1720 super ().__init__ ()
1821 if isinstance (hidden_layers , int ):
1922 hidden_layers = [hidden_layers ]
2023
2124 assert len (hidden_layers ) > 0
22- self .out_features = hidden_layers [- 1 ]
25+
26+ # If out_features is not specified, use the last hidden layer dimension
27+ if out_features is None :
28+ out_features = hidden_layers [- 1 ]
29+ self .out_features = out_features
30+ self .in_features = in_features
2331
2432 mlp = []
2533
@@ -33,15 +41,23 @@ def __init__(
3341 for hidden_dim in hidden_layers :
3442 mlp += [nn .Linear (in_features , hidden_dim )]
3543 if bn :
36- mlp += [nn .BatchNorm1d (hidden_dim )]
44+ mlp += [nn .LayerNorm (hidden_dim )]
3745 mlp += [act_fn ()]
3846 if dropout > 0 :
3947 mlp += [nn .Dropout (dropout )]
4048 in_features = hidden_dim
4149
50+ # Add final projection layer if output dimension differs from last hidden layer
51+ if out_features != hidden_layers [- 1 ]:
52+ mlp += [nn .Linear (hidden_layers [- 1 ], out_features )]
53+
4254 self .mlp = nn .Sequential (* mlp )
4355
def forward(self, x):
    """Run the input through the MLP.

    Args:
        x: Input tensor. Inputs with more than 2 dimensions are
           flattened to ``(batch, -1)`` before the MLP is applied;
           2-D inputs are passed through unchanged.

    Returns:
        Output tensor produced by ``self.mlp``.
    """
    if x.dim() > 2:
        # reshape (unlike view) also handles non-contiguous inputs,
        # e.g. tensors coming from a permute/transpose upstream,
        # where view would raise a RuntimeError.
        x = x.reshape(x.size(0), -1)
    return self.mlp(x)
4662
4763
0 commit comments