8
8
import torch
9
9
from PIL import Image
10
10
from torch import nn
11
+ import pickle
12
+ import io
13
+ import torch .serialization
11
14
12
15
from .karras_ema import KarrasEMA
13
16
from .utils import _safe_torch_load , p_dot_p , sigma_rel_to_gamma , solve_weights
@@ -285,37 +288,49 @@ def _cleanup_old_checkpoints(self) -> None:
285
288
286
289
@contextmanager
287
290
def model (
288
- self , model : nn .Module , sigma_rel : float
289
- ) -> Generator [nn .Module , None , None ]:
290
- """Context manager that temporarily sets model parameters to EMA state.
291
+ self ,
292
+ model : nn .Module ,
293
+ sigma_rel : float ,
294
+ ) -> Iterator [nn .Module ]:
295
+ """
296
+ Context manager for temporarily setting model parameters to EMA state.
291
297
292
298
Args:
293
- model: Model to update
299
+ model: Model to temporarily set to EMA state
294
300
sigma_rel: Target relative standard deviation
295
301
296
- Returns :
297
- Model with EMA parameters
302
+ Yields :
303
+ nn.Module: Model with EMA parameters
298
304
"""
299
- # Store original device and move model to CPU
305
+ # Move model to CPU for memory efficiency
300
306
original_device = next (model .parameters ()).device
301
307
model .cpu ()
302
308
torch .cuda .empty_cache ()
303
309
304
310
try :
305
311
with self .state_dict (sigma_rel = sigma_rel ) as state_dict :
306
- ema_model = deepcopy (model )
307
- result = ema_model .load_state_dict (
312
+ # Store original state only for parameters that will be modified
313
+ original_state = {
314
+ name : param .detach ().clone ()
315
+ for name , param in model .state_dict ().items ()
316
+ if name in state_dict
317
+ }
318
+
319
+ # Load EMA state directly into model
320
+ result = model .load_state_dict (
308
321
state_dict , strict = not self .only_save_diff
309
322
)
310
323
assert (
311
324
len (result .unexpected_keys ) == 0
312
325
), f"Unexpected keys: { result .unexpected_keys } "
313
- ema_model .eval () # Set to eval mode to handle BatchNorm
314
- yield ema_model
315
- # Clean up EMA model
316
- if hasattr (ema_model , "cuda" ):
317
- ema_model .cpu ()
318
- del ema_model
326
+ model .eval () # Set to eval mode to handle BatchNorm
327
+ yield model
328
+
329
+ # Restore original state
330
+ model .load_state_dict (original_state , strict = False )
331
+ del original_state
332
+ del state_dict # Free memory for state dict
333
+ torch .cuda .empty_cache ()
319
334
finally :
320
335
# Restore model to original device
321
336
model .to (original_device )
@@ -341,10 +356,18 @@ def state_dict(
341
356
gamma = sigma_rel_to_gamma (sigma_rel )
342
357
device = torch .device ("cpu" ) # Keep synthesis on CPU for memory efficiency
343
358
344
- # Get all checkpoint files
359
+ # First count total checkpoints to pre-allocate tensors
360
+ total_checkpoints = 0
361
+ checkpoint_files = []
345
362
if self .ema_models is not None :
346
363
# When we have ema_models, use their indices
347
- indices = range (len (self .ema_models ))
364
+ for idx in range (len (self .ema_models )):
365
+ files = sorted (
366
+ self .checkpoint_dir .glob (f"{ idx } .*.pt" ),
367
+ key = lambda p : int (p .stem .split ("." )[1 ]),
368
+ )
369
+ total_checkpoints += len (files )
370
+ checkpoint_files .extend (files )
348
371
else :
349
372
# When loading from path, find all unique indices
350
373
indices = set ()
@@ -353,78 +376,101 @@ def state_dict(
353
376
indices .add (idx )
354
377
indices = sorted (indices )
355
378
356
- # Get checkpoint files and info
357
- checkpoint_files = []
358
- gammas = []
359
- timesteps = []
360
- for idx in indices :
361
- files = sorted (
362
- self .checkpoint_dir .glob (f"{ idx } .*.pt" ),
363
- key = lambda p : int (p .stem .split ("." )[1 ]),
364
- )
365
- for file in files :
366
- _ , timestep = map (int , file .stem .split ("." ))
367
- if self .ema_models is not None :
368
- gammas .append (self .gammas [idx ])
369
- else :
370
- # Load gamma from checkpoint
371
- checkpoint = _safe_torch_load (str (file ))
372
- sigma_rel = checkpoint .get ("sigma_rel" , None )
373
- if sigma_rel is not None :
374
- gammas .append (sigma_rel_to_gamma (sigma_rel ))
375
- else :
376
- gammas .append (self .gammas [idx ])
377
- del checkpoint # Free memory
378
- timesteps .append (timestep )
379
- checkpoint_files .append (file )
380
-
381
- if not gammas :
379
+ for idx in indices :
380
+ files = sorted (
381
+ self .checkpoint_dir .glob (f"{ idx } .*.pt" ),
382
+ key = lambda p : int (p .stem .split ("." )[1 ]),
383
+ )
384
+ total_checkpoints += len (files )
385
+ checkpoint_files .extend (files )
386
+
387
+ if total_checkpoints == 0 :
382
388
raise ValueError ("No checkpoints found" )
383
389
384
- # Convert to tensors
385
- gammas = torch .tensor (gammas , device = device )
386
- timesteps = torch .tensor (timesteps , device = device )
390
+ # Pre-allocate tensors
391
+ gammas = torch .empty (total_checkpoints , device = device )
392
+ timesteps = torch .empty (total_checkpoints , dtype = torch .long , device = device )
393
+
394
+ # Fill tensors one value at a time
395
+ for i , file in enumerate (checkpoint_files ):
396
+ idx = int (file .stem .split ("." )[0 ])
397
+ timestep = int (file .stem .split ("." )[1 ])
398
+ timesteps [i ] = timestep
399
+
400
+ if self .ema_models is not None :
401
+ gammas [i ] = self .gammas [idx ]
402
+ else :
403
+ # Load gamma from checkpoint
404
+ checkpoint = torch .load (
405
+ str (file ), weights_only = True , map_location = "cpu"
406
+ )
407
+ sigma_rel = checkpoint .get ("sigma_rel" , None )
408
+ if sigma_rel is not None :
409
+ gammas [i ] = sigma_rel_to_gamma (sigma_rel )
410
+ else :
411
+ gammas [i ] = self .gammas [idx ]
412
+ del checkpoint # Free memory immediately
413
+ torch .cuda .empty_cache ()
387
414
388
415
# Solve for weights
389
416
weights = solve_weights (gammas , timesteps , gamma )
390
417
391
- # Load first checkpoint to get state dict structure
392
- first_checkpoint = _safe_torch_load (str (checkpoint_files [0 ]))
393
- state_dict = {}
418
+ # Free memory for gamma and timestep tensors
419
+ del gammas
420
+ del timesteps
421
+ torch .cuda .empty_cache ()
394
422
395
- # Get parameter names from first checkpoint
423
+ # Load first checkpoint to get parameter names
424
+ first_checkpoint = torch .load (
425
+ str (checkpoint_files [0 ]), weights_only = True , map_location = "cpu"
426
+ )
396
427
param_names = {
397
428
k .replace ("ema_model." , "" ): k
398
429
for k in first_checkpoint .keys ()
399
430
if k .startswith ("ema_model." )
400
431
and k .replace ("ema_model." , "" ) not in ("initted" , "step" )
401
432
}
433
+ del first_checkpoint
434
+ torch .cuda .empty_cache ()
402
435
403
- # Process one parameter at a time
404
- for param_name , checkpoint_name in param_names .items ():
405
- param = first_checkpoint [checkpoint_name ]
406
- if not isinstance (param , torch .Tensor ):
407
- continue
436
+ # Initialize state dict with empty tensors
437
+ state_dict = {}
438
+
439
+ # Process one checkpoint at a time
440
+ for file_idx , (file , weight ) in enumerate (zip (checkpoint_files , weights )):
441
+ # Load checkpoint
442
+ checkpoint = torch .load (str (file ), weights_only = True , map_location = "cpu" )
408
443
409
- # Initialize with first weighted contribution
410
- state_dict [param_name ] = param .to (device ) * weights [0 ]
444
+ # Process all parameters from this checkpoint
445
+ for param_name , checkpoint_name in param_names .items ():
446
+ if checkpoint_name not in checkpoint :
447
+ continue
411
448
412
- # Add remaining weighted contributions
413
- for file , weight in zip (checkpoint_files [1 :], weights [1 :]):
414
- checkpoint = _safe_torch_load (str (file ))
415
- param = checkpoint [checkpoint_name ]
416
- if isinstance (param , torch .Tensor ):
417
- state_dict [param_name ].add_ (param .to (device ) * weight )
418
- del checkpoint # Free memory
449
+ param_data = checkpoint [checkpoint_name ]
450
+ if not isinstance (param_data , torch .Tensor ):
451
+ continue
452
+
453
+ if file_idx == 0 :
454
+ # Initialize parameter with first weighted contribution
455
+ state_dict [param_name ] = param_data .to (device ) * weight
456
+ else :
457
+ # Add weighted contribution to existing parameter
458
+ state_dict [param_name ].add_ (param_data .to (device ) * weight )
459
+
460
+ # Free memory for this checkpoint
461
+ del checkpoint
462
+ torch .cuda .empty_cache ()
419
463
420
464
# Free memory
421
- del first_checkpoint
465
+ del weights
466
+ torch .cuda .empty_cache ()
422
467
423
468
try :
424
469
yield state_dict
425
470
finally :
426
471
# Clean up
427
472
del state_dict
473
+ torch .cuda .empty_cache ()
428
474
429
475
def _solve_weights (
430
476
self ,
0 commit comments