@@ -367,29 +367,32 @@ def __init__(self, config):
367367 self .projection = None
368368 if config .add_projection :
369369 self .projection = nn .Conv2d (256 , 256 , kernel_size = (3 , 3 ), stride = (1 , 1 ), padding = (1 , 1 ))
370+ self .projection_act = nn .ReLU ()
370371
371372 features = config .fusion_hidden_size
372373 self .conv1 = nn .Conv2d (features , features // 2 , kernel_size = 3 , stride = 1 , padding = 1 )
373374 self .upsample = nn .Upsample (scale_factor = 2 , mode = "bilinear" , align_corners = True )
374375 self .conv2 = nn .Conv2d (features // 2 , config .num_relative_features , kernel_size = 3 , stride = 1 , padding = 1 )
376+ self .conv2_act = nn .ReLU ()
375377 self .conv3 = nn .Conv2d (config .num_relative_features , 1 , kernel_size = 1 , stride = 1 , padding = 0 )
378+ self .conv3_act = nn .ReLU ()
376379
377380 def forward (self , hidden_states : List [torch .Tensor ]) -> torch .Tensor :
378381 # use last features
379382 hidden_states = hidden_states [self .head_in_index ]
380383
381384 if self .projection is not None :
382385 hidden_states = self .projection (hidden_states )
383- hidden_states = nn . ReLU () (hidden_states )
386+ hidden_states = self . projection_act (hidden_states )
384387
385388 hidden_states = self .conv1 (hidden_states )
386389 hidden_states = self .upsample (hidden_states )
387390 hidden_states = self .conv2 (hidden_states )
388- hidden_states = nn . ReLU () (hidden_states )
391+ hidden_states = self . conv2_act (hidden_states )
389392 # we need the features here (after second conv + ReLu)
390393 features = hidden_states
391394 hidden_states = self .conv3 (hidden_states )
392- hidden_states = nn . ReLU () (hidden_states )
395+ hidden_states = self . conv3_act (hidden_states )
393396
394397 predicted_depth = hidden_states .squeeze (dim = 1 )
395398
0 commit comments