@@ -43,18 +43,18 @@ def inplace_lerp(tgt: Tensor, src: Tensor, weight):
 
 class KarrasEMA(Module):
     """
-    Exponential Moving Average module using hyperparameters from the Karras et al. paper.
+    Karras EMA implementation with power function decay profile.
 
     Args:
-        model: The model to create an EMA of
-        sigma_rel: Relative standard deviation for EMA profile width
-        gamma: Direct gamma parameter (alternative to sigma_rel)
+        model: Model to create EMA of
+        sigma_rel: Relative standard deviation for EMA profile
+        gamma: Alternative parameterization via gamma (don't specify both)
         ema_model: Optional pre-initialized EMA model
         update_every: Number of steps between EMA updates
-        frozen: If True, EMA weights are not updated
-        param_or_buffer_names_no_ema: Set of parameter/buffer names to exclude from EMA
-        ignore_names: Set of names to ignore
-        ignore_startswith_names: Set of name prefixes to ignore
+        frozen: Whether to freeze EMA updates
+        param_or_buffer_names_no_ema: Parameter/buffer names to exclude from EMA
+        ignore_names: Parameter/buffer names to ignore
+        ignore_startswith_names: Parameter/buffer name prefixes to ignore
         only_save_diff: If True, only save parameters with requires_grad=True
     """
 
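For orientation, a minimal usage sketch of the class as documented in the docstring above. The specific values (sigma_rel=0.05, update_every=10) are illustrative assumptions, not library defaults:

import torch
from torch import nn

net = nn.Linear(16, 16)

# sigma_rel parameterizes the width of the power-function EMA profile;
# gamma is the alternative parameterization -- pass one, not both.
ema = KarrasEMA(net, sigma_rel=0.05, update_every=10)

opt = torch.optim.SGD(net.parameters(), lr=1e-2)
for _ in range(100):
    loss = net(torch.randn(4, 16)).square().mean()
    loss.backward()
    opt.step()
    opt.zero_grad()
    ema.update()  # EMA weights advance every `update_every` calls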
@@ -111,12 +111,11 @@ def __init__(
         # Move model back to original device
         model.to(original_device)
 
-        # Get parameter names that require gradients
+        # Get parameter names for floating point or complex parameters
         self.param_names = {
             name
             for name, param in self.ema_model.named_parameters()
-            if (not only_save_diff or param.requires_grad)
-            and (torch.is_floating_point(param) or torch.is_complex(param))
+            if torch.is_floating_point(param) or torch.is_complex(param)
         }
 
         # Get buffer names for floating point or complex buffers
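The rewritten comprehension keys membership off dtype alone; the requires_grad / only_save_diff condition moves into the iterators below, so it is evaluated at iteration time rather than frozen at construction. A quick standalone check of what the dtype filter admits:

import torch

print(torch.is_floating_point(torch.zeros(1)))                    # True: float32
print(torch.is_floating_point(torch.zeros(1, dtype=torch.half)))  # True: float16
print(torch.is_complex(torch.zeros(1, dtype=torch.complex64)))    # True
print(torch.is_floating_point(torch.zeros(1, dtype=torch.long)))  # False: integer
# Integer tensors are never tracked in param_names / buffer_names.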
@@ -161,17 +160,27 @@ def update(self):
 
     def copy_params_from_model_to_ema(self):
         """Copy parameters from online model to EMA model."""
+        # Copy parameters
         for (name, ma_params), (_, current_params) in zip(
             self.get_params_iter(self.ema_model),
             self.get_params_iter(self.online_model[0]),
         ):
             if self._should_update_param(name):
                 inplace_copy(ma_params.data, current_params.data)
 
+        # Copy buffers
+        for (name, ma_buffer), (_, current_buffer) in zip(
+            self.get_buffers_iter(self.ema_model),
+            self.get_buffers_iter(self.online_model[0]),
+        ):
+            if self._should_update_param(name):
+                inplace_copy(ma_buffer.data, current_buffer.data)
+
     def update_moving_average(self):
         """Update EMA weights using current beta value."""
         current_decay = self.beta
 
+        # Update parameters
         for (name, current_params), (_, ma_params) in zip(
             self.get_params_iter(self.online_model[0]),
             self.get_params_iter(self.ema_model),
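Both sync paths lean on the in-place helpers (inplace_copy here, inplace_lerp below; the latter's signature appears in the first hunk header). Their definitions sit outside this diff; a plausible sketch, assuming thin wrappers over PyTorch's in-place tensor ops:

from torch import Tensor

def inplace_copy(tgt: Tensor, src: Tensor):
    # Overwrite tgt with src without allocating a new tensor
    tgt.copy_(src)

def inplace_lerp(tgt: Tensor, src: Tensor, weight):
    # tgt <- tgt + weight * (src - tgt), in place
    tgt.lerp_(src, weight)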
@@ -180,6 +189,15 @@ def update_moving_average(self):
                 continue
             inplace_lerp(ma_params.data, current_params.data, 1.0 - current_decay)
 
+        # Update buffers
+        for (name, current_buffer), (_, ma_buffer) in zip(
+            self.get_buffers_iter(self.online_model[0]),
+            self.get_buffers_iter(self.ema_model),
+        ):
+            if not self._should_update_param(name):
+                continue
+            inplace_lerp(ma_buffer.data, current_buffer.data, 1.0 - current_decay)
+
     def _should_update_param(self, name: str) -> bool:
         """Check if parameter should be updated based on ignore rules."""
         if name in self.ignore_names:
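With weight = 1.0 - current_decay, the lerp realizes the standard EMA recurrence ma <- decay * ma + (1 - decay) * current. A quick numeric check with illustrative values:

import torch

decay = 0.99
ma, current = torch.tensor([1.0]), torch.tensor([2.0])
ma.lerp_(current, 1.0 - decay)  # ma + 0.01 * (current - ma)
print(ma)  # tensor([1.0100])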
@@ -195,8 +213,19 @@ def get_params_iter(self, model):
         for name, param in model.named_parameters():
             if name not in self.param_names:
                 continue
+            if self.only_save_diff and not param.requires_grad:
+                continue
             yield name, param
 
+    def get_buffers_iter(self, model):
+        """Get iterator over model's buffers."""
+        for name, buffer in model.named_buffers():
+            if name not in self.buffer_names:
+                continue
+            if self.only_save_diff and not buffer.requires_grad:
+                continue
+            yield name, buffer
+
     def iter_all_ema_params_and_buffers(self):
         """Get iterator over all EMA parameters and buffers."""
         for name, param in self.ema_model.named_parameters():
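Because the only_save_diff check now runs per call, freezing a submodule changes what the iterators yield on the next update. A sketch of the effect, assuming the EMA copy preserves requires_grad flags and that the constructor accepts these documented arguments:

import torch
from torch import nn

net = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))
net[0].requires_grad_(False)  # freeze the first layer

ema = KarrasEMA(net, sigma_rel=0.05, only_save_diff=True)
print([name for name, _ in ema.get_params_iter(ema.ema_model)])
# expected: ['1.weight', '1.bias'] -- only the trainable layer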
@@ -250,24 +279,26 @@ def __call__(self, *args, **kwargs):
         return self.ema_model(*args, **kwargs)
 
     def state_dict(self):
-        """Get state dict of EMA model."""
+        """Get state dict for EMA model."""
         state_dict = {}
 
-        # Add parameters based on only_save_diff flag
+        # Save parameters based on only_save_diff flag
         for name, param in self.ema_model.named_parameters():
-            if (not self.only_save_diff or param.requires_grad) and (
-                torch.is_floating_point(param) or torch.is_complex(param)
-            ):
-                state_dict[name] = param
+            if name not in self.param_names:
+                continue
+            if self.only_save_diff and not param.requires_grad:
+                continue
+            state_dict[name] = param.data
 
-        # Add buffers (always included regardless of only_save_diff)
+        # Save buffers
         for name, buffer in self.ema_model.named_buffers():
-            if torch.is_floating_point(buffer) or torch.is_complex(buffer):
-                state_dict[name] = buffer
+            if name not in self.buffer_names:
+                continue
+            state_dict[name] = buffer.data
 
-        # Add internal state
-        state_dict["initted"] = self.initted
-        state_dict["step"] = self.step
+        # Save internal state
+        state_dict["initted"] = self.initted.data
+        state_dict["step"] = self.step.data
 
         return state_dict
 
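With state_dict() returning raw tensors keyed consistently with the iterators, checkpointing is a plain torch.save round trip. A minimal sketch (the "ema.pt" path is illustrative; a matching load path is not shown in this diff):

import torch

ckpt = ema.state_dict()      # parameters, buffers, plus initted/step
torch.save(ckpt, "ema.pt")

restored = torch.load("ema.pt")
print(restored["step"])      # internal step counter tensor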