|
94 | 94 |
|
95 | 95 |
|
def basic_clean(text):
    """
    Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/wan/pipeline_wan.py

    Fix text encoding issues via ftfy (when installed) and resolve
    doubly-escaped HTML entities, returning the stripped result.
    """
    if is_ftfy_available():
        text = ftfy.fix_text(text)
    # unescape twice to handle double-encoded entities like "&amp;amp;"
    return html.unescape(html.unescape(text)).strip()
102 | 106 |
|
103 | 107 |
|
def whitespace_clean(text):
    """
    Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/wan/pipeline_wan.py

    Collapse every run of whitespace into a single space and strip the ends.
    """
    collapsed = re.sub(r"\s+", " ", text)
    return collapsed.strip()
109 | 117 |
|
110 | 118 |
|
def prompt_clean(text):
    """
    Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/wan/pipeline_wan.py

    Run a prompt through basic cleaning, then whitespace normalization.
    """
    cleaned = basic_clean(text)
    return whitespace_clean(cleaned)
115 | 127 |
|
@@ -396,6 +408,53 @@ def _encode_prompt_clip( |
396 | 408 | pooled_embed = self.text_encoder_2(**inputs)["pooler_output"] |
397 | 409 |
|
398 | 410 | return pooled_embed.to(dtype) |
| 411 | + |
| 412 | + @staticmethod |
| 413 | + def adaptive_mean_std_normalization(source, reference): |
| 414 | + source_mean = source.mean(dim=(1,2,3,4),keepdim=True) |
| 415 | + source_std = source.std(dim=(1,2,3,4),keepdim=True) |
| 416 | + #magic constants - limit changes in latents |
| 417 | + clump_mean_low = 0.05 |
| 418 | + clump_mean_high = 0.1 |
| 419 | + clump_std_low = 0.1 |
| 420 | + clump_std_high = 0.25 |
| 421 | + |
| 422 | + reference_mean = torch.clamp(reference.mean(), source_mean - clump_mean_low, source_mean + clump_mean_high) |
| 423 | + reference_std = torch.clamp(reference.std(), source_std - clump_std_low, source_std + clump_std_high) |
| 424 | + |
| 425 | + # normalization |
| 426 | + normalized = (source - source_mean) / source_std |
| 427 | + normalized = normalized * reference_std + reference_mean |
| 428 | + |
| 429 | + return normalized |
| 430 | + |
| 431 | + def normalize_first_frame(self, latents, reference_frames=5, clump_values=False): |
| 432 | + latents_copy = latents.clone() |
| 433 | + samples = latents_copy |
| 434 | + |
| 435 | + if samples.shape[1] <= 1: |
| 436 | + return (latents, "Only one frame, no normalization needed") |
| 437 | + |
| 438 | + nFr = 4 |
| 439 | + first_frames = samples.clone()[:, :nFr] |
| 440 | + reference_frames_data = samples[:, nFr:nFr + min(reference_frames, samples.shape[1] - 1)] |
| 441 | + |
| 442 | + print(samples.shape, first_frames.shape, reference_frames_data.shape, nFr, min(reference_frames, samples.shape[1] - 1)) |
| 443 | + |
| 444 | + print(reference_frames_data.mean(), reference_frames_data.std(), reference_frames_data.shape) |
| 445 | + |
| 446 | + print("First frame stats - Mean:", first_frames.mean(dim=(1,2,3)), "Std: ", first_frames.std(dim=(1,2,3))) |
| 447 | + print(f"Reference frames stats - Mean: {reference_frames_data.mean().item():.4f}, Std: {reference_frames_data.std().item():.4f}") |
| 448 | + |
| 449 | + normalized_first = self.adaptive_mean_std_normalization(first_frames, reference_frames_data) |
| 450 | + if clump_values: |
| 451 | + min_val = reference_frames_data.min() |
| 452 | + max_val = reference_frames_data.max() |
| 453 | + normalized_first = torch.clamp(normalized_first, min_val, max_val) |
| 454 | + |
| 455 | + samples[:, :nFr] = normalized_first |
| 456 | + |
| 457 | + return samples |
399 | 458 |
|
400 | 459 | def encode_prompt( |
401 | 460 | self, |
@@ -973,8 +1032,11 @@ def __call__( |
973 | 1032 |
|
974 | 1033 | # 9. Post-processing - extract main latents |
975 | 1034 | latents = latents[:, :, :, :, :num_channels_latents] |
| 1035 | + |
| 1036 | + # 10. fix mesh artifacts |
| 1037 | + latents = self.normalize_first_frame(latents) |
976 | 1038 |
|
977 | | - # 10. Decode latents to video |
| 1039 | + # 11. Decode latents to video |
978 | 1040 | if output_type != "latent": |
979 | 1041 | latents = latents.to(self.vae.dtype) |
980 | 1042 | # Reshape and normalize latents |
|
0 commit comments