@@ -38,7 +38,12 @@ def inplace_lerp(tgt: Tensor, src: Tensor, weight):
         src: Source tensor to interpolate towards
         weight: Interpolation weight between 0 and 1
     """
-    tgt.lerp_(src.to(tgt.device), weight)
+    # Check if tensor is integer type - integer tensors can't use lerp,
+    # but we want to silently handle them instead of raising errors
+    if tgt.dtype in [torch.int, torch.int8, torch.int16, torch.int32, torch.int64, torch.long]:
+        tgt.copy_(src.to(tgt.device))
+    else:
+        tgt.lerp_(src.to(tgt.device), weight)


 class KarrasEMA(Module):
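# A minimal sanity-check sketch of the new `inplace_lerp` behaviour (it assumes the
# function above is in scope): float tensors are interpolated in place, while
# integer tensors are overwritten via copy_ instead of raising a dtype error.
import torch

float_tgt = torch.zeros(3)
inplace_lerp(float_tgt, torch.ones(3), 0.1)                   # -> tensor([0.1000, 0.1000, 0.1000])

int_tgt = torch.zeros(3, dtype=torch.long)
inplace_lerp(int_tgt, torch.ones(3, dtype=torch.long), 0.1)   # -> tensor([1, 1, 1])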
@@ -49,94 +54,73 @@ class KarrasEMA(Module):
         model: Model to create EMA of
         sigma_rel: Relative standard deviation for EMA profile
         gamma: Alternative parameterization via gamma (don't specify both)
-        ema_model: Optional pre-initialized EMA model
         update_every: Number of steps between EMA updates
         frozen: Whether to freeze EMA updates
         param_or_buffer_names_no_ema: Parameter/buffer names to exclude from EMA
         ignore_names: Parameter/buffer names to ignore
         ignore_startswith_names: Parameter/buffer name prefixes to ignore
         only_save_diff: If True, only save parameters with requires_grad=True
+        device: Device to store EMA parameters on (default='cpu')
     """

+    # Buffers that should always be included in the state dict even with only_save_diff=True
+    _ALWAYS_INCLUDE_BUFFERS = {"running_mean", "running_var", "num_batches_tracked"}
+
     def __init__(
         self,
         model: Module,
         sigma_rel: float | None = None,
         gamma: float | None = None,
-        ema_model: Module | Callable[[], Module] | None = None,
         update_every: int = 10,
         frozen: bool = False,
         param_or_buffer_names_no_ema: set[str] = set(),
         ignore_names: set[str] = set(),
         ignore_startswith_names: set[str] = set(),
         only_save_diff: bool = False,
+        device: str = 'cpu',
     ):
         super().__init__()
-
-        assert exists(sigma_rel) ^ exists(
-            gamma
-        ), "either sigma_rel or gamma must be given"
-
-        if exists(sigma_rel):
-            gamma = sigma_rel_to_gamma(sigma_rel)
-
+
+        # Store all the configuration parameters first
         self.gamma = gamma
         self.frozen = frozen
         self.update_every = update_every
         self.only_save_diff = only_save_diff
-
+        self.ignore_names = ignore_names
+        self.ignore_startswith_names = ignore_startswith_names
+        self.param_or_buffer_names_no_ema = param_or_buffer_names_no_ema
+        self.device = device
+
+        assert exists(sigma_rel) ^ exists(gamma), "either sigma_rel or gamma must be given"
+
+        if exists(sigma_rel):
+            gamma = sigma_rel_to_gamma(sigma_rel)
+            self.gamma = gamma
+
         # Store reference to online model
         self.online_model = [model]
-
-        # Initialize EMA model
-        if callable(ema_model) and not isinstance(ema_model, Module):
-            ema_model = ema_model()
-
-        # Store original device
-        original_device = next(model.parameters()).device
-
-        # Move model to CPU before copying to avoid VRAM spike
-        model.cpu()
-
-        try:
-            # Create EMA model on CPU
-            self.ema_model = (ema_model if exists(ema_model) else deepcopy(model)).cpu()
-
-            # Ensure all parameters and buffers are on CPU and detached
-            for p in self.ema_model.parameters():
-                p.data = p.data.cpu().detach()
-            for b in self.ema_model.buffers():
-                b.data = b.data.cpu().detach()
-
-            # Move model back to original device
-            model.to(original_device)
-
-            # Get parameter names for floating point or complex parameters
-            self.param_names = {
-                name
-                for name, param in self.ema_model.named_parameters()
-                if torch.is_floating_point(param) or torch.is_complex(param)
-            }
-
-            # Get buffer names for floating point or complex buffers
-            self.buffer_names = {
-                name
-                for name, buffer in self.ema_model.named_buffers()
-                if torch.is_floating_point(buffer) or torch.is_complex(buffer)
-            }
-
-            # Names to ignore
-            self.param_or_buffer_names_no_ema = param_or_buffer_names_no_ema
-            self.ignore_names = ignore_names
-            self.ignore_startswith_names = ignore_startswith_names
-
-            # State buffers on CPU
-            self.register_buffer("initted", torch.tensor(False, device="cpu"))
-            self.register_buffer("step", torch.tensor(0, device="cpu"))
-        except:
-            # Ensure model is moved back even if initialization fails
-            model.to(original_device)
-            raise
+
+        # Instead of copying the whole model, just store parameter tensors
+        self.ema_params = {}
+        self.ema_buffers = {}
+
+        # Get parameter and buffer names to track
+        with torch.no_grad():
+            for name, param in model.named_parameters():
+                if self._should_update_param(name):
+                    if not only_save_diff or param.requires_grad:
+                        self.ema_params[name] = param.detach().clone().to(self.device)
+
+            for name, buffer in model.named_buffers():
+                if self._should_update_param(name):
+                    buffer_name = name.split('.')[-1]  # Get the base name
+                    # Always include critical buffers regardless of only_save_diff
+                    if not only_save_diff or buffer.requires_grad or buffer_name in self._ALWAYS_INCLUDE_BUFFERS:
+                        self.ema_buffers[name] = buffer.detach().clone().to(self.device)
+
+        # State buffers
+        self.register_buffer("initted", torch.tensor(False))
+        self.register_buffer("step", torch.tensor(0))

     @property
     def beta(self):
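# Usage sketch for the refactored constructor (the toy model below is hypothetical;
# everything else is as in this diff): exactly one of sigma_rel / gamma is given,
# and EMA snapshots of the tracked tensors live in plain dicts on `device`.
import torch
from torch import nn

net = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))
ema = KarrasEMA(net, sigma_rel=0.05, update_every=10, device='cpu')

# Parameter and buffer snapshots are keyed by their dotted names
assert '0.weight' in ema.ema_params
assert '1.running_mean' in ema.ema_buffers   # always tracked, even with only_save_diff=True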
@@ -161,42 +145,33 @@ def update(self):
     def copy_params_from_model_to_ema(self):
         """Copy parameters from online model to EMA model."""
         # Copy parameters
-        for (name, ma_params), (_, current_params) in zip(
-            self.get_params_iter(self.ema_model),
-            self.get_params_iter(self.online_model[0]),
-        ):
-            if self._should_update_param(name):
-                inplace_copy(ma_params.data, current_params.data)
-
-        # Copy buffers
-        for (name, ma_buffer), (_, current_buffer) in zip(
-            self.get_buffers_iter(self.ema_model),
-            self.get_buffers_iter(self.online_model[0]),
-        ):
-            if self._should_update_param(name):
-                inplace_copy(ma_buffer.data, current_buffer.data)
+        with torch.no_grad():
+            for name, param in self.online_model[0].named_parameters():
+                if name in self.ema_params:
+                    # Explicitly move to device (usually CPU)
+                    self.ema_params[name] = param.detach().clone().to(self.device)
+
+            # Copy buffers
+            for name, buffer in self.online_model[0].named_buffers():
+                if name in self.ema_buffers:
+                    # Explicitly move to device (usually CPU)
+                    self.ema_buffers[name] = buffer.detach().clone().to(self.device)

     def update_moving_average(self):
         """Update EMA weights using current beta value."""
         current_decay = self.beta

-        # Update parameters
-        for (name, current_params), (_, ma_params) in zip(
-            self.get_params_iter(self.online_model[0]),
-            self.get_params_iter(self.ema_model),
-        ):
-            if not self._should_update_param(name):
-                continue
-            inplace_lerp(ma_params.data, current_params.data, 1.0 - current_decay)
-
-        # Update buffers
-        for (name, current_buffer), (_, ma_buffer) in zip(
-            self.get_buffers_iter(self.online_model[0]),
-            self.get_buffers_iter(self.ema_model),
-        ):
-            if not self._should_update_param(name):
-                continue
-            inplace_lerp(ma_buffer.data, current_buffer.data, 1.0 - current_decay)
+        # Update parameters using the simplified lerp function (which now handles integer tensors)
+        for name, current_params in self.online_model[0].named_parameters():
+            if name in self.ema_params:
+                # inplace_lerp now handles integer tensors internally
+                inplace_lerp(self.ema_params[name], current_params.data, 1.0 - current_decay)
+
+        # Update buffers with the same simplified approach
+        for name, current_buffer in self.online_model[0].named_buffers():
+            if name in self.ema_buffers:
+                # inplace_lerp now handles integer tensors internally
+                inplace_lerp(self.ema_buffers[name], current_buffer.data, 1.0 - current_decay)

     def _should_update_param(self, name: str) -> bool:
         """Check if parameter should be updated based on ignore rules."""
@@ -208,10 +183,17 @@ def _should_update_param(self, name: str) -> bool:
             return False
         return True

+    def _parameter_requires_grad(self, name: str) -> bool:
+        """Check if parameter requires gradients in the online model."""
+        for n, p in self.online_model[0].named_parameters():
+            if n == name:
+                return p.requires_grad
+        return False
+
     def get_params_iter(self, model):
         """Get iterator over model's parameters."""
         for name, param in model.named_parameters():
-            if name not in self.param_names:
+            if name not in self.ema_params:
                 continue
             if self.only_save_diff and not param.requires_grad:
                 continue
@@ -220,17 +202,19 @@ def get_params_iter(self, model):
     def get_buffers_iter(self, model):
         """Get iterator over model's buffers."""
         for name, buffer in model.named_buffers():
-            if name not in self.buffer_names:
+            if name not in self.ema_buffers:
                 continue
-            if self.only_save_diff and not buffer.requires_grad:
+
+            # Handle critical buffers that should always be included
+            buffer_name = name.split('.')[-1]
+            if self.only_save_diff and not buffer.requires_grad and buffer_name not in self._ALWAYS_INCLUDE_BUFFERS:
                 continue
+
             yield name, buffer

     def iter_all_ema_params_and_buffers(self):
         """Get iterator over all EMA parameters and buffers."""
-        for name, param in self.ema_model.named_parameters():
-            if name not in self.param_names:
-                continue
+        for name, param in self.ema_params.items():
             if name in self.param_or_buffer_names_no_ema:
                 continue
             if name in self.ignore_names:
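# Sketch of how only_save_diff interacts with _ALWAYS_INCLUDE_BUFFERS, assuming a
# toy BatchNorm model (hypothetical): running statistics are still yielded even
# though buffers never require grad.
import torch
from torch import nn

bn = nn.BatchNorm1d(4)
ema = KarrasEMA(bn, sigma_rel=0.05, only_save_diff=True)

kept = {name for name, _ in ema.get_buffers_iter(bn)}
# kept == {'running_mean', 'running_var', 'num_batches_tracked'}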
@@ -239,21 +223,10 @@ def iter_all_ema_params_and_buffers(self):
                 continue
             yield param

-        for name, buffer in self.ema_model.named_buffers():
-            if name not in self.buffer_names:
-                continue
-            if name in self.param_or_buffer_names_no_ema:
-                continue
-            if name in self.ignore_names:
-                continue
-            if any(name.startswith(prefix) for prefix in self.ignore_startswith_names):
-                continue
-            yield buffer
-
     def iter_all_model_params_and_buffers(self, model: Module):
         """Get iterator over all model parameters and buffers."""
         for name, param in model.named_parameters():
-            if name not in self.param_names:
+            if name not in self.ema_params:
                 continue
             if name in self.param_or_buffer_names_no_ema:
                 continue
@@ -263,59 +236,66 @@ def iter_all_model_params_and_buffers(self, model: Module):
                 continue
             yield param

-        for name, buffer in model.named_buffers():
-            if name not in self.buffer_names:
-                continue
-            if name in self.param_or_buffer_names_no_ema:
-                continue
-            if name in self.ignore_names:
-                continue
-            if any(name.startswith(prefix) for prefix in self.ignore_startswith_names):
-                continue
-            yield buffer
-
     def __call__(self, *args, **kwargs):
         """Forward pass using EMA model."""
-        return self.ema_model(*args, **kwargs)
+        raise NotImplementedError("KarrasEMA no longer maintains a full model copy")
+
+    @property
+    def ema_model(self):
+        """
+        For backward compatibility with tests.
+        Creates a temporary model with EMA parameters.
+
+        Returns:
+            Module: A copy of the online model with EMA parameters
+        """
+        # Create a copy of the online model
+        model_copy = deepcopy(self.online_model[0])
+
+        # Load EMA parameters into the model
+        for name, param in model_copy.named_parameters():
+            if name in self.ema_params:
+                param.data.copy_(self.ema_params[name])
+
+        # Load EMA buffers into the model
+        for name, buffer in model_copy.named_buffers():
+            if name in self.ema_buffers:
+                buffer.data.copy_(self.ema_buffers[name])
+
+        # Ensure the model is on CPU
+        model_copy.to('cpu')
+        return model_copy

     def state_dict(self):
-        """Get state dict for EMA model."""
+        """Get state dict with EMA parameters."""
         state_dict = {}
-
-        # Save parameters based on only_save_diff flag
-        for name, param in self.ema_model.named_parameters():
-            if name not in self.param_names:
-                continue
-            if self.only_save_diff and not param.requires_grad:
-                continue
-            state_dict[name] = param.data
-
-        # Save buffers
-        for name, buffer in self.ema_model.named_buffers():
-            if name not in self.buffer_names:
-                continue
-            state_dict[name] = buffer.data
-
-        # Save internal state
-        state_dict["initted"] = self.initted.data
-        state_dict["step"] = self.step.data
-
+
+        # For parameters, respect only_save_diff
+        for name, param in self.ema_params.items():
+            if not self.only_save_diff or self._parameter_requires_grad(name):
+                state_dict[name] = param.data
+
+        # For buffers, identify which ones should always be included
+        for name, buffer in self.ema_buffers.items():
+            buffer_name = name.split('.')[-1]  # Get the base name
+            # Always include critical buffers regardless of only_save_diff
+            if not self.only_save_diff or buffer_name in self._ALWAYS_INCLUDE_BUFFERS:
+                state_dict[name] = buffer.data
+
+        # Add internal state
+        state_dict["initted"] = self.initted
+        state_dict["step"] = self.step
+
         return state_dict

     def load_state_dict(self, state_dict):
-        """Load state dict into EMA model."""
-        # Load parameters based on only_save_diff flag
-        for name, param in self.ema_model.named_parameters():
-            if (not self.only_save_diff or param.requires_grad) and name in state_dict:
-                param.data.copy_(state_dict[name].data)
-
-        # Load buffers
-        for name, buffer in self.ema_model.named_buffers():
-            if name in state_dict:
-                buffer.data.copy_(state_dict[name].data)
-
-        # Load internal state
-        if "initted" in state_dict:
-            self.initted.data.copy_(state_dict["initted"].data)
-        if "step" in state_dict:
-            self.step.data.copy_(state_dict["step"].data)
+        """Load state dict with EMA parameters."""
+        for name, param in state_dict.items():
+            if name == "initted":
+                self.initted.data.copy_(param)
+            elif name == "step":
+                self.step.data.copy_(param)
+            elif name in self.ema_params:
+                self.ema_params[name].data.copy_(param)
+            elif name in self.ema_buffers:
+                self.ema_buffers[name].data.copy_(param)
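# End-to-end sketch of the refactored flow (toy model names are hypothetical; the
# KarrasEMA API is as in this diff): update the EMA during training, save the flat
# state dict, and restore it into a fresh tracker later.
import torch
from torch import nn

net = nn.Linear(8, 8)
ema = KarrasEMA(net, sigma_rel=0.05, update_every=1)

# ... each optimizer step mutates net ...
ema.update()                      # advance the EMA (copies on the first update, lerps afterwards)

snapshot = ema.state_dict()       # flat dict of EMA tensors plus "initted" / "step"
torch.save(snapshot, 'ema.pt')

restored = KarrasEMA(nn.Linear(8, 8), sigma_rel=0.05, update_every=1)
restored.load_state_dict(torch.load('ema.pt'))

eval_model = restored.ema_model   # CPU copy of the online model carrying the EMA weights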