
Commit a544522

improve: less ram usage
1 parent 2a65aab commit a544522

File tree

4 files changed: +195 -188 lines changed


posthoc_ema/karras_ema.py

Lines changed: 139 additions & 159 deletions
```diff
@@ -38,7 +38,12 @@ def inplace_lerp(tgt: Tensor, src: Tensor, weight):
         src: Source tensor to interpolate towards
         weight: Interpolation weight between 0 and 1
     """
-    tgt.lerp_(src.to(tgt.device), weight)
+    # Check if tensor is integer type - integer tensors can't use lerp
+    # but we want to silently handle them instead of raising errors
+    if tgt.dtype in [torch.int, torch.int8, torch.int16, torch.int32, torch.int64, torch.long]:
+        tgt.copy_(src.to(tgt.device))
+    else:
+        tgt.lerp_(src.to(tgt.device), weight)


 class KarrasEMA(Module):
```
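`Tensor.lerp_` only supports floating-point and complex dtypes, so integer buffers (e.g. step counters) would previously error out; the new branch silently falls back to an in-place copy. A minimal standalone sketch of the resulting behaviour (the example tensors are illustrative, not from the repository):

```python
import torch

def inplace_lerp(tgt, src, weight):
    # Integer tensors cannot be lerped; copy the source value instead.
    if tgt.dtype in [torch.int, torch.int8, torch.int16, torch.int32, torch.int64, torch.long]:
        tgt.copy_(src.to(tgt.device))
    else:
        tgt.lerp_(src.to(tgt.device), weight)

count = torch.tensor(10)              # int64, like BatchNorm's num_batches_tracked
inplace_lerp(count, torch.tensor(12), 0.1)
print(count)                          # tensor(12) -- copied, not interpolated

weights = torch.zeros(3)
inplace_lerp(weights, torch.ones(3), 0.1)
print(weights)                        # tensor([0.1000, 0.1000, 0.1000])
```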
```diff
@@ -49,94 +54,73 @@ class KarrasEMA(Module):
         model: Model to create EMA of
         sigma_rel: Relative standard deviation for EMA profile
         gamma: Alternative parameterization via gamma (don't specify both)
-        ema_model: Optional pre-initialized EMA model
         update_every: Number of steps between EMA updates
         frozen: Whether to freeze EMA updates
         param_or_buffer_names_no_ema: Parameter/buffer names to exclude from EMA
         ignore_names: Parameter/buffer names to ignore
         ignore_startswith_names: Parameter/buffer name prefixes to ignore
         only_save_diff: If True, only save parameters with requires_grad=True
+        device: Device to store EMA parameters on (default='cpu')
     """

+    # Buffers that should always be included in the state dict even with only_save_diff=True
+    _ALWAYS_INCLUDE_BUFFERS = {"running_mean", "running_var", "num_batches_tracked"}
+
     def __init__(
         self,
         model: Module,
         sigma_rel: float | None = None,
         gamma: float | None = None,
-        ema_model: Module | Callable[[], Module] | None = None,
         update_every: int = 10,
         frozen: bool = False,
         param_or_buffer_names_no_ema: set[str] = set(),
         ignore_names: set[str] = set(),
         ignore_startswith_names: set[str] = set(),
         only_save_diff: bool = False,
+        device: str = 'cpu',
     ):
         super().__init__()
-
-        assert exists(sigma_rel) ^ exists(
-            gamma
-        ), "either sigma_rel or gamma must be given"
-
-        if exists(sigma_rel):
-            gamma = sigma_rel_to_gamma(sigma_rel)
-
+
+        # Store all the configuration parameters first
         self.gamma = gamma
         self.frozen = frozen
         self.update_every = update_every
         self.only_save_diff = only_save_diff
-
+        self.ignore_names = ignore_names
+        self.ignore_startswith_names = ignore_startswith_names
+        self.param_or_buffer_names_no_ema = param_or_buffer_names_no_ema
+        self.device = device
+
+        assert exists(sigma_rel) ^ exists(gamma), "either sigma_rel or gamma must be given"
+
+        if exists(sigma_rel):
+            gamma = sigma_rel_to_gamma(sigma_rel)
+            self.gamma = gamma
+
         # Store reference to online model
         self.online_model = [model]
-
-        # Initialize EMA model
-        if callable(ema_model) and not isinstance(ema_model, Module):
-            ema_model = ema_model()
-
-        # Store original device
-        original_device = next(model.parameters()).device
-
-        # Move model to CPU before copying to avoid VRAM spike
-        model.cpu()
-
-        try:
-            # Create EMA model on CPU
-            self.ema_model = (ema_model if exists(ema_model) else deepcopy(model)).cpu()
-
-            # Ensure all parameters and buffers are on CPU and detached
-            for p in self.ema_model.parameters():
-                p.data = p.data.cpu().detach()
-            for b in self.ema_model.buffers():
-                b.data = b.data.cpu().detach()
-
-            # Move model back to original device
-            model.to(original_device)
-
-            # Get parameter names for floating point or complex parameters
-            self.param_names = {
-                name
-                for name, param in self.ema_model.named_parameters()
-                if torch.is_floating_point(param) or torch.is_complex(param)
-            }
-
-            # Get buffer names for floating point or complex buffers
-            self.buffer_names = {
-                name
-                for name, buffer in self.ema_model.named_buffers()
-                if torch.is_floating_point(buffer) or torch.is_complex(buffer)
-            }
-
-            # Names to ignore
-            self.param_or_buffer_names_no_ema = param_or_buffer_names_no_ema
-            self.ignore_names = ignore_names
-            self.ignore_startswith_names = ignore_startswith_names
-
-            # State buffers on CPU
-            self.register_buffer("initted", torch.tensor(False, device="cpu"))
-            self.register_buffer("step", torch.tensor(0, device="cpu"))
-        except:
-            # Ensure model is moved back even if initialization fails
-            model.to(original_device)
-            raise
+
+        # Instead of copying the whole model, just store parameter tensors
+        self.ema_params = {}
+        self.ema_buffers = {}
+
+        # Get parameter and buffer names to track
+        with torch.no_grad():
+            for name, param in model.named_parameters():
+                if self._should_update_param(name):
+                    if not only_save_diff or param.requires_grad:
+                        self.ema_params[name] = param.detach().clone().to(self.device)
+
+            for name, buffer in model.named_buffers():
+                if self._should_update_param(name):
+                    buffer_name = name.split('.')[-1]  # Get the base name
+                    # Always include critical buffers regardless of only_save_diff
+                    if not only_save_diff or buffer.requires_grad or buffer_name in self._ALWAYS_INCLUDE_BUFFERS:
+                        self.ema_buffers[name] = buffer.detach().clone().to(self.device)
+
+        # State buffers
+        self.register_buffer("initted", torch.tensor(False))
+        self.register_buffer("step", torch.tensor(0))

     @property
     def beta(self):
```
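With the rewritten `__init__`, the EMA no longer deepcopies the model (avoiding the temporary VRAM spike and the second full copy); it keeps per-tensor copies in `ema_params` / `ema_buffers` on `device`. A rough usage sketch, with the caveat that the `posthoc_ema.karras_ema` import path, the toy model, and the training loop are illustrative assumptions rather than code from the repository:

```python
import torch
from torch import nn
from posthoc_ema.karras_ema import KarrasEMA  # assumed import path

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = nn.Sequential(nn.Linear(8, 8), nn.BatchNorm1d(8)).to(device)

# EMA tensors live in ema.ema_params / ema.ema_buffers on device='cpu',
# so no second full copy of the model sits in GPU memory.
ema = KarrasEMA(model, sigma_rel=0.05, update_every=10, device='cpu')

optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
for _ in range(100):
    loss = model(torch.randn(16, 8, device=device)).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    ema.update()  # every update_every steps this lerps the CPU-side EMA tensors
```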
```diff
@@ -161,42 +145,33 @@ def update(self):
     def copy_params_from_model_to_ema(self):
         """Copy parameters from online model to EMA model."""
         # Copy parameters
-        for (name, ma_params), (_, current_params) in zip(
-            self.get_params_iter(self.ema_model),
-            self.get_params_iter(self.online_model[0]),
-        ):
-            if self._should_update_param(name):
-                inplace_copy(ma_params.data, current_params.data)
-
-        # Copy buffers
-        for (name, ma_buffer), (_, current_buffer) in zip(
-            self.get_buffers_iter(self.ema_model),
-            self.get_buffers_iter(self.online_model[0]),
-        ):
-            if self._should_update_param(name):
-                inplace_copy(ma_buffer.data, current_buffer.data)
+        with torch.no_grad():
+            for name, param in self.online_model[0].named_parameters():
+                if name in self.ema_params:
+                    # Explicitly move to device (usually CPU)
+                    self.ema_params[name] = param.detach().clone().to(self.device)
+
+            # Copy buffers
+            for name, buffer in self.online_model[0].named_buffers():
+                if name in self.ema_buffers:
+                    # Explicitly move to device (usually CPU)
+                    self.ema_buffers[name] = buffer.detach().clone().to(self.device)

     def update_moving_average(self):
         """Update EMA weights using current beta value."""
         current_decay = self.beta

-        # Update parameters
-        for (name, current_params), (_, ma_params) in zip(
-            self.get_params_iter(self.online_model[0]),
-            self.get_params_iter(self.ema_model),
-        ):
-            if not self._should_update_param(name):
-                continue
-            inplace_lerp(ma_params.data, current_params.data, 1.0 - current_decay)
-
-        # Update buffers
-        for (name, current_buffer), (_, ma_buffer) in zip(
-            self.get_buffers_iter(self.online_model[0]),
-            self.get_buffers_iter(self.ema_model),
-        ):
-            if not self._should_update_param(name):
-                continue
-            inplace_lerp(ma_buffer.data, current_buffer.data, 1.0 - current_decay)
+        # Update parameters using the simplified lerp function (which now handles integer tensors)
+        for name, current_params in self.online_model[0].named_parameters():
+            if name in self.ema_params:
+                # inplace_lerp now handles integer tensors internally
+                inplace_lerp(self.ema_params[name], current_params.data, 1.0 - current_decay)
+
+        # Update buffers with the same simplified approach
+        for name, current_buffer in self.online_model[0].named_buffers():
+            if name in self.ema_buffers:
+                # inplace_lerp now handles integer tensors internally
+                inplace_lerp(self.ema_buffers[name], current_buffer.data, 1.0 - current_decay)

     def _should_update_param(self, name: str) -> bool:
         """Check if parameter should be updated based on ignore rules."""
```
```diff
@@ -208,10 +183,17 @@ def _should_update_param(self, name: str) -> bool:
                 return False
         return True

+    def _parameter_requires_grad(self, name: str) -> bool:
+        """Check if parameter requires gradients in the online model."""
+        for n, p in self.online_model[0].named_parameters():
+            if n == name:
+                return p.requires_grad
+        return False
+
     def get_params_iter(self, model):
         """Get iterator over model's parameters."""
         for name, param in model.named_parameters():
-            if name not in self.param_names:
+            if name not in self.ema_params:
                 continue
             if self.only_save_diff and not param.requires_grad:
                 continue
@@ -220,17 +202,19 @@ def get_params_iter(self, model):
     def get_buffers_iter(self, model):
         """Get iterator over model's buffers."""
         for name, buffer in model.named_buffers():
-            if name not in self.buffer_names:
+            if name not in self.ema_buffers:
                 continue
-            if self.only_save_diff and not buffer.requires_grad:
+
+            # Handle critical buffers that should always be included
+            buffer_name = name.split('.')[-1]
+            if self.only_save_diff and not buffer.requires_grad and buffer_name not in self._ALWAYS_INCLUDE_BUFFERS:
                 continue
+
             yield name, buffer

     def iter_all_ema_params_and_buffers(self):
         """Get iterator over all EMA parameters and buffers."""
-        for name, param in self.ema_model.named_parameters():
-            if name not in self.param_names:
-                continue
+        for name, param in self.ema_params.items():
             if name in self.param_or_buffer_names_no_ema:
                 continue
             if name in self.ignore_names:
@@ -239,21 +223,10 @@ def iter_all_ema_params_and_buffers(self):
                 continue
             yield param

-        for name, buffer in self.ema_model.named_buffers():
-            if name not in self.buffer_names:
-                continue
-            if name in self.param_or_buffer_names_no_ema:
-                continue
-            if name in self.ignore_names:
-                continue
-            if any(name.startswith(prefix) for prefix in self.ignore_startswith_names):
-                continue
-            yield buffer
-
     def iter_all_model_params_and_buffers(self, model: Module):
         """Get iterator over all model parameters and buffers."""
         for name, param in model.named_parameters():
-            if name not in self.param_names:
+            if name not in self.ema_params:
                 continue
             if name in self.param_or_buffer_names_no_ema:
                 continue
```
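The `_ALWAYS_INCLUDE_BUFFERS` special case exists because normalization statistics are buffers and never require gradients, so a plain `requires_grad` filter under `only_save_diff` would drop them. A standalone illustration (not part of the diff):

```python
import torch
from torch import nn

bn = nn.BatchNorm1d(4)
for name, buf in bn.named_buffers():
    print(name, buf.requires_grad)
# running_mean False
# running_var False
# num_batches_tracked False
```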
```diff
@@ -263,59 +236,66 @@ def iter_all_model_params_and_buffers(self, model: Module):
                 continue
             yield param

-        for name, buffer in model.named_buffers():
-            if name not in self.buffer_names:
-                continue
-            if name in self.param_or_buffer_names_no_ema:
-                continue
-            if name in self.ignore_names:
-                continue
-            if any(name.startswith(prefix) for prefix in self.ignore_startswith_names):
-                continue
-            yield buffer
-
     def __call__(self, *args, **kwargs):
         """Forward pass using EMA model."""
-        return self.ema_model(*args, **kwargs)
+        raise NotImplementedError("KarrasEMA no longer maintains a full model copy")
+
+    @property
+    def ema_model(self):
+        """
+        For backward compatibility with tests.
+        Creates a temporary model with EMA parameters.
+
+        Returns:
+            Module: A copy of the online model with EMA parameters
+        """
+        # Create a copy of the online model
+        model_copy = deepcopy(self.online_model[0])
+
+        # Load EMA parameters into the model
+        for name, param in model_copy.named_parameters():
+            if name in self.ema_params:
+                param.data.copy_(self.ema_params[name])
+
+        # Load EMA buffers into the model
+        for name, buffer in model_copy.named_buffers():
+            if name in self.ema_buffers:
+                buffer.data.copy_(self.ema_buffers[name])
+
+        # Ensure the model is on CPU
+        model_copy.to('cpu')
+        return model_copy

     def state_dict(self):
-        """Get state dict for EMA model."""
+        """Get state dict with EMA parameters."""
         state_dict = {}
-
-        # Save parameters based on only_save_diff flag
-        for name, param in self.ema_model.named_parameters():
-            if name not in self.param_names:
-                continue
-            if self.only_save_diff and not param.requires_grad:
-                continue
-            state_dict[name] = param.data
-
-        # Save buffers
-        for name, buffer in self.ema_model.named_buffers():
-            if name not in self.buffer_names:
-                continue
-            state_dict[name] = buffer.data
-
-        # Save internal state
-        state_dict["initted"] = self.initted.data
-        state_dict["step"] = self.step.data
-
+
+        # For parameters, respect only_save_diff
+        for name, param in self.ema_params.items():
+            if not self.only_save_diff or self._parameter_requires_grad(name):
+                state_dict[name] = param.data
+
+        # For buffers, identify which ones should always be included
+        for name, buffer in self.ema_buffers.items():
+            buffer_name = name.split('.')[-1]  # Get the base name
+            # Always include critical buffers regardless of only_save_diff
+            if not self.only_save_diff or buffer_name in self._ALWAYS_INCLUDE_BUFFERS:
+                state_dict[name] = buffer.data
+
+        # Add internal state
+        state_dict["initted"] = self.initted
+        state_dict["step"] = self.step
+
         return state_dict

     def load_state_dict(self, state_dict):
-        """Load state dict into EMA model."""
-        # Load parameters based on only_save_diff flag
-        for name, param in self.ema_model.named_parameters():
-            if (not self.only_save_diff or param.requires_grad) and name in state_dict:
-                param.data.copy_(state_dict[name].data)
-
-        # Load buffers
-        for name, buffer in self.ema_model.named_buffers():
-            if name in state_dict:
-                buffer.data.copy_(state_dict[name].data)
-
-        # Load internal state
-        if "initted" in state_dict:
-            self.initted.data.copy_(state_dict["initted"].data)
-        if "step" in state_dict:
-            self.step.data.copy_(state_dict["step"].data)
+        """Load state dict with EMA parameters."""
+        for name, param in state_dict.items():
+            if name == "initted":
+                self.initted.data.copy_(param)
+            elif name == "step":
+                self.step.data.copy_(param)
+            elif name in self.ema_params:
+                self.ema_params[name].data.copy_(param)
+            elif name in self.ema_buffers:
+                self.ema_buffers[name].data.copy_(param)
```
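A rough sketch of how the backward-compatibility pieces fit together (the import path and toy model are assumptions; `ema_model` is now a property that rebuilds a full module on demand rather than a persistent attribute):

```python
import torch
from torch import nn
from posthoc_ema.karras_ema import KarrasEMA  # assumed import path

net = nn.Linear(4, 4)
ema = KarrasEMA(net, sigma_rel=0.05)

# ema.ema_model deepcopies the online model and loads the EMA tensors into it,
# so the full copy only exists (on CPU) while you hold a reference to it.
eval_copy = ema.ema_model
with torch.no_grad():
    out = eval_copy(torch.randn(2, 4))

# state_dict()/load_state_dict() round-trip the flat name -> tensor mapping,
# plus the internal "initted" and "step" entries.
checkpoint = ema.state_dict()
ema.load_state_dict(checkpoint)
```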
