 import math
 from typing import Any, Self
 
+
+try:
+    from typing import override
+except ImportError:
+    from typing_extensions import override
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
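For context on the hunk above: `typing.override` ships only with Python 3.12+, so the try/except falls back to the `typing_extensions` backport on older interpreters. At runtime the decorator merely sets `__override__ = True`; its value is that static type checkers such as mypy and pyright flag overrides whose parent method is missing or renamed. A minimal sketch (class names here are illustrative, not from the diff):

try:
    from typing import override  # Python 3.12+
except ImportError:
    from typing_extensions import override  # backport for older Pythons

class Base:
    def extra_repr(self) -> str:
        return ""

class Child(Base):
    @override  # type checkers verify Base really defines extra_repr
    def extra_repr(self) -> str:
        return "child"

    # Renaming Base.extra_repr would turn the @override above into a
    # type-checking error instead of leaving a silently dead method.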
@@ -51,8 +57,8 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
 
         assert self.padding_mode == "zeros", "Only zeros padding is supported"
 
-        self.weight: nn.Parameter[torch.Tensor]
-        self.bias: nn.Parameter[torch.Tensor] | None | str
+        self.weight: nn.Parameter
+        self.bias: nn.Parameter | None
 
         flat_weight_shape = [self.out_channels, math.prod(self.weight.shape[1:])]
         if self.bias is not None:
@@ -63,7 +69,6 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
             flat_weight_buffer[..., -1].copy_(self.bias)
             del self.bias
             self.has_bias = True
-            self.bias = "dummy"  # Trick con1d.extra_repr() to not print bias=False
         else:
             flat_weight_buffer.copy_(self.weight.view(self.out_channels, -1))
             self.has_bias = False
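A minimal, self-contained sketch of the packing scheme this hunk sets up (shapes are illustrative, and `flat` stands in for the single flattened parameter the class allocates): one row per output channel, holding the flattened kernel weights followed by one trailing bias column.

import math
import torch

out_channels, in_channels, kernel_size = 4, 3, 5
weight = torch.randn(out_channels, in_channels, kernel_size)
bias = torch.randn(out_channels)

# One row per output channel: flattened kernel weights plus one
# extra column holding that channel's bias.
flat = torch.empty(out_channels, math.prod(weight.shape[1:]) + 1)
flat[..., :-1].copy_(weight.view(out_channels, -1))
flat[..., -1].copy_(bias)

# Unpacking reverses the layout, as forward() does below.
assert torch.equal(flat[..., :-1].view(weight.shape), weight)
assert torch.equal(flat[..., -1], bias)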
@@ -98,6 +103,7 @@ def from_conv1d(cls, conv1d: nn.Conv1d) -> Self:
     def weight_shape(self) -> tuple[int, int, int]:
         return (self.out_channels, self.in_channels // self.groups, self.kernel_size[0])
 
+    @override
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         if self.has_bias:
             weight = self.weight[..., :-1].view(self.weight_shape)
@@ -108,6 +114,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
         return F.conv1d(x, weight, bias, self.stride, self.padding, self.dilation, self.groups)
 
+    @override
     def extra_repr(self) -> str:
         base_repr = super().extra_repr()
+        if self.has_bias:
+            base_repr += ", bias=True"
         return f"{base_repr}, flattened_param_shape={tuple(self.weight.shape)}"
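A usage sketch against the full class, assuming it is reachable as `FlatConv1d` (the class name is not visible in these hunks; treat it as a placeholder) and that `from_conv1d` copies the wrapped layer's parameters verbatim:

import torch
import torch.nn as nn

conv = nn.Conv1d(in_channels=3, out_channels=4, kernel_size=5, bias=True)
flat_conv = FlatConv1d.from_conv1d(conv)  # hypothetical class name

x = torch.randn(2, 3, 32)
# forward() unpacks weight and bias from the flat buffer, so the
# outputs should match the original convolution.
assert torch.allclose(conv(x), flat_conv(x))

# extra_repr() now reports bias=True when has_bias is set.
print(flat_conv)

Deriving `bias=True` from `has_bias` in `extra_repr()` replaces the removed `self.bias = "dummy"` hack: the repr stays accurate without storing a fake attribute on the module.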