File tree Expand file tree Collapse file tree 1 file changed +15
-7
lines changed
emerging_optimizers/utils Expand file tree Collapse file tree 1 file changed +15
-7
lines changed Original file line number Diff line number Diff line change @@ -116,10 +116,18 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
116116
117117 @override
118118 def extra_repr (self ) -> str :
119- if self .has_bias :
120- # trick extra_repr() to print bias=True
121- self .bias = torch .tensor (0 )
122- base_repr = super ().extra_repr ()
123- if self .has_bias :
124- del self .bias
125- return f"{ base_repr } , flattened_param_shape={ tuple (self .weight .shape )} "
119+ s = "{in_channels}, {out_channels}, kernel_size={kernel_size}, stride={stride}"
120+ if self .padding != (0 ,) * len (self .padding ):
121+ s += ", padding={padding}"
122+ if self .dilation != (1 ,) * len (self .dilation ):
123+ s += ", dilation={dilation}"
124+ if self .output_padding != (0 ,) * len (self .output_padding ):
125+ s += ", output_padding={output_padding}"
126+ if self .groups != 1 :
127+ s += ", groups={groups}"
128+ if not self .has_bias :
129+ s += ", bias=False"
130+ if self .padding_mode != "zeros" :
131+ s += ", padding_mode={padding_mode}"
132+ s += ", flattened_param_shape={tuple(self.weight.shape)}"
133+ return s .format (** self .__dict__ )
You can’t perform that action at this time.
0 commit comments