Skip to content

Commit 45c4d44

Browse files
committed
fix norm at last feat_idx
1 parent 46433ad commit 45c4d44

File tree

6 files changed

+30
-12
lines changed

6 files changed

+30
-12
lines changed

timm/models/mambaout.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -479,7 +479,6 @@ def prune_intermediate_layers(
479479
self.reset_classifier(0, '')
480480
return take_indices
481481

482-
483482
def forward_features(self, x):
484483
x = self.stem(x)
485484
x = self.stages(x)

timm/models/maxxvit.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1302,7 +1302,8 @@ def forward_intermediates(
13021302
if intermediates_only:
13031303
return intermediates
13041304

1305-
x = self.norm(x)
1305+
if feat_idx == last_idx:
1306+
x = self.norm(x)
13061307

13071308
return x, intermediates
13081309

timm/models/nest.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -449,6 +449,7 @@ def forward_intermediates(
449449

450450
# forward pass
451451
x = self.patch_embed(x)
452+
last_idx = self.num_blocks - 1
452453
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
453454
stages = self.levels
454455
else:
@@ -457,13 +458,18 @@ def forward_intermediates(
457458
for feat_idx, stage in enumerate(stages):
458459
x = stage(x)
459460
if feat_idx in take_indices:
460-
intermediates.append(x)
461+
if norm and feat_idx == last_idx:
462+
x_inter = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
463+
intermediates.append(x_inter)
464+
else:
465+
intermediates.append(x)
461466

462467
if intermediates_only:
463468
return intermediates
464469

465-
# Layer norm done over channel dim only (to NHWC and back)
466-
x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
470+
if feat_idx == last_idx:
471+
# Layer norm done over channel dim only (to NHWC and back)
472+
x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
467473

468474
return x, intermediates
469475

timm/models/nextvit.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -588,6 +588,7 @@ def forward_intermediates(
588588

589589
# forward pass
590590
x = self.stem(x)
591+
last_idx = len(self.stages) - 1
591592
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
592593
stages = self.stages
593594
else:
@@ -596,12 +597,17 @@ def forward_intermediates(
596597
for feat_idx, stage in enumerate(stages):
597598
x = stage(x)
598599
if feat_idx in take_indices:
599-
intermediates.append(x)
600+
if feat_idx == last_idx:
601+
x_inter = self.norm(x) if norm else x
602+
intermediates.append(x_inter)
603+
else:
604+
intermediates.append(x)
600605

601606
if intermediates_only:
602607
return intermediates
603608

604-
x = self.norm(x)
609+
if feat_idx == last_idx:
610+
x = self.norm(x)
605611

606612
return x, intermediates
607613

timm/models/rdnet.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,7 @@ def forward_intermediates(
309309
x = self.stem(x)
310310
if feat_idx in take_indices:
311311
intermediates.append(x)
312-
312+
last_idx = len(self.dense_stages)
313313
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
314314
dense_stages = self.dense_stages
315315
else:
@@ -324,7 +324,8 @@ def forward_intermediates(
324324
if intermediates_only:
325325
return intermediates
326326

327-
x = self.norm_pre(x)
327+
if feat_idx == last_idx:
328+
x = self.norm_pre(x)
328329

329330
return x, intermediates
330331

timm/models/resnetv2.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -574,7 +574,7 @@ def forward_intermediates(
574574
x = self.stem(x)
575575
if feat_idx in take_indices:
576576
intermediates.append(x)
577-
577+
last_idx = len(self.stages)
578578
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
579579
stages = self.stages
580580
else:
@@ -583,12 +583,17 @@ def forward_intermediates(
583583
for feat_idx, stage in enumerate(stages, start=1):
584584
x = stage(x)
585585
if feat_idx in take_indices:
586-
intermediates.append(x)
586+
if feat_idx == last_idx:
587+
x_inter = self.norm(x) if norm else x
588+
intermediates.append(x_inter)
589+
else:
590+
intermediates.append(x)
587591

588592
if intermediates_only:
589593
return intermediates
590594

591-
x = self.norm(x)
595+
if feat_idx == last_idx:
596+
x = self.norm(x)
592597

593598
return x, intermediates
594599

0 commit comments

Comments
 (0)