Skip to content

Commit 3e28a48

Browse files
committed
Move activation definitions of zoe_depth to init()
Signed-off-by: Phillip Kuznetsov <[email protected]>
1 parent c7f9c47 commit 3e28a48

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

src/transformers/models/zoedepth/modeling_zoedepth.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -367,29 +367,32 @@ def __init__(self, config):
367367
self.projection = None
368368
if config.add_projection:
369369
self.projection = nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
370+
self.projection_act = nn.ReLU()
370371

371372
features = config.fusion_hidden_size
372373
self.conv1 = nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1)
373374
self.upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
374375
self.conv2 = nn.Conv2d(features // 2, config.num_relative_features, kernel_size=3, stride=1, padding=1)
376+
self.conv2_act = nn.ReLU()
375377
self.conv3 = nn.Conv2d(config.num_relative_features, 1, kernel_size=1, stride=1, padding=0)
378+
self.conv3_act = nn.ReLU()
376379

377380
def forward(self, hidden_states: List[torch.Tensor]) -> torch.Tensor:
378381
# use last features
379382
hidden_states = hidden_states[self.head_in_index]
380383

381384
if self.projection is not None:
382385
hidden_states = self.projection(hidden_states)
383-
hidden_states = nn.ReLU()(hidden_states)
386+
hidden_states = self.projection_act(hidden_states)
384387

385388
hidden_states = self.conv1(hidden_states)
386389
hidden_states = self.upsample(hidden_states)
387390
hidden_states = self.conv2(hidden_states)
388-
hidden_states = nn.ReLU()(hidden_states)
391+
hidden_states = self.conv2_act(hidden_states)
389392
# we need the features here (after second conv + ReLu)
390393
features = hidden_states
391394
hidden_states = self.conv3(hidden_states)
392-
hidden_states = nn.ReLU()(hidden_states)
395+
hidden_states = self.conv3_act(hidden_states)
393396

394397
predicted_depth = hidden_states.squeeze(dim=1)
395398

0 commit comments

Comments
 (0)