diff --git a/README.md b/README.md
index bdcd798..2e8ed04 100644
--- a/README.md
+++ b/README.md
@@ -71,17 +71,20 @@ pip install openvino
    [midas_v21_384](#model_type), [midas_v21_small_256](#model_type), [openvino_midas_v21_small_256](#model_type).
 
 3) The resulting depth maps are written to the `output` folder.
+   When using `--save_npy`, the raw depth arrays are stored alongside the images for
+   further processing. The `run()` function also returns these arrays when used as a module.
 
 #### optional
 
 1) By default, the inference resizes the height of input images to the size of a model to fit into the encoder. This
    size is given by the numbers in the model names of the [accuracy table](#accuracy). Some models do not only support a single
-   inference height but a range of different heights. Feel free to explore different heights by appending the extra
+   inference height but a range of different heights. Feel free to explore different heights by appending the extra
    command line argument `--height`. Unsupported height values will throw an error. Note that using this argument may decrease the model accuracy.
-2) By default, the inference keeps the aspect ratio of input images when feeding them into the encoder if this is
+2) To store the predicted depth maps as NumPy arrays alongside the images, set the flag `--save_npy`.
+3) By default, the inference keeps the aspect ratio of input images when feeding them into the encoder if this is
    supported by a model (all models except for Swin, Swin2, LeViT). In order to resize to a square resolution,
-   disregarding the aspect ratio while preserving the height, use the command line argument `--square`.
+   disregarding the aspect ratio while preserving the height, use the command line argument `--square`.
 
 #### via Camera
 
@@ -92,9 +95,30 @@
 
    python run.py --model_type <model_type> --side
    ```
 
-   The argument `--side` is optional and causes both the input RGB image and the output depth map to be shown
+   The argument `--side` is optional and causes both the input RGB image and the output depth map to be shown
    side-by-side for comparison.
+
+   A different camera can be specified via `--camera_source`. The parameter accepts either a numeric
+   device index or a stream URL. For example, to use a Wi-Fi camera stream:
+
+   ```shell
+   python run.py --model_type <model_type> --camera_source http://<ip>:<port>/stream
+   ```
+
+   Use `--save_npy` to additionally store the raw depth predictions as NumPy arrays.
+
+#### Streaming depth arrays
+
+The `stream_depth.py` script forwards each predicted depth map to another
+process via HTTP POST. Provide the destination URL:
+
+```shell
+python stream_depth.py --server_url http://localhost:8000/depth
+```
+
+Use `--camera_source` to select a different camera or stream and `--model_type`
+to choose another model.
+
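+For quick testing, the snippet below sketches a minimal receiver for these POSTs. It is
+illustrative only and assumes the depth maps arrive as raw serialized NumPy (`.npy`) bytes;
+check `stream_depth.py` for the actual wire format before relying on it:
+
+```python
+# depth_receiver.py - hypothetical endpoint for stream_depth.py (sketch, not part of this repo)
+import io
+from http.server import BaseHTTPRequestHandler, HTTPServer
+
+import numpy as np
+
+
+class DepthHandler(BaseHTTPRequestHandler):
+    def do_POST(self):
+        # Read the POST body and decode it as a serialized NumPy array.
+        length = int(self.headers.get("Content-Length", 0))
+        depth = np.load(io.BytesIO(self.rfile.read(length)), allow_pickle=False)
+        print(f"depth map {depth.shape}: min={depth.min():.3f} max={depth.max():.3f}")
+        self.send_response(200)
+        self.end_headers()
+
+
+if __name__ == "__main__":
+    HTTPServer(("localhost", 8000), DepthHandler).serve_forever()
+```
+
+Start the receiver first, then run `python stream_depth.py --server_url http://localhost:8000/depth`.
+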
 #### via Docker
 
 1) Make sure you have installed Docker and the
diff --git a/patched_timm/models/maxxvit.py b/patched_timm/models/maxxvit.py
new file mode 100644
index 0000000..7db3ac4
--- /dev/null
+++ b/patched_timm/models/maxxvit.py
@@ -0,0 +1,2650 @@
+""" MaxVit and CoAtNet Vision Transformer - CNN Hybrids in PyTorch
+
+This is a from-scratch implementation of both CoAtNet and MaxVit in PyTorch.
+
+99% of the implementation was done from the papers; however, some last-minute adjustments were made
+based on the (as yet unfinished?) public code release https://github.com/google-research/maxvit
+
+There are multiple sets of models defined for both architectures. Typically, names with a
+`_rw` suffix are my own original configs, created prior to referencing https://github.com/google-research/maxvit.
+These configs work well and appear to be a bit faster / lower resource than the paper ones.
+
+The models without an extra prefix / suffix (coatnet_0_224, maxvit_tiny_224, etc.) are intended to
+match the paper, BUT, without any official pretrained weights it's difficult to confirm a 100% match.
+
+Papers:
+
+MaxViT: Multi-Axis Vision Transformer - https://arxiv.org/abs/2204.01697
+@article{tu2022maxvit,
+  title={MaxViT: Multi-Axis Vision Transformer},
+  author={Tu, Zhengzhong and Talebi, Hossein and Zhang, Han and Yang, Feng and Milanfar, Peyman and Bovik, Alan and Li, Yinxiao},
+  journal={ECCV},
+  year={2022},
+}
+
+CoAtNet: Marrying Convolution and Attention for All Data Sizes - https://arxiv.org/abs/2106.04803
+@article{DBLP:journals/corr/abs-2106-04803,
+  author    = {Zihang Dai and Hanxiao Liu and Quoc V. Le and Mingxing Tan},
+  title     = {CoAtNet: Marrying Convolution and Attention for All Data Sizes},
+  journal   = {CoRR},
+  volume    = {abs/2106.04803},
+  year      = {2021}
+}
+
+Hacked together by / Copyright 2022, Ross Wightman
+"""
+
+import math
+from collections import OrderedDict
+from dataclasses import dataclass, replace, field
+from functools import partial
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
+
+import torch
+from torch import nn
+from torch.jit import Final
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from timm.layers import Mlp, ConvMlp, DropPath, LayerNorm, ClassifierHead, NormMlpClassifierHead
+from timm.layers import create_attn, get_act_layer, get_norm_layer, get_norm_act_layer, create_conv2d, create_pool2d
+from timm.layers import trunc_normal_tf_, to_2tuple, extend_tuple, make_divisible, _assert
+from timm.layers import RelPosMlp, RelPosBias, RelPosBiasTf, use_fused_attn, resize_rel_pos_bias_table
+from ._builder import build_model_with_cfg
+from ._features import feature_take_indices
+from ._features_fx import register_notrace_function
+from ._manipulate import named_apply, checkpoint_seq
+from ._registry import generate_default_cfgs, register_model
+
+__all__ = ['MaxxVitCfg', 'MaxxVitConvCfg', 'MaxxVitTransformerCfg', 'MaxxVit']
+
+
+@dataclass
+class MaxxVitTransformerCfg:
+    """Configuration for MaxxVit transformer blocks."""
+    dim_head: int = 32
+    head_first: bool = True  # head ordering in qkv channel dim
+    expand_ratio: float = 4.0
+    expand_first: bool = True
+    shortcut_bias: bool = True
+    attn_bias: bool = True
+    attn_drop: float = 0.
+    proj_drop: float = 0.
+ pool_type: str = 'avg2' + rel_pos_type: str = 'bias' + rel_pos_dim: int = 512 # for relative position types w/ MLP + partition_ratio: int = 32 + window_size: Optional[Tuple[int, int]] = None + grid_size: Optional[Tuple[int, int]] = None + no_block_attn: bool = False # disable window block attention for maxvit (ie only grid) + use_nchw_attn: bool = False # for MaxViT variants (not used for CoAt), keep tensors in NCHW order + init_values: Optional[float] = None + act_layer: str = 'gelu' + norm_layer: str = 'layernorm2d' + norm_layer_cl: str = 'layernorm' + norm_eps: float = 1e-6 + + def __post_init__(self): + if self.grid_size is not None: + self.grid_size = to_2tuple(self.grid_size) + if self.window_size is not None: + self.window_size = to_2tuple(self.window_size) + if self.grid_size is None: + self.grid_size = self.window_size + + +@dataclass +class MaxxVitConvCfg: + """Configuration for MaxxVit convolution blocks.""" + block_type: str = 'mbconv' + expand_ratio: float = 4.0 + expand_output: bool = True # calculate expansion channels from output (vs input chs) + kernel_size: int = 3 + group_size: int = 1 # 1 == depthwise + pre_norm_act: bool = False # activation after pre-norm + output_bias: bool = True # bias for shortcut + final 1x1 projection conv + stride_mode: str = 'dw' # stride done via one of 'pool', '1x1', 'dw' + pool_type: str = 'avg2' + downsample_pool_type: str = 'avg2' + padding: str = '' + attn_early: bool = False # apply attn between conv2 and norm2, instead of after norm2 + attn_layer: str = 'se' + attn_act_layer: str = 'silu' + attn_ratio: float = 0.25 + init_values: Optional[float] = 1e-6 # for ConvNeXt block, ignored by MBConv + act_layer: str = 'gelu' + norm_layer: str = '' + norm_layer_cl: str = '' + norm_eps: Optional[float] = None + + def __post_init__(self): + # mbconv vs convnext blocks have different defaults, set in post_init to avoid explicit config args + assert self.block_type in ('mbconv', 'convnext') + use_mbconv = self.block_type == 'mbconv' + if not self.norm_layer: + self.norm_layer = 'batchnorm2d' if use_mbconv else 'layernorm2d' + if not self.norm_layer_cl and not use_mbconv: + self.norm_layer_cl = 'layernorm' + if self.norm_eps is None: + self.norm_eps = 1e-5 if use_mbconv else 1e-6 + self.downsample_pool_type = self.downsample_pool_type or self.pool_type + + +@dataclass +class MaxxVitCfg: + """Configuration for MaxxVit models.""" + embed_dim: Tuple[int, ...] = (96, 192, 384, 768) + depths: Tuple[int, ...] = (2, 3, 5, 2) + block_type: Tuple[Union[str, Tuple[str, ...]], ...] = ('C', 'C', 'T', 'T') + stem_width: Union[int, Tuple[int, int]] = 64 + stem_bias: bool = False + conv_cfg: MaxxVitConvCfg = field(default_factory=MaxxVitConvCfg) + transformer_cfg: MaxxVitTransformerCfg = field(default_factory=MaxxVitTransformerCfg) + head_hidden_size: Optional[int] = None + weight_init: str = 'vit_eff' + + +class Attention2d(nn.Module): + """Multi-head attention for 2D NCHW tensors.""" + fused_attn: Final[bool] + + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + dim_head: int = 32, + bias: bool = True, + expand_first: bool = True, + head_first: bool = True, + rel_pos_cls: Optional[Callable] = None, + attn_drop: float = 0., + proj_drop: float = 0. + ): + """ + Args: + dim: Input dimension. + dim_out: Output dimension (defaults to input dimension). + dim_head: Dimension per attention head. + bias: Whether to use bias in qkv and projection. + expand_first: Whether to expand channels before or after qkv. 
+ head_first: Whether heads are first in tensor layout. + rel_pos_cls: Relative position class to use. + attn_drop: Attention dropout rate. + proj_drop: Projection dropout rate. + """ + super().__init__() + dim_out = dim_out or dim + dim_attn = dim_out if expand_first else dim + self.num_heads = dim_attn // dim_head + self.dim_head = dim_head + self.head_first = head_first + self.scale = dim_head ** -0.5 + self.fused_attn = use_fused_attn() + + self.qkv = nn.Conv2d(dim, dim_attn * 3, 1, bias=bias) + self.rel_pos = rel_pos_cls(num_heads=self.num_heads) if rel_pos_cls else None + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Conv2d(dim_attn, dim_out, 1, bias=bias) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: torch.Tensor, shared_rel_pos: Optional[torch.Tensor] = None) -> torch.Tensor: + B, C, H, W = x.shape + + if self.head_first: + q, k, v = self.qkv(x).view(B, self.num_heads, self.dim_head * 3, -1).chunk(3, dim=2) + else: + q, k, v = self.qkv(x).reshape(B, 3, self.num_heads, self.dim_head, -1).unbind(1) + + if self.fused_attn: + attn_bias = None + if self.rel_pos is not None: + attn_bias = self.rel_pos.get_bias() + elif shared_rel_pos is not None: + attn_bias = shared_rel_pos + + x = torch.nn.functional.scaled_dot_product_attention( + q.transpose(-1, -2).contiguous(), + k.transpose(-1, -2).contiguous(), + v.transpose(-1, -2).contiguous(), + attn_mask=attn_bias, + dropout_p=self.attn_drop.p if self.training else 0., + ).transpose(-1, -2).reshape(B, -1, H, W) + else: + q = q * self.scale + attn = q.transpose(-2, -1) @ k + if self.rel_pos is not None: + attn = self.rel_pos(attn) + elif shared_rel_pos is not None: + attn = attn + shared_rel_pos + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W) + + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class AttentionCl(nn.Module): + """Channels-last multi-head attention (B, ..., C).""" + fused_attn: Final[bool] + + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + dim_head: int = 32, + bias: bool = True, + expand_first: bool = True, + head_first: bool = True, + rel_pos_cls: Optional[Callable] = None, + attn_drop: float = 0., + proj_drop: float = 0. + ): + """ + Args: + dim: Input dimension. + dim_out: Output dimension (defaults to input dimension). + dim_head: Dimension per attention head. + bias: Whether to use bias in qkv and projection. + expand_first: Whether to expand channels before or after qkv. + head_first: Whether heads are first in tensor layout. + rel_pos_cls: Relative position class to use. + attn_drop: Attention dropout rate. + proj_drop: Projection dropout rate. 
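+
+        Example (illustrative usage; shapes chosen arbitrarily):
+
+            >>> import torch
+            >>> attn = AttentionCl(dim=64, dim_head=32)
+            >>> x = torch.randn(2, 7, 7, 64)  # channels-last (B, H, W, C)
+            >>> attn(x).shape
+            torch.Size([2, 7, 7, 64])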
+ """ + super().__init__() + dim_out = dim_out or dim + dim_attn = dim_out if expand_first and dim_out > dim else dim + assert dim_attn % dim_head == 0, 'attn dim should be divisible by head_dim' + self.num_heads = dim_attn // dim_head + self.dim_head = dim_head + self.head_first = head_first + self.scale = dim_head ** -0.5 + self.fused_attn = use_fused_attn() + + self.qkv = nn.Linear(dim, dim_attn * 3, bias=bias) + self.rel_pos = rel_pos_cls(num_heads=self.num_heads) if rel_pos_cls else None + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim_attn, dim_out, bias=bias) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: torch.Tensor, shared_rel_pos: Optional[torch.Tensor] = None) -> torch.Tensor: + B = x.shape[0] + restore_shape = x.shape[:-1] + + if self.head_first: + q, k, v = self.qkv(x).view(B, -1, self.num_heads, self.dim_head * 3).transpose(1, 2).chunk(3, dim=3) + else: + q, k, v = self.qkv(x).reshape(B, -1, 3, self.num_heads, self.dim_head).transpose(1, 3).unbind(2) + + if self.fused_attn: + attn_bias = None + if self.rel_pos is not None: + attn_bias = self.rel_pos.get_bias() + elif shared_rel_pos is not None: + attn_bias = shared_rel_pos + + x = torch.nn.functional.scaled_dot_product_attention( + q, k, v, + attn_mask=attn_bias, + dropout_p=self.attn_drop.p if self.training else 0., + ) + else: + q = q * self.scale + attn = q @ k.transpose(-2, -1) + if self.rel_pos is not None: + attn = self.rel_pos(attn, shared_rel_pos=shared_rel_pos) + elif shared_rel_pos is not None: + attn = attn + shared_rel_pos + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = attn @ v + + x = x.transpose(1, 2).reshape(restore_shape + (-1,)) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class LayerScale(nn.Module): + """Per-channel scaling layer.""" + + def __init__(self, dim: int, init_values: float = 1e-5, inplace: bool = False): + """ + Args: + dim: Number of channels. + init_values: Initial scaling value. + inplace: Whether to perform inplace operations. + """ + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + gamma = self.gamma + return x.mul_(gamma) if self.inplace else x * gamma + + +class LayerScale2d(nn.Module): + """Per-channel scaling layer for 2D tensors.""" + + def __init__(self, dim: int, init_values: float = 1e-5, inplace: bool = False): + """ + Args: + dim: Number of channels. + init_values: Initial scaling value. + inplace: Whether to perform inplace operations. + """ + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + gamma = self.gamma.view(1, -1, 1, 1) + return x.mul_(gamma) if self.inplace else x * gamma + + +class Downsample2d(nn.Module): + """A downsample pooling module supporting several maxpool and avgpool modes. + + * 'max' - MaxPool2d w/ kernel_size 3, stride 2, padding 1 + * 'max2' - MaxPool2d w/ kernel_size = stride = 2 + * 'avg' - AvgPool2d w/ kernel_size 3, stride 2, padding 1 + * 'avg2' - AvgPool2d w/ kernel_size = stride = 2 + """ + + def __init__( + self, + dim: int, + dim_out: int, + pool_type: str = 'avg2', + padding: str = '', + bias: bool = True, + ): + """ + Args: + dim: Input dimension. + dim_out: Output dimension. + pool_type: Type of pooling operation. + padding: Padding mode. + bias: Whether to use bias in expansion conv. 
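+
+        Example (illustrative; 'avg2' halves the spatial dims and the 1x1 conv expands channels):
+
+            >>> import torch
+            >>> down = Downsample2d(64, 128, pool_type='avg2')
+            >>> down(torch.randn(1, 64, 14, 14)).shape
+            torch.Size([1, 128, 7, 7])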
+ """ + super().__init__() + assert pool_type in ('max', 'max2', 'avg', 'avg2') + if pool_type == 'max': + self.pool = create_pool2d('max', kernel_size=3, stride=2, padding=padding or 1) + elif pool_type == 'max2': + self.pool = create_pool2d('max', 2, padding=padding or 0) # kernel_size == stride == 2 + elif pool_type == 'avg': + self.pool = create_pool2d( + 'avg', kernel_size=3, stride=2, count_include_pad=False, padding=padding or 1) + else: + self.pool = create_pool2d('avg', 2, padding=padding or 0) + + if dim != dim_out: + self.expand = nn.Conv2d(dim, dim_out, 1, bias=bias) + else: + self.expand = nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.pool(x) # spatial downsample + x = self.expand(x) # expand chs + return x + + +def _init_transformer(module: nn.Module, name: str, scheme: str = '') -> None: + """Initialize transformer module weights.""" + if isinstance(module, (nn.Conv2d, nn.Linear)): + if scheme == 'normal': + nn.init.normal_(module.weight, std=.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif scheme == 'trunc_normal': + trunc_normal_tf_(module.weight, std=.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif scheme == 'xavier_normal': + nn.init.xavier_normal_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + else: + # vit like + nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + if 'mlp' in name: + nn.init.normal_(module.bias, std=1e-6) + else: + nn.init.zeros_(module.bias) + + +class TransformerBlock2d(nn.Module): + """Transformer block with 2D downsampling. + + '2D' NCHW tensor layout + + Some gains can be seen on GPU using a 1D / CL block, BUT w/ the need to switch back/forth to NCHW + for spatial pooling, the benefit is minimal so ended up using just this variant for CoAt configs. + + This impl was faster on TPU w/ PT XLA than the 1D experiment. + """ + + def __init__( + self, + dim: int, + dim_out: int, + stride: int = 1, + rel_pos_cls: Optional[Callable] = None, + cfg: MaxxVitTransformerCfg | None = None, + drop_path: float = 0., + ): + """ + Args: + dim: Input dimension. + dim_out: Output dimension. + stride: Stride for downsampling. + rel_pos_cls: Relative position class. + cfg: Transformer block configuration. + drop_path: Drop path rate. + """ + super().__init__() + cfg = cfg or MaxxVitTransformerCfg() + norm_layer = partial(get_norm_layer(cfg.norm_layer), eps=cfg.norm_eps) + act_layer = get_act_layer(cfg.act_layer) + + if stride == 2: + self.shortcut = Downsample2d(dim, dim_out, pool_type=cfg.pool_type, bias=cfg.shortcut_bias) + self.norm1 = nn.Sequential(OrderedDict([ + ('norm', norm_layer(dim)), + ('down', Downsample2d(dim, dim, pool_type=cfg.pool_type)), + ])) + else: + assert dim == dim_out + self.shortcut = nn.Identity() + self.norm1 = norm_layer(dim) + + self.attn = Attention2d( + dim, + dim_out, + dim_head=cfg.dim_head, + expand_first=cfg.expand_first, + bias=cfg.attn_bias, + rel_pos_cls=rel_pos_cls, + attn_drop=cfg.attn_drop, + proj_drop=cfg.proj_drop + ) + self.ls1 = LayerScale2d(dim_out, init_values=cfg.init_values) if cfg.init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + + self.norm2 = norm_layer(dim_out) + self.mlp = ConvMlp( + in_features=dim_out, + hidden_features=int(dim_out * cfg.expand_ratio), + act_layer=act_layer, + drop=cfg.proj_drop) + self.ls2 = LayerScale2d(dim_out, init_values=cfg.init_values) if cfg.init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + def init_weights(self, scheme: str = '') -> None: + named_apply(partial(_init_transformer, scheme=scheme), self) + + def forward(self, x: torch.Tensor, shared_rel_pos: Optional[torch.Tensor] = None) -> torch.Tensor: + x = self.shortcut(x) + self.drop_path1(self.ls1(self.attn(self.norm1(x), shared_rel_pos=shared_rel_pos))) + x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) + return x + + +def _init_conv(module: nn.Module, name: str, scheme: str = '') -> None: + """Initialize convolution module weights.""" + if isinstance(module, nn.Conv2d): + if scheme == 'normal': + nn.init.normal_(module.weight, std=.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif scheme == 'trunc_normal': + trunc_normal_tf_(module.weight, std=.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif scheme == 'xavier_normal': + nn.init.xavier_normal_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + else: + # efficientnet like + fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels + fan_out //= module.groups + nn.init.normal_(module.weight, 0, math.sqrt(2.0 / fan_out)) + if module.bias is not None: + nn.init.zeros_(module.bias) + + +def num_groups(group_size: Optional[int], channels: int) -> int: + """Calculate number of groups for grouped convolution.""" + if not group_size: # 0 or None + return 1 # normal conv with 1 group + else: + # NOTE group_size == 1 -> depthwise conv + assert channels % group_size == 0 + return channels // group_size + + +class MbConvBlock(nn.Module): + """Pre-Norm Conv Block - 1x1 - kxk - 1x1, w/ inverted bottleneck (expand).""" + + def __init__( + self, + in_chs: int, + out_chs: int, + stride: int = 1, + dilation: Tuple[int, int] = (1, 1), + cfg: MaxxVitConvCfg | None = None, + drop_path: float = 0. + ): + """ + Args: + in_chs: Input channels. + out_chs: Output channels. + stride: Stride for conv. + dilation: Dilation for conv. + cfg: Convolution block configuration. + drop_path: Drop path rate. 
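+
+        Example (illustrative; stride=2 halves the spatial dims while the shortcut pools to match):
+
+            >>> import torch
+            >>> block = MbConvBlock(64, 128, stride=2)
+            >>> block(torch.randn(1, 64, 14, 14)).shape
+            torch.Size([1, 128, 7, 7])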
+ """ + super(MbConvBlock, self).__init__() + cfg = cfg or MaxxVitConvCfg() + norm_act_layer = partial(get_norm_act_layer(cfg.norm_layer, cfg.act_layer), eps=cfg.norm_eps) + mid_chs = make_divisible((out_chs if cfg.expand_output else in_chs) * cfg.expand_ratio) + groups = num_groups(cfg.group_size, mid_chs) + + if stride == 2: + self.shortcut = Downsample2d( + in_chs, out_chs, pool_type=cfg.pool_type, bias=cfg.output_bias, padding=cfg.padding) + else: + self.shortcut = nn.Identity() + + assert cfg.stride_mode in ('pool', '1x1', 'dw') + stride_pool, stride_1, stride_2 = 1, 1, 1 + if cfg.stride_mode == 'pool': + # NOTE this is not described in paper, experiment to find faster option that doesn't stride in 1x1 + stride_pool, dilation_2 = stride, dilation[1] + # FIXME handle dilation of avg pool + elif cfg.stride_mode == '1x1': + # NOTE I don't like this option described in paper, 1x1 w/ stride throws info away + stride_1, dilation_2 = stride, dilation[1] + else: + stride_2, dilation_2 = stride, dilation[0] + + self.pre_norm = norm_act_layer(in_chs, apply_act=cfg.pre_norm_act) + if stride_pool > 1: + self.down = Downsample2d(in_chs, in_chs, pool_type=cfg.downsample_pool_type, padding=cfg.padding) + else: + self.down = nn.Identity() + self.conv1_1x1 = create_conv2d(in_chs, mid_chs, 1, stride=stride_1) + self.norm1 = norm_act_layer(mid_chs) + + self.conv2_kxk = create_conv2d( + mid_chs, mid_chs, cfg.kernel_size, + stride=stride_2, dilation=dilation_2, groups=groups, padding=cfg.padding) + + attn_kwargs = {} + if isinstance(cfg.attn_layer, str): + if cfg.attn_layer == 'se' or cfg.attn_layer == 'eca': + attn_kwargs['act_layer'] = cfg.attn_act_layer + attn_kwargs['rd_channels'] = int(cfg.attn_ratio * (out_chs if cfg.expand_output else mid_chs)) + + # two different orderings for SE and norm2 (due to some weights and trials using SE before norm2) + if cfg.attn_early: + self.se_early = create_attn(cfg.attn_layer, mid_chs, **attn_kwargs) + self.norm2 = norm_act_layer(mid_chs) + self.se = None + else: + self.se_early = None + self.norm2 = norm_act_layer(mid_chs) + self.se = create_attn(cfg.attn_layer, mid_chs, **attn_kwargs) + + self.conv3_1x1 = create_conv2d(mid_chs, out_chs, 1, bias=cfg.output_bias) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + def init_weights(self, scheme: str = '') -> None: + named_apply(partial(_init_conv, scheme=scheme), self) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + shortcut = self.shortcut(x) + x = self.pre_norm(x) + x = self.down(x) + + # 1x1 expansion conv & norm-act + x = self.conv1_1x1(x) + x = self.norm1(x) + + # depthwise / grouped 3x3 conv w/ SE (or other) channel attention & norm-act + x = self.conv2_kxk(x) + if self.se_early is not None: + x = self.se_early(x) + x = self.norm2(x) + if self.se is not None: + x = self.se(x) + + # 1x1 linear projection to output width + x = self.conv3_1x1(x) + x = self.drop_path(x) + shortcut + return x + + +class ConvNeXtBlock(nn.Module): + """ConvNeXt Block.""" + + def __init__( + self, + in_chs: int, + out_chs: Optional[int] = None, + kernel_size: int = 7, + stride: int = 1, + dilation: Tuple[int, int] = (1, 1), + cfg: MaxxVitConvCfg | None = None, + conv_mlp: bool = True, + drop_path: float = 0. + ): + """ + Args: + in_chs: Input channels. + out_chs: Output channels. + kernel_size: Kernel size for depthwise conv. + stride: Stride for conv. + dilation: Dilation for conv. + cfg: Convolution block configuration. + conv_mlp: Whether to use convolutional MLP. + drop_path: Drop path rate. 
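+
+        Example (illustrative; uses the convnext config defaults rather than the mbconv ones):
+
+            >>> import torch
+            >>> block = ConvNeXtBlock(64, cfg=MaxxVitConvCfg(block_type='convnext'))
+            >>> block(torch.randn(1, 64, 14, 14)).shape
+            torch.Size([1, 64, 14, 14])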
+ """ + super().__init__() + cfg = cfg or MaxxVitConvCfg() + out_chs = out_chs or in_chs + act_layer = get_act_layer(cfg.act_layer) + if conv_mlp: + norm_layer = partial(get_norm_layer(cfg.norm_layer), eps=cfg.norm_eps) + mlp_layer = ConvMlp + else: + assert 'layernorm' in cfg.norm_layer + norm_layer = LayerNorm + mlp_layer = Mlp + self.use_conv_mlp = conv_mlp + + if stride == 2: + self.shortcut = Downsample2d(in_chs, out_chs) + elif in_chs != out_chs: + self.shortcut = nn.Conv2d(in_chs, out_chs, kernel_size=1, bias=cfg.output_bias) + else: + self.shortcut = nn.Identity() + + assert cfg.stride_mode in ('pool', 'dw') + stride_pool, stride_dw = 1, 1 + # FIXME handle dilation? + if cfg.stride_mode == 'pool': + stride_pool = stride + else: + stride_dw = stride + + if stride_pool == 2: + self.down = Downsample2d(in_chs, in_chs, pool_type=cfg.downsample_pool_type) + else: + self.down = nn.Identity() + + self.conv_dw = create_conv2d( + in_chs, out_chs, kernel_size=kernel_size, stride=stride_dw, dilation=dilation[1], + depthwise=True, bias=cfg.output_bias) + self.norm = norm_layer(out_chs) + self.mlp = mlp_layer(out_chs, int(cfg.expand_ratio * out_chs), bias=cfg.output_bias, act_layer=act_layer) + if conv_mlp: + self.ls = LayerScale2d(out_chs, cfg.init_values) if cfg.init_values else nn.Identity() + else: + self.ls = LayerScale(out_chs, cfg.init_values) if cfg.init_values else nn.Identity() + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + shortcut = self.shortcut(x) + x = self.down(x) + x = self.conv_dw(x) + if self.use_conv_mlp: + x = self.norm(x) + x = self.mlp(x) + x = self.ls(x) + else: + x = x.permute(0, 2, 3, 1) + x = self.norm(x) + x = self.mlp(x) + x = self.ls(x) + x = x.permute(0, 3, 1, 2) + + x = self.drop_path(x) + shortcut + return x + + +def window_partition(x: torch.Tensor, window_size: List[int]) -> torch.Tensor: + """Partition into non-overlapping windows.""" + B, H, W, C = x.shape + _assert(H % window_size[0] == 0, f'height ({H}) must be divisible by window ({window_size[0]})') + _assert(W % window_size[1] == 0, f'width ({W}) must be divisible by window ({window_size[1]})') + x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C) + return windows + + +@register_notrace_function # reason: int argument is a Proxy +def window_reverse(windows: torch.Tensor, window_size: List[int], img_size: List[int]) -> torch.Tensor: + """Reverse window partition.""" + H, W = img_size + C = windows.shape[-1] + x = windows.view(-1, H // window_size[0], W // window_size[1], window_size[0], window_size[1], C) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C) + return x + + +def grid_partition(x: torch.Tensor, grid_size: List[int]) -> torch.Tensor: + """Partition into overlapping windows with grid striding.""" + B, H, W, C = x.shape + _assert(H % grid_size[0] == 0, f'height {H} must be divisible by grid {grid_size[0]}') + _assert(W % grid_size[1] == 0, f'width {W} must be divisible by grid {grid_size[1]}') + x = x.view(B, grid_size[0], H // grid_size[0], grid_size[1], W // grid_size[1], C) + windows = x.permute(0, 2, 4, 1, 3, 5).contiguous().view(-1, grid_size[0], grid_size[1], C) + return windows + + +@register_notrace_function # reason: int argument is a Proxy +def grid_reverse(windows: torch.Tensor, grid_size: List[int], img_size: List[int]) -> torch.Tensor: + """Reverse 
grid partition.""" + H, W = img_size + C = windows.shape[-1] + x = windows.view(-1, H // grid_size[0], W // grid_size[1], grid_size[0], grid_size[1], C) + x = x.permute(0, 3, 1, 4, 2, 5).contiguous().view(-1, H, W, C) + return x + + +def get_rel_pos_cls(cfg: MaxxVitTransformerCfg, window_size: Tuple[int, int]) -> Optional[Callable]: + """Get relative position class based on config.""" + rel_pos_cls = None + if cfg.rel_pos_type == 'mlp': + rel_pos_cls = partial(RelPosMlp, window_size=window_size, hidden_dim=cfg.rel_pos_dim) + elif cfg.rel_pos_type == 'bias': + rel_pos_cls = partial(RelPosBias, window_size=window_size) + elif cfg.rel_pos_type == 'bias_tf': + rel_pos_cls = partial(RelPosBiasTf, window_size=window_size) + return rel_pos_cls + + +class PartitionAttentionCl(nn.Module): + """Grid or Block partition + Attn + FFN. + + NxC 'channels last' tensor layout. + """ + + def __init__( + self, + dim: int, + partition_type: str = 'block', + cfg: MaxxVitTransformerCfg | None = None, + drop_path: float = 0., + ): + super().__init__() + cfg = cfg or MaxxVitTransformerCfg() + norm_layer = partial(get_norm_layer(cfg.norm_layer_cl), eps=cfg.norm_eps) # NOTE this block is channels-last + act_layer = get_act_layer(cfg.act_layer) + + self.partition_block = partition_type == 'block' + self.partition_size = to_2tuple(cfg.window_size if self.partition_block else cfg.grid_size) + rel_pos_cls = get_rel_pos_cls(cfg, self.partition_size) + + self.norm1 = norm_layer(dim) + self.attn = AttentionCl( + dim, + dim, + dim_head=cfg.dim_head, + bias=cfg.attn_bias, + head_first=cfg.head_first, + rel_pos_cls=rel_pos_cls, + attn_drop=cfg.attn_drop, + proj_drop=cfg.proj_drop, + ) + self.ls1 = LayerScale(dim, init_values=cfg.init_values) if cfg.init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.norm2 = norm_layer(dim) + self.mlp = Mlp( + in_features=dim, + hidden_features=int(dim * cfg.expand_ratio), + act_layer=act_layer, + drop=cfg.proj_drop) + self.ls2 = LayerScale(dim, init_values=cfg.init_values) if cfg.init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + def _partition_attn(self, x): + img_size = x.shape[1:3] + if self.partition_block: + partitioned = window_partition(x, self.partition_size) + else: + partitioned = grid_partition(x, self.partition_size) + + partitioned = self.attn(partitioned) + + if self.partition_block: + x = window_reverse(partitioned, self.partition_size, img_size) + else: + x = grid_reverse(partitioned, self.partition_size, img_size) + return x + + def forward(self, x): + x = x + self.drop_path1(self.ls1(self._partition_attn(self.norm1(x)))) + x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) + return x + + +class ParallelPartitionAttention(nn.Module): + """Experimental. Grid and Block partition + single FFN. + + NxC tensor layout. + """ + + def __init__( + self, + dim: int, + cfg: MaxxVitTransformerCfg | None = None, + drop_path: float = 0., + ): + """ + Args: + dim: Input dimension. + cfg: Transformer block configuration. + drop_path: Drop path rate. 
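+
+        Example (illustrative; window and grid sizes must match, each branch producing dim // 2 channels):
+
+            >>> import torch
+            >>> cfg = MaxxVitTransformerCfg(window_size=(7, 7))  # grid_size defaults to window_size
+            >>> attn = ParallelPartitionAttention(dim=64, cfg=cfg)
+            >>> attn(torch.randn(1, 14, 14, 64)).shape  # channels-last (B, H, W, C)
+            torch.Size([1, 14, 14, 64])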
+ """ + cfg = cfg or MaxxVitTransformerCfg() + super().__init__() + assert dim % 2 == 0 + norm_layer = partial(get_norm_layer(cfg.norm_layer_cl), eps=cfg.norm_eps) # NOTE this block is channels-last + act_layer = get_act_layer(cfg.act_layer) + + assert cfg.window_size == cfg.grid_size + self.partition_size = to_2tuple(cfg.window_size) + rel_pos_cls = get_rel_pos_cls(cfg, self.partition_size) + + self.norm1 = norm_layer(dim) + self.attn_block = AttentionCl( + dim, + dim // 2, + dim_head=cfg.dim_head, + bias=cfg.attn_bias, + head_first=cfg.head_first, + rel_pos_cls=rel_pos_cls, + attn_drop=cfg.attn_drop, + proj_drop=cfg.proj_drop, + ) + self.attn_grid = AttentionCl( + dim, + dim // 2, + dim_head=cfg.dim_head, + bias=cfg.attn_bias, + head_first=cfg.head_first, + rel_pos_cls=rel_pos_cls, + attn_drop=cfg.attn_drop, + proj_drop=cfg.proj_drop, + ) + self.ls1 = LayerScale(dim, init_values=cfg.init_values) if cfg.init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.norm2 = norm_layer(dim) + self.mlp = Mlp( + in_features=dim, + hidden_features=int(dim * cfg.expand_ratio), + out_features=dim, + act_layer=act_layer, + drop=cfg.proj_drop) + self.ls2 = LayerScale(dim, init_values=cfg.init_values) if cfg.init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + def _partition_attn(self, x: torch.Tensor) -> torch.Tensor: + img_size = x.shape[1:3] + + partitioned_block = window_partition(x, self.partition_size) + partitioned_block = self.attn_block(partitioned_block) + x_window = window_reverse(partitioned_block, self.partition_size, img_size) + + partitioned_grid = grid_partition(x, self.partition_size) + partitioned_grid = self.attn_grid(partitioned_grid) + x_grid = grid_reverse(partitioned_grid, self.partition_size, img_size) + + return torch.cat([x_window, x_grid], dim=-1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x + self.drop_path1(self.ls1(self._partition_attn(self.norm1(x)))) + x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) + return x + + +def window_partition_nchw(x: torch.Tensor, window_size: List[int]) -> torch.Tensor: + """Partition windows for NCHW tensors.""" + B, C, H, W = x.shape + _assert(H % window_size[0] == 0, f'height ({H}) must be divisible by window ({window_size[0]})') + _assert(W % window_size[1] == 0, f'width ({W}) must be divisible by window ({window_size[1]})') + x = x.view(B, C, H // window_size[0], window_size[0], W // window_size[1], window_size[1]) + windows = x.permute(0, 2, 4, 1, 3, 5).contiguous().view(-1, C, window_size[0], window_size[1]) + return windows + + +@register_notrace_function # reason: int argument is a Proxy +def window_reverse_nchw(windows: torch.Tensor, window_size: List[int], img_size: List[int]) -> torch.Tensor: + """Reverse window partition for NCHW tensors.""" + H, W = img_size + C = windows.shape[1] + x = windows.view(-1, H // window_size[0], W // window_size[1], C, window_size[0], window_size[1]) + x = x.permute(0, 3, 1, 4, 2, 5).contiguous().view(-1, C, H, W) + return x + + +def grid_partition_nchw(x: torch.Tensor, grid_size: List[int]) -> torch.Tensor: + """Grid partition for NCHW tensors.""" + B, C, H, W = x.shape + _assert(H % grid_size[0] == 0, f'height {H} must be divisible by grid {grid_size[0]}') + _assert(W % grid_size[1] == 0, f'width {W} must be divisible by grid {grid_size[1]}') + x = x.view(B, C, grid_size[0], H // grid_size[0], grid_size[1], W // grid_size[1]) + windows = x.permute(0, 3, 
5, 1, 2, 4).contiguous().view(-1, C, grid_size[0], grid_size[1]) + return windows + + +@register_notrace_function # reason: int argument is a Proxy +def grid_reverse_nchw(windows: torch.Tensor, grid_size: List[int], img_size: List[int]) -> torch.Tensor: + """Reverse grid partition for NCHW tensors.""" + H, W = img_size + C = windows.shape[1] + x = windows.view(-1, H // grid_size[0], W // grid_size[1], C, grid_size[0], grid_size[1]) + x = x.permute(0, 3, 4, 1, 5, 2).contiguous().view(-1, C, H, W) + return x + + +class PartitionAttention2d(nn.Module): + """Grid or Block partition + Attn + FFN. + + '2D' NCHW tensor layout. + """ + + def __init__( + self, + dim: int, + partition_type: str = 'block', + cfg: MaxxVitTransformerCfg | None = None, + drop_path: float = 0., + ): + """ + Args: + dim: Input dimension. + partition_type: Partition type ('block' or 'grid'). + cfg: Transformer block configuration. + drop_path: Drop path rate. + """ + cfg = cfg or MaxxVitTransformerCfg() + super().__init__() + norm_layer = partial(get_norm_layer(cfg.norm_layer), eps=cfg.norm_eps) # NOTE this block is channels-last + act_layer = get_act_layer(cfg.act_layer) + + self.partition_block = partition_type == 'block' + self.partition_size = to_2tuple(cfg.window_size if self.partition_block else cfg.grid_size) + rel_pos_cls = get_rel_pos_cls(cfg, self.partition_size) + + self.norm1 = norm_layer(dim) + self.attn = Attention2d( + dim, + dim, + dim_head=cfg.dim_head, + bias=cfg.attn_bias, + head_first=cfg.head_first, + rel_pos_cls=rel_pos_cls, + attn_drop=cfg.attn_drop, + proj_drop=cfg.proj_drop, + ) + self.ls1 = LayerScale2d(dim, init_values=cfg.init_values) if cfg.init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.norm2 = norm_layer(dim) + self.mlp = ConvMlp( + in_features=dim, + hidden_features=int(dim * cfg.expand_ratio), + act_layer=act_layer, + drop=cfg.proj_drop) + self.ls2 = LayerScale2d(dim, init_values=cfg.init_values) if cfg.init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + def _partition_attn(self, x: torch.Tensor) -> torch.Tensor: + img_size = x.shape[-2:] + if self.partition_block: + partitioned = window_partition_nchw(x, self.partition_size) + else: + partitioned = grid_partition_nchw(x, self.partition_size) + + partitioned = self.attn(partitioned) + + if self.partition_block: + x = window_reverse_nchw(partitioned, self.partition_size, img_size) + else: + x = grid_reverse_nchw(partitioned, self.partition_size, img_size) + return x + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x + self.drop_path1(self.ls1(self._partition_attn(self.norm1(x)))) + x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) + return x + + +class MaxxVitBlock(nn.Module): + """MaxVit conv, window partition + FFN , grid partition + FFN.""" + + def __init__( + self, + dim: int, + dim_out: int, + stride: int = 1, + conv_cfg: MaxxVitConvCfg | None = None, + transformer_cfg: MaxxVitTransformerCfg | None = None, + drop_path: float = 0., + ): + """Initialize MaxxVitBlock. + + Args: + dim: Input channel dimension. + dim_out: Output channel dimension. + stride: Stride for downsampling. + conv_cfg: Configuration for convolutional blocks. + transformer_cfg: Configuration for transformer blocks. + drop_path: Drop path rate. 
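+
+        Example (illustrative; MBConv downsamples, then window and grid attention run on the result):
+
+            >>> import torch
+            >>> cfg = MaxxVitTransformerCfg(window_size=(7, 7))
+            >>> block = MaxxVitBlock(64, 128, stride=2, transformer_cfg=cfg)
+            >>> block(torch.randn(1, 64, 14, 14)).shape  # NCHW in, NCHW out
+            torch.Size([1, 128, 7, 7])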
+ """ + conv_cfg = conv_cfg or MaxxVitConvCfg() + transformer_cfg = transformer_cfg or MaxxVitTransformerCfg() + super().__init__() + self.nchw_attn = transformer_cfg.use_nchw_attn + + conv_cls = ConvNeXtBlock if conv_cfg.block_type == 'convnext' else MbConvBlock + self.conv = conv_cls(dim, dim_out, stride=stride, cfg=conv_cfg, drop_path=drop_path) + + attn_kwargs = dict(dim=dim_out, cfg=transformer_cfg, drop_path=drop_path) + partition_layer = PartitionAttention2d if self.nchw_attn else PartitionAttentionCl + self.attn_block = None if transformer_cfg.no_block_attn else partition_layer(**attn_kwargs) + self.attn_grid = partition_layer(partition_type='grid', **attn_kwargs) + + def init_weights(self, scheme=''): + if self.attn_block is not None: + named_apply(partial(_init_transformer, scheme=scheme), self.attn_block) + named_apply(partial(_init_transformer, scheme=scheme), self.attn_grid) + named_apply(partial(_init_conv, scheme=scheme), self.conv) + + def forward(self, x): + # NCHW format + x = self.conv(x) + + if not self.nchw_attn: + x = x.permute(0, 2, 3, 1) # to NHWC (channels-last) + if self.attn_block is not None: + x = self.attn_block(x) + x = self.attn_grid(x) + if not self.nchw_attn: + x = x.permute(0, 3, 1, 2) # back to NCHW + return x + + +class ParallelMaxxVitBlock(nn.Module): + """MaxVit block with parallel cat(window + grid), one FF. + + Experimental timm block. + """ + + def __init__( + self, + dim: int, + dim_out: int, + stride: int = 1, + num_conv: int = 2, + conv_cfg: MaxxVitConvCfg | None = None, + transformer_cfg: MaxxVitTransformerCfg | None = None, + drop_path: float = 0., + ): + """ + Args: + dim: Input dimension. + dim_out: Output dimension. + stride: Stride for first conv block. + num_conv: Number of convolution blocks. + conv_cfg: Convolution block configuration. + transformer_cfg: Transformer block configuration. + drop_path: Drop path rate. + """ + conv_cfg = conv_cfg or MaxxVitConvCfg() + transformer_cfg = transformer_cfg or MaxxVitTransformerCfg() + super().__init__() + + conv_cls = ConvNeXtBlock if conv_cfg.block_type == 'convnext' else MbConvBlock + if num_conv > 1: + convs = [conv_cls(dim, dim_out, stride=stride, cfg=conv_cfg, drop_path=drop_path)] + convs += [conv_cls(dim_out, dim_out, cfg=conv_cfg, drop_path=drop_path)] * (num_conv - 1) + self.conv = nn.Sequential(*convs) + else: + self.conv = conv_cls(dim, dim_out, stride=stride, cfg=conv_cfg, drop_path=drop_path) + self.attn = ParallelPartitionAttention(dim=dim_out, cfg=transformer_cfg, drop_path=drop_path) + + def init_weights(self, scheme: str = '') -> None: + named_apply(partial(_init_transformer, scheme=scheme), self.attn) + named_apply(partial(_init_conv, scheme=scheme), self.conv) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.conv(x) + x = x.permute(0, 2, 3, 1) + x = self.attn(x) + x = x.permute(0, 3, 1, 2) + return x + + +class MaxxVitStage(nn.Module): + """MaxxVit stage consisting of mixed convolution and transformer blocks.""" + + def __init__( + self, + in_chs: int, + out_chs: int, + stride: int = 2, + depth: int = 4, + feat_size: Tuple[int, int] = (14, 14), + block_types: Union[str, Tuple[str]] = 'C', + transformer_cfg: MaxxVitTransformerCfg | None = None, + conv_cfg: MaxxVitConvCfg | None = None, + drop_path: Union[float, List[float]] = 0., + ): + """ + Args: + in_chs: Input channels. + out_chs: Output channels. + stride: Stride for first block. + depth: Number of blocks in stage. + feat_size: Feature map size. 
+ block_types: Block types ('C' for conv, 'T' for transformer, etc). + transformer_cfg: Transformer block configuration. + conv_cfg: Convolution block configuration. + drop_path: Drop path rate(s). + """ + transformer_cfg = transformer_cfg or MaxxVitTransformerCfg() + conv_cfg = conv_cfg or MaxxVitConvCfg() + super().__init__() + self.grad_checkpointing = False + + block_types = extend_tuple(block_types, depth) + blocks = [] + for i, t in enumerate(block_types): + block_stride = stride if i == 0 else 1 + assert t in ('C', 'T', 'M', 'PM') + if t == 'C': + conv_cls = ConvNeXtBlock if conv_cfg.block_type == 'convnext' else MbConvBlock + blocks += [conv_cls( + in_chs, + out_chs, + stride=block_stride, + cfg=conv_cfg, + drop_path=drop_path[i], + )] + elif t == 'T': + rel_pos_cls = get_rel_pos_cls(transformer_cfg, feat_size) + blocks += [TransformerBlock2d( + in_chs, + out_chs, + stride=block_stride, + rel_pos_cls=rel_pos_cls, + cfg=transformer_cfg, + drop_path=drop_path[i], + )] + elif t == 'M': + blocks += [MaxxVitBlock( + in_chs, + out_chs, + stride=block_stride, + conv_cfg=conv_cfg, + transformer_cfg=transformer_cfg, + drop_path=drop_path[i], + )] + elif t == 'PM': + blocks += [ParallelMaxxVitBlock( + in_chs, + out_chs, + stride=block_stride, + conv_cfg=conv_cfg, + transformer_cfg=transformer_cfg, + drop_path=drop_path[i], + )] + in_chs = out_chs + self.blocks = nn.Sequential(*blocks) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.blocks, x) + else: + x = self.blocks(x) + return x + + +class Stem(nn.Module): + """Stem layer for feature extraction.""" + + def __init__( + self, + in_chs: int, + out_chs: int, + kernel_size: int = 3, + padding: str = '', + bias: bool = False, + act_layer: str = 'gelu', + norm_layer: str = 'batchnorm2d', + norm_eps: float = 1e-5, + ): + """ + Args: + in_chs: Input channels. + out_chs: Output channels. + kernel_size: Kernel size for convolutions. + padding: Padding mode. + bias: Whether to use bias. + act_layer: Activation layer. + norm_layer: Normalization layer. + norm_eps: Normalization epsilon. 
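+
+        Example (illustrative; the stem always reduces spatial dims by 2):
+
+            >>> import torch
+            >>> stem = Stem(3, (32, 64))
+            >>> stem(torch.randn(1, 3, 224, 224)).shape
+            torch.Size([1, 64, 112, 112])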
+ """ + super().__init__() + if not isinstance(out_chs, (list, tuple)): + out_chs = to_2tuple(out_chs) + + norm_act_layer = partial(get_norm_act_layer(norm_layer, act_layer), eps=norm_eps) + self.out_chs = out_chs[-1] + self.stride = 2 + + self.conv1 = create_conv2d(in_chs, out_chs[0], kernel_size, stride=2, padding=padding, bias=bias) + self.norm1 = norm_act_layer(out_chs[0]) + self.conv2 = create_conv2d(out_chs[0], out_chs[1], kernel_size, stride=1, padding=padding, bias=bias) + + def init_weights(self, scheme: str = '') -> None: + named_apply(partial(_init_conv, scheme=scheme), self) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.conv1(x) + x = self.norm1(x) + x = self.conv2(x) + return x + + +def cfg_window_size(cfg: MaxxVitTransformerCfg, img_size: Tuple[int, int]) -> MaxxVitTransformerCfg: + """Configure window size based on image size and partition ratio.""" + if cfg.window_size is not None: + assert cfg.grid_size + return cfg + partition_size = img_size[0] // cfg.partition_ratio, img_size[1] // cfg.partition_ratio + cfg = replace(cfg, window_size=partition_size, grid_size=partition_size) + return cfg + + +def _overlay_kwargs(cfg: MaxxVitCfg, **kwargs: Any) -> MaxxVitCfg: + """Overlay keyword arguments onto configuration.""" + transformer_kwargs = {} + conv_kwargs = {} + base_kwargs = {} + for k, v in kwargs.items(): + if k.startswith('transformer_'): + transformer_kwargs[k.replace('transformer_', '')] = v + elif k.startswith('conv_'): + conv_kwargs[k.replace('conv_', '')] = v + else: + base_kwargs[k] = v + cfg = replace( + cfg, + transformer_cfg=replace(cfg.transformer_cfg, **transformer_kwargs), + conv_cfg=replace(cfg.conv_cfg, **conv_kwargs), + **base_kwargs + ) + return cfg + + +class MaxxVit(nn.Module): + """CoaTNet + MaxVit base model. + + Highly configurable for different block compositions, tensor layouts, pooling types. + """ + + def __init__( + self, + cfg: MaxxVitCfg, + img_size: Union[int, Tuple[int, int]] = 224, + in_chans: int = 3, + num_classes: int = 1000, + global_pool: str = 'avg', + drop_rate: float = 0., + drop_path_rate: float = 0., + **kwargs: Any, + ): + """ + Args: + cfg: Model configuration. + img_size: Input image size. + in_chans: Number of input channels. + num_classes: Number of classification classes. + global_pool: Global pooling type. + drop_rate: Dropout rate. + drop_path_rate: Drop path rate. + **kwargs: Additional keyword arguments to overlay on config. 
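+
+        Example (an illustrative, untrained toy config; not one of the registered model configs):
+
+            >>> import torch
+            >>> cfg = MaxxVitCfg(
+            ...     embed_dim=(64, 128, 256, 512), depths=(1, 1, 1, 1),
+            ...     block_type=('C', 'C', 'T', 'T'), stem_width=(32, 64))
+            >>> model = MaxxVit(cfg, img_size=224, num_classes=10)
+            >>> model(torch.randn(1, 3, 224, 224)).shape
+            torch.Size([1, 10])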
+ """ + super().__init__() + img_size = to_2tuple(img_size) + if kwargs: + cfg = _overlay_kwargs(cfg, **kwargs) + transformer_cfg = cfg_window_size(cfg.transformer_cfg, img_size) + self.num_classes = num_classes + self.global_pool = global_pool + self.num_features = self.embed_dim = cfg.embed_dim[-1] + self.drop_rate = drop_rate + self.grad_checkpointing = False + self.feature_info = [] + + self.stem = Stem( + in_chs=in_chans, + out_chs=cfg.stem_width, + padding=cfg.conv_cfg.padding, + bias=cfg.stem_bias, + act_layer=cfg.conv_cfg.act_layer, + norm_layer=cfg.conv_cfg.norm_layer, + norm_eps=cfg.conv_cfg.norm_eps, + ) + stride = self.stem.stride + self.feature_info += [dict(num_chs=self.stem.out_chs, reduction=2, module='stem')] + feat_size = tuple([i // s for i, s in zip(img_size, to_2tuple(stride))]) + + num_stages = len(cfg.embed_dim) + assert len(cfg.depths) == num_stages + dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.depths)).split(cfg.depths)] + in_chs = self.stem.out_chs + stages = [] + for i in range(num_stages): + stage_stride = 2 + out_chs = cfg.embed_dim[i] + feat_size = tuple([(r - 1) // stage_stride + 1 for r in feat_size]) + stages += [MaxxVitStage( + in_chs, + out_chs, + depth=cfg.depths[i], + block_types=cfg.block_type[i], + conv_cfg=cfg.conv_cfg, + transformer_cfg=transformer_cfg, + feat_size=feat_size, + drop_path=dpr[i], + )] + stride *= stage_stride + in_chs = out_chs + self.feature_info += [dict(num_chs=out_chs, reduction=stride, module=f'stages.{i}')] + self.stages = nn.Sequential(*stages) + + final_norm_layer = partial(get_norm_layer(cfg.transformer_cfg.norm_layer), eps=cfg.transformer_cfg.norm_eps) + if cfg.head_hidden_size: + self.norm = nn.Identity() + self.head_hidden_size = cfg.head_hidden_size + self.head = NormMlpClassifierHead( + self.num_features, + num_classes, + hidden_size=self.head_hidden_size, + pool_type=global_pool, + drop_rate=drop_rate, + norm_layer=final_norm_layer, + ) + else: + # standard classifier head w/ norm, pooling, fc classifier + self.head_hidden_size = self.num_features + self.norm = final_norm_layer(self.num_features) + self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate) + + # Weight init (default PyTorch init works well for AdamW if scheme not set) + assert cfg.weight_init in ('', 'normal', 'trunc_normal', 'xavier_normal', 'vit_eff') + if cfg.weight_init: + named_apply(partial(self._init_weights, scheme=cfg.weight_init), self) + + def _init_weights(self, module: nn.Module, name: str, scheme: str = '') -> None: + if hasattr(module, 'init_weights'): + try: + module.init_weights(scheme=scheme) + except TypeError: + module.init_weights() + + @torch.jit.ignore + def no_weight_decay(self) -> Set[str]: + return { + k for k, _ in self.named_parameters() + if any(n in k for n in ["relative_position_bias_table", "rel_pos.mlp"])} + + @torch.jit.ignore + def group_matcher(self, coarse: bool = False) -> Dict[str, Any]: + matcher = dict( + stem=r'^stem', # stem and embed + blocks=[(r'^stages\.(\d+)', None), (r'^norm', (99999,))] + ) + return matcher + + @torch.jit.ignore + def set_grad_checkpointing(self, enable: bool = True) -> None: + for s in self.stages: + s.grad_checkpointing = enable + + @torch.jit.ignore + def get_classifier(self) -> nn.Module: + return self.head.fc + + def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None) -> None: + self.num_classes = num_classes + self.head.reset(num_classes, global_pool) + + def forward_intermediates( + self, 
+            x: torch.Tensor,
+            indices: Optional[Union[int, List[int]]] = None,
+            norm: bool = False,
+            stop_early: bool = False,
+            output_fmt: str = 'NCHW',
+            intermediates_only: bool = False,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+        """Forward pass that also returns intermediate features.
+
+        Args:
+            x: Input image tensor.
+            indices: Take last n blocks if int, all if None, select matching indices if sequence.
+            norm: Apply norm layer to compatible intermediates.
+            stop_early: Stop iterating over blocks when last desired intermediate hit.
+            output_fmt: Shape of intermediate feature outputs.
+            intermediates_only: Only return intermediate features.
+        Returns:
+            The intermediate features only, or a tuple of (final features, intermediate features).
+        """
+        assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
+        intermediates = []
+        take_indices, max_index = feature_take_indices(len(self.stages) + 1, indices)
+
+        # forward pass
+        feat_idx = 0  # stem is index 0
+        x = self.stem(x)
+        if feat_idx in take_indices:
+            intermediates.append(x)
+
+        last_idx = len(self.stages)
+        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
+            stages = self.stages
+        else:
+            stages = self.stages[:max_index]
+        for stage in stages:
+            feat_idx += 1
+            x = stage(x)
+            if feat_idx in take_indices:
+                if norm and feat_idx == last_idx:
+                    x_inter = self.norm(x)  # applying final norm to last intermediate
+                else:
+                    x_inter = x
+                intermediates.append(x_inter)
+
+        if intermediates_only:
+            return intermediates
+
+        if feat_idx == last_idx:
+            x = self.norm(x)
+
+        return x, intermediates
+
+    def prune_intermediate_layers(
+            self,
+            indices: Union[int, List[int]] = 1,
+            prune_norm: bool = False,
+            prune_head: bool = True,
+    ) -> List[int]:
+        """Prune layers not required for specified intermediates."""
+        take_indices, max_index = feature_take_indices(len(self.stages) + 1, indices)
+        self.stages = self.stages[:max_index]  # truncate blocks w/ stem as idx 0
+        if prune_norm:
+            self.norm = nn.Identity()
+        if prune_head:
+            self.reset_classifier(0, '')  # resets self.head in place; returns None, so no assignment
+        return take_indices
+
+    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.stem(x)
+        x = self.stages(x)
+        x = self.norm(x)
+        return x
+
+    def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
+        return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.forward_features(x)
+        x = self.forward_head(x)
+        return x
+
+
+def _rw_coat_cfg(
+        stride_mode: str = 'pool',
+        pool_type: str = 'avg2',
+        conv_output_bias: bool = False,
+        conv_attn_early: bool = False,
+        conv_attn_act_layer: str = 'relu',
+        conv_norm_layer: str = '',
+        transformer_shortcut_bias: bool = True,
+        transformer_norm_layer: str = 'layernorm2d',
+        transformer_norm_layer_cl: str = 'layernorm',
+        init_values: Optional[float] = None,
+        rel_pos_type: str = 'bias',
+        rel_pos_dim: int = 512,
+) -> Dict[str, Any]:
+    """RW variant configuration for CoAtNet models.
+
+    These models were created and trained before seeing https://github.com/google-research/maxvit
+
+    Common differences for initial timm models:
+    - pre-norm layer in MbConv included an activation after norm
+    - mbconv expansion calculated from input instead of output chs
+    - mbconv shortcut and final 1x1 conv did not have a bias
+    - SE act layer was relu, not silu
+    - mbconv uses silu in timm, not gelu
+    - expansion in attention block done via output proj, not input proj
+
+    Variable differences (evolved over training initial models):
+    - avg pool with kernel_size=2 favoured for downsampling (instead of maxpool for coat)
+    - SE attention was between conv2 and norm/act
+    - default to avg pool for mbconv downsample instead of 1x1 or dw conv
+    - transformer block shortcut has no bias
+    """
+    return dict(
+        conv_cfg=MaxxVitConvCfg(
+            stride_mode=stride_mode,
+            pool_type=pool_type,
+            pre_norm_act=True,
+            expand_output=False,
+            output_bias=conv_output_bias,
+            attn_early=conv_attn_early,
+            attn_act_layer=conv_attn_act_layer,
+            act_layer='silu',
+            norm_layer=conv_norm_layer,
+        ),
+        transformer_cfg=MaxxVitTransformerCfg(
+            expand_first=False,
+            shortcut_bias=transformer_shortcut_bias,
+            pool_type=pool_type,
+            init_values=init_values,
+            norm_layer=transformer_norm_layer,
+            norm_layer_cl=transformer_norm_layer_cl,
+            rel_pos_type=rel_pos_type,
+            rel_pos_dim=rel_pos_dim,
+        ),
+    )
+
+
+def _rw_max_cfg(
+        stride_mode: str = 'dw',
+        pool_type: str = 'avg2',
+        conv_output_bias: bool = False,
+        conv_attn_ratio: float = 1 / 16,
+        conv_norm_layer: str = '',
+        transformer_norm_layer: str = 'layernorm2d',
+        transformer_norm_layer_cl: str = 'layernorm',
+        window_size: Optional[Tuple[int, int]] = None,
+        dim_head: int = 32,
+        init_values: Optional[float] = None,
+        rel_pos_type: str = 'bias',
+        rel_pos_dim: int = 512,
+) -> Dict[str, Any]:
+    """RW variant configuration for MaxViT models.
+ + These models were created and trained before seeing https://github.com/google-research/maxvit + + Differences of initial timm models: + - mbconv expansion calculated from input instead of output chs + - mbconv shortcut and final 1x1 conv did not have a bias + - mbconv uses silu in timm, not gelu + - expansion in attention block done via output proj, not input proj + """ + return dict( + conv_cfg=MaxxVitConvCfg( + stride_mode=stride_mode, + pool_type=pool_type, + expand_output=False, + output_bias=conv_output_bias, + attn_ratio=conv_attn_ratio, + act_layer='silu', + norm_layer=conv_norm_layer, + ), + transformer_cfg=MaxxVitTransformerCfg( + expand_first=False, + pool_type=pool_type, + dim_head=dim_head, + window_size=window_size, + init_values=init_values, + norm_layer=transformer_norm_layer, + norm_layer_cl=transformer_norm_layer_cl, + rel_pos_type=rel_pos_type, + rel_pos_dim=rel_pos_dim, + ), + ) + + +def _next_cfg( + stride_mode: str = 'dw', + pool_type: str = 'avg2', + conv_norm_layer: str = 'layernorm2d', + conv_norm_layer_cl: str = 'layernorm', + transformer_norm_layer: str = 'layernorm2d', + transformer_norm_layer_cl: str = 'layernorm', + window_size: Optional[Tuple[int, int]] = None, + no_block_attn: bool = False, + init_values: Union[float, Tuple[float, float]] = 1e-6, + rel_pos_type: str = 'mlp', # MLP by default for maxxvit + rel_pos_dim: int = 512, +) -> Dict[str, Any]: + """Configuration for experimental ConvNeXt-based MaxxViT models.""" + init_values = to_2tuple(init_values) + return dict( + conv_cfg=MaxxVitConvCfg( + block_type='convnext', + stride_mode=stride_mode, + pool_type=pool_type, + expand_output=False, + init_values=init_values[0], + norm_layer=conv_norm_layer, + norm_layer_cl=conv_norm_layer_cl, + ), + transformer_cfg=MaxxVitTransformerCfg( + expand_first=False, + pool_type=pool_type, + window_size=window_size, + no_block_attn=no_block_attn, # enabled for MaxxViT-V2 + init_values=init_values[1], + norm_layer=transformer_norm_layer, + norm_layer_cl=transformer_norm_layer_cl, + rel_pos_type=rel_pos_type, + rel_pos_dim=rel_pos_dim, + ), + ) + + +def _tf_cfg() -> Dict[str, Any]: + """Configuration matching TensorFlow MaxViT models.""" + return dict( + conv_cfg=MaxxVitConvCfg( + norm_eps=1e-3, + act_layer='gelu_tanh', + padding='same', + ), + transformer_cfg=MaxxVitTransformerCfg( + norm_eps=1e-5, + act_layer='gelu_tanh', + head_first=False, # heads are interleaved (q_nh, q_hdim, k_nh, q_hdim, ....) 
+ rel_pos_type='bias_tf', + ), + ) + + +model_cfgs = dict( + # timm specific CoAtNet configs + coatnet_pico_rw=MaxxVitCfg( + embed_dim=(64, 128, 256, 512), + depths=(2, 3, 5, 2), + stem_width=(32, 64), + **_rw_max_cfg( # using newer max defaults here + conv_output_bias=True, + conv_attn_ratio=0.25, + ), + ), + coatnet_nano_rw=MaxxVitCfg( + embed_dim=(64, 128, 256, 512), + depths=(3, 4, 6, 3), + stem_width=(32, 64), + **_rw_max_cfg( # using newer max defaults here + stride_mode='pool', + conv_output_bias=True, + conv_attn_ratio=0.25, + ), + ), + coatnet_0_rw=MaxxVitCfg( + embed_dim=(96, 192, 384, 768), + depths=(2, 3, 7, 2), # deeper than paper '0' model + stem_width=(32, 64), + **_rw_coat_cfg( + conv_attn_early=True, + transformer_shortcut_bias=False, + ), + ), + coatnet_1_rw=MaxxVitCfg( + embed_dim=(96, 192, 384, 768), + depths=(2, 6, 14, 2), + stem_width=(32, 64), + **_rw_coat_cfg( + stride_mode='dw', + conv_attn_early=True, + transformer_shortcut_bias=False, + ) + ), + coatnet_2_rw=MaxxVitCfg( + embed_dim=(128, 256, 512, 1024), + depths=(2, 6, 14, 2), + stem_width=(64, 128), + **_rw_coat_cfg( + stride_mode='dw', + conv_attn_act_layer='silu', + #init_values=1e-6, + ), + ), + coatnet_3_rw=MaxxVitCfg( + embed_dim=(192, 384, 768, 1536), + depths=(2, 6, 14, 2), + stem_width=(96, 192), + **_rw_coat_cfg( + stride_mode='dw', + conv_attn_act_layer='silu', + init_values=1e-6, + ), + ), + + # Experimental CoAtNet configs w/ ImageNet-1k train (different norm layers, MLP rel-pos) + coatnet_bn_0_rw=MaxxVitCfg( + embed_dim=(96, 192, 384, 768), + depths=(2, 3, 7, 2), # deeper than paper '0' model + stem_width=(32, 64), + **_rw_coat_cfg( + stride_mode='dw', + conv_attn_early=True, + transformer_shortcut_bias=False, + transformer_norm_layer='batchnorm2d', + ) + ), + coatnet_rmlp_nano_rw=MaxxVitCfg( + embed_dim=(64, 128, 256, 512), + depths=(3, 4, 6, 3), + stem_width=(32, 64), + **_rw_max_cfg( + conv_output_bias=True, + conv_attn_ratio=0.25, + rel_pos_type='mlp', + rel_pos_dim=384, + ), + ), + coatnet_rmlp_0_rw=MaxxVitCfg( + embed_dim=(96, 192, 384, 768), + depths=(2, 3, 7, 2), # deeper than paper '0' model + stem_width=(32, 64), + **_rw_coat_cfg( + stride_mode='dw', + rel_pos_type='mlp', + ), + ), + coatnet_rmlp_1_rw=MaxxVitCfg( + embed_dim=(96, 192, 384, 768), + depths=(2, 6, 14, 2), + stem_width=(32, 64), + **_rw_coat_cfg( + pool_type='max', + conv_attn_early=True, + transformer_shortcut_bias=False, + rel_pos_type='mlp', + rel_pos_dim=384, # was supposed to be 512, woops + ), + ), + coatnet_rmlp_1_rw2=MaxxVitCfg( + embed_dim=(96, 192, 384, 768), + depths=(2, 6, 14, 2), + stem_width=(32, 64), + **_rw_coat_cfg( + stride_mode='dw', + rel_pos_type='mlp', + rel_pos_dim=512, # was supposed to be 512, woops + ), + ), + coatnet_rmlp_2_rw=MaxxVitCfg( + embed_dim=(128, 256, 512, 1024), + depths=(2, 6, 14, 2), + stem_width=(64, 128), + **_rw_coat_cfg( + stride_mode='dw', + conv_attn_act_layer='silu', + init_values=1e-6, + rel_pos_type='mlp' + ), + ), + coatnet_rmlp_3_rw=MaxxVitCfg( + embed_dim=(192, 384, 768, 1536), + depths=(2, 6, 14, 2), + stem_width=(96, 192), + **_rw_coat_cfg( + stride_mode='dw', + conv_attn_act_layer='silu', + init_values=1e-6, + rel_pos_type='mlp' + ), + ), + + coatnet_nano_cc=MaxxVitCfg( + embed_dim=(64, 128, 256, 512), + depths=(3, 4, 6, 3), + stem_width=(32, 64), + block_type=('C', 'C', ('C', 'T'), ('C', 'T')), + **_rw_coat_cfg(), + ), + coatnext_nano_rw=MaxxVitCfg( + embed_dim=(64, 128, 256, 512), + depths=(3, 4, 6, 3), + stem_width=(32, 64), + weight_init='normal', + **_next_cfg( + 
rel_pos_type='bias', + init_values=(1e-5, None) + ), + ), + + # Trying to be like the CoAtNet paper configs + coatnet_0=MaxxVitCfg( + embed_dim=(96, 192, 384, 768), + depths=(2, 3, 5, 2), + stem_width=64, + head_hidden_size=768, + ), + coatnet_1=MaxxVitCfg( + embed_dim=(96, 192, 384, 768), + depths=(2, 6, 14, 2), + stem_width=64, + head_hidden_size=768, + ), + coatnet_2=MaxxVitCfg( + embed_dim=(128, 256, 512, 1024), + depths=(2, 6, 14, 2), + stem_width=128, + head_hidden_size=1024, + ), + coatnet_3=MaxxVitCfg( + embed_dim=(192, 384, 768, 1536), + depths=(2, 6, 14, 2), + stem_width=192, + head_hidden_size=1536, + ), + coatnet_4=MaxxVitCfg( + embed_dim=(192, 384, 768, 1536), + depths=(2, 12, 28, 2), + stem_width=192, + head_hidden_size=1536, + ), + coatnet_5=MaxxVitCfg( + embed_dim=(256, 512, 1280, 2048), + depths=(2, 12, 28, 2), + stem_width=192, + head_hidden_size=2048, + ), + + # Experimental MaxVit configs + maxvit_pico_rw=MaxxVitCfg( + embed_dim=(32, 64, 128, 256), + depths=(2, 2, 5, 2), + block_type=('M',) * 4, + stem_width=(24, 32), + **_rw_max_cfg(), + ), + maxvit_nano_rw=MaxxVitCfg( + embed_dim=(64, 128, 256, 512), + depths=(1, 2, 3, 1), + block_type=('M',) * 4, + stem_width=(32, 64), + **_rw_max_cfg(), + ), + maxvit_tiny_rw=MaxxVitCfg( + embed_dim=(64, 128, 256, 512), + depths=(2, 2, 5, 2), + block_type=('M',) * 4, + stem_width=(32, 64), + **_rw_max_cfg(), + ), + maxvit_tiny_pm=MaxxVitCfg( + embed_dim=(64, 128, 256, 512), + depths=(2, 2, 5, 2), + block_type=('PM',) * 4, + stem_width=(32, 64), + **_rw_max_cfg(), + ), + + maxvit_rmlp_pico_rw=MaxxVitCfg( + embed_dim=(32, 64, 128, 256), + depths=(2, 2, 5, 2), + block_type=('M',) * 4, + stem_width=(24, 32), + **_rw_max_cfg(rel_pos_type='mlp'), + ), + maxvit_rmlp_nano_rw=MaxxVitCfg( + embed_dim=(64, 128, 256, 512), + depths=(1, 2, 3, 1), + block_type=('M',) * 4, + stem_width=(32, 64), + **_rw_max_cfg(rel_pos_type='mlp'), + ), + maxvit_rmlp_tiny_rw=MaxxVitCfg( + embed_dim=(64, 128, 256, 512), + depths=(2, 2, 5, 2), + block_type=('M',) * 4, + stem_width=(32, 64), + **_rw_max_cfg(rel_pos_type='mlp'), + ), + maxvit_rmlp_small_rw=MaxxVitCfg( + embed_dim=(96, 192, 384, 768), + depths=(2, 2, 5, 2), + block_type=('M',) * 4, + stem_width=(32, 64), + **_rw_max_cfg( + rel_pos_type='mlp', + init_values=1e-6, + ), + ), + maxvit_rmlp_base_rw=MaxxVitCfg( + embed_dim=(96, 192, 384, 768), + depths=(2, 6, 14, 2), + block_type=('M',) * 4, + stem_width=(32, 64), + head_hidden_size=768, + **_rw_max_cfg( + rel_pos_type='mlp', + ), + ), + + maxxvit_rmlp_nano_rw=MaxxVitCfg( + embed_dim=(64, 128, 256, 512), + depths=(1, 2, 3, 1), + block_type=('M',) * 4, + stem_width=(32, 64), + weight_init='normal', + **_next_cfg(), + ), + maxxvit_rmlp_tiny_rw=MaxxVitCfg( + embed_dim=(64, 128, 256, 512), + depths=(2, 2, 5, 2), + block_type=('M',) * 4, + stem_width=(32, 64), + **_next_cfg(), + ), + maxxvit_rmlp_small_rw=MaxxVitCfg( + embed_dim=(96, 192, 384, 768), + depths=(2, 2, 5, 2), + block_type=('M',) * 4, + stem_width=(48, 96), + **_next_cfg(), + ), + + maxxvitv2_nano_rw=MaxxVitCfg( + embed_dim=(96, 192, 384, 768), + depths=(1, 2, 3, 1), + block_type=('M',) * 4, + stem_width=(48, 96), + weight_init='normal', + **_next_cfg( + no_block_attn=True, + rel_pos_type='bias', + ), + ), + maxxvitv2_rmlp_base_rw=MaxxVitCfg( + embed_dim=(128, 256, 512, 1024), + depths=(2, 6, 12, 2), + block_type=('M',) * 4, + stem_width=(64, 128), + **_next_cfg( + no_block_attn=True, + ), + ), + maxxvitv2_rmlp_large_rw=MaxxVitCfg( + embed_dim=(160, 320, 640, 1280), + depths=(2, 6, 16, 2), + 
block_type=('M',) * 4, + stem_width=(80, 160), + head_hidden_size=1280, + **_next_cfg( + no_block_attn=True, + ), + ), + + # Trying to be like the MaxViT paper configs + maxvit_tiny_tf=MaxxVitCfg( + embed_dim=(64, 128, 256, 512), + depths=(2, 2, 5, 2), + block_type=('M',) * 4, + stem_width=64, + stem_bias=True, + head_hidden_size=512, + **_tf_cfg(), + ), + maxvit_small_tf=MaxxVitCfg( + embed_dim=(96, 192, 384, 768), + depths=(2, 2, 5, 2), + block_type=('M',) * 4, + stem_width=64, + stem_bias=True, + head_hidden_size=768, + **_tf_cfg(), + ), + maxvit_base_tf=MaxxVitCfg( + embed_dim=(96, 192, 384, 768), + depths=(2, 6, 14, 2), + block_type=('M',) * 4, + stem_width=64, + stem_bias=True, + head_hidden_size=768, + **_tf_cfg(), + ), + maxvit_large_tf=MaxxVitCfg( + embed_dim=(128, 256, 512, 1024), + depths=(2, 6, 14, 2), + block_type=('M',) * 4, + stem_width=128, + stem_bias=True, + head_hidden_size=1024, + **_tf_cfg(), + ), + maxvit_xlarge_tf=MaxxVitCfg( + embed_dim=(192, 384, 768, 1536), + depths=(2, 6, 14, 2), + block_type=('M',) * 4, + stem_width=192, + stem_bias=True, + head_hidden_size=1536, + **_tf_cfg(), + ), +) + + +def checkpoint_filter_fn(state_dict: Dict[str, torch.Tensor], model: nn.Module) -> Dict[str, torch.Tensor]: + """Filter checkpoint state dict for compatibility.""" + model_state_dict = model.state_dict() + out_dict = {} + for k, v in state_dict.items(): + if k.endswith('relative_position_bias_table'): + m = model.get_submodule(k[:-29]) + if v.shape != m.relative_position_bias_table.shape or m.window_size[0] != m.window_size[1]: + v = resize_rel_pos_bias_table( + v, + new_window_size=m.window_size, + new_bias_shape=m.relative_position_bias_table.shape, + ) + + if k in model_state_dict and v.ndim != model_state_dict[k].ndim and v.numel() == model_state_dict[k].numel(): + # adapt between conv2d / linear layers + assert v.ndim in (2, 4) + v = v.reshape(model_state_dict[k].shape) + out_dict[k] = v + return out_dict + + +def _create_maxxvit(variant: str, cfg_variant: Optional[str] = None, pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """Create a MaxxVit model variant.""" + if cfg_variant is None: + if variant in model_cfgs: + cfg_variant = variant + else: + cfg_variant = '_'.join(variant.split('_')[:-1]) + return build_model_with_cfg( + MaxxVit, variant, pretrained, + model_cfg=model_cfgs[cfg_variant], + feature_cfg=dict(flatten_sequential=True), + pretrained_filter_fn=checkpoint_filter_fn, + **kwargs) + + +def _cfg(url: str = '', **kwargs: Any) -> Dict[str, Any]: + """Create a default configuration dict.""" + return { + 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.95, 'interpolation': 'bicubic', + 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5), + 'first_conv': 'stem.conv1', 'classifier': 'head.fc', + 'fixed_input_size': True, + **kwargs + } + + +default_cfgs = generate_default_cfgs({ + # timm specific CoAtNet configs, ImageNet-1k pretrain, fixed rel-pos + 'coatnet_pico_rw_224.untrained': _cfg(url=''), + 'coatnet_nano_rw_224.sw_in1k': _cfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_nano_rw_224_sw-f53093b4.pth', + crop_pct=0.9), + 'coatnet_0_rw_224.sw_in1k': _cfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_0_rw_224_sw-a6439706.pth'), + 'coatnet_1_rw_224.sw_in1k': _cfg( + hf_hub_id='timm/', + 
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_1_rw_224_sw-5cae1ea8.pth'
+    ),
+
+    # timm specific CoAtNet configs, ImageNet-12k pretrain w/ 1k fine-tune, fixed rel-pos
+    'coatnet_2_rw_224.sw_in12k_ft_in1k': _cfg(
+        hf_hub_id='timm/'),
+    #'coatnet_3_rw_224.untrained': _cfg(url=''),
+
+    # Experimental CoAtNet configs w/ ImageNet-12k pretrain -> 1k fine-tune (different norm layers, MLP rel-pos)
+    'coatnet_rmlp_1_rw2_224.sw_in12k_ft_in1k': _cfg(
+        hf_hub_id='timm/'),
+    'coatnet_rmlp_2_rw_224.sw_in12k_ft_in1k': _cfg(
+        hf_hub_id='timm/'),
+    'coatnet_rmlp_2_rw_384.sw_in12k_ft_in1k': _cfg(
+        hf_hub_id='timm/',
+        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
+
+    # Experimental CoAtNet configs w/ ImageNet-1k train (different norm layers, MLP rel-pos)
+    'coatnet_bn_0_rw_224.sw_in1k': _cfg(
+        hf_hub_id='timm/',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_bn_0_rw_224_sw-c228e218.pth',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
+        crop_pct=0.95),
+    'coatnet_rmlp_nano_rw_224.sw_in1k': _cfg(
+        hf_hub_id='timm/',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_nano_rw_224_sw-bd1d51b3.pth',
+        crop_pct=0.9),
+    'coatnet_rmlp_0_rw_224.untrained': _cfg(url=''),
+    'coatnet_rmlp_1_rw_224.sw_in1k': _cfg(
+        hf_hub_id='timm/',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_1_rw_224_sw-9051e6c3.pth'),
+    'coatnet_rmlp_2_rw_224.sw_in1k': _cfg(
+        hf_hub_id='timm/',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_2_rw_224_sw-5ccfac55.pth'),
+    'coatnet_rmlp_3_rw_224.untrained': _cfg(url=''),
+    'coatnet_nano_cc_224.untrained': _cfg(url=''),
+    'coatnext_nano_rw_224.sw_in1k': _cfg(
+        hf_hub_id='timm/',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnext_nano_rw_224_ad-22cb71c2.pth',
+        crop_pct=0.9),
+
+    # ImageNet-12k pretrain CoAtNet
+    'coatnet_2_rw_224.sw_in12k': _cfg(
+        hf_hub_id='timm/',
+        num_classes=11821),
+    'coatnet_3_rw_224.sw_in12k': _cfg(
+        hf_hub_id='timm/',
+        num_classes=11821),
+    'coatnet_rmlp_1_rw2_224.sw_in12k': _cfg(
+        hf_hub_id='timm/',
+        num_classes=11821),
+    'coatnet_rmlp_2_rw_224.sw_in12k': _cfg(
+        hf_hub_id='timm/',
+        num_classes=11821),
+
+    # Trying to be like the CoAtNet paper configs (will adapt if 'tf' weights are ever released)
+    'coatnet_0_224.untrained': _cfg(url=''),
+    'coatnet_1_224.untrained': _cfg(url=''),
+    'coatnet_2_224.untrained': _cfg(url=''),
+    'coatnet_3_224.untrained': _cfg(url=''),
+    'coatnet_4_224.untrained': _cfg(url=''),
+    'coatnet_5_224.untrained': _cfg(url=''),
+
+    # timm specific MaxVit configs, ImageNet-1k pretrain or untrained
+    'maxvit_pico_rw_256.untrained': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
+    'maxvit_nano_rw_256.sw_in1k': _cfg(
+        hf_hub_id='timm/',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_nano_rw_256_sw-fb127241.pth',
+        input_size=(3, 256, 256), pool_size=(8, 8)),
+    'maxvit_tiny_rw_224.sw_in1k': _cfg(
+        hf_hub_id='timm/',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_tiny_rw_224_sw-7d0dffeb.pth'),
+    'maxvit_tiny_rw_256.untrained': _cfg(
+        url='',
+        input_size=(3, 256, 256), pool_size=(8, 8)),
+    'maxvit_tiny_pm_256.untrained': _cfg(url='', input_size=(3, 256,
256), pool_size=(8, 8)), + + # timm specific MaxVit w/ MLP rel-pos, ImageNet-1k pretrain + 'maxvit_rmlp_pico_rw_256.sw_in1k': _cfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_pico_rw_256_sw-8d82f2c6.pth', + input_size=(3, 256, 256), pool_size=(8, 8)), + 'maxvit_rmlp_nano_rw_256.sw_in1k': _cfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_nano_rw_256_sw-c17bb0d6.pth', + input_size=(3, 256, 256), pool_size=(8, 8)), + 'maxvit_rmlp_tiny_rw_256.sw_in1k': _cfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_tiny_rw_256_sw-bbef0ff5.pth', + input_size=(3, 256, 256), pool_size=(8, 8)), + 'maxvit_rmlp_small_rw_224.sw_in1k': _cfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_small_rw_224_sw-6ef0ae4f.pth', + crop_pct=0.9, + ), + 'maxvit_rmlp_small_rw_256.untrained': _cfg( + url='', + input_size=(3, 256, 256), pool_size=(8, 8)), + + # timm specific MaxVit w/ ImageNet-12k pretrain and 1k fine-tune + 'maxvit_rmlp_base_rw_224.sw_in12k_ft_in1k': _cfg( + hf_hub_id='timm/', + ), + 'maxvit_rmlp_base_rw_384.sw_in12k_ft_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), + + # timm specific MaxVit w/ ImageNet-12k pretrain + 'maxvit_rmlp_base_rw_224.sw_in12k': _cfg( + hf_hub_id='timm/', + num_classes=11821, + ), + + # timm MaxxViT configs (ConvNeXt conv blocks mixed with MaxVit transformer blocks) + 'maxxvit_rmlp_nano_rw_256.sw_in1k': _cfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxxvit_rmlp_nano_rw_256_sw-0325d459.pth', + input_size=(3, 256, 256), pool_size=(8, 8)), + 'maxxvit_rmlp_tiny_rw_256.untrained': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), + 'maxxvit_rmlp_small_rw_256.sw_in1k': _cfg( + hf_hub_id='timm/', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxxvit_rmlp_small_rw_256_sw-37e217ff.pth', + input_size=(3, 256, 256), pool_size=(8, 8)), + + # timm MaxxViT-V2 configs (ConvNeXt conv blocks mixed with MaxVit transformer blocks, more width, no block attn) + 'maxxvitv2_nano_rw_256.sw_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 256, 256), pool_size=(8, 8)), + 'maxxvitv2_rmlp_base_rw_224.sw_in12k_ft_in1k': _cfg( + hf_hub_id='timm/'), + 'maxxvitv2_rmlp_base_rw_384.sw_in12k_ft_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), + 'maxxvitv2_rmlp_large_rw_224.untrained': _cfg(url=''), + + 'maxxvitv2_rmlp_base_rw_224.sw_in12k': _cfg( + hf_hub_id='timm/', + num_classes=11821), + + # MaxViT models ported from official Tensorflow impl + 'maxvit_tiny_tf_224.in1k': _cfg( + hf_hub_id='timm/', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), + 'maxvit_tiny_tf_384.in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), + 'maxvit_tiny_tf_512.in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'), + 'maxvit_small_tf_224.in1k': _cfg( + hf_hub_id='timm/', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), + 'maxvit_small_tf_384.in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 384, 384), pool_size=(12, 
12), crop_pct=1.0, crop_mode='squash'),
+    'maxvit_small_tf_512.in1k': _cfg(
+        hf_hub_id='timm/',
+        input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),
+    'maxvit_base_tf_224.in1k': _cfg(
+        hf_hub_id='timm/',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
+    'maxvit_base_tf_384.in1k': _cfg(
+        hf_hub_id='timm/',
+        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
+    'maxvit_base_tf_512.in1k': _cfg(
+        hf_hub_id='timm/',
+        input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),
+    'maxvit_large_tf_224.in1k': _cfg(
+        hf_hub_id='timm/',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
+    'maxvit_large_tf_384.in1k': _cfg(
+        hf_hub_id='timm/',
+        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
+    'maxvit_large_tf_512.in1k': _cfg(
+        hf_hub_id='timm/',
+        input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),
+
+    'maxvit_base_tf_224.in21k': _cfg(
+        hf_hub_id='timm/',
+        num_classes=21843),
+    'maxvit_base_tf_384.in21k_ft_in1k': _cfg(
+        hf_hub_id='timm/',
+        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
+    'maxvit_base_tf_512.in21k_ft_in1k': _cfg(
+        hf_hub_id='timm/',
+        input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),
+    'maxvit_large_tf_224.in21k': _cfg(
+        hf_hub_id='timm/',
+        num_classes=21843),
+    'maxvit_large_tf_384.in21k_ft_in1k': _cfg(
+        hf_hub_id='timm/',
+        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
+    'maxvit_large_tf_512.in21k_ft_in1k': _cfg(
+        hf_hub_id='timm/',
+        input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),
+    'maxvit_xlarge_tf_224.in21k': _cfg(
+        hf_hub_id='timm/',
+        num_classes=21843),
+    'maxvit_xlarge_tf_384.in21k_ft_in1k': _cfg(
+        hf_hub_id='timm/',
+        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
+    'maxvit_xlarge_tf_512.in21k_ft_in1k': _cfg(
+        hf_hub_id='timm/',
+        input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),
+})
+
+
+@register_model
+def coatnet_pico_rw_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
+    """CoatNet Pico model with RW configuration."""
+    return _create_maxxvit('coatnet_pico_rw_224', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def coatnet_nano_rw_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
+    """CoatNet Nano model with RW configuration."""
+    return _create_maxxvit('coatnet_nano_rw_224', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def coatnet_0_rw_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
+    """CoatNet-0 model with RW configuration."""
+    return _create_maxxvit('coatnet_0_rw_224', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def coatnet_1_rw_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
+    """CoatNet-1 model with RW configuration."""
+    return _create_maxxvit('coatnet_1_rw_224', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def coatnet_2_rw_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
+    """CoatNet-2 model with RW configuration."""
+    return _create_maxxvit('coatnet_2_rw_224', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def coatnet_3_rw_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
+    """CoatNet-3 model with RW configuration."""
+    return _create_maxxvit('coatnet_3_rw_224', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def coatnet_bn_0_rw_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit:
+    """CoatNet-0 model with
BatchNorm and RW configuration.""" + return _create_maxxvit('coatnet_bn_0_rw_224', pretrained=pretrained, **kwargs) + + +@register_model +def coatnet_rmlp_nano_rw_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """CoatNet Nano model with Relative Position MLP.""" + return _create_maxxvit('coatnet_rmlp_nano_rw_224', pretrained=pretrained, **kwargs) + + +@register_model +def coatnet_rmlp_0_rw_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """CoatNet-0 model with Relative Position MLP.""" + return _create_maxxvit('coatnet_rmlp_0_rw_224', pretrained=pretrained, **kwargs) + + +@register_model +def coatnet_rmlp_1_rw_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """CoatNet-1 model with Relative Position MLP.""" + return _create_maxxvit('coatnet_rmlp_1_rw_224', pretrained=pretrained, **kwargs) + + +@register_model +def coatnet_rmlp_1_rw2_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """CoatNet-1 model with Relative Position MLP v2.""" + return _create_maxxvit('coatnet_rmlp_1_rw2_224', pretrained=pretrained, **kwargs) + + +@register_model +def coatnet_rmlp_2_rw_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """CoatNet-2 model with Relative Position MLP.""" + return _create_maxxvit('coatnet_rmlp_2_rw_224', pretrained=pretrained, **kwargs) + + +@register_model +def coatnet_rmlp_2_rw_384(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """CoatNet-2 model with Relative Position MLP at 384x384.""" + return _create_maxxvit('coatnet_rmlp_2_rw_384', pretrained=pretrained, **kwargs) + + +@register_model +def coatnet_rmlp_3_rw_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """CoatNet-3 model with Relative Position MLP.""" + return _create_maxxvit('coatnet_rmlp_3_rw_224', pretrained=pretrained, **kwargs) + + +@register_model +def coatnet_nano_cc_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """CoatNet Nano model with ConvNeXt blocks.""" + return _create_maxxvit('coatnet_nano_cc_224', pretrained=pretrained, **kwargs) + + +@register_model +def coatnext_nano_rw_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """CoAtNeXt Nano model with RW configuration.""" + return _create_maxxvit('coatnext_nano_rw_224', pretrained=pretrained, **kwargs) + + +@register_model +def coatnet_0_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """CoatNet-0 model.""" + return _create_maxxvit('coatnet_0_224', pretrained=pretrained, **kwargs) + + +@register_model +def coatnet_1_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """CoatNet-1 model.""" + return _create_maxxvit('coatnet_1_224', pretrained=pretrained, **kwargs) + + +@register_model +def coatnet_2_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """CoatNet-2 model.""" + return _create_maxxvit('coatnet_2_224', pretrained=pretrained, **kwargs) + + +@register_model +def coatnet_3_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """CoatNet-3 model.""" + return _create_maxxvit('coatnet_3_224', pretrained=pretrained, **kwargs) + + +@register_model +def coatnet_4_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """CoatNet-4 model.""" + return _create_maxxvit('coatnet_4_224', pretrained=pretrained, **kwargs) + + +@register_model +def coatnet_5_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """CoatNet-5 model.""" + return _create_maxxvit('coatnet_5_224', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_pico_rw_256(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """MaxViT Pico model with RW 
configuration.""" + return _create_maxxvit('maxvit_pico_rw_256', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_nano_rw_256(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """MaxViT Nano model with RW configuration.""" + return _create_maxxvit('maxvit_nano_rw_256', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_tiny_rw_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """MaxViT Tiny model with RW configuration.""" + return _create_maxxvit('maxvit_tiny_rw_224', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_tiny_rw_256(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """MaxViT Tiny model with RW configuration at 256x256.""" + return _create_maxxvit('maxvit_tiny_rw_256', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_rmlp_pico_rw_256(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """MaxViT Relative Position MLP Pico RW 256x256 model.""" + return _create_maxxvit('maxvit_rmlp_pico_rw_256', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_rmlp_nano_rw_256(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """MaxViT Relative Position MLP Nano RW 256x256 model.""" + return _create_maxxvit('maxvit_rmlp_nano_rw_256', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_rmlp_tiny_rw_256(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """MaxViT Relative Position MLP Tiny RW 256x256 model.""" + return _create_maxxvit('maxvit_rmlp_tiny_rw_256', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_rmlp_small_rw_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """MaxViT Relative Position MLP Small RW 224x224 model.""" + return _create_maxxvit('maxvit_rmlp_small_rw_224', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_rmlp_small_rw_256(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """MaxViT Small model with Relative Position MLP at 256x256.""" + return _create_maxxvit('maxvit_rmlp_small_rw_256', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_rmlp_base_rw_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """MaxViT Base model with Relative Position MLP.""" + return _create_maxxvit('maxvit_rmlp_base_rw_224', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_rmlp_base_rw_384(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """MaxViT Base model with Relative Position MLP at 384x384.""" + return _create_maxxvit('maxvit_rmlp_base_rw_384', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_tiny_pm_256(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """MaxViT Tiny model with parallel blocks.""" + return _create_maxxvit('maxvit_tiny_pm_256', pretrained=pretrained, **kwargs) + + +@register_model +def maxxvit_rmlp_nano_rw_256(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """MaxxViT Relative Position MLP Nano RW 256x256 model.""" + return _create_maxxvit('maxxvit_rmlp_nano_rw_256', pretrained=pretrained, **kwargs) + + +@register_model +def maxxvit_rmlp_tiny_rw_256(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """MaxxViT Tiny model with Relative Position MLP.""" + return _create_maxxvit('maxxvit_rmlp_tiny_rw_256', pretrained=pretrained, **kwargs) + + +@register_model +def maxxvit_rmlp_small_rw_256(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """MaxxViT Small model with Relative Position MLP.""" + return _create_maxxvit('maxxvit_rmlp_small_rw_256', pretrained=pretrained, **kwargs) + + +@register_model +def 
maxxvitv2_nano_rw_256(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """MaxxViT-V2 Nano model.""" + return _create_maxxvit('maxxvitv2_nano_rw_256', pretrained=pretrained, **kwargs) + + +@register_model +def maxxvitv2_rmlp_base_rw_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """MaxxViT-V2 Base model with Relative Position MLP.""" + return _create_maxxvit('maxxvitv2_rmlp_base_rw_224', pretrained=pretrained, **kwargs) + + +@register_model +def maxxvitv2_rmlp_base_rw_384(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """MaxxViT-V2 Base model with Relative Position MLP at 384x384.""" + return _create_maxxvit('maxxvitv2_rmlp_base_rw_384', pretrained=pretrained, **kwargs) + + +@register_model +def maxxvitv2_rmlp_large_rw_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """MaxxViT-V2 Large model with Relative Position MLP.""" + return _create_maxxvit('maxxvitv2_rmlp_large_rw_224', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_tiny_tf_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """MaxViT Tiny model from TensorFlow.""" + return _create_maxxvit('maxvit_tiny_tf_224', 'maxvit_tiny_tf', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_tiny_tf_384(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """MaxViT Tiny model from TensorFlow at 384x384.""" + return _create_maxxvit('maxvit_tiny_tf_384', 'maxvit_tiny_tf', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_tiny_tf_512(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """MaxViT Tiny model from TensorFlow at 512x512.""" + return _create_maxxvit('maxvit_tiny_tf_512', 'maxvit_tiny_tf', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_small_tf_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """MaxViT Small model from TensorFlow.""" + return _create_maxxvit('maxvit_small_tf_224', 'maxvit_small_tf', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_small_tf_384(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """MaxViT Small model from TensorFlow at 384x384.""" + return _create_maxxvit('maxvit_small_tf_384', 'maxvit_small_tf', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_small_tf_512(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """MaxViT Small model from TensorFlow at 512x512.""" + return _create_maxxvit('maxvit_small_tf_512', 'maxvit_small_tf', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_base_tf_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """MaxViT Base model from TensorFlow.""" + return _create_maxxvit('maxvit_base_tf_224', 'maxvit_base_tf', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_base_tf_384(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """MaxViT Base model from TensorFlow at 384x384.""" + return _create_maxxvit('maxvit_base_tf_384', 'maxvit_base_tf', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_base_tf_512(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """MaxViT Base model from TensorFlow at 512x512.""" + return _create_maxxvit('maxvit_base_tf_512', 'maxvit_base_tf', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_large_tf_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """MaxViT Large model from TensorFlow.""" + return _create_maxxvit('maxvit_large_tf_224', 'maxvit_large_tf', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_large_tf_384(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """MaxViT Large model 
from TensorFlow at 384x384.""" + return _create_maxxvit('maxvit_large_tf_384', 'maxvit_large_tf', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_large_tf_512(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """MaxViT Large model from TensorFlow at 512x512.""" + return _create_maxxvit('maxvit_large_tf_512', 'maxvit_large_tf', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_xlarge_tf_224(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """MaxViT XLarge model from TensorFlow.""" + return _create_maxxvit('maxvit_xlarge_tf_224', 'maxvit_xlarge_tf', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_xlarge_tf_384(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """MaxViT XLarge model from TensorFlow at 384x384.""" + return _create_maxxvit('maxvit_xlarge_tf_384', 'maxvit_xlarge_tf', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_xlarge_tf_512(pretrained: bool = False, **kwargs: Any) -> MaxxVit: + """MaxViT XLarge model from TensorFlow at 512x512.""" + return _create_maxxvit('maxvit_xlarge_tf_512', 'maxvit_xlarge_tf', pretrained=pretrained, **kwargs) diff --git a/run.py b/run.py index 5696ef0..d2e8fe2 100644 --- a/run.py +++ b/run.py @@ -103,7 +103,7 @@ def create_side_by_side(image, depth, grayscale): def run(input_path, output_path, model_path, model_type="dpt_beit_large_512", optimize=False, side=False, height=None, - square=False, grayscale=False): + square=False, grayscale=False, camera_source=0, save_npy=False): """Run MonoDepthNN to compute depth maps. Args: @@ -116,6 +116,11 @@ def run(input_path, output_path, model_path, model_type="dpt_beit_large_512", op height (int): inference encoder image height square (bool): resize to a square resolution? grayscale (bool): use a grayscale colormap? + camera_source: camera device index or URL used for video capture + save_npy (bool): save predicted depth arrays as .npy files + + Returns: + list[np.ndarray]: all predictions in order of processing """ print("Initialize") @@ -138,6 +143,8 @@ def run(input_path, output_path, model_path, model_type="dpt_beit_large_512", op print("Start processing") + predictions = [] + if input_path is not None: if output_path is None: print("Warning: No output path specified. 
Images will be processed but not shown or stored anywhere.") @@ -166,11 +173,16 @@ def run(input_path, output_path, model_path, model_type="dpt_beit_large_512", op content = create_side_by_side(original_image_bgr*255, prediction, grayscale) cv2.imwrite(filename + ".png", content) utils.write_pfm(filename + ".pfm", prediction.astype(np.float32)) + if save_npy: + np.save(filename + ".npy", prediction) + + predictions.append(prediction) else: with torch.no_grad(): fps = 1 - video = VideoStream(0).start() + src = int(camera_source) if str(camera_source).isdigit() else camera_source + video = VideoStream(src).start() time_start = time.time() frame_index = 0 while True: @@ -189,6 +201,10 @@ def run(input_path, output_path, model_path, model_type="dpt_beit_large_512", op if output_path is not None: filename = os.path.join(output_path, 'Camera' + '-' + model_type + '_' + str(frame_index)) cv2.imwrite(filename + ".png", content) + if save_npy: + np.save(filename + ".npy", prediction) + + predictions.append(prediction) alpha = 0.1 if time.time()-time_start > 0: @@ -200,10 +216,13 @@ def run(input_path, output_path, model_path, model_type="dpt_beit_large_512", op break frame_index += 1 + video.stop() print() print("Finished") + return predictions + if __name__ == "__main__": parser = argparse.ArgumentParser() @@ -262,6 +281,16 @@ def run(input_path, output_path, model_path, model_type="dpt_beit_large_512", op 'colormap.' ) + parser.add_argument('--camera_source', + default=0, + help='Camera device index or video/stream URL to read frames from.' + ) + + parser.add_argument('--save_npy', + action='store_true', + help='Save predicted depth arrays as .npy files.' + ) + args = parser.parse_args() @@ -273,5 +302,16 @@ def run(input_path, output_path, model_path, model_type="dpt_beit_large_512", op torch.backends.cudnn.benchmark = True # compute depth maps - run(args.input_path, args.output_path, args.model_weights, args.model_type, args.optimize, args.side, args.height, - args.square, args.grayscale) + run( + args.input_path, + args.output_path, + args.model_weights, + args.model_type, + args.optimize, + args.side, + args.height, + args.square, + args.grayscale, + args.camera_source, + args.save_npy, + ) diff --git a/stream_depth.py b/stream_depth.py new file mode 100644 index 0000000..0a6ba06 --- /dev/null +++ b/stream_depth.py @@ -0,0 +1,72 @@ +import argparse +import numpy as np +import requests +import torch +import cv2 + +from midas.model_loader import default_models, load_model +from run import process + + +def stream_depth(server_url, camera_source=0, model_type="dpt_beit_large_512", model_weights=None, optimize=False): + """Stream depth predictions via HTTP POST.""" + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + if model_weights is None: + model_weights = default_models[model_type] + + model, transform, net_w, net_h = load_model(device, model_weights, model_type, optimize) + + src = int(camera_source) if str(camera_source).isdigit() else camera_source + cap = cv2.VideoCapture(src) + + try: + while True: + ret, frame = cap.read() + if not ret: + break + image_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + image = transform({"image": image_rgb / 255.0})["image"] + + with torch.no_grad(): + prediction = process( + device, + model, + model_type, + image, + (net_w, net_h), + image_rgb.shape[1::-1], + optimize, + True, + ) + + requests.post( + server_url, + data=prediction.astype(np.float32).tobytes(), + headers={ + "Content-Type": "application/octet-stream", + "X-Width": 
str(prediction.shape[1]),
+                    "X-Height": str(prediction.shape[0]),
+                },
+                # Bound each POST so a slow or unreachable receiver cannot
+                # stall the capture loop indefinitely.
+                timeout=1,
+            )
+    finally:
+        cap.release()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Stream depth maps to another process.")
+    parser.add_argument("--server_url", required=True, help="URL to POST depth arrays to")
+    parser.add_argument("--camera_source", default=0, help="Camera device index or video/stream URL")
+    parser.add_argument("--model_type", default="dpt_beit_large_512", help="Model type")
+    parser.add_argument("--model_weights", default=None, help="Path to model weights")
+    parser.add_argument("--optimize", action="store_true", help="Use half-float optimization")
+    args = parser.parse_args()
+
+    stream_depth(
+        args.server_url,
+        args.camera_source,
+        args.model_type,
+        args.model_weights,
+        args.optimize,
+    )
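
For reference, the wire format used by `stream_depth.py` is deliberately simple: each POST body is the raw float32 buffer (native byte order) of one depth map, and the array shape travels in the `X-Width` and `X-Height` headers. The sketch below shows a matching receiver; it is illustrative rather than part of this change, uses only the Python standard library, and assumes the port and `/depth` path agree with the `--server_url` passed to the sender.

```python
# receive_depth.py: illustrative receiver for stream_depth.py (not part of this patch).
# Assumption: the sender was started with --server_url http://localhost:8000/depth.
from http.server import BaseHTTPRequestHandler, HTTPServer

import numpy as np


class DepthHandler(BaseHTTPRequestHandler):
    def do_POST(self):
        # Reverse the sender's encoding: raw float32 bytes plus shape headers.
        # The request path is ignored in this sketch.
        height = int(self.headers["X-Height"])
        width = int(self.headers["X-Width"])
        body = self.rfile.read(int(self.headers["Content-Length"]))
        depth = np.frombuffer(body, dtype=np.float32).reshape(height, width)
        print(f"depth map {depth.shape}: min={depth.min():.2f}, max={depth.max():.2f}")
        self.send_response(200)
        self.end_headers()


if __name__ == "__main__":
    HTTPServer(("localhost", 8000), DepthHandler).serve_forever()
```

Similarly, since `run()` now returns the accumulated predictions (see the `run.py` hunk above), the depth arrays can be consumed in-process instead of reading the written files back. A brief sketch; the directories and weights path below are placeholders:

```python
# Sketch of using run() as a module; "input", "output" and the weights path
# are placeholders for your own locations.
from run import run

predictions = run(
    "input",                          # folder of RGB images
    "output",                         # PNG/PFM (and .npy with save_npy) written here
    "weights/dpt_beit_large_512.pt",  # placeholder path to the model weights
    model_type="dpt_beit_large_512",
    save_npy=True,
)
print(len(predictions), predictions[0].shape, predictions[0].dtype)
```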