
Commit f6c4f80

feature: step for ema synthesis
1 parent 6ac6ebb commit f6c4f80

File tree

2 files changed: +118 -3 lines changed


posthoc_ema/posthoc_ema.py

Lines changed: 24 additions & 3 deletions
@@ -322,7 +322,7 @@ def model(
         self,
         model: nn.Module,
         sigma_rel: float,
-        *,
+        step: int | None = None,
         calculation_dtype: torch.dtype = torch.float32,
     ) -> Iterator[nn.Module]:
         """
@@ -331,6 +331,7 @@ def model(
         Args:
             model: Model to temporarily set to EMA state
             sigma_rel: Target relative standard deviation
+            step: Target training step to synthesize for (defaults to latest available)
             calculation_dtype: Data type for synthesis calculations (default=torch.float32)
 
         Yields:
@@ -343,7 +344,9 @@ def model(
 
         try:
             with self.state_dict(
-                sigma_rel, calculation_dtype=calculation_dtype
+                sigma_rel=sigma_rel,
+                step=step,
+                calculation_dtype=calculation_dtype,
             ) as state_dict:
                 # Store original state only for parameters that will be modified
                 original_state = {
@@ -376,14 +379,15 @@ def model(
     def state_dict(
         self,
         sigma_rel: float,
-        *,
+        step: int | None = None,
         calculation_dtype: torch.dtype = torch.float32,
     ) -> Iterator[Dict[str, torch.Tensor]]:
         """
         Context manager for getting state dict for synthesized EMA model.
 
         Args:
             sigma_rel: Target relative standard deviation
+            step: Target training step to synthesize for (defaults to latest available)
             calculation_dtype: Data type for synthesis calculations (default=torch.float32)
 
         Yields:
@@ -424,6 +428,23 @@ def state_dict(
         if total_checkpoints == 0:
             raise ValueError("No checkpoints found")
 
+        # Get all timesteps and find max
+        timesteps = [int(f.stem.split(".")[1]) for f in checkpoint_files]
+        max_step = max(timesteps)
+
+        # Use provided step or default to max
+        target_step = max_step if step is None else step
+        assert target_step <= max_step, (
+            f"Cannot synthesize for step {target_step} as it is greater than "
+            f"the maximum available step {max_step}"
+        )
+
+        # Filter checkpoints to only use those up to target_step
+        checkpoint_files = [
+            f for f, t in zip(checkpoint_files, timesteps) if t <= target_step
+        ]
+        total_checkpoints = len(checkpoint_files)
+
         # Pre-allocate tensors in calculation dtype
         gammas = torch.empty(total_checkpoints, dtype=calculation_dtype, device=device)
         timesteps = torch.empty(
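For context, a minimal usage sketch of the new step argument. The class name PostHocEMA, the import path, and the checkpoint directory are assumptions inferred from the file name and from the test below (which aliases the class as OurPostHocEMA); the from_model arguments mirror the test rather than a documented public API.

# Sketch only: names and defaults are assumptions, not a documented API.
import torch
import torch.nn as nn

from posthoc_ema.posthoc_ema import PostHocEMA  # assumed import path

net = nn.Linear(512, 512)
emas = PostHocEMA.from_model(
    model=net,
    checkpoint_dir="./checkpoints",  # placeholder directory
    update_every=10,
    checkpoint_every=10,
    sigma_rels=(0.03, 0.20),
)

# Simulate training so that checkpoints exist every 10 steps.
for _ in range(100):
    with torch.no_grad():
        net.weight.copy_(torch.randn_like(net.weight))
        net.bias.copy_(torch.randn_like(net.bias))
    emas.update_(net)

# New in this commit: synthesize EMA weights as of step 50 rather than the latest checkpoint.
with emas.state_dict(sigma_rel=0.15, step=50) as ema_state:
    print(sorted(ema_state.keys()))

# Or temporarily swap the synthesized weights into an existing module.
with emas.model(net, sigma_rel=0.15, step=50) as ema_net:
    out = ema_net(torch.randn(1, 512))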

tests/test_same_as_reference.py

Lines changed: 94 additions & 0 deletions
@@ -382,3 +382,97 @@ def test_update_after_step():
                 break
 
     assert weights_changed, "EMA weights did not change after update_after_step"
+
+
+def test_same_output_as_reference_different_step():
+    """Test that our implementation produces identical outputs to the reference when synthesizing at a different step."""
+    # Create a simple model
+    net = nn.Linear(512, 512)
+
+    # Initialize with same parameters
+    sigma_rels = (0.03, 0.20)
+    update_every = 10
+    checkpoint_every = 10
+
+    print("\nInitializing with parameters:")
+    print(f"sigma_rels: {sigma_rels}")
+    print(f"update_every: {update_every}")
+    print(f"checkpoint_every: {checkpoint_every}")
+
+    # Create both implementations
+    ref_emas = ReferencePostHocEMA(
+        net,
+        sigma_rels=sigma_rels,
+        update_every=update_every,
+        checkpoint_every_num_steps=checkpoint_every,
+        checkpoint_folder="./test-checkpoints-ref",
+        checkpoint_dtype=torch.float32,
+    )
+
+    our_emas = OurPostHocEMA.from_model(
+        model=net,
+        checkpoint_dir="./test-checkpoints-our",
+        update_every=update_every,
+        checkpoint_every=checkpoint_every,
+        sigma_rels=sigma_rels,
+        checkpoint_dtype=torch.float32,
+        update_after_step=0,  # Start immediately to match reference behavior
+    )
+
+    # Train both with identical updates
+    torch.manual_seed(42)  # For reproducibility
+    net.train()
+
+    print("\nTraining:")
+    for step in range(100):
+        # Apply identical mutations to network
+        with torch.no_grad():
+            net.weight.copy_(torch.randn_like(net.weight))
+            net.bias.copy_(torch.randn_like(net.bias))
+
+        # Update both EMA wrappers
+        ref_emas.update()
+        our_emas.update_(net)
+
+        if step % 10 == 0:
+            print(f"Step {step}: Updated model and EMAs")
+
+    # Synthesize EMA models with same parameters at step 50 (middle of training)
+    target_sigma = 0.15
+    target_step = 50
+    print(f"\nSynthesizing with target_sigma = {target_sigma} at step {target_step}")
+
+    # Get reference checkpoints and weights
+    ref_checkpoints = sorted(Path("./test-checkpoints-ref").glob("*.pt"))
+    print("\nReference checkpoints:")
+    for cp in ref_checkpoints:
+        print(f" {cp.name}")
+
+    # Get our checkpoints and weights
+    our_checkpoints = sorted(Path("./test-checkpoints-our").glob("*.pt"))
+    print("\nOur checkpoints:")
+    for cp in our_checkpoints:
+        print(f" {cp.name}")
+
+    ref_synth = ref_emas.synthesize_ema_model(sigma_rel=target_sigma, step=target_step)
+
+    with our_emas.model(net, target_sigma, step=target_step) as our_synth:
+        # Test with same input
+        data = torch.randn(1, 512)
+        ref_output = ref_synth(data)
+        our_output = our_synth(data)
+
+        print("\nComparing outputs:")
+        print(f"Reference output mean: {ref_output.mean().item():.4f}")
+        print(f"Our output mean: {our_output.mean().item():.4f}")
+        print(f"Max difference: {(ref_output - our_output).abs().max().item():.4f}")
+
+        # Verify outputs match
+        assert torch.allclose(
+            ref_output, our_output, rtol=1e-4, atol=1e-4
+        ), "Output from our implementation doesn't match reference"
+
+    # Clean up
+    for path in ["./test-checkpoints-ref", "./test-checkpoints-our"]:
+        if Path(path).exists():
+            shutil.rmtree(path)
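For reference, the step-filtering behaviour exercised by this test can be reproduced in isolation. The sketch below assumes checkpoint file stems of the form "<index>.<step>" (e.g. 0.50.pt); this convention is only inferred from the diff's f.stem.split(".")[1] and may differ from the actual naming scheme.

from pathlib import Path


def filter_checkpoints(checkpoint_files: list[Path], step: int | None = None) -> list[Path]:
    """Keep only checkpoints recorded at or before the requested step.

    Assumes stems of the form "<index>.<step>" (inferred from the diff);
    adjust the parsing if the real naming scheme differs.
    """
    timesteps = [int(f.stem.split(".")[1]) for f in checkpoint_files]
    max_step = max(timesteps)
    target_step = max_step if step is None else step
    if target_step > max_step:
        raise ValueError(
            f"Cannot synthesize for step {target_step}; latest checkpoint is {max_step}"
        )
    return [f for f, t in zip(checkpoint_files, timesteps) if t <= target_step]


# Hypothetical checkpoint names, mirroring the assumed convention.
files = [Path("0.10.pt"), Path("0.50.pt"), Path("0.90.pt")]
print([f.name for f in filter_checkpoints(files, step=50)])  # ['0.10.pt', '0.50.pt']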
