|
12 | 12 | "InvertedBottleneckConv", |
13 | 13 | "FusedMobileInvertedConv", |
14 | 14 | "HoverNetDenseConv", |
| 15 | + "BasicConvOld", |
15 | 16 | ] |
16 | 17 |  |
17 | 18 |  |
@@ -829,3 +830,135 @@ def forward_features_preact(self, x: torch.Tensor) -> torch.Tensor: |
829 | 830 | x = self.conv(x) |
830 | 831 |  |
831 | 832 | return x |
| 833 | + |
| 834 | + |
| 835 | +class BasicConvOld(nn.Module): |
| 836 | + def __init__( |
| 837 | + self, |
| 838 | + in_channels: int, |
| 839 | + out_channels: int, |
| 840 | + same_padding: bool = True, |
| 841 | + normalization: str = "bn", |
| 842 | + activation: str = "relu", |
| 843 | + convolution: str = "conv", |
| 844 | + preactivate: bool = False, |
| 845 | + kernel_size: int = 3, |
| 846 | + groups: int = 1, |
| 847 | + bias: bool = False, |
| 848 | + attention: str = None, |
| 849 | + preattend: bool = False, |
| 850 | + **kwargs |
| 851 | + ) -> None: |
| 852 | + """Conv-block (basic) parent class. |
| 853 | + |
| 854 | + Parameters |
| 855 | + ---------- |
| 856 | + in_channels : int |
| 857 | + Number of input channels. |
| 858 | + out_channels : int |
| 859 | + Number of output channels. |
| 860 | + same_padding : bool, default=True |
| 861 | + If True, performs same-convolution. |
| 862 | + normalization : str, default="bn": |
| 863 | + Normalization method. |
| 864 | + One of: "bn", "bcn", "gn", "in", "ln", None |
| 865 | + activation : str, default="relu" |
| 866 | + Activation method. |
| 867 | + One of: "mish", "swish", "relu", "relu6", "rrelu", "selu", |
| 868 | + "celu", "gelu", "glu", "tanh", "sigmoid", "silu", "prelu", |
| 869 | + "leaky-relu", "elu", "hardshrink", "tanhshrink", "hardsigmoid" |
| 870 | + convolution : str, default="conv" |
| 871 | + The convolution method. One of: "conv", "wsconv", "scaled_wsconv" |
| 872 | + preactivate : bool, default=False |
| 873 | + If True, normalization will be applied before convolution. |
| 874 | + kernel_size : int, default=3 |
| 875 | + The size of the convolution kernel. |
| 876 | + groups : int, default=1 |
| 877 | + Number of groups the kernels are divided into. If `groups == 1` |
| 878 | + normal convolution is applied. If `groups == in_channels` |
| 879 | + depthwise convolution is applied. |
| 880 | + bias : bool, default=False |
| 881 | + Include bias term in the convolution. |
| 882 | + attention : str, default=None |
| 883 | + Attention method. One of: "se", "scse", "gc", "eca", None |
| 884 | + preattend : bool, default=False |
| 885 | + If True, attention is applied at the beginning of the forward pass. |
| 886 | + """ |
| 887 | + super().__init__() |
| 888 | + self.conv_choice = convolution |
| 889 | + self.out_channels = out_channels |
| 890 | + self.preattend = preattend |
| 891 | + self.preactivate = preactivate |
| 892 | + |
| 893 | + # set norm channel number for preactivation or normal |
| 894 | + norm_channels = in_channels if preactivate else self.out_channels |
| 895 | + |
| 896 | + # set padding. Works if dilation or stride are not adjusted |
| 897 | + padding = (kernel_size - 1) // 2 if same_padding else 0 |
| 898 | + |
| 899 | + self.conv = Conv( |
| 900 | + name=self.conv_choice, |
| 901 | + in_channels=in_channels, |
| 902 | + out_channels=out_channels, |
| 903 | + kernel_size=kernel_size, |
| 904 | + groups=groups, |
| 905 | + padding=padding, |
| 906 | + bias=bias, |
| 907 | + ) |
| 908 | + |
| 909 | + self.norm = Norm(normalization, num_features=norm_channels) |
| 910 | + self.act = Activation(activation) |
| 911 | + |
| 912 | + # set attention channels |
| 913 | + att_channels = in_channels if preattend else self.out_channels |
| 914 | + self.att = Attention(attention, in_channels=att_channels) |
| 915 | + |
| 916 | + self.downsample = None |
| 917 | + if in_channels != out_channels: |
| 918 | + self.downsample = nn.Sequential( |
| 919 | + Conv( |
| 920 | + self.conv_choice, |
| 921 | + in_channels=in_channels, |
| 922 | + out_channels=out_channels, |
| 923 | + bias=False, |
| 924 | + kernel_size=1, |
| 925 | + padding=0, |
| 926 | + ), |
| 927 | + Norm(normalization, num_features=out_channels), |
| 928 | + ) |
| 929 | + |
| 930 | + def forward_features(self, x: torch.Tensor) -> torch.Tensor: |
| 931 | + """Forward pass.""" |
| 932 | + identity = x |
| 933 | + if self.downsample is not None: |
| 934 | + identity = self.downsample(x) |
| 935 | + |
| 936 | + x = self.att(x) |
| 937 | + |
| 938 | + # residual |
| 939 | + x = self.conv(x) |
| 940 | + x = self.norm(x) |
| 941 | + |
| 942 | + x += identity |
| 943 | + x = self.act(x) |
| 944 | + |
| 945 | + return x |
| 946 | + |
| 947 | + def forward_features_preact(self, x: torch.Tensor) -> torch.Tensor: |
| 948 | + """Forward pass with pre-activation.""" |
| 949 | + identity = x |
| 950 | + if self.downsample is not None: |
| 951 | + identity = self.downsample(x) |
| 952 | + |
| 953 | + # pre-attention |
| 954 | + x = self.att(x) |
| 955 | + |
| 956 | + # preact residual |
| 957 | + x = self.norm(x) |
| 958 | + x = self.act(x) |
| 959 | + x = self.conv(x) |
| 960 | + |
| 961 | + x += identity |
| 962 | + x = self.act(x) |
| 963 | + |
| 964 | + return x |
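
Below is a minimal usage sketch of the new block. It is illustrative only: the import path is hypothetical, and since this hunk defines no `forward()`, the example calls `forward_features` / `forward_features_preact` directly. The arguments pair `attention="se"` with `preattend=True` so the attention module is built for the channel count of the tensor it receives at the start of the pass.

```python
# Usage sketch for BasicConvOld (assumed import path; adjust to the real module).
import torch

from cellseg_models_pytorch.modules import BasicConvOld  # hypothetical path

# Standard ordering: attend -> conv -> norm -> (+ identity) -> act.
# preattend=True builds the "se" attention for in_channels, matching the
# tensor it receives before the convolution.
block = BasicConvOld(
    in_channels=32,
    out_channels=64,
    normalization="bn",
    activation="relu",
    convolution="conv",
    attention="se",
    preattend=True,
)

x = torch.rand(8, 32, 64, 64)    # (B, C, H, W)
out = block.forward_features(x)
print(out.shape)                 # torch.Size([8, 64, 64, 64])

# Pre-activation ordering: attend -> norm -> act -> conv -> (+ identity) -> act.
# preactivate=True builds the norm for in_channels so it matches the input.
preact_block = BasicConvOld(
    in_channels=32,
    out_channels=64,
    preactivate=True,
    preattend=True,
)
out_pre = preact_block.forward_features_preact(x)
```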