@@ -69,7 +69,8 @@ def update(self, group: Dict):
6969 param_state ['slow_mom' ] = torch .zeros_like (fast )
7070
7171 slow = param_state ['slow_param' ]
72- slow += (fast - slow ) * self .alpha
72+ slow .add_ (fast - slow , alpha = self .alpha )
73+
7374 fast .copy_ (slow )
7475
7576 if 'momentum_buffer' not in self .optimizer .state [fast ]:
@@ -98,30 +99,21 @@ def step(self, closure: CLOSURE = None) -> LOSS:
9899 return loss
99100
100101 def state_dict (self ) -> STATE :
101- fast_state_dict : STATE = self .optimizer .state_dict ()
102- fast_state = fast_state_dict ['state' ]
103- param_groups = fast_state_dict ['param_groups' ]
104-
102+ fast_state : STATE = self .optimizer .state_dict ()
105103 slow_state : STATE = {(id (k ) if isinstance (k , torch .Tensor ) else k ): v for k , v in self .state .items ()}
106104
107105 return {
108- 'fast_state' : fast_state ,
106+ 'fast_state' : fast_state [ 'state' ] ,
109107 'slow_state' : slow_state ,
110- 'param_groups' : param_groups ,
108+ 'param_groups' : fast_state [ ' param_groups' ] ,
111109 }
112110
113- def load_state_dict (self , state_dict : STATE ):
114- slow_state_dict : STATE = {
115- 'state' : state_dict ['slow_state' ],
116- 'param_groups' : state_dict ['param_groups' ],
117- }
118- fast_state_dict : STATE = {
119- 'state' : state_dict ['fast_state' ],
120- 'param_groups' : state_dict ['param_groups' ],
121- }
122- super ().load_state_dict (slow_state_dict )
111+ def load_state_dict (self , state : STATE ):
112+ slow_state : STATE = {'state' : state ['slow_state' ], 'param_groups' : state ['param_groups' ]}
113+ fast_state : STATE = {'state' : state ['fast_state' ], 'param_groups' : state ['param_groups' ]}
114+ super ().load_state_dict (slow_state )
123115
124- self .optimizer .load_state_dict (fast_state_dict )
116+ self .optimizer .load_state_dict (fast_state )
125117 self .fast_state = self .optimizer .state
126118
127119 def add_param_group (self , param_group ):
0 commit comments