@@ -341,12 +341,7 @@ def state_dict(
        gamma = sigma_rel_to_gamma(sigma_rel)
        device = torch.device("cpu")  # Keep synthesis on CPU for memory efficiency

-        # Get all checkpoints
-        gammas = []
-        timesteps = []
-        checkpoints = []
-
-        # Collect checkpoint info
+        # Get all checkpoint files
        if self.ema_models is not None:
            # When we have ema_models, use their indices
            indices = range(len(self.ema_models))
@@ -358,139 +353,78 @@ def state_dict(
                indices.add(idx)
        indices = sorted(indices)

-        # Collect checkpoint info
+        # Get checkpoint files and info
+        checkpoint_files = []
+        gammas = []
+        timesteps = []
        for idx in indices:
-            checkpoint_files = sorted(
+            files = sorted(
                self.checkpoint_dir.glob(f"{idx}.*.pt"),
                key=lambda p: int(p.stem.split(".")[1]),
            )
-            for file in checkpoint_files:
+            for file in files:
                _, timestep = map(int, file.stem.split("."))
-                # When we have ema_models, use their gammas
                if self.ema_models is not None:
                    gammas.append(self.gammas[idx])
                else:
-                    # When loading from path, load gamma from checkpoint
+                    # Load gamma from checkpoint
                    checkpoint = _safe_torch_load(str(file))
                    sigma_rel = checkpoint.get("sigma_rel", None)
                    if sigma_rel is not None:
                        gammas.append(sigma_rel_to_gamma(sigma_rel))
                    else:
-                        # If no sigma_rel in checkpoint, use index-based gamma
                        gammas.append(self.gammas[idx])
+                    del checkpoint  # Free memory
                timesteps.append(timestep)
-                checkpoints.append(file)
+                checkpoint_files.append(file)

        if not gammas:
-            raise ValueError("No valid gamma values found in checkpoints")
-
-        # Use latest step if not specified
-        step = step if step is not None else max(timesteps)
-        assert step <= max(
-            timesteps
-        ), f"Cannot synthesize for step {step} > max available step {max(timesteps)}"
+            raise ValueError("No checkpoints found")

-        # Solve for optimal weights using double precision
-        gamma_i = torch.tensor(gammas, device=device, dtype=torch.float64)
-        t_i = torch.tensor(timesteps, device=device, dtype=torch.float64)
-        gamma_r = torch.tensor([gamma], device=device, dtype=torch.float64)
-        t_r = torch.tensor([step], device=device, dtype=torch.float64)
+        # Convert to tensors
+        gammas = torch.tensor(gammas, device=device)
+        timesteps = torch.tensor(timesteps, device=device)

-        weights = self._solve_weights(t_i, gamma_i, t_r, gamma_r)
-        weights = weights.squeeze(-1).to(dtype=torch.float64)  # Keep in float64
+        # Solve for weights
+        weights = solve_weights(gammas, timesteps, gamma)

        # Load first checkpoint to get state dict structure
-        ckpt = _safe_torch_load(str(checkpoints[0]), map_location=device)
+        first_checkpoint = _safe_torch_load(str(checkpoint_files[0]))
+        state_dict = {}
+
+        # Get parameter names from first checkpoint
+        param_names = {
+            k.replace("ema_model.", ""): k
+            for k in first_checkpoint.keys()
+            if k.startswith("ema_model.")
+            and k.replace("ema_model.", "") not in ("initted", "step")
+        }

-        # Extract model parameters, handling both formats and filtering out internal state
-        model_keys = {}
-        for k in ckpt.keys():
-            # Skip internal EMA tracking variables
-            if k in ("initted", "step", "sigma_rel"):
+        # Process one parameter at a time
+        for param_name, checkpoint_name in param_names.items():
+            param = first_checkpoint[checkpoint_name]
+            if not isinstance(param, torch.Tensor):
                continue
-            if k.startswith("ema_model."):
-                # Reference format: "ema_model.weight" -> "weight"
-                model_keys[k] = k.replace("ema_model.", "")
-            else:
-                # Our format: "weight" -> "weight"
-                model_keys[k] = k
-
-        # Zero initialize synthesized state with double precision
-        synth_state = {}
-        original_dtypes = {}  # Store original dtypes
-        for ref_key, our_key in model_keys.items():
-            if ref_key in ckpt:
-                original_dtypes[our_key] = ckpt[ref_key].dtype
-                synth_state[our_key] = torch.zeros_like(
-                    ckpt[ref_key], device=device, dtype=torch.float64
-                )
-            elif our_key in ckpt:
-                original_dtypes[our_key] = ckpt[our_key].dtype
-                synth_state[our_key] = torch.zeros_like(
-                    ckpt[our_key], device=device, dtype=torch.float64
-                )

-        # Combine checkpoints using solved weights
-        for checkpoint, weight in zip(checkpoints, weights.tolist()):
-            ckpt_state = _safe_torch_load(str(checkpoint), map_location=device)
-            for ref_key, our_key in model_keys.items():
-                if ref_key in ckpt_state:
-                    ckpt_tensor = ckpt_state[ref_key]
-                elif our_key in ckpt_state:
-                    ckpt_tensor = ckpt_state[our_key]
-                else:
-                    continue
-                # Convert checkpoint tensor to double precision
-                ckpt_tensor = ckpt_tensor.to(dtype=torch.float64, device=device)
-                # Use double precision for accumulation
-                synth_state[our_key].add_(ckpt_tensor * weight)
-
-        # Convert final state to target dtype and filter out internal state
-        # Only include parameters with requires_grad=True and buffers
-        if self.ema_models is not None:
-            # When we have ema_models, use their parameter names
-            param_names = {
-                name for name, param in self.ema_models[0].ema_model.named_parameters()
-            }
-            if self.only_save_diff:
-                param_names = {
-                    name
-                    for name in param_names
-                    if self.ema_models[0].ema_model.get_parameter(name).requires_grad
-                }
-            buffer_names = {
-                name for name, _ in self.ema_models[0].ema_model.named_buffers()
-            }
-        else:
-            # When loading from path, we can't filter by requires_grad
-            # since we don't have access to the original model
-            param_names = set()
-            buffer_names = set()
-
-        synth_state = {
-            k: v.to(
-                dtype=(
-                    self.checkpoint_dtype
-                    if self.checkpoint_dtype is not None
-                    else original_dtypes[k]
-                ),
-                device=device,
-            )
-            for k, v in synth_state.items()
-            if k not in ("initted", "step", "sigma_rel")  # Filter out internal state
-            and (
-                not param_names  # If we don't have param_names, include everything
-                or k in param_names  # Include parameters with requires_grad
-                or k in buffer_names  # Include all buffers
-            )
-        }
+            # Initialize with first weighted contribution
+            state_dict[param_name] = param.to(device) * weights[0]
+
+            # Add remaining weighted contributions
+            for file, weight in zip(checkpoint_files[1:], weights[1:]):
+                checkpoint = _safe_torch_load(str(file))
+                param = checkpoint[checkpoint_name]
+                if isinstance(param, torch.Tensor):
+                    state_dict[param_name].add_(param.to(device) * weight)
+                del checkpoint  # Free memory
+
+        # Free memory
+        del first_checkpoint

        try:
-            yield synth_state
+            yield state_dict
        finally:
-            # Clean up tensors
-            del synth_state
-            torch.cuda.empty_cache()
+            # Clean up
+            del state_dict

    def _solve_weights(
        self,