|
16 | 16 | """ |
17 | 17 | import math |
18 | 18 | from functools import partial |
19 | | -from typing import Dict, Optional |
| 19 | +from typing import Dict, List, Optional, Tuple, Union |
20 | 20 |
|
21 | 21 | import torch |
22 | 22 | import torch.nn as nn |
|
25 | 25 | from timm.layers import create_conv2d, create_norm_layer, get_act_layer, get_norm_layer, ConvNormAct |
26 | 26 | from timm.layers import DropPath, trunc_normal_, to_2tuple, to_ntuple, ndgrid |
27 | 27 | from ._builder import build_model_with_cfg |
| 28 | +from ._features import feature_take_indices |
28 | 29 | from ._manipulate import checkpoint_seq |
29 | 30 | from ._registry import generate_default_cfgs, register_model |
30 | 31 |
|
@@ -625,6 +626,73 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): |
625 | 626 | def set_distilled_training(self, enable=True): |
626 | 627 | self.distilled_training = enable |
627 | 628 |
|
| 629 | + def forward_intermediates( |
| 630 | + self, |
| 631 | + x: torch.Tensor, |
| 632 | + indices: Optional[Union[int, List[int]]] = None, |
| 633 | + norm: bool = False, |
| 634 | + stop_early: bool = False, |
| 635 | + output_fmt: str = 'NCHW', |
| 636 | + intermediates_only: bool = False, |
| 637 | + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: |
| 638 | + """ Forward features that returns intermediates. |
| 639 | +
|
| 640 | + Args: |
| 641 | + x: Input image tensor |
| 642 | + indices: Take last n blocks if int, all if None, select matching indices if sequence |
| 643 | + norm: Apply norm layer to compatible intermediates |
| 644 | + stop_early: Stop iterating over blocks when last desired intermediate hit |
| 645 | + output_fmt: Shape of intermediate feature outputs |
| 646 | + intermediates_only: Only return intermediate features |
| 647 | + Returns: |
| 648 | +
|
| 649 | + """ |
| 650 | + assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' |
| 651 | + intermediates = [] |
| 652 | + take_indices, max_index = feature_take_indices(len(self.stages), indices) |
| 653 | + |
| 654 | + # forward pass |
| 655 | + x = self.stem(x) |
| 656 | + |
| 657 | + last_idx = len(self.stages) - 1 |
| 658 | + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript |
| 659 | + stages = self.stages |
| 660 | + else: |
| 661 | + stages = self.stages[:max_index + 1] |
| 662 | + |
| 663 | + for feat_idx, stage in enumerate(stages): |
| 664 | + x = stage(x) |
| 665 | + if feat_idx in take_indices: |
| 666 | + if feat_idx == last_idx: |
| 667 | + x_inter = self.norm(x) if norm else x |
| 668 | + intermediates.append(x_inter) |
| 669 | + else: |
| 670 | + intermediates.append(x) |
| 671 | + |
| 672 | + if intermediates_only: |
| 673 | + return intermediates |
| 674 | + |
| 675 | + if feat_idx == last_idx: |
| 676 | + x = self.norm(x) |
| 677 | + |
| 678 | + return x, intermediates |
| 679 | + |
| 680 | + def prune_intermediate_layers( |
| 681 | + self, |
| 682 | + indices: Union[int, List[int]] = 1, |
| 683 | + prune_norm: bool = False, |
| 684 | + prune_head: bool = True, |
| 685 | + ): |
| 686 | + """ Prune layers not required for specified intermediates. |
| 687 | + """ |
| 688 | + take_indices, max_index = feature_take_indices(len(self.stages), indices) |
| 689 | + self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0 |
| 690 | + if prune_norm: |
| 691 | + self.norm = nn.Identity() |
| 692 | + if prune_head: |
| 693 | + self.reset_classifier(0, '') |
| 694 | + return take_indices |
| 695 | + |
628 | 696 | def forward_features(self, x): |
629 | 697 | x = self.stem(x) |
630 | 698 | x = self.stages(x) |
|
0 commit comments