File tree Expand file tree Collapse file tree 1 file changed +2
-2
lines changed
pytorch_optimizer/optimizer Expand file tree Collapse file tree 1 file changed +2
-2
lines changed Original file line number Diff line number Diff line change @@ -71,7 +71,7 @@ def reset(self):
7171 state ['m_avg' ] = torch .zeros_like (p )
7272 state ['v_avg' ] = torch .zeros_like (p )
7373 else :
74- state ['v_avg_0' ] = torch .zeros_like (p .shape (dim = 1 ))
74+ state ['v_avg_0' ] = torch .zeros_like (p .mean (dim = 1 ))
7575 state ['v_avg_1' ] = torch .zeros_like (p .mean (dim = 0 ))
7676
7777 state ['m_avg_c' ] = torch .zeros_like (p .mean (dim = 1 )[:, None ])
@@ -108,7 +108,7 @@ def step(self, closure: CLOSURE = None) -> LOSS:
108108 state ['m_avg' ] = torch .zeros_like (p )
109109 state ['v_avg' ] = torch .zeros_like (p )
110110 else :
111- state ['v_avg_0' ] = torch .zeros_like (p .shape (dim = 1 ))
111+ state ['v_avg_0' ] = torch .zeros_like (p .mean (dim = 1 ))
112112 state ['v_avg_1' ] = torch .zeros_like (p .mean (dim = 0 ))
113113
114114 state ['m_avg_c' ] = torch .zeros_like (p .mean (dim = 1 )[:, None ])
You can’t perform that action at this time.
0 commit comments