Skip to content

Commit cda39b3

Browse files
committed
Add a deprecation phase to module re-org
1 parent 927f031 commit cda39b3

18 files changed (+60 −2 lines changed)

benchmark.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@
1919
import torch.nn.parallel
2020

2121
from timm.data import resolve_data_config
22-
from timm.models import create_model, is_model, list_models, set_fast_norm
22+
from timm.layers import set_fast_norm
23+
from timm.models import create_model, is_model, list_models
2324
from timm.optim import create_optimizer_v2
2425
from timm.utils import setup_default_logging, set_jit_fuser, decay_batch_step, check_batch_size_retry
2526

timm/models/_builder.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@
2323
_CHECK_HASH = False
2424

2525

26+
__all__ = ['set_pretrained_download_progress', 'set_pretrained_check_hash', 'load_custom_pretrained', 'load_pretrained',
27+
'pretrained_cfg_for_features', 'resolve_pretrained_cfg', 'build_model_with_cfg']
28+
29+
2630
def _resolve_pretrained_source(pretrained_cfg):
2731
cfg_source = pretrained_cfg.get('source', '')
2832
pretrained_url = pretrained_cfg.get('url', None)

timm/models/_factory.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@
99
from ._registry import is_model, model_entrypoint
1010

1111

12+
__all__ = ['parse_model_name', 'safe_model_name', 'create_model']
13+
14+
1215
def parse_model_name(model_name):
1316
if model_name.startswith('hf_hub'):
1417
# NOTE for backwards compat, deprecate hf_hub use

timm/models/_features.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717
import torch.nn as nn
1818

1919

20+
__all__ = ['FeatureInfo', 'FeatureHooks', 'FeatureDictNet', 'FeatureListNet', 'FeatureHookNet']
21+
22+
2023
class FeatureInfo:
2124

2225
def __init__(self, feature_info: List[Dict], out_indices: Tuple[int]):

timm/models/_features_fx.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,10 @@
3535
pass
3636

3737

38+
__all__ = ['register_notrace_module', 'register_notrace_function', 'create_feature_extractor',
39+
'FeatureGraphNet', 'GraphExtractNet']
40+
41+
3842
def register_notrace_module(module: Type[nn.Module]):
3943
"""
4044
Any module not under timm.models.layers should get this decorator if we don't want to trace through it.

timm/models/_helpers.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212

1313
_logger = logging.getLogger(__name__)
1414

15+
__all__ = ['clean_state_dict', 'load_state_dict', 'load_checkpoint', 'remap_checkpoint', 'resume_checkpoint']
16+
1517

1618
def clean_state_dict(state_dict):
1719
# 'clean' checkpoint by removing .module prefix from state dict if it exists from parallel training

timm/models/_hub.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@
3131

3232
_logger = logging.getLogger(__name__)
3333

34+
__all__ = ['get_cache_dir', 'download_cached_file', 'has_hf_hub', 'hf_split', 'load_model_config_from_hf',
35+
'load_state_dict_from_hf', 'save_for_hf', 'push_to_hf_hub']
36+
3437

3538
def get_cache_dir(child_dir=''):
3639
"""

timm/models/_manipulate.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@
99
from torch import nn as nn
1010
from torch.utils.checkpoint import checkpoint
1111

12+
__all__ = ['model_parameters', 'named_apply', 'named_modules', 'named_modules_with_params', 'adapt_input_conv',
13+
'group_with_matcher', 'group_modules', 'group_parameters', 'flatten_modules', 'checkpoint_seq']
14+
1215

1316
def model_parameters(model, exclude_head=False):
1417
if exclude_head:

timm/models/_pretrained.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44
from typing import Any, Deque, Dict, Tuple, Optional, Union
55

66

7+
__all__ = ['PretrainedCfg', 'filter_pretrained_cfg', 'DefaultCfg', 'split_model_name_tag', 'generate_default_cfgs']
8+
9+
710
@dataclass
811
class PretrainedCfg:
912
"""

timm/models/_prune.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55

66
from timm.layers import Conv2dSame, BatchNormAct2d, Linear
77

8+
__all__ = ['extract_layer', 'set_layer', 'adapt_model_from_string', 'adapt_model_from_file']
9+
810

911
def extract_layer(model, layer):
1012
layer = layer.split('.')

0 commit comments

Comments (0)