|
7 | 7 | """ |
8 | 8 |
|
9 | 9 | __all__ = ['EfficientVit', 'EfficientVitLarge'] |
10 | | -from typing import List, Optional |
| 10 | +from typing import List, Optional, Tuple, Union |
11 | 11 | from functools import partial |
12 | 12 |
|
13 | 13 | import torch |
|
17 | 17 | from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD |
18 | 18 | from timm.layers import SelectAdaptivePool2d, create_conv2d, GELUTanh |
19 | 19 | from ._builder import build_model_with_cfg |
| 20 | +from ._features import feature_take_indices |
20 | 21 | from ._features_fx import register_notrace_module |
21 | 22 | from ._manipulate import checkpoint_seq |
22 | 23 | from ._registry import register_model, generate_default_cfgs |
@@ -754,6 +755,63 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): |
754 | 755 | self.num_classes = num_classes |
755 | 756 | self.head.reset(num_classes, global_pool) |
756 | 757 |
|
| 758 | + def forward_intermediates( |
| 759 | + self, |
| 760 | + x: torch.Tensor, |
| 761 | + indices: Optional[Union[int, List[int]]] = None, |
| 762 | + norm: bool = False, |
| 763 | + stop_early: bool = False, |
| 764 | + output_fmt: str = 'NCHW', |
| 765 | + intermediates_only: bool = False, |
| 766 | + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: |
| 767 | + """ Forward features that returns intermediates. |
| 768 | +
|
| 769 | + Args: |
| 770 | + x: Input image tensor |
| 771 | + indices: Take last n blocks if int, all if None, select matching indices if sequence |
| 772 | + norm: Apply norm layer to compatible intermediates |
| 773 | + stop_early: Stop iterating over blocks when last desired intermediate hit |
| 774 | + output_fmt: Shape of intermediate feature outputs |
| 775 | + intermediates_only: Only return intermediate features |
| 776 | + Returns: |
| 777 | +
|
| 778 | + """ |
| 779 | + assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' |
| 780 | + intermediates = [] |
| 781 | + take_indices, max_index = feature_take_indices(len(self.stages), indices) |
| 782 | + |
| 783 | + # forward pass |
| 784 | + x = self.stem(x) |
| 785 | + |
| 786 | + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript |
| 787 | + stages = self.stages |
| 788 | + else: |
| 789 | + stages = self.stages[:max_index + 1] |
| 790 | + |
| 791 | + for feat_idx, stage in enumerate(stages): |
| 792 | + x = stage(x) |
| 793 | + if feat_idx in take_indices: |
| 794 | + intermediates.append(x) |
| 795 | + |
| 796 | + if intermediates_only: |
| 797 | + return intermediates |
| 798 | + |
| 799 | + return x, intermediates |
| 800 | + |
| 801 | + def prune_intermediate_layers( |
| 802 | + self, |
| 803 | + indices: Union[int, List[int]] = 1, |
| 804 | + prune_norm: bool = False, |
| 805 | + prune_head: bool = True, |
| 806 | + ): |
| 807 | + """ Prune layers not required for specified intermediates. |
| 808 | + """ |
| 809 | + take_indices, max_index = feature_take_indices(len(self.stages), indices) |
| 810 | + self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0 |
| 811 | + if prune_head: |
| 812 | + self.reset_classifier(0, '') |
| 813 | + return take_indices |
| 814 | + |
757 | 815 | def forward_features(self, x): |
758 | 816 | x = self.stem(x) |
759 | 817 | if self.grad_checkpointing and not torch.jit.is_scripting(): |
@@ -851,6 +909,63 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None): |
851 | 909 | self.num_classes = num_classes |
852 | 910 | self.head.reset(num_classes, global_pool) |
853 | 911 |
|
| 912 | + def forward_intermediates( |
| 913 | + self, |
| 914 | + x: torch.Tensor, |
| 915 | + indices: Optional[Union[int, List[int]]] = None, |
| 916 | + norm: bool = False, |
| 917 | + stop_early: bool = False, |
| 918 | + output_fmt: str = 'NCHW', |
| 919 | + intermediates_only: bool = False, |
| 920 | + ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: |
| 921 | + """ Forward features that returns intermediates. |
| 922 | +
|
| 923 | + Args: |
| 924 | + x: Input image tensor |
| 925 | + indices: Take last n blocks if int, all if None, select matching indices if sequence |
| 926 | + norm: Apply norm layer to compatible intermediates |
| 927 | + stop_early: Stop iterating over blocks when last desired intermediate hit |
| 928 | + output_fmt: Shape of intermediate feature outputs |
| 929 | + intermediates_only: Only return intermediate features |
| 930 | + Returns: |
| 931 | +
|
| 932 | + """ |
| 933 | + assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' |
| 934 | + intermediates = [] |
| 935 | + take_indices, max_index = feature_take_indices(len(self.stages), indices) |
| 936 | + |
| 937 | + # forward pass |
| 938 | + x = self.stem(x) |
| 939 | + |
| 940 | + if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript |
| 941 | + stages = self.stages |
| 942 | + else: |
| 943 | + stages = self.stages[:max_index + 1] |
| 944 | + |
| 945 | + for feat_idx, stage in enumerate(stages): |
| 946 | + x = stage(x) |
| 947 | + if feat_idx in take_indices: |
| 948 | + intermediates.append(x) |
| 949 | + |
| 950 | + if intermediates_only: |
| 951 | + return intermediates |
| 952 | + |
| 953 | + return x, intermediates |
| 954 | + |
| 955 | + def prune_intermediate_layers( |
| 956 | + self, |
| 957 | + indices: Union[int, List[int]] = 1, |
| 958 | + prune_norm: bool = False, |
| 959 | + prune_head: bool = True, |
| 960 | + ): |
| 961 | + """ Prune layers not required for specified intermediates. |
| 962 | + """ |
| 963 | + take_indices, max_index = feature_take_indices(len(self.stages), indices) |
| 964 | + self.stages = self.stages[:max_index + 1] # truncate blocks w/ stem as idx 0 |
| 965 | + if prune_head: |
| 966 | + self.reset_classifier(0, '') |
| 967 | + return take_indices |
| 968 | + |
854 | 969 | def forward_features(self, x): |
855 | 970 | x = self.stem(x) |
856 | 971 | if self.grad_checkpointing and not torch.jit.is_scripting(): |
|
0 commit comments