Skip to content

Commit c9f9c30

Browse files
committed
fix some model
1 parent d1e7779 commit c9f9c30

17 files changed

+57
-51
lines changed

timm/models/byobnet.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
)
4545
from ._builder import build_model_with_cfg
4646
from ._features import feature_take_indices
47-
from ._manipulate import checkpoint, checkpoint_seq, named_apply
47+
from ._manipulate import named_apply, checkpoint_seq
4848
from ._registry import generate_default_cfgs, register_model
4949

5050
__all__ = ['ByobNet', 'ByoModelCfg', 'ByoBlockCfg', 'create_byob_stem', 'create_block']
@@ -1385,7 +1385,7 @@ def forward_intermediates(
13851385
for stage in stages:
13861386
feat_idx += 1
13871387
if self.grad_checkpointing and not torch.jit.is_scripting():
1388-
x = checkpoint(stage, x)
1388+
x = checkpoint_seq(stage, x)
13891389
else:
13901390
x = stage(x)
13911391
if not exclude_final_conv and feat_idx == last_idx:

timm/models/efficientnet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ def forward_intermediates(
212212
blocks = self.blocks[:max_index]
213213
for feat_idx, blk in enumerate(blocks, start=1):
214214
if self.grad_checkpointing and not torch.jit.is_scripting():
215-
x = checkpoint(blk, x)
215+
x = checkpoint_seq(blk, x, flatten=True)
216216
else:
217217
x = blk(x)
218218
if feat_idx in take_indices:

timm/models/efficientvit_mit.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -790,7 +790,7 @@ def forward_intermediates(
790790

791791
for feat_idx, stage in enumerate(stages):
792792
if self.grad_checkpointing and not torch.jit.is_scripting():
793-
x = checkpoint(stage, x)
793+
x = checkpoint_seq(stages, x)
794794
else:
795795
x = stage(x)
796796
if feat_idx in take_indices:
@@ -947,7 +947,7 @@ def forward_intermediates(
947947

948948
for feat_idx, stage in enumerate(stages):
949949
if self.grad_checkpointing and not torch.jit.is_scripting():
950-
x = checkpoint(stage, x)
950+
x = checkpoint_seq(stages, x)
951951
else:
952952
x = stage(x)
953953
if feat_idx in take_indices:

timm/models/ghostnet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -728,7 +728,7 @@ def forward_intermediates(
728728

729729
for feat_idx, stage in enumerate(stages, start=1):
730730
if self.grad_checkpointing and not torch.jit.is_scripting():
731-
x = checkpoint(stage, x)
731+
x = checkpoint_seq(stage, x, flatten=True)
732732
else:
733733
x = stage(x)
734734
if feat_idx in take_indices:

timm/models/hieradet_sam2.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,20 @@
11
import math
22
from copy import deepcopy
33
from functools import partial
4-
from typing import Callable, Dict, List, Optional, Tuple, Union
4+
from typing import Dict, List, Optional, Tuple, Union
55

66
import torch
77
import torch.nn as nn
88
import torch.nn.functional as F
9-
from torch.jit import Final
109

1110
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
1211
from timm.layers import PatchEmbed, Mlp, DropPath, ClNormMlpClassifierHead, LayerScale, \
1312
get_norm_layer, get_act_layer, init_weight_jax, init_weight_vit, to_2tuple, use_fused_attn
1413

1514
from ._builder import build_model_with_cfg
1615
from ._features import feature_take_indices
17-
from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv
18-
from ._registry import generate_default_cfgs, register_model, register_model_deprecations
16+
from ._manipulate import named_apply, checkpoint
17+
from ._registry import generate_default_cfgs, register_model
1918

2019

2120
def window_partition(x, window_size: Tuple[int, int]):
@@ -471,7 +470,10 @@ def forward_intermediates(
471470
else:
472471
blocks = self.blocks[:max_index + 1]
473472
for i, blk in enumerate(blocks):
474-
x = blk(x)
473+
if self.grad_checkpointing and not torch.jit.is_scripting():
474+
x = checkpoint(blk, x)
475+
else:
476+
x = blk(x)
475477
if i in take_indices:
476478
x_out = x.permute(0, 3, 1, 2) if output_fmt == 'NCHW' else x
477479
intermediates.append(x_out)
@@ -503,8 +505,11 @@ def prune_intermediate_layers(
503505
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
504506
x = self.patch_embed(x) # BHWC
505507
x = self._pos_embed(x)
506-
for i, blk in enumerate(self.blocks):
507-
x = blk(x)
508+
for blk in self.blocks:
509+
if self.grad_checkpointing and not torch.jit.is_scripting():
510+
x = checkpoint(blk, x)
511+
else:
512+
x = blk(x)
508513
return x
509514

510515
def forward_head(self, x, pre_logits: bool = False) -> torch.Tensor:

timm/models/inception_resnet_v2.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from functools import partial
66
import torch
77
import torch.nn as nn
8-
import torch.nn.functional as F
98

109
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
1110
from timm.layers import create_classifier, ConvNormAct

timm/models/inception_v3.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
Licensed BSD-Clause 3 https://github.com/pytorch/vision/blob/master/LICENSE
55
"""
66
from functools import partial
7-
from typing import Optional
87

98
import torch
109
import torch.nn as nn

timm/models/nasnet.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,9 @@
33
https://github.com/Cadene/pretrained-models.pytorch
44
"""
55
from functools import partial
6-
from typing import Optional
76

87
import torch
98
import torch.nn as nn
10-
import torch.nn.functional as F
119

1210
from timm.layers import ConvNormAct, create_conv2d, create_pool2d, create_classifier
1311
from ._builder import build_model_with_cfg

timm/models/nextvit.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from timm.layers import ClassifierHead
1818
from ._builder import build_model_with_cfg
1919
from ._features import feature_take_indices
20-
from ._manipulate import checkpoint, checkpoint_seq
20+
from ._manipulate import checkpoint_seq
2121
from ._registry import generate_default_cfgs, register_model
2222

2323
__all__ = ['NextViT']
@@ -595,7 +595,7 @@ def forward_intermediates(
595595

596596
for feat_idx, stage in enumerate(stages):
597597
if self.grad_checkpointing and not torch.jit.is_scripting():
598-
x = checkpoint(stage, x)
598+
x = checkpoint_seq(stage, x)
599599
else:
600600
x = stage(x)
601601
if feat_idx in take_indices:

timm/models/nfnet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ def create_stem(
215215
if 'deep' in stem_type:
216216
if 'quad' in stem_type:
217217
# 4 deep conv stack as in NFNet-F models
218-
assert not 'pool' in stem_type
218+
assert 'pool' not in stem_type
219219
stem_chs = (out_chs // 8, out_chs // 4, out_chs // 2, out_chs)
220220
strides = (2, 1, 1, 2)
221221
stem_stride = 4

0 commit comments

Comments
 (0)