
Commit 527b13a

Refactor norm_func in modulate function (#24)

Parent: f1ce7e2

2 files changed: +2 −3 lines


opendit/models/dit.py

Lines changed: 1 addition & 2 deletions

@@ -39,8 +39,7 @@ def get_layernorm(hidden_size: torch.Tensor, eps: float, affine: bool, use_kernel
 def modulate(norm_func, x, shift, scale, use_kernel=False):
     # Suppose x is (N, T, D), shift is (N, D), scale is (N, D)
     dtype = x.dtype
-    x, shift, scale = x.to(torch.float32), shift.to(torch.float32), scale.to(torch.float32)
-    x = norm_func(x)
+    x = norm_func(x.to(torch.float32)).to(dtype)
     if use_kernel:
         try:
             from opendit.kernels.fused_modulate import fused_modulate
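
For context, the sketch below shows what the refactored function computes on the non-kernel path: only x is upcast to float32 for the normalization, and the result is cast straight back to the input dtype, so shift and scale keep their original dtype. The modulation line itself (x * (1 + scale) + shift with unsqueeze broadcasting) is an assumption based on the standard DiT adaLN formulation and does not appear in this diff.

import torch
import torch.nn as nn

def modulate_sketch(norm_func, x, shift, scale):
    # x: (N, T, D), shift: (N, D), scale: (N, D)
    dtype = x.dtype
    # Refactored step from this commit: normalize in float32, cast back.
    x = norm_func(x.to(torch.float32)).to(dtype)
    # Assumed adaLN-style modulation (standard DiT; not part of this diff):
    # broadcast the per-sample shift/scale over the sequence dimension T.
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

norm = nn.LayerNorm(64, eps=1e-6, elementwise_affine=False)
x = torch.randn(2, 16, 64, dtype=torch.float16)
shift = torch.randn(2, 64, dtype=torch.float16)
scale = torch.randn(2, 64, dtype=torch.float16)
out = modulate_sketch(norm, x, shift, scale)
assert out.dtype == torch.float16  # dtype preserved despite float32 norm

One practical effect of the change: shift and scale are no longer upcast, so the downstream modulation (including the fused-kernel path) receives them in their original dtype instead of float32.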

requirements.txt

Lines changed: 1 addition & 1 deletion

@@ -25,4 +25,4 @@ matplotlib
 accelerate
 diffusers
 transformers
-flash_attn==2.0.5
+flash_attn
