Commit dff4a55

improve: optional calculation dtype
1 parent 6784d70 commit dff4a55

File tree: 3 files changed (+128 −40 lines)


posthoc_ema/posthoc_ema.py

Lines changed: 41 additions & 10 deletions
```diff
@@ -3,7 +3,7 @@
 from contextlib import contextmanager
 from copy import deepcopy
 from pathlib import Path
-from typing import Iterator, Optional, Generator
+from typing import Iterator, Optional, Generator, Dict
 
 import torch
 from PIL import Image
@@ -28,6 +28,7 @@ class PostHocEMA:
         update_every: Number of steps between EMA updates
         checkpoint_every: Number of steps between checkpoints
         checkpoint_dtype: Data type for checkpoint storage (if None, uses original parameter dtype)
+        calculation_dtype: Data type for synthesis calculations (default=torch.float32)
         only_save_diff: If True, only save parameters with requires_grad=True
     """
 
@@ -39,6 +40,7 @@ def __init__(
         update_every: int = 10,
         checkpoint_every: int = 1000,
         checkpoint_dtype: Optional[torch.dtype] = None,
+        calculation_dtype: torch.dtype = torch.float32,
         only_save_diff: bool = False,
     ):
         if sigma_rels is None:
@@ -47,6 +49,7 @@ def __init__(
         self.checkpoint_dir = Path(checkpoint_dir)
         self.max_checkpoints = max_checkpoints
         self.checkpoint_dtype = checkpoint_dtype
+        self.calculation_dtype = calculation_dtype
         self.update_every = update_every
         self.checkpoint_every = checkpoint_every
         self.only_save_diff = only_save_diff
@@ -67,6 +70,7 @@ def from_model(
         update_every: int = 10,
         checkpoint_every: int = 1000,
         checkpoint_dtype: Optional[torch.dtype] = None,
+        calculation_dtype: torch.dtype = torch.float32,
         only_save_diff: bool = False,
     ) -> PostHocEMA:
         """
@@ -80,6 +84,7 @@ def from_model(
             update_every: Number of steps between EMA updates
             checkpoint_every: Number of steps between checkpoints
             checkpoint_dtype: Data type for checkpoint storage (if None, uses original parameter dtype)
+            calculation_dtype: Data type for synthesis calculations (default=torch.float32)
             only_save_diff: If True, only save parameters with requires_grad=True
 
         Returns:
@@ -92,6 +97,7 @@ def from_model(
             update_every=update_every,
             checkpoint_every=checkpoint_every,
             checkpoint_dtype=checkpoint_dtype,
+            calculation_dtype=calculation_dtype,
             only_save_diff=only_save_diff,
         )
         instance.checkpoint_dir.mkdir(exist_ok=True, parents=True)
@@ -291,13 +297,16 @@ def model(
         self,
         model: nn.Module,
         sigma_rel: float,
+        *,
+        calculation_dtype: torch.dtype = torch.float32,
     ) -> Iterator[nn.Module]:
         """
         Context manager for temporarily setting model parameters to EMA state.
 
         Args:
             model: Model to temporarily set to EMA state
             sigma_rel: Target relative standard deviation
+            calculation_dtype: Data type for synthesis calculations (default=torch.float32)
 
         Yields:
             nn.Module: Model with EMA parameters
@@ -308,7 +317,9 @@ def model(
             torch.cuda.empty_cache()
 
         try:
-            with self.state_dict(sigma_rel=sigma_rel) as state_dict:
+            with self.state_dict(
+                sigma_rel, calculation_dtype=calculation_dtype
+            ) as state_dict:
                 # Store original state only for parameters that will be modified
                 original_state = {
                     name: param.detach().clone()
@@ -340,14 +351,15 @@ def model(
     def state_dict(
         self,
         sigma_rel: float,
-        step: int | None = None,
-    ) -> Iterator[dict[str, torch.Tensor]]:
+        *,
+        calculation_dtype: torch.dtype = torch.float32,
+    ) -> Iterator[Dict[str, torch.Tensor]]:
         """
         Context manager for getting state dict for synthesized EMA model.
 
         Args:
             sigma_rel: Target relative standard deviation
-            step: Optional specific training step to synthesize for
+            calculation_dtype: Data type for synthesis calculations (default=torch.float32)
 
         Yields:
             dict[str, torch.Tensor]: State dict with synthesized weights
@@ -387,8 +399,8 @@ def state_dict(
         if total_checkpoints == 0:
            raise ValueError("No checkpoints found")
 
-        # Pre-allocate tensors
-        gammas = torch.empty(total_checkpoints, device=device)
+        # Pre-allocate tensors in calculation dtype
+        gammas = torch.empty(total_checkpoints, dtype=calculation_dtype, device=device)
         timesteps = torch.empty(total_checkpoints, dtype=torch.long, device=device)
 
         # Fill tensors one value at a time
@@ -412,15 +424,20 @@ def state_dict(
             del checkpoint  # Free memory immediately
             torch.cuda.empty_cache()
 
-        # Solve for weights
-        weights = solve_weights(gammas, timesteps, gamma)
+        # Solve for weights in calculation dtype
+        weights = solve_weights(
+            gammas,
+            timesteps,
+            gamma,
+            calculation_dtype=calculation_dtype,
+        )
 
         # Free memory for gamma and timestep tensors
         del gammas
         del timesteps
         torch.cuda.empty_cache()
 
-        # Load first checkpoint to get parameter names
+        # Load first checkpoint to get parameter names and original dtypes
         first_checkpoint = torch.load(
             str(checkpoint_files[0]), weights_only=True, map_location="cpu"
         )
@@ -430,6 +447,12 @@ def state_dict(
             if k.startswith("ema_model.")
             and k.replace("ema_model.", "") not in ("initted", "step")
         }
+        # Store original dtypes for each parameter
+        param_dtypes = {
+            name: first_checkpoint[checkpoint_name].dtype
+            for name, checkpoint_name in param_names.items()
+            if isinstance(first_checkpoint[checkpoint_name], torch.Tensor)
+        }
         del first_checkpoint
         torch.cuda.empty_cache()
 
@@ -450,6 +473,9 @@ def state_dict(
                 if not isinstance(param_data, torch.Tensor):
                     continue
 
+                # Convert to calculation dtype for synthesis
+                param_data = param_data.to(calculation_dtype)
+
                 if file_idx == 0:
                     # Initialize parameter with first weighted contribution
                     state_dict[param_name] = param_data.to(device) * weight
@@ -461,6 +487,11 @@ def model(
             del checkpoint
             torch.cuda.empty_cache()
 
+        # Convert back to original dtypes
+        for name, tensor in state_dict.items():
+            if name in param_dtypes:
+                state_dict[name] = tensor.to(param_dtypes[name])
+
         # Free memory
         del weights
         torch.cuda.empty_cache()
```
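
To make the new option concrete, here is a minimal usage sketch based on the signatures in the diff above. The top-level import path, the toy model, and the bare-bones update loop are illustrative assumptions, not part of this commit:

```python
# Sketch only: assumes the package exports PostHocEMA at the top level
# (the class itself lives in posthoc_ema/posthoc_ema.py).
import torch
from posthoc_ema import PostHocEMA

model = torch.nn.Linear(16, 16).to(torch.float16)

posthoc_ema = PostHocEMA.from_model(
    model,
    "posthoc-ema",
    checkpoint_every=5,
    sigma_rels=(0.05,),
    calculation_dtype=torch.float32,  # default dtype used for synthesis math
)

# Toy update loop so that checkpoints exist on disk.
for _ in range(10):
    posthoc_ema.update_(model)

# Synthesize in float64 for extra numerical headroom; the yielded tensors are
# converted back to each parameter's original dtype (float16 here).
with posthoc_ema.state_dict(sigma_rel=0.05, calculation_dtype=torch.float64) as state_dict:
    for name, tensor in state_dict.items():
        print(name, tensor.dtype)  # original dtypes, not float64

# The model() context manager accepts the same keyword and forwards it to state_dict().
with posthoc_ema.model(model, sigma_rel=0.05, calculation_dtype=torch.float64) as ema_model:
    ema_model.eval()
```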

posthoc_ema/utils.py

Lines changed: 32 additions & 30 deletions
```diff
@@ -63,18 +63,19 @@ def sigma_rel_to_gamma(sigma_rel: float) -> float:
     return np.roots([1, 7, 16 - t, 12 - t]).real.max().item()
 
 
-def p_dot_p(t_a: Tensor, gamma_a: Tensor, t_b: Tensor, gamma_b: Tensor) -> Tensor:
-    """
-    Compute dot product between two power function EMA profiles.
+def p_dot_p(
+    t_a: torch.Tensor, gamma_a: torch.Tensor, t_b: torch.Tensor, gamma_b: torch.Tensor
+) -> torch.Tensor:
+    """Compute p_dot_p value for EMA synthesis.
 
     Args:
-        t_a: First timestep tensor
-        gamma_a: First gamma parameter tensor
-        t_b: Second timestep tensor
-        gamma_b: Second gamma parameter tensor
+        t_a: First timestep
+        gamma_a: First gamma value
+        t_b: Second timestep
+        gamma_b: Second gamma value
 
     Returns:
-        Tensor: Dot product between the profiles
+        Tensor: p_dot_p value
     """
     # Handle t=0 case: if both times are 0, ratio is 1
     t_ratio = torch.where(
@@ -99,27 +100,34 @@ def solve_weights(
     gammas: torch.Tensor,
     timesteps: torch.Tensor,
     target_gamma: float,
+    *,
+    calculation_dtype: torch.dtype = torch.float32,
 ) -> torch.Tensor:
-    """
-    Solve for optimal weights to synthesize EMA model with target gamma.
+    """Solve for weights that produce target gamma when applied to gammas.
 
     Args:
-        gammas: Gamma values for each checkpoint
-        timesteps: Timesteps for each checkpoint
+        gammas: Tensor of gamma values
+        timesteps: Tensor of timesteps
         target_gamma: Target gamma value
+        calculation_dtype: Data type for calculations (default=torch.float32)
 
     Returns:
-        torch.Tensor: Optimal weights for each checkpoint
+        Tensor of weights
     """
-    # Convert to float32 for numerical stability
-    gammas = gammas.to(dtype=torch.float32)
-    timesteps = timesteps.to(dtype=torch.float32)
-    target_gamma = torch.tensor(target_gamma, dtype=torch.float32, device=gammas.device)
+    # Convert inputs to calculation dtype
+    gammas = gammas.to(calculation_dtype)
+    timesteps = timesteps.to(calculation_dtype)
+    target_gamma = torch.tensor(
+        target_gamma, dtype=calculation_dtype, device=gammas.device
+    )
+    target_timestep = timesteps[-1]  # Use last timestep as target
 
-    # Compute p_dot_p matrix
-    p_dot_p_matrix = torch.zeros(
-        (len(gammas), len(gammas)), dtype=torch.float32, device=gammas.device
+    # Pre-allocate tensor in calculation dtype
+    p_dot_p_matrix = torch.empty(
+        (len(gammas), len(gammas)), dtype=calculation_dtype, device=gammas.device
     )
+
+    # Compute p_dot_p matrix
     for i in range(len(gammas)):
         for j in range(len(gammas)):
             p_dot_p_matrix[i, j] = p_dot_p(
@@ -129,21 +137,15 @@ def solve_weights(
     # Compute target vector
     target_vector = torch.tensor(
         [
-            p_dot_p(timesteps[i], gammas[i], timesteps[-1], target_gamma)
+            p_dot_p(timesteps[i], gammas[i], target_timestep, target_gamma)
             for i in range(len(gammas))
         ],
-        dtype=torch.float32,
+        dtype=calculation_dtype,
         device=gammas.device,
     )
 
-    # Solve linear system
-    try:
-        weights = torch.linalg.solve(p_dot_p_matrix, target_vector)
-    except RuntimeError:
-        # If matrix is singular, use least squares
-        weights = torch.linalg.lstsq(p_dot_p_matrix, target_vector).solution
-
-    return weights
+    # Solve for weights
+    return torch.linalg.solve(p_dot_p_matrix, target_vector)
 
 
 def _safe_torch_load(path: str | Path, *, map_location=None):
```
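
`solve_weights` assembles the matrix of `p_dot_p` values between the stored EMA profiles and solves the resulting linear system for the mixing weights, now entirely in `calculation_dtype`; note that the least-squares fallback was removed, so `torch.linalg.solve` raises if the matrix is singular. A small sketch of calling it directly, with the import path assumed from the file location in this commit:

```python
# Sketch only: import path assumed from posthoc_ema/utils.py in this commit.
import torch
from posthoc_ema.utils import sigma_rel_to_gamma, solve_weights

# One (gamma, timestep) pair per stored checkpoint: two sigma_rel profiles
# checkpointed at step 1000 and one of them again at step 2000.
gammas = torch.tensor(
    [sigma_rel_to_gamma(0.05), sigma_rel_to_gamma(0.10), sigma_rel_to_gamma(0.05)]
)
timesteps = torch.tensor([1000, 1000, 2000])

# Mixing weights for a target profile, solved entirely in float64.
weights = solve_weights(
    gammas,
    timesteps,
    sigma_rel_to_gamma(0.08),
    calculation_dtype=torch.float64,
)
print(weights.dtype)  # torch.float64 — weights remain in the calculation dtype
```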

tests/test_usage.py

Lines changed: 55 additions & 0 deletions
```diff
@@ -600,3 +600,58 @@ def test_context_manager_with_only_save_diff():
     for file in Path("posthoc-ema").glob("*"):
         file.unlink()
     Path("posthoc-ema").rmdir()
+
+
+def test_calculation_dtype():
+    """Test that synthesis calculations use specified calculation_dtype."""
+    # Create a model with mixed dtypes
+    model = torch.nn.Sequential(
+        torch.nn.Linear(512, 512),  # Default is float32
+        torch.nn.BatchNorm1d(512, track_running_stats=True),
+    )
+
+    # Convert model to float16
+    model = model.to(torch.float16)
+
+    # Create EMA instance
+    posthoc_ema = PostHocEMA.from_model(
+        model,
+        "posthoc-ema",
+        checkpoint_every=5,
+        sigma_rels=(0.05,),
+    )
+
+    # Update model
+    for _ in range(10):
+        with torch.no_grad():
+            model[0].weight.copy_(torch.randn_like(model[0].weight))
+            model[0].bias.copy_(torch.randn_like(model[0].bias))
+        posthoc_ema.update_(model)
+
+    # Test default behavior (float32 calculations, float16 output)
+    with posthoc_ema.state_dict(sigma_rel=0.05) as state_dict:
+        # All parameters should be float16 (original dtype)
+        assert state_dict["0.weight"].dtype == torch.float16
+        assert state_dict["0.bias"].dtype == torch.float16
+        assert state_dict["1.weight"].dtype == torch.float16
+        assert state_dict["1.bias"].dtype == torch.float16
+        assert state_dict["1.running_mean"].dtype == torch.float16
+        assert state_dict["1.running_var"].dtype == torch.float16
+
+    # Test float64 behavior
+    with posthoc_ema.state_dict(
+        sigma_rel=0.05, calculation_dtype=torch.float64
+    ) as state_dict:
+        # All parameters should still be float16 (original dtype)
+        assert state_dict["0.weight"].dtype == torch.float16
+        assert state_dict["0.bias"].dtype == torch.float16
+        assert state_dict["1.weight"].dtype == torch.float16
+        assert state_dict["1.bias"].dtype == torch.float16
+        assert state_dict["1.running_mean"].dtype == torch.float16
+        assert state_dict["1.running_var"].dtype == torch.float16
+
+    # Clean up
+    if Path("posthoc-ema").exists():
+        for file in Path("posthoc-ema").glob("*"):
+            file.unlink()
+        Path("posthoc-ema").rmdir()
```
