
Commit 30c836e

fix AdaLayerNorm (#94)
1 parent 45ec466 commit 30c836e

File tree

1 file changed (+2, -2 lines)


diffsynth_engine/models/basic/transformer_helper.py

Lines changed: 2 additions & 2 deletions
@@ -15,7 +15,7 @@ def __init__(self, dim: int, eps: float = 1e-6, device: str = "cuda:0", dtype: t
         self.silu = nn.SiLU()
 
     def forward(self, x, emb):
-        shift, scale = self.linear(self.silu(emb)).unsqueeze(1).chunk(2, dim=1)
+        shift, scale = self.linear(self.silu(emb)).unsqueeze(1).chunk(2, dim=2)
         return modulate(self.norm(x), shift, scale)
 
@@ -27,7 +27,7 @@ def __init__(self, dim, device: str, dtype: torch.dtype):
         self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)
 
     def forward(self, x, emb):
-        shift, scale, gate = self.linear(self.silu(emb)).unsqueeze(1).chunk(3, dim=1)
+        shift, scale, gate = self.linear(self.silu(emb)).unsqueeze(1).chunk(3, dim=2)
         return modulate(self.norm(x), shift, scale), gate
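Why the change: self.linear(self.silu(emb)) yields a tensor of shape (batch, k * dim), and unsqueeze(1) makes it (batch, 1, k * dim). Chunking along dim=1 tries to split the singleton axis and returns a single chunk, so unpacking into shift/scale (or shift/scale/gate) fails; chunking along dim=2 splits the feature axis into k tensors of shape (batch, 1, dim), which broadcast over the sequence dimension. The sketch below illustrates the shapes; it is not the repository's exact code, and `modulate` is assumed here to follow the usual x * (1 + scale) + shift convention.

# Minimal sketch (assumptions noted above), not diffsynth_engine's actual helper.
import torch
import torch.nn as nn

def modulate(x, shift, scale):
    # assumed convention: shift/scale are (B, 1, dim) and broadcast over (B, seq_len, dim)
    return x * (1 + scale) + shift

class AdaLayerNormSketch(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear = nn.Linear(dim, 2 * dim)
        self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)

    def forward(self, x, emb):
        # emb: (B, dim) -> linear: (B, 2*dim) -> unsqueeze(1): (B, 1, 2*dim).
        # chunk(2, dim=2) splits the feature axis into two (B, 1, dim) tensors;
        # chunk(2, dim=1) would try to split the size-1 axis, return a single
        # chunk, and make the two-way unpacking fail.
        shift, scale = self.linear(self.silu(emb)).unsqueeze(1).chunk(2, dim=2)
        return modulate(self.norm(x), shift, scale)

x = torch.randn(2, 16, 64)   # (batch, seq_len, dim)
emb = torch.randn(2, 64)     # (batch, dim) conditioning embedding
print(AdaLayerNormSketch(64)(x, emb).shape)  # torch.Size([2, 16, 64])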

0 commit comments
