class Conv1dFlatWeights(nn.Conv1d):
25- """Conv1d with weights+bias stored in a single 2D tensor"""
25+ """Conv1d with weights+bias stored in a single 2D tensor
26+
27+ There are conv1d used in some LLM, in mamba mixer for example. Because the weight is not 2d, we cannot apply
28+ many of the emerging optimizers originally introduced for 2d weights of Linear layers without bias. Since
29+ convolution can be viewed as a matrix multiplication with im2col (either implicit or explicit), we can flatten
30+ the weight into a single 2D tensor and then apply the emerging optimizers to it.
31+
32+ Bias is not commonly used in most LLM's anymore, but they are often included in this type of conv1d.
33+ Since bias is mathematically the 0 order term of the polynomial, we can combine weight and bias into a
34+ single 2D tensor.
35+
36+ Arguments are the same as ::class:`torch.nn.Conv1d`.
37+
38+ Note:
39+ Similar flattening logic can be applied to N-D convolution. But since we don't have use cases of them in LLM
40+ yet, they are not supported despite the __init__() function is generalized enough to support N-D convolution.
41+
42+ """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
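The docstring's im2col premise is easy to verify outside the class. A minimal, self-contained check (all shapes here are arbitrary illustration values, not from the source): unrolling the input with Tensor.unfold and appending a ones column reproduces conv1d as a single matmul against the flattened [W | b] tensor.

import torch
import torch.nn.functional as F

B, C_in, L, C_out, K = 2, 4, 16, 8, 3
x = torch.randn(B, C_in, L)
w = torch.randn(C_out, C_in, K)
b = torch.randn(C_out)

# Explicit im2col: each output position becomes a row of C_in * K input values.
cols = x.unfold(dimension=2, size=K, step=1)                # (B, C_in, L-K+1, K)
cols = cols.permute(0, 2, 1, 3).reshape(B, -1, C_in * K)    # (B, L-K+1, C_in*K)

# Flattened weight with the bias appended as the last column, as in the class above.
flat_w = torch.cat([w.view(C_out, -1), b[:, None]], dim=1)  # (C_out, C_in*K + 1)

# A ones column turns the bias column into the constant term of the affine map.
ones = torch.ones(B, cols.shape[1], 1)
out = torch.cat([cols, ones], dim=2) @ flat_w.T             # (B, L-K+1, C_out)

torch.testing.assert_close(out.transpose(1, 2), F.conv1d(x, w, b))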
@@ -37,8 +54,8 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
            flat_weight_shape[1] += 1
        flat_weight_buffer = torch.empty(flat_weight_shape, device=self.weight.device, dtype=self.weight.dtype)
        if self.bias is not None:
-             flat_weight_buffer[:, :-1].copy_(self.weight.view(self.out_channels, -1))
-             flat_weight_buffer[:, -1].copy_(self.bias)
+             flat_weight_buffer[..., :-1].copy_(self.weight.view(self.out_channels, -1))
+             flat_weight_buffer[..., -1].copy_(self.bias)
            del self.bias
            self.has_bias = True
            self.bias = "dummy"  # Trick conv1d.extra_repr() to not print bias=False
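The hunk only shows the middle of __init__. A hedged reconstruction of the whole method, under two assumptions not visible in the diff: the flat shape is derived as out_channels x (numel // out_channels), and the buffer ultimately replaces self.weight as the trainable 2D parameter.

def __init__(self, *args: Any, **kwargs: Any) -> None:
    super().__init__(*args, **kwargs)
    # Assumption: flat shape is derived from the existing 3D weight.
    flat_weight_shape = [self.out_channels, self.weight.numel() // self.out_channels]
    if self.bias is not None:
        flat_weight_shape[1] += 1  # one extra column for the bias
    flat_weight_buffer = torch.empty(flat_weight_shape, device=self.weight.device, dtype=self.weight.dtype)
    if self.bias is not None:
        flat_weight_buffer[..., :-1].copy_(self.weight.view(self.out_channels, -1))
        flat_weight_buffer[..., -1].copy_(self.bias)
        del self.bias
        self.has_bias = True
        self.bias = "dummy"  # Trick conv1d.extra_repr() to not print bias=False
    else:
        flat_weight_buffer.copy_(self.weight.view(self.out_channels, -1))
        self.has_bias = False
    # Assumption: the 3D weight parameter is replaced by the flat 2D one.
    del self.weight
    self.weight = nn.Parameter(flat_weight_buffer)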
@@ -66,8 +83,8 @@ def from_conv1d(cls, conv1d: nn.Conv1d) -> Self:
        )

        if conv1d.bias is not None:
-             conv1d_flat.weight.data[:, :-1].copy_(conv1d.weight.data.view(conv1d.out_channels, -1))
-             conv1d_flat.weight.data[:, -1].copy_(conv1d.bias.data)
+             conv1d_flat.weight.data[..., :-1].copy_(conv1d.weight.data.view(conv1d.out_channels, -1))
+             conv1d_flat.weight.data[..., -1].copy_(conv1d.bias.data)
        else:
            conv1d_flat.weight.data.copy_(conv1d.weight.data.view(conv1d.out_channels, -1))
        return conv1d_flat
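A usage sketch (sizes are illustrative, not from the source): converting a stock Conv1d should leave outputs unchanged while exposing a purely 2D weight that flat-weight-aware optimizers can consume.

conv = nn.Conv1d(in_channels=64, out_channels=64, kernel_size=4, bias=True)
flat = Conv1dFlatWeights.from_conv1d(conv)

x = torch.randn(2, 64, 128)
torch.testing.assert_close(flat(x), conv(x))  # forward pass is unchanged
assert flat.weight.ndim == 2                  # (64, 64 * 4 + 1): bias lives in the last column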
@@ -78,8 +95,8 @@ def weight_shape(self) -> tuple[int, int, int]:

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.has_bias:
-             weight = self.weight[:, :-1].view(self.weight_shape)
-             bias = self.weight[:, -1]
+             weight = self.weight[..., :-1].view(self.weight_shape)
+             bias = self.weight[..., -1]
        else:
            weight = self.weight.view(self.weight_shape)
            bias = None
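The hunk cuts off before the convolution itself. Presumably forward() closes by delegating to the parent's implementation; _conv_forward is nn.Conv1d's internal helper, and this closing line is an assumption, not shown in the diff.

        # Assumed closing line of forward(), not visible in the diff:
        return self._conv_forward(x, weight, bias)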