Skip to content

Commit 278b3b8

Browse files
committed
fix geglu and swiglu gate computation
1 parent d7f369c commit 278b3b8

File tree

1 file changed

+12
-4
lines changed

1 file changed

+12
-4
lines changed

src/diffusers/models/memory_utils.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -133,10 +133,18 @@ def apply_memory_optimized_feedforward(module: torch.nn.Module, num_splits: Opti
133133
num_splits = submodule._mult if num_splits is None else num_splits
134134

135135
# remap net.0.proj.weight
136-
net_0_proj = state_dict.pop("net.0.proj.weight")
137-
net_0_proj = net_0_proj.chunk(num_splits, dim=0)
138-
for i in range(num_splits):
139-
state_dict[f"proj_in.{i}.proj.weight"] = net_0_proj[i]
136+
if isinstance(submodule.net[0], (GEGLU, SwiGLU)):
137+
net_0_proj = state_dict.pop("net.0.proj.weight")
138+
proj, gate = net_0_proj.chunk(2, dim=0)
139+
proj = proj.chunk(num_splits, dim=0)
140+
gate = gate.chunk(num_splits, dim=0)
141+
for i in range(num_splits):
142+
state_dict[f"proj_in.{i}.proj.weight"] = torch.cat([proj[i], gate[i]], dim=0)
143+
else:
144+
net_0_proj = state_dict.pop("net.0.proj.weight")
145+
net_0_proj = net_0_proj.chunk(num_splits, dim=0)
146+
for i in range(num_splits):
147+
state_dict[f"proj_in.{i}.proj.weight"] = net_0_proj[i]
140148

141149
# remap net.0.proj.bias
142150
if "net.0.proj.bias" in state_dict:

0 commit comments

Comments
 (0)