Skip to content

Commit 1cbdcbb

Browse files
committed
fix tricks Conv1dFlatWeights
Signed-off-by: Hao Wu <[email protected]>
1 parent ba1b68a commit 1cbdcbb

File tree

1 file changed

+15
-7
lines changed

1 file changed

+15
-7
lines changed

emerging_optimizers/utils/modules.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -116,10 +116,18 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
116116

117117
@override
118118
def extra_repr(self) -> str:
    """Return the extra-repr string in ``nn.Conv1d``'s format, plus the flattened weight shape.

    Mirrors ``torch.nn.modules.conv._ConvNd.extra_repr``: always shows
    in/out channels, ``kernel_size`` and ``stride``; shows ``padding``,
    ``dilation``, ``output_padding``, ``groups``, ``bias`` and
    ``padding_mode`` only when they differ from their defaults. A
    ``flattened_param_shape=...`` field is appended so the printed module
    reveals the flat weight layout.

    Returns:
        The formatted repr fragment, e.g.
        ``"4, 8, kernel_size=(3,), stride=(1,), bias=False, flattened_param_shape=(96,)"``.
    """
    s = "{in_channels}, {out_channels}, kernel_size={kernel_size}, stride={stride}"
    if self.padding != (0,) * len(self.padding):
        s += ", padding={padding}"
    if self.dilation != (1,) * len(self.dilation):
        s += ", dilation={dilation}"
    if self.output_padding != (0,) * len(self.output_padding):
        s += ", output_padding={output_padding}"
    if self.groups != 1:
        s += ", groups={groups}"
    if not self.has_bias:
        s += ", bias=False"
    if self.padding_mode != "zeros":
        s += ", padding_mode={padding_mode}"
    # Must be an f-string: str.format cannot evaluate an expression like
    # tuple(self.weight.shape) as a replacement field (it would raise
    # KeyError at field-name lookup), and `weight` lives in _parameters,
    # not __dict__. The f-string's result contains no braces, so the
    # final .format() call below leaves this suffix untouched.
    s += f", flattened_param_shape={tuple(self.weight.shape)}"
    return s.format(**self.__dict__)

0 commit comments

Comments
 (0)