
Commit 6784d70

fix: decrease peak ram usage
1 parent 485021c

2 files changed: +172 −86 lines


posthoc_ema/posthoc_ema.py

Lines changed: 111 additions & 65 deletions
@@ -8,6 +8,9 @@
 import torch
 from PIL import Image
 from torch import nn
+import pickle
+import io
+import torch.serialization
 
 from .karras_ema import KarrasEMA
 from .utils import _safe_torch_load, p_dot_p, sigma_rel_to_gamma, solve_weights
@@ -285,37 +288,49 @@ def _cleanup_old_checkpoints(self) -> None:
 
     @contextmanager
     def model(
-        self, model: nn.Module, sigma_rel: float
-    ) -> Generator[nn.Module, None, None]:
-        """Context manager that temporarily sets model parameters to EMA state.
+        self,
+        model: nn.Module,
+        sigma_rel: float,
+    ) -> Iterator[nn.Module]:
+        """
+        Context manager for temporarily setting model parameters to EMA state.
 
         Args:
-            model: Model to update
+            model: Model to temporarily set to EMA state
             sigma_rel: Target relative standard deviation
 
-        Returns:
-            Model with EMA parameters
+        Yields:
+            nn.Module: Model with EMA parameters
         """
-        # Store original device and move model to CPU
+        # Move model to CPU for memory efficiency
         original_device = next(model.parameters()).device
         model.cpu()
         torch.cuda.empty_cache()
 
         try:
             with self.state_dict(sigma_rel=sigma_rel) as state_dict:
-                ema_model = deepcopy(model)
-                result = ema_model.load_state_dict(
+                # Store original state only for parameters that will be modified
+                original_state = {
+                    name: param.detach().clone()
+                    for name, param in model.state_dict().items()
+                    if name in state_dict
+                }
+
+                # Load EMA state directly into model
+                result = model.load_state_dict(
                     state_dict, strict=not self.only_save_diff
                 )
                 assert (
                     len(result.unexpected_keys) == 0
                 ), f"Unexpected keys: {result.unexpected_keys}"
-                ema_model.eval()  # Set to eval mode to handle BatchNorm
-                yield ema_model
-                # Clean up EMA model
-                if hasattr(ema_model, "cuda"):
-                    ema_model.cpu()
-                del ema_model
+                model.eval()  # Set to eval mode to handle BatchNorm
+                yield model
+
+                # Restore original state
+                model.load_state_dict(original_state, strict=False)
+                del original_state
+                del state_dict  # Free memory for state dict
+                torch.cuda.empty_cache()
         finally:
             # Restore model to original device
             model.to(original_device)
@@ -341,10 +356,18 @@ def state_dict(
         gamma = sigma_rel_to_gamma(sigma_rel)
         device = torch.device("cpu")  # Keep synthesis on CPU for memory efficiency
 
-        # Get all checkpoint files
+        # First count total checkpoints to pre-allocate tensors
+        total_checkpoints = 0
+        checkpoint_files = []
         if self.ema_models is not None:
             # When we have ema_models, use their indices
-            indices = range(len(self.ema_models))
+            for idx in range(len(self.ema_models)):
+                files = sorted(
+                    self.checkpoint_dir.glob(f"{idx}.*.pt"),
+                    key=lambda p: int(p.stem.split(".")[1]),
+                )
+                total_checkpoints += len(files)
+                checkpoint_files.extend(files)
         else:
             # When loading from path, find all unique indices
             indices = set()
@@ -353,78 +376,101 @@ def state_dict(
                 indices.add(idx)
             indices = sorted(indices)
 
-        # Get checkpoint files and info
-        checkpoint_files = []
-        gammas = []
-        timesteps = []
-        for idx in indices:
-            files = sorted(
-                self.checkpoint_dir.glob(f"{idx}.*.pt"),
-                key=lambda p: int(p.stem.split(".")[1]),
-            )
-            for file in files:
-                _, timestep = map(int, file.stem.split("."))
-                if self.ema_models is not None:
-                    gammas.append(self.gammas[idx])
-                else:
-                    # Load gamma from checkpoint
-                    checkpoint = _safe_torch_load(str(file))
-                    sigma_rel = checkpoint.get("sigma_rel", None)
-                    if sigma_rel is not None:
-                        gammas.append(sigma_rel_to_gamma(sigma_rel))
-                    else:
-                        gammas.append(self.gammas[idx])
-                    del checkpoint  # Free memory
-                timesteps.append(timestep)
-                checkpoint_files.append(file)
-
-        if not gammas:
+            for idx in indices:
+                files = sorted(
+                    self.checkpoint_dir.glob(f"{idx}.*.pt"),
+                    key=lambda p: int(p.stem.split(".")[1]),
+                )
+                total_checkpoints += len(files)
+                checkpoint_files.extend(files)
+
+        if total_checkpoints == 0:
             raise ValueError("No checkpoints found")
 
-        # Convert to tensors
-        gammas = torch.tensor(gammas, device=device)
-        timesteps = torch.tensor(timesteps, device=device)
+        # Pre-allocate tensors
+        gammas = torch.empty(total_checkpoints, device=device)
+        timesteps = torch.empty(total_checkpoints, dtype=torch.long, device=device)
+
+        # Fill tensors one value at a time
+        for i, file in enumerate(checkpoint_files):
+            idx = int(file.stem.split(".")[0])
+            timestep = int(file.stem.split(".")[1])
+            timesteps[i] = timestep
+
+            if self.ema_models is not None:
+                gammas[i] = self.gammas[idx]
+            else:
+                # Load gamma from checkpoint
+                checkpoint = torch.load(
+                    str(file), weights_only=True, map_location="cpu"
+                )
+                sigma_rel = checkpoint.get("sigma_rel", None)
+                if sigma_rel is not None:
+                    gammas[i] = sigma_rel_to_gamma(sigma_rel)
+                else:
+                    gammas[i] = self.gammas[idx]
+                del checkpoint  # Free memory immediately
+                torch.cuda.empty_cache()
 
         # Solve for weights
         weights = solve_weights(gammas, timesteps, gamma)
 
-        # Load first checkpoint to get state dict structure
-        first_checkpoint = _safe_torch_load(str(checkpoint_files[0]))
-        state_dict = {}
+        # Free memory for gamma and timestep tensors
+        del gammas
+        del timesteps
+        torch.cuda.empty_cache()
 
-        # Get parameter names from first checkpoint
+        # Load first checkpoint to get parameter names
+        first_checkpoint = torch.load(
+            str(checkpoint_files[0]), weights_only=True, map_location="cpu"
+        )
         param_names = {
             k.replace("ema_model.", ""): k
             for k in first_checkpoint.keys()
             if k.startswith("ema_model.")
             and k.replace("ema_model.", "") not in ("initted", "step")
         }
+        del first_checkpoint
+        torch.cuda.empty_cache()
 
-        # Process one parameter at a time
-        for param_name, checkpoint_name in param_names.items():
-            param = first_checkpoint[checkpoint_name]
-            if not isinstance(param, torch.Tensor):
-                continue
+        # Initialize state dict with empty tensors
+        state_dict = {}
+
+        # Process one checkpoint at a time
+        for file_idx, (file, weight) in enumerate(zip(checkpoint_files, weights)):
+            # Load checkpoint
+            checkpoint = torch.load(str(file), weights_only=True, map_location="cpu")
 
-            # Initialize with first weighted contribution
-            state_dict[param_name] = param.to(device) * weights[0]
+            # Process all parameters from this checkpoint
+            for param_name, checkpoint_name in param_names.items():
+                if checkpoint_name not in checkpoint:
+                    continue
 
-            # Add remaining weighted contributions
-            for file, weight in zip(checkpoint_files[1:], weights[1:]):
-                checkpoint = _safe_torch_load(str(file))
-                param = checkpoint[checkpoint_name]
-                if isinstance(param, torch.Tensor):
-                    state_dict[param_name].add_(param.to(device) * weight)
-                del checkpoint  # Free memory
+                param_data = checkpoint[checkpoint_name]
+                if not isinstance(param_data, torch.Tensor):
+                    continue
+
+                if file_idx == 0:
+                    # Initialize parameter with first weighted contribution
+                    state_dict[param_name] = param_data.to(device) * weight
+                else:
+                    # Add weighted contribution to existing parameter
+                    state_dict[param_name].add_(param_data.to(device) * weight)
+
+            # Free memory for this checkpoint
+            del checkpoint
+            torch.cuda.empty_cache()
 
         # Free memory
-        del first_checkpoint
+        del weights
+        torch.cuda.empty_cache()
 
         try:
             yield state_dict
         finally:
             # Clean up
             del state_dict
+            torch.cuda.empty_cache()
 
     def _solve_weights(
         self,
tests/test_vram_usage.py

Lines changed: 61 additions & 21 deletions
@@ -1,4 +1,5 @@
 from pathlib import Path
+import time
 
 import psutil
 import torch
@@ -73,25 +74,21 @@ def reset_peak_ram():
     _peak_ram_usage = get_ram_usage()
 
 
-def monitor_operation(operation, interval=0.001):
+def monitor_operation(operation):
     """Monitor RAM usage during an operation."""
-    import threading
-    import time
-
-    stop_monitoring = threading.Event()
-
-    def monitor():
-        while not stop_monitoring.is_set():
-            get_ram_usage()  # This will update peak RAM
-            time.sleep(interval)
-
-    monitor_thread = threading.Thread(target=monitor)
-    monitor_thread.start()
-    try:
-        result = operation()
-    finally:
-        stop_monitoring.set()
-        monitor_thread.join()
+    # Get RAM usage before operation
+    pre_ram = get_ram_usage()
+
+    # Run operation
+    result = operation()
+
+    # Get RAM usage after operation
+    post_ram = get_ram_usage()
+
+    # Update peak RAM if needed
+    global _peak_ram_usage
+    _peak_ram_usage = max(_peak_ram_usage, pre_ram, post_ram)
+
     return result
@@ -352,25 +349,68 @@ def test_synthesis_memory_usage():
 
     # Monitor synthesis memory usage
     print("\nStarting synthesis...")
+
+    # First test: model context manager
+    print("\nTesting model context manager synthesis:")
     pre_synthesis_ram = get_ram_usage()
     reset_peak_ram()
 
-    def synthesize():
+    def synthesize_with_model():
         with posthoc_ema.model(model, sigma_rel=0.15) as ema_model:
             # Force a full synthesis by accessing parameters
             for param in ema_model.parameters():
                 _ = param.shape
 
-    monitor_operation(synthesize)
+    monitor_operation(synthesize_with_model)
+
+    post_synthesis_ram = get_ram_usage()
+    peak_synthesis_ram = get_peak_ram_usage()
+    print(f"Pre-synthesis RAM: {pre_synthesis_ram:.2f}MB")
+    print(f"Post-synthesis RAM: {post_synthesis_ram:.2f}MB")
+    print(f"Peak synthesis RAM: {peak_synthesis_ram:.2f}MB")
+    print(f"RAM increase: {post_synthesis_ram - pre_synthesis_ram:.2f}MB")
+    print(f"RAM spike: {peak_synthesis_ram - pre_synthesis_ram:.2f}MB")
+    print(f"Relative peak RAM: {(peak_synthesis_ram/pre_synthesis_ram)*100:.1f}%")
+
+    # Assert RAM usage is within limits
+    assert peak_synthesis_ram <= pre_synthesis_ram * 2.5, (
+        f"Synthesis caused excessive RAM usage. "
+        f"Peak RAM ({peak_synthesis_ram:.2f}MB) was more than 2.5x "
+        f"pre-synthesis RAM ({pre_synthesis_ram:.2f}MB)"
+    )
+
+    # Second test: state_dict synthesis
+    print("\nTesting state_dict synthesis:")
+    pre_synthesis_ram = get_ram_usage()
+    reset_peak_ram()
+
+    synthesis_start = time.perf_counter()
+
+    def synthesize_with_state_dict():
+        with posthoc_ema.state_dict(sigma_rel=0.15) as ema_state_dict:
+            # Force processing by accessing dict
+            for param in ema_state_dict.values():
+                _ = param.shape
+
+    monitor_operation(synthesize_with_state_dict)
+    synthesis_end = time.perf_counter()
 
     post_synthesis_ram = get_ram_usage()
     peak_synthesis_ram = get_peak_ram_usage()
-    print(f"\nSynthesis memory usage:")
     print(f"Pre-synthesis RAM: {pre_synthesis_ram:.2f}MB")
     print(f"Post-synthesis RAM: {post_synthesis_ram:.2f}MB")
     print(f"Peak synthesis RAM: {peak_synthesis_ram:.2f}MB")
     print(f"RAM increase: {post_synthesis_ram - pre_synthesis_ram:.2f}MB")
     print(f"RAM spike: {peak_synthesis_ram - pre_synthesis_ram:.2f}MB")
+    print(f"Relative peak RAM: {(peak_synthesis_ram/pre_synthesis_ram)*100:.1f}%")
+    print(f"Synthesis time: {synthesis_end - synthesis_start:.2f} seconds")
+
+    # Assert RAM usage is within limits
+    assert peak_synthesis_ram <= pre_synthesis_ram * 2.5, (
+        f"Synthesis caused excessive RAM usage. "
+        f"Peak RAM ({peak_synthesis_ram:.2f}MB) was more than 2.5x "
+        f"pre-synthesis RAM ({pre_synthesis_ram:.2f}MB)"
+    )
 
     # Cleanup
     model.cpu()
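
The assertions above compare the observed peak resident set size against 2.5x the pre-synthesis baseline. The helpers get_ram_usage, get_peak_ram_usage, and reset_peak_ram are defined earlier in this test file; a plausible psutil-based sketch of them (details may differ from the actual helpers) is:

# Hedged sketch of psutil-based RSS helpers like the ones this test relies on;
# the real get_ram_usage / get_peak_ram_usage in tests/test_vram_usage.py may
# differ in detail.
import os

import psutil

_peak_ram_usage = 0.0


def get_ram_usage() -> float:
    """Return the current resident set size of this process in MB, tracking the peak."""
    global _peak_ram_usage
    rss_mb = psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)
    _peak_ram_usage = max(_peak_ram_usage, rss_mb)
    return rss_mb


def get_peak_ram_usage() -> float:
    """Return the highest RSS value observed so far, in MB."""
    return _peak_ram_usage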
