Skip to content

Commit bac56a6

Browse files
committed
fix tresnet and rdnet
1 parent 880b761 commit bac56a6

File tree

2 files changed

+11
-7
lines changed

2 files changed

+11
-7
lines changed

timm/models/rdnet.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,7 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
355355
def forward_features(self, x):
356356
x = self.stem(x)
357357
x = self.dense_stages(x)
358+
x = self.norm_pre(x)
358359
return x
359360

360361
def forward_head(self, x, pre_logits: bool = False):

timm/models/tresnet.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -252,14 +252,15 @@ def forward_intermediates(
252252
"""
253253
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
254254
intermediates = []
255-
take_indices, max_index = feature_take_indices(len(self.body) - 1, indices)
256-
255+
stage_ends = [1, 2, 3, 4, 5]
256+
take_indices, max_index = feature_take_indices(len(stage_ends), indices)
257+
take_indices = [stage_ends[i] for i in take_indices]
258+
max_index = stage_ends[max_index]
257259
# forward pass
258-
x = self.body[0](x) # s2d
259260
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
260-
stages = [self.body[1], self.body[2], self.body[3], self.body[4], self.body[5]]
261+
stages = self.body
261262
else:
262-
stages = self.body[1:max_index + 2]
263+
stages = self.body[:max_index + 1]
263264

264265
for feat_idx, stage in enumerate(stages):
265266
x = stage(x)
@@ -279,8 +280,10 @@ def prune_intermediate_layers(
279280
):
280281
""" Prune layers not required for specified intermediates.
281282
"""
282-
take_indices, max_index = feature_take_indices(len(self.body) - 1, indices)
283-
self.body = self.body[1:max_index + 2] # truncate blocks w/ stem as idx 0
283+
stage_ends = [1, 2, 3, 4, 5]
284+
take_indices, max_index = feature_take_indices(len(stage_ends), indices)
285+
max_index = stage_ends[max_index]
286+
self.body = self.body[:max_index + 1] # truncate blocks w/ stem as idx 0
284287
if prune_head:
285288
self.reset_classifier(0, '')
286289
return take_indices

0 commit comments

Comments
 (0)