Skip to content

Commit 3c70a83

Browse files
committed
feat: add device and dtype alignment helpers for ELLA model
1 parent 73a1aa1 commit 3c70a83

File tree

2 files changed

+128
-19
lines changed

2 files changed

+128
-19
lines changed

ella.py

Lines changed: 71 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,37 @@
2929
current_paths, _ = folder_paths.folder_names_and_paths["ella_encoder"]
3030
folder_paths.folder_names_and_paths["ella_encoder"] = (current_paths, folder_paths.supported_pt_extensions)
3131

32-
32+
# === device/dtype alignment helpers ===
33+
def _infer_float_dtype_from_embeds(d: dict):
34+
import torch
35+
for v in d.values():
36+
if torch.is_tensor(v) and v.is_floating_point():
37+
return v.dtype
38+
if isinstance(v, (list, tuple)):
39+
for t in v:
40+
if torch.is_tensor(t) and t.is_floating_point():
41+
return t.dtype
42+
if isinstance(v, dict):
43+
dt = _infer_float_dtype_from_embeds(v)
44+
if dt is not None:
45+
return dt
46+
return None
47+
48+
def _align_to_model_device_dtype(x, device, dtype):
49+
import torch
50+
if x is None:
51+
return None
52+
if torch.is_tensor(x):
53+
if x.is_floating_point():
54+
return x.to(device=device, dtype=dtype, non_blocking=True)
55+
return x.to(device=device, non_blocking=True)
56+
if isinstance(x, (list, tuple)):
57+
return type(x)(_align_to_model_device_dtype(xx, device, dtype) for xx in x)
58+
if isinstance(x, dict):
59+
return {k: _align_to_model_device_dtype(v, device, dtype) for k, v in x.items()}
60+
return x
61+
62+
# === /helpers ===
3363
def ella_encode(ella: ELLA, timesteps: torch.Tensor, embeds: dict):
3464
num_steps = len(timesteps) - 1
3565
# print(f"creating ELLA conds for {num_steps} timesteps")
@@ -39,7 +69,14 @@ def ella_encode(ella: ELLA, timesteps: torch.Tensor, embeds: dict):
3969
start = i / num_steps # Start percentage is calculated based on the index
4070
end = (i + 1) / num_steps # End percentage is calculated based on the next index
4171

42-
cond_ella = ella(timestep, **embeds)
72+
# cond_ella = ella(timestep, **embeds)
73+
# align dtype/device to ELLA model
74+
device = getattr(ella, "output_device", timesteps.device)
75+
want_dtype = _infer_float_dtype_from_embeds(embeds) or torch.float16
76+
_t = timestep.to(device=device, dtype=want_dtype)
77+
_embeds = _align_to_model_device_dtype(embeds, device, want_dtype)
78+
79+
cond_ella = ella(_t, **_embeds)
4380

4481
cond_ella_dict = {"start_percent": start, "end_percent": end}
4582
conds.append([cond_ella, cond_ella_dict])
@@ -69,13 +106,24 @@ def __init__(
69106
self.embeds[i][k] = CONDCrossAttn(self.embeds[i][k])
70107

71108
def process_cond(self, embeds: Dict[str, CONDCrossAttn], batch_size, **kwargs):
    """Resolve each CONDCrossAttn wrapper to its tensor, then align the
    result to the ELLA output device with a common floating dtype.

    The dtype is inferred from the resolved outputs themselves; when none
    of them is floating-point, fp16 is used as the fallback.
    """
    device = self.ella.output_device
    resolved = {}
    for name, wrapped in embeds.items():
        resolved[name] = wrapped.process_cond(batch_size, device, **kwargs).cond
    target_dtype = _infer_float_dtype_from_embeds(resolved) or torch.float16
    return _align_to_model_device_dtype(resolved, device, target_dtype)
73114

74115
def prepare_conds(self):
116+
75117
cond_embeds = self.process_cond(self.embeds[0], 1)
76-
cond = self.ella(torch.Tensor([999]), **cond_embeds)
118+
want_dtype = _infer_float_dtype_from_embeds(cond_embeds) or torch.float16
119+
t999 = torch.tensor([999.0], device=self.ella.output_device, dtype=want_dtype)
120+
cond = self.ella(t999, **cond_embeds)
121+
77122
uncond_embeds = self.process_cond(self.embeds[1], 1)
78-
uncond = self.ella(torch.Tensor([999]), **uncond_embeds)
123+
# same dtype for consistency
124+
t999u = t999
125+
uncond = self.ella(t999u, **uncond_embeds)
126+
79127
if self.mode == APPLY_MODE_ELLA_ONLY:
80128
return cond, uncond
81129
if "clip_embeds" not in cond_embeds or "clip_embeds" not in uncond_embeds:
@@ -94,22 +142,28 @@ def __call__(self, apply_model, kwargs: dict):
94142
_device = c["c_crossattn"].device
95143

96144
time_aware_encoder_hidden_states = []
97-
for i in cond_or_uncond:
145+
# get the dtype of the target model from the cond-data of the first group
146+
# (process_cond has already aligned device to self.ella.output_device)
147+
for idx, i in enumerate(cond_or_uncond):
98148
cond_embeds = self.process_cond(self.embeds[i], input_x.size(0) // len(cond_or_uncond))
99-
h = self.ella(
100-
self.model_sampling.timestep(timestep_[0]),
101-
**cond_embeds,
102-
)
103-
if self.mode == APPLY_MODE_ELLA_ONLY:
149+
want_dtype = _infer_float_dtype_from_embeds(cond_embeds) or torch.float16
150+
151+
# timestep from sampler can be on CPU and in fp32 - we will align it
152+
t_model = self.model_sampling.timestep(timestep_[0])
153+
t_model = t_model.to(device=self.ella.output_device, dtype=want_dtype)
154+
155+
h = self.ella(t_model, **cond_embeds)
156+
157+
if self.mode == APPLY_MODE_ELLA_ONLY or "clip_embeds" not in cond_embeds:
104158
time_aware_encoder_hidden_states.append(h)
105-
continue
106-
if "clip_embeds" not in cond_embeds:
159+
else:
160+
h = torch.concat([h, cond_embeds["clip_embeds"]], dim=1)
107161
time_aware_encoder_hidden_states.append(h)
108-
continue
109-
h = torch.concat([h, cond_embeds["clip_embeds"]], dim=1)
110-
time_aware_encoder_hidden_states.append(h)
111162

112-
c["c_crossattn"] = torch.cat(time_aware_encoder_hidden_states, dim=0).to(_device)
163+
# build a batch and move it under the downstream-UNet device
164+
hidden = torch.cat(time_aware_encoder_hidden_states, dim=0)
165+
c["c_crossattn"] = hidden.to(_device)
166+
113167

114168
return apply_model(input_x, timestep_, **c)
115169

model.py

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
from typing import Optional
44

55
import torch
6+
import os
7+
from safetensors.torch import load_file
68
from comfy import model_management
79
from comfy.model_patcher import ModelPatcher
810
from safetensors.torch import load_model
@@ -12,6 +14,59 @@
1214
from .utils import patch_device_empty_setter, remove_weights
1315

1416

17+
# debug logging toggle; enabled via ELLA_DEBUG=1/true/True in the environment
ELLA_DEBUG = os.getenv("ELLA_DEBUG", "0") in {"1", "true", "True"}
18+
19+
def _count_params(m: torch.nn.Module):
20+
total = sum(p.numel() for p in m.parameters())
21+
trainable = sum(p.numel() for p in m.parameters() if p.requires_grad)
22+
return total, trainable
23+
24+
def _size_mb(m: torch.nn.Module):
25+
# estimate the size of parameters in MB
26+
bytes_total = sum(p.numel() * p.element_size() for p in m.parameters())
27+
return bytes_total / (1024 ** 2)
28+
29+
30+
31+
def load_model_lenient(model: torch.nn.Module, path: str):
    """Load a safetensors checkpoint into *model*, tolerating mismatches.

    Keys absent from the model, or whose shapes differ, are skipped (and
    reported); tensors with a differing dtype are cast to the parameter's
    dtype, and every kept tensor is moved to the parameter's device.  The
    final load uses ``strict=False`` so model keys missing from the
    checkpoint keep their initialized values.  Returns *model*.
    """
    sd_file = load_file(path)  # dict[name -> Tensor]
    model_sd = model.state_dict()

    new_sd = {}
    skipped_shape = []
    extra = []
    casted = []

    for name, tensor in sd_file.items():
        if name not in model_sd:
            extra.append(name)
            continue
        target = model_sd[name]
        if target.shape != tensor.shape:
            skipped_shape.append((name, tuple(tensor.shape), tuple(target.shape)))
            continue
        if target.dtype != tensor.dtype:
            tensor = tensor.to(target.dtype)
            casted.append(name)
        # transfer to the parameter device
        if tensor.device != target.device:
            tensor = tensor.to(target.device)
        new_sd[name] = tensor

    if skipped_shape:
        print(f"[ELLA/load] skipped by shape: {len(skipped_shape)} (e.g. {skipped_shape[:3]})")
    if extra:
        print(f"[ELLA/load] extra keys in ckpt: {len(extra)} (e.g. {extra[:5]})")
    if casted:
        print(f"[ELLA/load] dtype casted: {len(casted)} (e.g. {casted[:5]})")

    missing = [k for k in model_sd if k not in new_sd]
    if missing:
        print(f"[ELLA/load] missing in ckpt: {len(missing)} (e.g. {missing[:5]})")

    model.load_state_dict(new_sd, strict=False)
    return model
68+
69+
1570
class AdaLayerNorm(nn.Module):
1671
def __init__(self, embedding_dim: int, time_embedding_dim: Optional[int] = None):
1772
super().__init__()
@@ -327,8 +382,8 @@ def __init__(self, path: str, **kwargs) -> None:
327382
self.dtype = model_management.text_encoder_dtype(self.load_device)
328383
self.output_device = model_management.intermediate_device()
329384
self.model = ELLAModel()
330-
load_model(self.model, path, strict=True)
331-
self.model.to(self.dtype) # type: ignore
385+
load_model_lenient(self.model, path)
386+
self.model.to(dtype=torch.float16)
332387
self.patcher = ModelPatcher(self.model, load_device=self.load_device, offload_device=self.offload_device)
333388

334389
def load_model(self):

0 commit comments

Comments
 (0)