Commit 12e0969

fix: synthesize model when only saving diff
1 parent 46b06ea · commit 12e0969

File tree

4 files changed: +159 −124 lines changed


.gitignore

Lines changed: 2 additions & 0 deletions

@@ -127,3 +127,5 @@ dmypy.json
 
 # Pyre type checker
 .pyre/
+
+/test_ema_checkpoint

posthoc_ema/posthoc_ema.py

Lines changed: 25 additions & 25 deletions

@@ -3,7 +3,7 @@
 from contextlib import contextmanager
 from copy import deepcopy
 from pathlib import Path
-from typing import Iterator, Optional
+from typing import Iterator, Optional, Generator
 
 import torch
 from PIL import Image
@@ -285,41 +285,41 @@ def _cleanup_old_checkpoints(self) -> None:
 
     @contextmanager
     def model(
-        self,
-        base_model: nn.Module,
-        sigma_rel: float,
-        step: int | None = None,
-    ) -> Iterator[nn.Module]:
-        """
-        Context manager for using synthesized EMA model.
+        self, model: nn.Module, sigma_rel: float
+    ) -> Generator[nn.Module, None, None]:
+        """Context manager that temporarily sets model parameters to EMA state.
 
         Args:
-            base_model: Model to apply EMA weights to
+            model: Model to update
             sigma_rel: Target relative standard deviation
-            step: Optional specific training step to synthesize for
 
-        Yields:
-            nn.Module: Model with synthesized EMA weights
+        Returns:
+            Model with EMA parameters
         """
-        # Store original device and move base model to CPU
-        original_device = next(base_model.parameters()).device
-        base_model.cpu()
+        # Store original device and move model to CPU
+        original_device = next(model.parameters()).device
+        model.cpu()
         torch.cuda.empty_cache()
 
-        # Get state dict and create EMA model
-        with self.state_dict(sigma_rel=sigma_rel, step=step) as state_dict:
-            ema_model = deepcopy(base_model)
-            ema_model.load_state_dict(state_dict)
-
-        try:
+        try:
+            with self.state_dict(sigma_rel=sigma_rel) as state_dict:
+                ema_model = deepcopy(model)
+                result = ema_model.load_state_dict(
+                    state_dict, strict=not self.only_save_diff
+                )
+                assert (
+                    len(result.unexpected_keys) == 0
+                ), f"Unexpected keys: {result.unexpected_keys}"
+                ema_model.eval()  # Set to eval mode to handle BatchNorm
             yield ema_model
-        finally:
-            # Clean up EMA model and restore base model device
+            # Clean up EMA model
             if hasattr(ema_model, "cuda"):
                 ema_model.cpu()
             del ema_model
-            base_model.to(original_device)
-            torch.cuda.empty_cache()
+        finally:
+            # Restore model to original device
+            model.to(original_device)
+            torch.cuda.empty_cache()
 
     @contextmanager
     def state_dict(
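
What changed here: when the EMA instance was constructed with only_save_diff=True, checkpoints contain only the parameters that differ, so the synthesized state dict is now loaded with strict=not self.only_save_diff, and only unexpected keys (keys present in the checkpoint but absent from the model) are treated as errors. A minimal usage sketch; the model(...) signature and only_save_diff come from this diff, while the class name, from_model factory, and update_ hook are assumptions about the surrounding API:

import torch
import torch.nn as nn

from posthoc_ema import PostHocEMA  # assumed export

net = nn.Linear(16, 16)
# Hypothetical factory; "test_ema_checkpoint" matches the directory
# ignored by the .gitignore change above.
ema = PostHocEMA.from_model(net, "test_ema_checkpoint", only_save_diff=True)

for _ in range(100):
    ...  # optimizer step on net
    ema.update_(net)  # hypothetical per-step update hook

# Synthesize EMA weights for a target sigma_rel. With only_save_diff=True,
# missing keys are tolerated during load_state_dict; this commit fixes the
# previously strict load.
with ema.model(net, sigma_rel=0.15) as ema_model:
    out = ema_model(torch.randn(1, 16))
# net is restored to its original device when the context exits.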

posthoc_ema/utils.py

Lines changed: 10 additions & 26 deletions

@@ -80,34 +80,18 @@ def p_dot_p(t_a: Tensor, gamma_a: Tensor, t_b: Tensor, gamma_b: Tensor) -> Tensor:
     t_ratio = torch.where(
         (t_a == 0) & (t_b == 0),
         torch.ones_like(t_a),
-        t_a / torch.where(t_b == 0, torch.ones_like(t_b), t_b)
+        t_a / torch.where(t_b == 0, torch.ones_like(t_b), t_b),
     )
-
+
     t_exp = torch.where(t_a < t_b, gamma_b, -gamma_a)
     t_max = torch.maximum(t_a, t_b)
-
+
     # Handle t=0 case: if both times are 0, max is 1
-    t_max = torch.where(
-        (t_a == 0) & (t_b == 0),
-        torch.ones_like(t_max),
-        t_max
-    )
-
-    # Print debug info for first few values
-    if t_a.shape[0] < 10:  # Only print for small tensors
-        print(f"\nt_ratio shape: {t_ratio.shape}")
-        print(f"t_ratio first few: {t_ratio[:5, :5]}")
-        print(f"t_exp first few: {t_exp[:5, :5]}")
-        print(f"t_max first few: {t_max[:5, :5]}")
-
+    t_max = torch.where((t_a == 0) & (t_b == 0), torch.ones_like(t_max), t_max)
+
     num = (gamma_a + 1) * (gamma_b + 1) * t_ratio**t_exp
     den = (gamma_a + gamma_b + 1) * t_max
-
-    if t_a.shape[0] < 10:  # Only print for small tensors
-        print(f"num first few: {num[:5, :5]}")
-        print(f"den first few: {den[:5, :5]}")
-        print(f"result first few: {(num/den)[:5, :5]}")
-
+
     return num / den
 
 
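
For context, p_dot_p evaluates the closed-form inner product of two power-function EMA profiles used for post-hoc EMA reconstruction (Karras et al., 2024, "Analyzing and Improving the Training Dynamics of Diffusion Models"). Reading the formula back from the code that remains, with \gamma_* = \gamma_b when t_a < t_b and \gamma_a otherwise:

    \langle p_a, p_b \rangle
        = \frac{(\gamma_a + 1)(\gamma_b + 1)}
               {(\gamma_a + \gamma_b + 1)\,\max(t_a, t_b)}
          \left( \frac{\min(t_a, t_b)}{\max(t_a, t_b)} \right)^{\gamma_*}

The torch.where guards only keep the t_a = t_b = 0 case well-defined by forcing both the ratio and the max to 1.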

@@ -127,13 +111,13 @@ def solve_weights(t_i: Tensor, gamma_i: Tensor, t_r: Tensor, gamma_r: Tensor) -> Tensor:
     # Reshape tensors for matrix operations
     rv = lambda x: x.reshape(-1, 1)  # Column vector
     cv = lambda x: x.reshape(1, -1)  # Row vector
-
+
     # Compute matrices A and b using p_dot_p
     A = p_dot_p(rv(t_i), rv(gamma_i), cv(t_i), cv(gamma_i))
     b = p_dot_p(rv(t_i), rv(gamma_i), cv(t_r), cv(gamma_r))
-
+
     # Solve linear system
-    return torch.linalg.solve(A, b)
+    return torch.linalg.solve(A, b)
 
 
 def _safe_torch_load(path: str | Path, *, map_location=None):
@@ -142,4 +126,4 @@ def _safe_torch_load(path: str | Path, *, map_location=None):
         return torch.load(path, map_location=map_location, weights_only=True)
     except TypeError:
         # Older PyTorch versions don't support weights_only
-        return torch.load(path, map_location=map_location)
+        return torch.load(path, map_location=map_location)
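
To make the linear system concrete: solve_weights builds A_ij = <p_i, p_j> between the stored snapshot profiles and b_i = <p_i, p_r> against the target profile, then solves A x = b for the per-snapshot mixing weights. A hypothetical call (tensor values are illustrative, not from this repo):

import torch

from posthoc_ema.utils import solve_weights

# Three stored EMA snapshots taken at steps t_i with profile exponents
# gamma_i, and one target profile (t_r, gamma_r) to reconstruct.
t_i = torch.tensor([1024.0, 2048.0, 3072.0], dtype=torch.float64)
gamma_i = torch.tensor([6.94, 6.94, 16.97], dtype=torch.float64)
t_r = torch.tensor([3072.0], dtype=torch.float64)
gamma_r = torch.tensor([10.0], dtype=torch.float64)

weights = solve_weights(t_i, gamma_i, t_r, gamma_r)  # shape (3, 1)
# The synthesized EMA state is then sum_i weights[i] * snapshot_i,
# applied parameter by parameter.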
