Commit 46b06ea

fix: avoid double vram usage during init
1 parent 68745d5 commit 46b06ea
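
In short: creating the EMA copy with deepcopy while the model still lives on the GPU briefly holds two full sets of weights in VRAM. The diff below records the model's original device, moves the model to CPU before the copy, builds the EMA model there, and then restores the device, moving the model back even if initialization raises. A small standalone sketch of the pattern follows the first hunk.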

File tree: 3 files changed (+205, -144 lines)

posthoc_ema/karras_ema.py

Lines changed: 52 additions & 39 deletions

@@ -92,39 +92,52 @@ def __init__(
         if callable(ema_model) and not isinstance(ema_model, Module):
             ema_model = ema_model()
 
-        # Create EMA model on CPU
-        self.ema_model = (ema_model if exists(ema_model) else deepcopy(model)).cpu()
-
-        # Ensure all parameters and buffers are on CPU and detached
-        for p in self.ema_model.parameters():
-            p.data = p.data.cpu().detach()
-        for b in self.ema_model.buffers():
-            b.data = b.data.cpu().detach()
-
-        # Get parameter names that require gradients
-        self.param_names = {
-            name
-            for name, param in self.ema_model.named_parameters()
-            if (not only_save_diff or param.requires_grad) and (
-                torch.is_floating_point(param) or torch.is_complex(param)
-            )
-        }
-
-        # Get buffer names for floating point or complex buffers
-        self.buffer_names = {
-            name
-            for name, buffer in self.ema_model.named_buffers()
-            if torch.is_floating_point(buffer) or torch.is_complex(buffer)
-        }
-
-        # Names to ignore
-        self.param_or_buffer_names_no_ema = param_or_buffer_names_no_ema
-        self.ignore_names = ignore_names
-        self.ignore_startswith_names = ignore_startswith_names
-
-        # State buffers on CPU
-        self.register_buffer("initted", torch.tensor(False, device="cpu"))
-        self.register_buffer("step", torch.tensor(0, device="cpu"))
+        # Store original device
+        original_device = next(model.parameters()).device
+
+        # Move model to CPU before copying to avoid VRAM spike
+        model.cpu()
+
+        try:
+            # Create EMA model on CPU
+            self.ema_model = (ema_model if exists(ema_model) else deepcopy(model)).cpu()
+
+            # Ensure all parameters and buffers are on CPU and detached
+            for p in self.ema_model.parameters():
+                p.data = p.data.cpu().detach()
+            for b in self.ema_model.buffers():
+                b.data = b.data.cpu().detach()
+
+            # Move model back to original device
+            model.to(original_device)
+
+            # Get parameter names that require gradients
+            self.param_names = {
+                name
+                for name, param in self.ema_model.named_parameters()
+                if (not only_save_diff or param.requires_grad)
+                and (torch.is_floating_point(param) or torch.is_complex(param))
+            }
+
+            # Get buffer names for floating point or complex buffers
+            self.buffer_names = {
+                name
+                for name, buffer in self.ema_model.named_buffers()
+                if torch.is_floating_point(buffer) or torch.is_complex(buffer)
+            }
+
+            # Names to ignore
+            self.param_or_buffer_names_no_ema = param_or_buffer_names_no_ema
+            self.ignore_names = ignore_names
+            self.ignore_startswith_names = ignore_startswith_names
+
+            # State buffers on CPU
+            self.register_buffer("initted", torch.tensor(False, device="cpu"))
+            self.register_buffer("step", torch.tensor(0, device="cpu"))
+        except:
+            # Ensure model is moved back even if initialization fails
+            model.to(original_device)
+            raise
 
     @property
     def beta(self):
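
For reference, a minimal standalone sketch of the pattern applied in the hunk above. The helper name make_cpu_ema_copy is illustrative and not part of the repository, and it uses try/finally where the commit uses an explicit except/re-raise; both restore the model's device on failure.

    import torch
    from copy import deepcopy
    from torch import nn

    def make_cpu_ema_copy(model: nn.Module) -> nn.Module:
        """Copy a model to CPU without ever holding two GPU-resident copies."""
        # Remember where the live model was (assumes it has at least one parameter).
        original_device = next(model.parameters()).device

        # Move the live model to CPU first so deepcopy allocates on CPU, not in VRAM.
        model.cpu()
        try:
            ema_model = deepcopy(model).cpu()
            # Detach every parameter and buffer so the copy never tracks gradients.
            for p in ema_model.parameters():
                p.data = p.data.cpu().detach()
            for b in ema_model.buffers():
                b.data = b.data.cpu().detach()
        finally:
            # Restore the training model's device whether or not the copy succeeded.
            model.to(original_device)
        return ema_model
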
@@ -239,23 +252,23 @@ def __call__(self, *args, **kwargs):
     def state_dict(self):
         """Get state dict of EMA model."""
         state_dict = {}
-
+
         # Add parameters based on only_save_diff flag
         for name, param in self.ema_model.named_parameters():
             if (not self.only_save_diff or param.requires_grad) and (
                 torch.is_floating_point(param) or torch.is_complex(param)
             ):
                 state_dict[name] = param
-
+
         # Add buffers (always included regardless of only_save_diff)
         for name, buffer in self.ema_model.named_buffers():
             if torch.is_floating_point(buffer) or torch.is_complex(buffer):
                 state_dict[name] = buffer
-
+
         # Add internal state
         state_dict["initted"] = self.initted
         state_dict["step"] = self.step
-
+
         return state_dict
 
     def load_state_dict(self, state_dict):
@@ -264,12 +277,12 @@ def load_state_dict(self, state_dict):
         for name, param in self.ema_model.named_parameters():
             if (not self.only_save_diff or param.requires_grad) and name in state_dict:
                 param.data.copy_(state_dict[name].data)
-
+
         # Load buffers
         for name, buffer in self.ema_model.named_buffers():
             if name in state_dict:
                 buffer.data.copy_(state_dict[name].data)
-
+
         # Load internal state
         if "initted" in state_dict:
             self.initted.data.copy_(state_dict["initted"].data)
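
The last two hunks only touch whitespace, but the surrounding state_dict/load_state_dict code shows the load pattern: parameters and buffers are copied in place with .data.copy_(). A small round-trip sketch of that pattern, using plain nn.Linear modules rather than the EMA wrapper itself (src and dst are illustrative names):

    import torch
    from torch import nn

    src = nn.Linear(4, 4)
    dst = nn.Linear(4, 4)

    # Build a state dict the way state_dict() above does: store the tensors directly.
    state = {name: param for name, param in src.named_parameters()}

    # Load it the way load_state_dict() above does: in-place copy into existing tensors.
    for name, param in dst.named_parameters():
        if name in state:
            param.data.copy_(state[name].data)

    assert torch.equal(dst.weight, src.weight) and torch.equal(dst.bias, src.bias)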
