@@ -1,21 +1,20 @@
 import math
 from copy import deepcopy
 from functools import partial
-from typing import Callable, Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.jit import Final
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import PatchEmbed, Mlp, DropPath, ClNormMlpClassifierHead, LayerScale, \
     get_norm_layer, get_act_layer, init_weight_jax, init_weight_vit, to_2tuple, use_fused_attn
 
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
-from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv
-from ._registry import generate_default_cfgs, register_model, register_model_deprecations
+from ._manipulate import named_apply, checkpoint
+from ._registry import generate_default_cfgs, register_model
 
 
 def window_partition(x, window_size: Tuple[int, int]):
@@ -471,7 +470,10 @@ def forward_intermediates( |
         else:
             blocks = self.blocks[:max_index + 1]
         for i, blk in enumerate(blocks):
-            x = blk(x)
+            if self.grad_checkpointing and not torch.jit.is_scripting():
+                x = checkpoint(blk, x)
+            else:
+                x = blk(x)
             if i in take_indices:
                 x_out = x.permute(0, 3, 1, 2) if output_fmt == 'NCHW' else x
                 intermediates.append(x_out)
@@ -503,8 +505,11 @@ def prune_intermediate_layers( |
     def forward_features(self, x: torch.Tensor) -> torch.Tensor:
         x = self.patch_embed(x)  # BHWC
         x = self._pos_embed(x)
-        for i, blk in enumerate(self.blocks):
-            x = blk(x)
+        for blk in self.blocks:
+            if self.grad_checkpointing and not torch.jit.is_scripting():
+                x = checkpoint(blk, x)
+            else:
+                x = blk(x)
         return x
 
     def forward_head(self, x, pre_logits: bool = False) -> torch.Tensor:
|
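Both hunks apply the same per-block dispatch: when `self.grad_checkpointing` is set and the model is not being TorchScript-ed, each block runs through the `checkpoint` helper imported from `._manipulate` instead of being called directly. The sketch below is a minimal, self-contained illustration of that pattern using `torch.utils.checkpoint` directly; `TinyBlock`, the standalone `grad_checkpointing` flag, and the tensor shapes are invented for the example and are not part of this diff.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class TinyBlock(nn.Module):
    """Stand-in for a transformer block; any nn.Module taking a single tensor works."""

    def __init__(self, dim: int = 32):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.mlp(self.norm(x))


blocks = nn.ModuleList(TinyBlock() for _ in range(4))
grad_checkpointing = True  # mirrors the self.grad_checkpointing flag tested in the diff

x = torch.randn(2, 7, 7, 32, requires_grad=True)  # BHWC, as produced by patch_embed above
for blk in blocks:
    if grad_checkpointing and not torch.jit.is_scripting():
        # Block activations are discarded in the forward pass and recomputed during
        # backward, trading extra compute for a smaller peak memory footprint.
        x = checkpoint(blk, x, use_reentrant=False)
    else:
        x = blk(x)
x.sum().backward()
```

In practice the flag would normally be flipped through timm's usual `set_grad_checkpointing()` method on the built model rather than set by hand; that toggle is assumed here and is not shown in this commit.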