Skip to content

Commit 0270a0b

Browse files
Reduce artifacts on Wan by doing the patch embedding in fp32.
1 parent 26c7baf commit 0270a0b

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

comfy/ldm/wan/model.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,7 @@ def sinusoidal_embedding_1d(dim, position):
1818
# preprocess
1919
assert dim % 2 == 0
2020
half = dim // 2
21-
position = position.type(torch.float64)
21+
position = position.type(torch.float32)
2222

2323
# calculation
2424
sinusoid = torch.outer(
@@ -353,7 +353,7 @@ def __init__(self,
353353

354354
# embeddings
355355
self.patch_embedding = operations.Conv3d(
356-
in_dim, dim, kernel_size=patch_size, stride=patch_size, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
356+
in_dim, dim, kernel_size=patch_size, stride=patch_size, device=operation_settings.get("device"), dtype=torch.float32)
357357
self.text_embedding = nn.Sequential(
358358
operations.Linear(text_dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")), nn.GELU(approximate='tanh'),
359359
operations.Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")))
@@ -411,7 +411,7 @@ def forward_orig(
411411
List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
412412
"""
413413
# embeddings
414-
x = self.patch_embedding(x)
414+
x = self.patch_embedding(x.float()).to(x.dtype)
415415
grid_sizes = x.shape[2:]
416416
x = x.flatten(2).transpose(1, 2)
417417

0 commit comments

Comments (0)