@@ -10,7 +10,7 @@
 
 import itertools
 from functools import partial
-from typing import Dict, Optional
+from typing import Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -20,6 +20,7 @@
 from timm.layers import LayerNorm2d, NormMlpClassifierHead, DropPath,\
     trunc_normal_, resize_rel_pos_bias_table_levit, use_fused_attn
 from ._builder import build_model_with_cfg
+from ._features import feature_take_indices
 from ._features_fx import register_notrace_module
 from ._manipulate import checkpoint_seq
 from ._registry import register_model, generate_default_cfgs
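The new import, `feature_take_indices`, is timm's shared helper for resolving which stage outputs to collect. A hedged illustration of its semantics (behavior as in recent timm versions, where it returns a tuple of selected indices and the maximum index):

```python
from timm.models._features import feature_take_indices

# (take_indices, max_index) resolved against a model with 4 stages:
print(feature_take_indices(4, 2))       # ([2, 3], 3) - last two stages
print(feature_take_indices(4, None))    # ([0, 1, 2, 3], 3) - all stages
print(feature_take_indices(4, [0, 2]))  # ([0, 2], 2) - explicit selection
```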
@@ -536,6 +537,62 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.num_classes = num_classes
         self.head.reset(num_classes, pool_type=global_pool)
 
+    def forward_intermediates(
+            self,
+            x: torch.Tensor,
+            indices: Optional[Union[int, List[int]]] = None,
+            norm: bool = False,
+            stop_early: bool = False,
+            output_fmt: str = 'NCHW',
+            intermediates_only: bool = False,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+        """ Forward pass that also returns intermediate feature maps.
+
+        Args:
+            x: Input image tensor
+            indices: Take last n stages if int, all if None, select matching indices if sequence
+            norm: Apply norm layer to compatible intermediates
+            stop_early: Stop iterating over stages when last desired intermediate hit
+            output_fmt: Shape of intermediate feature outputs
+            intermediates_only: Only return intermediate features
+        Returns:
+            A list of intermediate feature tensors, or a tuple of
+            (final features, intermediates) if intermediates_only is False.
+        """
+        assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
+        intermediates = []
+        take_indices, max_index = feature_take_indices(len(self.stages), indices)
+
+        # forward pass
+        x = self.patch_embed(x)
+        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
+            stages = self.stages
+        else:
+            stages = self.stages[:max_index + 1]
+
+        for feat_idx, stage in enumerate(stages):
+            x = stage(x)
+            if feat_idx in take_indices:
+                intermediates.append(x)
+
+        if intermediates_only:
+            return intermediates
+
+        return x, intermediates
+
+    def prune_intermediate_layers(
+            self,
+            indices: Union[int, List[int]] = 1,
+            prune_norm: bool = False,
+            prune_head: bool = True,
+    ):
+        """ Prune layers not required for specified intermediates.
+
+        Note: prune_norm is accepted for API compatibility but is a no-op here,
+        as the final norm lives inside the classifier head.
+        """
+        take_indices, max_index = feature_take_indices(len(self.stages), indices)
+        self.stages = self.stages[:max_index + 1]  # truncate stages; patch_embed (stem) is kept separately
+        if prune_head:
+            self.reset_classifier(0, '')
+        return take_indices
+
     def forward_features(self, x):
         x = self.patch_embed(x)
         if self.grad_checkpointing and not torch.jit.is_scripting():
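To show how the added API is used downstream, here is a minimal sketch, assuming this is timm's TinyViT model file; `tiny_vit_5m_224` is one of its registered variants, used here purely as an illustrative example:

```python
import torch
import timm

model = timm.create_model('tiny_vit_5m_224', pretrained=False).eval()
x = torch.randn(1, 3, 224, 224)

# Final features plus the outputs of the last two stages.
final, intermediates = model.forward_intermediates(x, indices=2)
for i, t in enumerate(intermediates):
    print(i, tuple(t.shape))  # NCHW feature maps from successive stages
```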
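And a companion sketch for `prune_intermediate_layers`, trimming the network down to a pure feature extractor (same illustrative model name as above):

```python
import torch
import timm

model = timm.create_model('tiny_vit_5m_224', pretrained=False).eval()

# Keep only the stages needed for the first two intermediates, drop the head.
kept = model.prune_intermediate_layers(indices=[0, 1], prune_head=True)
print(kept)  # [0, 1]

x = torch.randn(1, 3, 224, 224)
feats = model.forward_intermediates(x, indices=[0, 1], intermediates_only=True)
print([tuple(f.shape) for f in feats])
```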