3
3
from contextlib import contextmanager
4
4
from copy import deepcopy
5
5
from pathlib import Path
6
- from typing import Iterator , Optional , Generator
6
+ from typing import Iterator , Optional , Generator , Dict
7
7
8
8
import torch
9
9
from PIL import Image
@@ -28,6 +28,7 @@ class PostHocEMA:
28
28
update_every: Number of steps between EMA updates
29
29
checkpoint_every: Number of steps between checkpoints
30
30
checkpoint_dtype: Data type for checkpoint storage (if None, uses original parameter dtype)
31
+ calculation_dtype: Data type for synthesis calculations (default=torch.float32)
31
32
only_save_diff: If True, only save parameters with requires_grad=True
32
33
"""
33
34
@@ -39,6 +40,7 @@ def __init__(
39
40
update_every : int = 10 ,
40
41
checkpoint_every : int = 1000 ,
41
42
checkpoint_dtype : Optional [torch .dtype ] = None ,
43
+ calculation_dtype : torch .dtype = torch .float32 ,
42
44
only_save_diff : bool = False ,
43
45
):
44
46
if sigma_rels is None :
@@ -47,6 +49,7 @@ def __init__(
47
49
self .checkpoint_dir = Path (checkpoint_dir )
48
50
self .max_checkpoints = max_checkpoints
49
51
self .checkpoint_dtype = checkpoint_dtype
52
+ self .calculation_dtype = calculation_dtype
50
53
self .update_every = update_every
51
54
self .checkpoint_every = checkpoint_every
52
55
self .only_save_diff = only_save_diff
@@ -67,6 +70,7 @@ def from_model(
67
70
update_every : int = 10 ,
68
71
checkpoint_every : int = 1000 ,
69
72
checkpoint_dtype : Optional [torch .dtype ] = None ,
73
+ calculation_dtype : torch .dtype = torch .float32 ,
70
74
only_save_diff : bool = False ,
71
75
) -> PostHocEMA :
72
76
"""
@@ -80,6 +84,7 @@ def from_model(
80
84
update_every: Number of steps between EMA updates
81
85
checkpoint_every: Number of steps between checkpoints
82
86
checkpoint_dtype: Data type for checkpoint storage (if None, uses original parameter dtype)
87
+ calculation_dtype: Data type for synthesis calculations (default=torch.float32)
83
88
only_save_diff: If True, only save parameters with requires_grad=True
84
89
85
90
Returns:
@@ -92,6 +97,7 @@ def from_model(
92
97
update_every = update_every ,
93
98
checkpoint_every = checkpoint_every ,
94
99
checkpoint_dtype = checkpoint_dtype ,
100
+ calculation_dtype = calculation_dtype ,
95
101
only_save_diff = only_save_diff ,
96
102
)
97
103
instance .checkpoint_dir .mkdir (exist_ok = True , parents = True )
@@ -291,13 +297,16 @@ def model(
291
297
self ,
292
298
model : nn .Module ,
293
299
sigma_rel : float ,
300
+ * ,
301
+ calculation_dtype : torch .dtype = torch .float32 ,
294
302
) -> Iterator [nn .Module ]:
295
303
"""
296
304
Context manager for temporarily setting model parameters to EMA state.
297
305
298
306
Args:
299
307
model: Model to temporarily set to EMA state
300
308
sigma_rel: Target relative standard deviation
309
+ calculation_dtype: Data type for synthesis calculations (default=torch.float32)
301
310
302
311
Yields:
303
312
nn.Module: Model with EMA parameters
@@ -308,7 +317,9 @@ def model(
308
317
torch .cuda .empty_cache ()
309
318
310
319
try :
311
- with self .state_dict (sigma_rel = sigma_rel ) as state_dict :
320
+ with self .state_dict (
321
+ sigma_rel , calculation_dtype = calculation_dtype
322
+ ) as state_dict :
312
323
# Store original state only for parameters that will be modified
313
324
original_state = {
314
325
name : param .detach ().clone ()
@@ -340,14 +351,15 @@ def model(
340
351
def state_dict (
341
352
self ,
342
353
sigma_rel : float ,
343
- step : int | None = None ,
344
- ) -> Iterator [dict [str , torch .Tensor ]]:
354
+ * ,
355
+ calculation_dtype : torch .dtype = torch .float32 ,
356
+ ) -> Iterator [Dict [str , torch .Tensor ]]:
345
357
"""
346
358
Context manager for getting state dict for synthesized EMA model.
347
359
348
360
Args:
349
361
sigma_rel: Target relative standard deviation
350
- step: Optional specific training step to synthesize for
362
+ calculation_dtype: Data type for synthesis calculations (default=torch.float32)
351
363
352
364
Yields:
353
365
dict[str, torch.Tensor]: State dict with synthesized weights
@@ -387,8 +399,8 @@ def state_dict(
387
399
if total_checkpoints == 0 :
388
400
raise ValueError ("No checkpoints found" )
389
401
390
- # Pre-allocate tensors
391
- gammas = torch .empty (total_checkpoints , device = device )
402
+ # Pre-allocate tensors in calculation dtype
403
+ gammas = torch .empty (total_checkpoints , dtype = calculation_dtype , device = device )
392
404
timesteps = torch .empty (total_checkpoints , dtype = torch .long , device = device )
393
405
394
406
# Fill tensors one value at a time
@@ -412,15 +424,20 @@ def state_dict(
412
424
del checkpoint # Free memory immediately
413
425
torch .cuda .empty_cache ()
414
426
415
- # Solve for weights
416
- weights = solve_weights (gammas , timesteps , gamma )
427
+ # Solve for weights in calculation dtype
428
+ weights = solve_weights (
429
+ gammas ,
430
+ timesteps ,
431
+ gamma ,
432
+ calculation_dtype = calculation_dtype ,
433
+ )
417
434
418
435
# Free memory for gamma and timestep tensors
419
436
del gammas
420
437
del timesteps
421
438
torch .cuda .empty_cache ()
422
439
423
- # Load first checkpoint to get parameter names
440
+ # Load first checkpoint to get parameter names and original dtypes
424
441
first_checkpoint = torch .load (
425
442
str (checkpoint_files [0 ]), weights_only = True , map_location = "cpu"
426
443
)
@@ -430,6 +447,12 @@ def state_dict(
430
447
if k .startswith ("ema_model." )
431
448
and k .replace ("ema_model." , "" ) not in ("initted" , "step" )
432
449
}
450
+ # Store original dtypes for each parameter
451
+ param_dtypes = {
452
+ name : first_checkpoint [checkpoint_name ].dtype
453
+ for name , checkpoint_name in param_names .items ()
454
+ if isinstance (first_checkpoint [checkpoint_name ], torch .Tensor )
455
+ }
433
456
del first_checkpoint
434
457
torch .cuda .empty_cache ()
435
458
@@ -450,6 +473,9 @@ def state_dict(
450
473
if not isinstance (param_data , torch .Tensor ):
451
474
continue
452
475
476
+ # Convert to calculation dtype for synthesis
477
+ param_data = param_data .to (calculation_dtype )
478
+
453
479
if file_idx == 0 :
454
480
# Initialize parameter with first weighted contribution
455
481
state_dict [param_name ] = param_data .to (device ) * weight
@@ -461,6 +487,11 @@ def state_dict(
461
487
del checkpoint
462
488
torch .cuda .empty_cache ()
463
489
490
+ # Convert back to original dtypes
491
+ for name , tensor in state_dict .items ():
492
+ if name in param_dtypes :
493
+ state_dict [name ] = tensor .to (param_dtypes [name ])
494
+
464
495
# Free memory
465
496
del weights
466
497
torch .cuda .empty_cache ()
0 commit comments