
Commit 05e54a6

update docstring
Signed-off-by: Hao Wu <skyw@nvidia.com>
1 parent 4a14cf8 commit 05e54a6


2 files changed: +30 −7 lines changed


docs/apidocs/utils.md

Lines changed: 6 additions & 0 deletions
@@ -14,4 +14,10 @@ emerging_optimizers.utils.eig
 =============================
 .. automodule:: emerging_optimizers.utils.eig
    :members:
+
+
+emerging_optimizers.utils.modules
+=================================
+.. automodule:: emerging_optimizers.utils.modules
+   :members:
 ```

emerging_optimizers/utils/modules.py

Lines changed: 24 additions & 7 deletions
@@ -22,7 +22,24 @@
 
 
 class Conv1dFlatWeights(nn.Conv1d):
-    """Conv1d with weights+bias stored in a single 2D tensor"""
+    """Conv1d with weight and bias stored in a single 2D tensor.
+
+    Conv1d layers appear in some LLMs, for example in the Mamba mixer. Because the weight is not 2D, we cannot apply
+    many of the emerging optimizers originally introduced for the 2D weights of bias-free Linear layers. Since
+    convolution can be viewed as a matrix multiplication with im2col (either implicit or explicit), we can flatten
+    the weight into a single 2D tensor and then apply the emerging optimizers to it.
+
+    Bias is no longer common in most LLMs, but it is often included in this type of conv1d.
+    Since the bias is mathematically the zeroth-order term of the polynomial, we can combine weight and bias into a
+    single 2D tensor.
+
+    Arguments are the same as :class:`torch.nn.Conv1d`.
+
+    Note:
+        Similar flattening logic can be applied to N-D convolutions, but since we have no use cases for them in LLMs
+        yet, they are not supported, even though __init__() is generalized enough to handle N-D convolutions.
+    """
 
     def __init__(self, *args: Any, **kwargs: Any) -> None:
         super().__init__(*args, **kwargs)
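The docstring above leans on two facts: a 1-D convolution is a matrix multiplication over im2col columns, and the bias is the zeroth-order term that can be absorbed into the same matrix. A minimal sketch of that equivalence (illustrative names only, not code from this repository):

```python
# Sketch: conv1d == (flat 2D weight) @ (im2col columns augmented with a ones row).
# The extra weight column holds the bias, i.e. the zeroth-order term.
import torch
import torch.nn.functional as F

C_in, C_out, K, L = 3, 4, 5, 16
x = torch.randn(1, C_in, L)
weight = torch.randn(C_out, C_in, K)
bias = torch.randn(C_out)

# Reference result from the built-in op (stride 1, no padding).
ref = F.conv1d(x, weight, bias)

# Explicit im2col: each output position becomes a column of C_in * K input values.
cols = F.unfold(x.unsqueeze(-1), kernel_size=(K, 1)).squeeze(0)  # (C_in*K, L-K+1)

# Flat 2D parameter: flattened weight plus one extra column for the bias.
flat = torch.cat([weight.view(C_out, -1), bias.unsqueeze(1)], dim=1)  # (C_out, C_in*K + 1)

# A row of ones makes the bias column act as the constant term of the matmul.
cols_aug = torch.cat([cols, torch.ones(1, cols.shape[1])], dim=0)

out = (flat @ cols_aug).unsqueeze(0)
torch.testing.assert_close(out, ref)
```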
@@ -37,8 +54,8 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
         flat_weight_shape[1] += 1
         flat_weight_buffer = torch.empty(flat_weight_shape, device=self.weight.device, dtype=self.weight.dtype)
         if self.bias is not None:
-            flat_weight_buffer[:, :-1].copy_(self.weight.view(self.out_channels, -1))
-            flat_weight_buffer[:, -1].copy_(self.bias)
+            flat_weight_buffer[..., :-1].copy_(self.weight.view(self.out_channels, -1))
+            flat_weight_buffer[..., -1].copy_(self.bias)
             del self.bias
             self.has_bias = True
             self.bias = "dummy"  # Trick con1d.extra_repr() to not print bias=False
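As a toy re-creation of the copy in the hunk above (assumed shapes, illustrative names), the flat buffer gets one extra column: `[..., :-1]` receives the flattened convolution weight and `[..., -1]` receives the bias:

```python
import torch

out_channels, in_channels, kernel_size = 4, 3, 5
weight = torch.randn(out_channels, in_channels, kernel_size)
bias = torch.randn(out_channels)

# One extra column so the bias can live in the same 2D tensor as the weight.
flat_weight_shape = [out_channels, in_channels * kernel_size + 1]
flat_weight_buffer = torch.empty(flat_weight_shape, dtype=weight.dtype)
flat_weight_buffer[..., :-1].copy_(weight.view(out_channels, -1))
flat_weight_buffer[..., -1].copy_(bias)

assert torch.equal(flat_weight_buffer[..., -1], bias)
```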
@@ -66,8 +83,8 @@ def from_conv1d(cls, conv1d: nn.Conv1d) -> Self:
         )
 
         if conv1d.bias is not None:
-            conv1d_flat.weight.data[:, :-1].copy_(conv1d.weight.data.view(conv1d.out_channels, -1))
-            conv1d_flat.weight.data[:, -1].copy_(conv1d.bias.data)
+            conv1d_flat.weight.data[..., :-1].copy_(conv1d.weight.data.view(conv1d.out_channels, -1))
+            conv1d_flat.weight.data[..., -1].copy_(conv1d.bias.data)
         else:
             conv1d_flat.weight.data.copy_(conv1d.weight.data.view(conv1d.out_channels, -1))
         return conv1d_flat
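A hedged usage sketch of the `from_conv1d` classmethod shown in this hunk: convert an existing `nn.Conv1d` and check that the forward pass is unchanged. The import path follows the `emerging_optimizers.utils.modules` module documented above; adjust it if the package layout differs.

```python
import torch
import torch.nn as nn

from emerging_optimizers.utils.modules import Conv1dFlatWeights  # path assumed from the docs diff

conv = nn.Conv1d(in_channels=3, out_channels=4, kernel_size=5, bias=True)
flat_conv = Conv1dFlatWeights.from_conv1d(conv)

x = torch.randn(2, 3, 16)
torch.testing.assert_close(flat_conv(x), conv(x))

# The optimizer-facing parameter is now a single 2D tensor; per __init__ above it
# should be (out_channels, in_channels * kernel_size + 1), here (4, 16).
print(flat_conv.weight.shape)
```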
@@ -78,8 +95,8 @@ def weight_shape(self) -> tuple[int, int, int]:
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         if self.has_bias:
-            weight = self.weight[:, :-1].view(self.weight_shape)
-            bias = self.weight[:, -1]
+            weight = self.weight[..., :-1].view(self.weight_shape)
+            bias = self.weight[..., -1]
         else:
             weight = self.weight.view(self.weight_shape)
             bias = None
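The indexing change in this commit, `[:, ...]` to `[..., ...]`, is behavior-preserving for the 2D flat weight; a quick check on an arbitrary 2D tensor:

```python
import torch

w = torch.randn(4, 16)
assert torch.equal(w[:, :-1], w[..., :-1])
assert torch.equal(w[:, -1], w[..., -1])
```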
