Commit 41160c6 (1 parent: 43592fa)

improve: support larger sigma rel, raise if loading incorrectly, and more

10 files changed: +1248 −84 lines

notebooks/visualize_error.py

Lines changed: 3 additions & 3 deletions

@@ -3,8 +3,8 @@
 import posthoc_ema.visualization
 
 posthoc_ema.visualization.reconstruction_error(
-    sigma_rels=(0.15, 0.28),
-    target_sigma_rel_range=(0.1, 0.3),
-    max_checkpoints=50,
+    sigma_rels=(0.15, 0.5),
+    target_sigma_rel_range=(0.05, 0.5),
+    max_checkpoints=20,
 )
 # %%
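Note on "support larger sigma rel": in the Karras et al. (2024) post-hoc EMA formulation, sigma_rel and gamma are tied by sigma_rel^2 = (gamma + 1) / ((gamma + 2)^2 * (gamma + 3)), and for gamma >= 0 that relation tops out at sigma_rel = 1/sqrt(12) ≈ 0.289, which is why values like 0.5 above need extra handling. The sketch below shows only the standard conversion over the classical range; the function name is illustrative rather than this package's API, and how this commit handles sigma_rel beyond ~0.289 is not visible in these hunks.

# Illustrative only: classical sigma_rel -> gamma conversion, i.e. the relevant real root of
#     gamma^3 + 7*gamma^2 + (16 - t)*gamma + (12 - t) = 0,  where t = sigma_rel**-2.
import numpy as np

def sigma_rel_to_gamma(sigma_rel: float) -> float:
    t = sigma_rel ** -2
    return float(np.roots([1.0, 7.0, 16.0 - t, 12.0 - t]).real.max())

print(round(sigma_rel_to_gamma(0.05), 2))  # ~16.97
print(round(sigma_rel_to_gamma(0.15), 2))  # ~3.55
print(round(sigma_rel_to_gamma(0.28), 2))  # ~0.17, near the gamma >= 0 ceiling of sigma_rel ≈ 0.289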

posthoc_ema/karras_ema.py

Lines changed: 54 additions & 23 deletions

@@ -43,18 +43,18 @@ def inplace_lerp(tgt: Tensor, src: Tensor, weight):
 
 class KarrasEMA(Module):
     """
-    Exponential Moving Average module using hyperparameters from the Karras et al. paper.
+    Karras EMA implementation with power function decay profile.
 
     Args:
-        model: The model to create an EMA of
-        sigma_rel: Relative standard deviation for EMA profile width
-        gamma: Direct gamma parameter (alternative to sigma_rel)
+        model: Model to create EMA of
+        sigma_rel: Relative standard deviation for EMA profile
+        gamma: Alternative parameterization via gamma (don't specify both)
         ema_model: Optional pre-initialized EMA model
         update_every: Number of steps between EMA updates
-        frozen: If True, EMA weights are not updated
-        param_or_buffer_names_no_ema: Set of parameter/buffer names to exclude from EMA
-        ignore_names: Set of names to ignore
-        ignore_startswith_names: Set of name prefixes to ignore
+        frozen: Whether to freeze EMA updates
+        param_or_buffer_names_no_ema: Parameter/buffer names to exclude from EMA
+        ignore_names: Parameter/buffer names to ignore
+        ignore_startswith_names: Parameter/buffer name prefixes to ignore
         only_save_diff: If True, only save parameters with requires_grad=True
     """
 
@@ -111,12 +111,11 @@ def __init__(
         # Move model back to original device
         model.to(original_device)
 
-        # Get parameter names that require gradients
+        # Get parameter names for floating point or complex parameters
         self.param_names = {
             name
             for name, param in self.ema_model.named_parameters()
-            if (not only_save_diff or param.requires_grad)
-            and (torch.is_floating_point(param) or torch.is_complex(param))
+            if torch.is_floating_point(param) or torch.is_complex(param)
         }
 
         # Get buffer names for floating point or complex buffers
@@ -161,17 +160,27 @@ def update(self):
 
     def copy_params_from_model_to_ema(self):
         """Copy parameters from online model to EMA model."""
+        # Copy parameters
         for (name, ma_params), (_, current_params) in zip(
             self.get_params_iter(self.ema_model),
             self.get_params_iter(self.online_model[0]),
         ):
             if self._should_update_param(name):
                 inplace_copy(ma_params.data, current_params.data)
 
+        # Copy buffers
+        for (name, ma_buffer), (_, current_buffer) in zip(
+            self.get_buffers_iter(self.ema_model),
+            self.get_buffers_iter(self.online_model[0]),
+        ):
+            if self._should_update_param(name):
+                inplace_copy(ma_buffer.data, current_buffer.data)
+
     def update_moving_average(self):
         """Update EMA weights using current beta value."""
         current_decay = self.beta
 
+        # Update parameters
         for (name, current_params), (_, ma_params) in zip(
             self.get_params_iter(self.online_model[0]),
             self.get_params_iter(self.ema_model),
@@ -180,6 +189,15 @@ def update_moving_average(self):
                 continue
             inplace_lerp(ma_params.data, current_params.data, 1.0 - current_decay)
 
+        # Update buffers
+        for (name, current_buffer), (_, ma_buffer) in zip(
+            self.get_buffers_iter(self.online_model[0]),
+            self.get_buffers_iter(self.ema_model),
+        ):
+            if not self._should_update_param(name):
+                continue
+            inplace_lerp(ma_buffer.data, current_buffer.data, 1.0 - current_decay)
+
     def _should_update_param(self, name: str) -> bool:
         """Check if parameter should be updated based on ignore rules."""
         if name in self.ignore_names:
@@ -195,8 +213,19 @@ def get_params_iter(self, model):
         for name, param in model.named_parameters():
             if name not in self.param_names:
                 continue
+            if self.only_save_diff and not param.requires_grad:
+                continue
             yield name, param
 
+    def get_buffers_iter(self, model):
+        """Get iterator over model's buffers."""
+        for name, buffer in model.named_buffers():
+            if name not in self.buffer_names:
+                continue
+            if self.only_save_diff and not buffer.requires_grad:
+                continue
+            yield name, buffer
+
     def iter_all_ema_params_and_buffers(self):
         """Get iterator over all EMA parameters and buffers."""
         for name, param in self.ema_model.named_parameters():
@@ -250,24 +279,26 @@ def __call__(self, *args, **kwargs):
         return self.ema_model(*args, **kwargs)
 
     def state_dict(self):
-        """Get state dict of EMA model."""
+        """Get state dict for EMA model."""
         state_dict = {}
 
-        # Add parameters based on only_save_diff flag
+        # Save parameters based on only_save_diff flag
         for name, param in self.ema_model.named_parameters():
-            if (not self.only_save_diff or param.requires_grad) and (
-                torch.is_floating_point(param) or torch.is_complex(param)
-            ):
-                state_dict[name] = param
+            if name not in self.param_names:
+                continue
+            if self.only_save_diff and not param.requires_grad:
+                continue
+            state_dict[name] = param.data
 
-        # Add buffers (always included regardless of only_save_diff)
+        # Save buffers
         for name, buffer in self.ema_model.named_buffers():
-            if torch.is_floating_point(buffer) or torch.is_complex(buffer):
-                state_dict[name] = buffer
+            if name not in self.buffer_names:
+                continue
+            state_dict[name] = buffer.data
 
-        # Add internal state
-        state_dict["initted"] = self.initted
-        state_dict["step"] = self.step
+        # Save internal state
+        state_dict["initted"] = self.initted.data
+        state_dict["step"] = self.step.data
 
         return state_dict
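With buffers now flowing through copy_params_from_model_to_ema, update_moving_average, and state_dict, the EMA tracks things like BatchNorm running statistics rather than parameters only. A minimal usage sketch, assuming the constructor accepts the keyword arguments listed in the docstring above and that update() advances the EMA by one step:

import torch
from torch import nn
from posthoc_ema.karras_ema import KarrasEMA

model = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))
ema = KarrasEMA(model, sigma_rel=0.15, update_every=1)

model(torch.randn(8, 4))  # training-mode forward pass updates running_mean/running_var (buffers)
ema.update()

state = ema.state_dict()
# Buffer keys (e.g. "1.running_mean") now appear next to parameter keys,
# along with the internal "initted" and "step" tensors.
print(sorted(state.keys()))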

posthoc_ema/posthoc_ema.py

Lines changed: 53 additions & 30 deletions

@@ -89,7 +89,19 @@ def from_model(
 
         Returns:
             PostHocEMA: Instance ready for training
+
+        Raises:
+            ValueError: If checkpoint directory already exists and contains checkpoints
         """
+        checkpoint_dir = Path(checkpoint_dir)
+        if checkpoint_dir.exists():
+            checkpoints = list(checkpoint_dir.glob("*.pt"))
+            if checkpoints:
+                raise ValueError(
+                    f"Checkpoint directory {checkpoint_dir} already contains checkpoints. "
+                    "Use from_path() to load existing checkpoints instead of from_model()."
+                )
+
         instance = cls(
             checkpoint_dir=checkpoint_dir,
             max_checkpoints=max_checkpoints,
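Usage-wise, the new guard changes how a training script starts versus resumes. A hedged sketch of the intent (the import path and the exact argument order of from_model()/from_path() are assumptions, not shown in this diff):

from torch import nn
from posthoc_ema import PostHocEMA  # import path assumed

model = nn.Linear(4, 4)
try:
    # Fresh run: the directory must not already contain *.pt checkpoints.
    ema = PostHocEMA.from_model(model, checkpoint_dir="posthoc-ema-checkpoints")
except ValueError:
    # The directory already holds checkpoints from an earlier run;
    # load those explicitly instead of silently mixing two runs.
    ema = PostHocEMA.from_path("posthoc-ema-checkpoints", model)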
@@ -242,32 +254,35 @@ def _create_checkpoint(self) -> None:
         # Create checkpoint file
         checkpoint_file = self.checkpoint_dir / f"{idx}.{self.step}.pt"
 
-        # Get parameter and buffer names
-        param_names = {
-            name for name, param in ema_model.ema_model.named_parameters()
-        }
+        # Get state dict from EMA model
+        state_dict = ema_model.state_dict()
+
+        # Filter parameters based on only_save_diff
         if self.only_save_diff:
-            param_names = {
-                name
-                for name in param_names
-                if ema_model.ema_model.get_parameter(name).requires_grad
+            filtered_state_dict = {}
+            for name, param in ema_model.ema_model.named_parameters():
+                if param.requires_grad:
+                    key = name
+                    if key in state_dict:
+                        filtered_state_dict[key] = state_dict[key]
+            # Add buffers and internal state
+            for name, buffer in ema_model.ema_model.named_buffers():
+                key = name
+                if key in state_dict:
+                    filtered_state_dict[key] = state_dict[key]
+            for key in ["initted", "step"]:
+                if key in state_dict:
+                    filtered_state_dict[key] = state_dict[key]
+            state_dict = filtered_state_dict
+
+        # Convert to checkpoint dtype if specified
+        if self.checkpoint_dtype is not None:
+            state_dict = {
+                k: v.to(self.checkpoint_dtype) if isinstance(v, torch.Tensor) else v
+                for k, v in state_dict.items()
             }
-        buffer_names = {name for name, _ in ema_model.ema_model.named_buffers()}
-
-        # Save EMA model state with correct dtype and ema_model prefix
-        state_dict = {
-            f"ema_model.{k}": (
-                v.to(self.checkpoint_dtype)
-                if self.checkpoint_dtype is not None
-                else v
-            )
-            for k, v in ema_model.state_dict().items()
-            if (
-                k in param_names  # Include parameters based on only_save_diff
-                or k in buffer_names  # Include all buffers
-                or k in ("initted", "step")  # Include internal state
-            )
-        }
+
+        # Save checkpoint
         torch.save(state_dict, checkpoint_file)
 
         # Remove old checkpoints if needed
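Because _create_checkpoint() now serializes ema_model.state_dict() directly, the on-disk keys lose the old "ema_model." prefix; the loading code further down in state_dict() changes to match. A quick way to see the difference on an existing checkpoint (the path is illustrative, following the "{idx}.{step}.pt" naming above):

import torch

ckpt = torch.load("posthoc-ema-checkpoints/0.100.pt", weights_only=True, map_location="cpu")
print(sorted(ckpt.keys()))
# before this commit: ["ema_model.bias", "ema_model.initted", "ema_model.step", "ema_model.weight", ...]
# after this commit:  ["bias", "initted", "step", "weight", ...]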
@@ -401,13 +416,15 @@ def state_dict(
 
         # Pre-allocate tensors in calculation dtype
         gammas = torch.empty(total_checkpoints, dtype=calculation_dtype, device=device)
-        timesteps = torch.empty(total_checkpoints, dtype=torch.long, device=device)
+        timesteps = torch.empty(
+            total_checkpoints, dtype=calculation_dtype, device=device
+        )
 
         # Fill tensors one value at a time
         for i, file in enumerate(checkpoint_files):
             idx = int(file.stem.split(".")[0])
             timestep = int(file.stem.split(".")[1])
-            timesteps[i] = timestep
+            timesteps[i] = float(timestep)  # Convert to float
 
             if self.ema_models is not None:
                 gammas[i] = self.gammas[idx]
@@ -430,6 +447,7 @@ def state_dict(
             timesteps,
             gamma,
             calculation_dtype=calculation_dtype,
+            target_sigma_rel=sigma_rel,
         )
 
         # Free memory for gamma and timestep tensors
@@ -442,10 +460,7 @@ def state_dict(
             str(checkpoint_files[0]), weights_only=True, map_location="cpu"
         )
         param_names = {
-            k.replace("ema_model.", ""): k
-            for k in first_checkpoint.keys()
-            if k.startswith("ema_model.")
-            and k.replace("ema_model.", "") not in ("initted", "step")
+            k: k for k in first_checkpoint.keys() if k not in ("initted", "step")
        }
         # Store original dtypes for each parameter
         param_dtypes = {
@@ -467,6 +482,14 @@ def state_dict(
            # Process all parameters from this checkpoint
            for param_name, checkpoint_name in param_names.items():
                if checkpoint_name not in checkpoint:
+                    # If parameter is missing from checkpoint but we're not in only_save_diff mode,
+                    # or if it's a parameter with requires_grad=True, this is an error
+                    if not self.only_save_diff:
+                        raise ValueError(
+                            f"Parameter {param_name} missing from checkpoint {file} "
+                            "but only_save_diff=False"
+                        )
+                    # Skip parameters that are intentionally not saved in only_save_diff mode
                    continue
 
                param_data = checkpoint[checkpoint_name]
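Together with the from_model() guard, this is the "raise if loading incorrectly" part of the commit: a checkpoint missing keys that the first checkpoint has (typically because it was written with a different only_save_diff setting) now fails loudly instead of being skipped. A standalone sketch of the same consistency check, with illustrative paths:

import torch

first = torch.load("posthoc-ema-checkpoints/0.100.pt", weights_only=True, map_location="cpu")
later = torch.load("posthoc-ema-checkpoints/0.200.pt", weights_only=True, map_location="cpu")

expected = set(first.keys()) - {"initted", "step"}
missing = expected - set(later.keys())
only_save_diff = False  # matches the condition in the new check above
if missing and not only_save_diff:
    raise ValueError(f"Parameters missing from checkpoint 0.200.pt: {sorted(missing)}")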
