@@ -13,8 +13,10 @@
 # limitations under the License.
 
 import gc
+import tempfile
 import unittest
 
+import numpy as np
 import torch
 from transformers import AutoTokenizer, T5EncoderModel
 
@@ -85,29 +87,13 @@ def get_dummy_components(self):
             rope_max_seq_len=32,
         )
 
-        torch.manual_seed(0)
-        transformer_2 = WanTransformer3DModel(
-            patch_size=(1, 2, 2),
-            num_attention_heads=2,
-            attention_head_dim=12,
-            in_channels=16,
-            out_channels=16,
-            text_dim=32,
-            freq_dim=256,
-            ffn_dim=32,
-            num_layers=2,
-            cross_attn_norm=True,
-            qk_norm="rms_norm_across_heads",
-            rope_max_seq_len=32,
-        )
-
         components = {
             "transformer": transformer,
             "vae": vae,
             "scheduler": scheduler,
             "text_encoder": text_encoder,
             "tokenizer": tokenizer,
-            "transformer_2": transformer_2,
+            "transformer_2": None,
         }
         return components
 
@@ -155,6 +141,45 @@ def test_inference(self):
     def test_attention_slicing_forward_pass(self):
         pass
 
+    # _optional_components includes both transformer and transformer_2, but only transformer_2 is optional for this Wan 2.1 T2V pipeline
+    def test_save_load_optional_components(self, expected_max_difference=1e-4):
+        optional_component = "transformer_2"
+
+        components = self.get_dummy_components()
+        components[optional_component] = None
+        pipe = self.pipeline_class(**components)
+        for component in pipe.components.values():
+            if hasattr(component, "set_default_attn_processor"):
+                component.set_default_attn_processor()
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        generator_device = "cpu"
+        inputs = self.get_dummy_inputs(generator_device)
+        torch.manual_seed(0)
+        output = pipe(**inputs)[0]
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            pipe.save_pretrained(tmpdir, safe_serialization=False)
+            pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
+            for component in pipe_loaded.components.values():
+                if hasattr(component, "set_default_attn_processor"):
+                    component.set_default_attn_processor()
+            pipe_loaded.to(torch_device)
+            pipe_loaded.set_progress_bar_config(disable=None)
+
+        self.assertTrue(
+            getattr(pipe_loaded, optional_component) is None,
+            f"`{optional_component}` did not stay set to None after loading.",
+        )
+
+        inputs = self.get_dummy_inputs(generator_device)
+        torch.manual_seed(0)
+        output_loaded = pipe_loaded(**inputs)[0]
+
+        max_diff = np.abs(output.detach().cpu().numpy() - output_loaded.detach().cpu().numpy()).max()
+        self.assertLess(max_diff, expected_max_difference)
+
 
 @slow
 @require_torch_accelerator
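The contract the new test pins down is that a pipeline whose optional `transformer_2` slot is `None` round-trips through `save_pretrained`/`from_pretrained` with the slot still `None`. A minimal usage sketch of the same pattern outside the test harness, assuming the public Wan 2.1 T2V checkpoint id (any `WanPipeline` checkpoint without a second transformer should behave the same):

```python
import tempfile

from diffusers import WanPipeline

# Assumed checkpoint id; Wan 2.1 T2V checkpoints ship no second
# (low-noise) transformer, so the optional slot loads as None.
pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers")
print(pipe.transformer_2)  # -> None

with tempfile.TemporaryDirectory() as tmpdir:
    pipe.save_pretrained(tmpdir)
    # Reload while the temporary directory still exists; the missing
    # optional component should come back as None rather than erroring.
    reloaded = WanPipeline.from_pretrained(tmpdir)

# The optional component stays None after the save/load round trip,
# which is exactly what the new test asserts.
assert reloaded.transformer_2 is None
```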