Commit ba1b68a

fix tricks Conv1dFlatWeights
Signed-off-by: Hao Wu <[email protected]>
1 parent e8daede · commit ba1b68a

File tree: 1 file changed (+5 −2 lines)


emerging_optimizers/utils/modules.py

Lines changed: 5 additions & 2 deletions
@@ -58,7 +58,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
         assert self.padding_mode == "zeros", "Only zeros padding is supported"
 
         self.weight: nn.Parameter
-        self.bias: nn.Parameter | None
+        self.bias: torch.Tensor | None
 
         flat_weight_shape = [self.out_channels, math.prod(self.weight.shape[1:])]
         if self.bias is not None:
@@ -116,7 +116,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
     @override
     def extra_repr(self) -> str:
+        if self.has_bias:
+            # trick extra_repr() to print bias=True
+            self.bias = torch.tensor(0)
         base_repr = super().extra_repr()
         if self.has_bias:
-            base_repr += ", bias=True"
+            del self.bias
         return f"{base_repr}, flattened_param_shape={tuple(self.weight.shape)}"
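For context, the trick works because PyTorch's inherited _ConvNd.extra_repr() appends ", bias=False" only when self.bias is None; temporarily pointing self.bias at a throwaway tensor while calling super().extra_repr() keeps that suffix out of the repr even though the real bias has been folded into the flattened weight. Below is a minimal, self-contained sketch of the idea, assuming a simplified stand-in for Conv1dFlatWeights (the constructor and class name are a hypothetical reconstruction, not the repository's code); unlike the diff, it restores self.bias = None after the call rather than deleting the attribute, so repr() stays safe to call repeatedly.

import math

import torch
import torch.nn as nn


class Conv1dFlatWeightsSketch(nn.Conv1d):
    """Hypothetical stand-in: folds weight (and bias) into one flat 2D parameter.

    forward() is omitted here; the real module would reshape the flat weight
    back into conv layout before convolving.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        assert self.padding_mode == "zeros", "Only zeros padding is supported"
        self.has_bias = self.bias is not None

        flat_weight_shape = [self.out_channels, math.prod(self.weight.shape[1:])]
        if self.has_bias:
            flat_weight_shape[1] += 1  # one extra column holds the bias
            flat = torch.cat(
                [self.weight.detach().reshape(self.out_channels, -1),
                 self.bias.detach().reshape(self.out_channels, 1)],
                dim=1,
            )
        else:
            flat = self.weight.detach().reshape(flat_weight_shape)

        # Drop the original parameters and keep a single flattened one;
        # self.bias becomes a plain (non-parameter) attribute set to None.
        del self.weight, self.bias
        self.weight = nn.Parameter(flat)
        self.bias = None

    def extra_repr(self) -> str:
        if self.has_bias:
            # _ConvNd.extra_repr() appends ", bias=False" only when
            # self.bias is None, so a dummy tensor suppresses that suffix.
            self.bias = torch.tensor(0.0)
        base_repr = super().extra_repr()
        if self.has_bias:
            self.bias = None  # restore the placeholder attribute
        return f"{base_repr}, flattened_param_shape={tuple(self.weight.shape)}"


# Example: the repr reports the flattened shape and omits bias=False.
print(Conv1dFlatWeightsSketch(4, 8, kernel_size=3, bias=True))
# Conv1dFlatWeightsSketch(4, 8, kernel_size=(3,), stride=(1,), flattened_param_shape=(8, 13))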
