@@ -92,39 +92,52 @@ def __init__(
        if callable(ema_model) and not isinstance(ema_model, Module):
            ema_model = ema_model()

-        # Create EMA model on CPU
-        self.ema_model = (ema_model if exists(ema_model) else deepcopy(model)).cpu()
-
-        # Ensure all parameters and buffers are on CPU and detached
-        for p in self.ema_model.parameters():
-            p.data = p.data.cpu().detach()
-        for b in self.ema_model.buffers():
-            b.data = b.data.cpu().detach()
-
-        # Get parameter names that require gradients
-        self.param_names = {
-            name
-            for name, param in self.ema_model.named_parameters()
-            if (not only_save_diff or param.requires_grad) and (
-                torch.is_floating_point(param) or torch.is_complex(param)
-            )
-        }
-
-        # Get buffer names for floating point or complex buffers
-        self.buffer_names = {
-            name
-            for name, buffer in self.ema_model.named_buffers()
-            if torch.is_floating_point(buffer) or torch.is_complex(buffer)
-        }
-
-        # Names to ignore
-        self.param_or_buffer_names_no_ema = param_or_buffer_names_no_ema
-        self.ignore_names = ignore_names
-        self.ignore_startswith_names = ignore_startswith_names
-
-        # State buffers on CPU
-        self.register_buffer("initted", torch.tensor(False, device="cpu"))
-        self.register_buffer("step", torch.tensor(0, device="cpu"))
+        # Store original device
+        original_device = next(model.parameters()).device
+
+        # Move model to CPU before copying to avoid VRAM spike
+        model.cpu()
+
+        try:
+            # Create EMA model on CPU
+            self.ema_model = (ema_model if exists(ema_model) else deepcopy(model)).cpu()
+
+            # Ensure all parameters and buffers are on CPU and detached
+            for p in self.ema_model.parameters():
+                p.data = p.data.cpu().detach()
+            for b in self.ema_model.buffers():
+                b.data = b.data.cpu().detach()
+
+            # Move model back to original device
+            model.to(original_device)
+
+            # Get parameter names that require gradients
+            self.param_names = {
+                name
+                for name, param in self.ema_model.named_parameters()
+                if (not only_save_diff or param.requires_grad)
+                and (torch.is_floating_point(param) or torch.is_complex(param))
+            }
+
+            # Get buffer names for floating point or complex buffers
+            self.buffer_names = {
+                name
+                for name, buffer in self.ema_model.named_buffers()
+                if torch.is_floating_point(buffer) or torch.is_complex(buffer)
+            }
+
+            # Names to ignore
+            self.param_or_buffer_names_no_ema = param_or_buffer_names_no_ema
+            self.ignore_names = ignore_names
+            self.ignore_startswith_names = ignore_startswith_names
+
+            # State buffers on CPU
+            self.register_buffer("initted", torch.tensor(False, device="cpu"))
+            self.register_buffer("step", torch.tensor(0, device="cpu"))
+        except:
+            # Ensure model is moved back even if initialization fails
+            model.to(original_device)
+            raise

    @property
    def beta(self):
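For context, the hunk above guards the `deepcopy` with a CPU round-trip so the clone is allocated in host RAM rather than VRAM, and restores the source model's device even if construction fails. Below is a minimal standalone sketch of that pattern; the helper name `copy_on_cpu` is hypothetical, not part of this commit, and `try/finally` is used as the equivalent of the commit's `except`/`raise` branch.

import copy
import torch
from torch import nn

def copy_on_cpu(model: nn.Module) -> nn.Module:
    # Remember where the caller's model currently lives
    # (assumes the model has at least one parameter)
    original_device = next(model.parameters()).device
    model.cpu()  # deepcopy will now allocate the clone in host RAM, not VRAM
    try:
        clone = copy.deepcopy(model)
        for p in clone.parameters():
            p.data = p.data.detach()  # EMA weights are updated manually, not trained
        return clone
    finally:
        # Restore the caller's model on both success and failure,
        # mirroring the commit's except/raise handler.
        model.to(original_device)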
@@ -239,23 +252,23 @@ def __call__(self, *args, **kwargs):
    def state_dict(self):
        """Get state dict of EMA model."""
        state_dict = {}
-
+
        # Add parameters based on only_save_diff flag
        for name, param in self.ema_model.named_parameters():
            if (not self.only_save_diff or param.requires_grad) and (
                torch.is_floating_point(param) or torch.is_complex(param)
            ):
                state_dict[name] = param
-
+
        # Add buffers (always included regardless of only_save_diff)
        for name, buffer in self.ema_model.named_buffers():
            if torch.is_floating_point(buffer) or torch.is_complex(buffer):
                state_dict[name] = buffer
-
+
        # Add internal state
        state_dict["initted"] = self.initted
        state_dict["step"] = self.step
-
+
        return state_dict

    def load_state_dict(self, state_dict):
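Usage note: with `only_save_diff=True`, `state_dict()` keeps a parameter only if it requires gradients and is floating-point or complex, so frozen weights drop out of the checkpoint while buffers are always kept. A runnable sketch of that filter in isolation; the toy `nn.Linear` model is illustrative, not from this repo.

import torch
from torch import nn

model = nn.Linear(4, 2)
model.bias.requires_grad_(False)  # pretend the bias is frozen

# Mirrors the state_dict() condition with only_save_diff=True
kept = {
    name
    for name, param in model.named_parameters()
    if param.requires_grad
    and (torch.is_floating_point(param) or torch.is_complex(param))
}
print(kept)  # {'weight'} -- the frozen bias is excluded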
@@ -264,12 +277,12 @@ def load_state_dict(self, state_dict):
        for name, param in self.ema_model.named_parameters():
            if (not self.only_save_diff or param.requires_grad) and name in state_dict:
                param.data.copy_(state_dict[name].data)
-
+
        # Load buffers
        for name, buffer in self.ema_model.named_buffers():
            if name in state_dict:
                buffer.data.copy_(state_dict[name].data)
-
+
        # Load internal state
        if "initted" in state_dict:
            self.initted.data.copy_(state_dict["initted"].data)
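Design note: `load_state_dict` restores values with in-place `Tensor.copy_` rather than reassignment, so the EMA tensors keep their existing storage and CPU placement and stay registered on the module. A tiny sketch of that semantic:

import torch

dst = torch.zeros(3)     # stands in for an EMA parameter kept on CPU
src = torch.arange(3.0)  # stands in for a tensor from the checkpoint
dst.copy_(src)           # in-place: dst keeps its own storage and device
print(dst)               # tensor([0., 1., 2.])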