
Commit 1189a35

add more tests

committed · 1 parent a7bce5f · commit 1189a35

File tree

3 files changed · +502 -10 lines changed

tests/pipelines/wan/test_wan.py

Lines changed: 38 additions & 0 deletions
@@ -139,6 +139,44 @@ def test_inference(self):
     def test_attention_slicing_forward_pass(self):
         pass
 
+    # _optional_components includes transformer and transformer_2, but only transformer_2 is optional for this Wan 2.1 T2V pipeline
+    def test_save_load_optional_components(self, expected_max_difference=1e-4):
+        optional_component = "transformer_2"
+
+        components = self.get_dummy_components()
+        components[optional_component] = None
+        pipe = self.pipeline_class(**components)
+        for component in pipe.components.values():
+            if hasattr(component, "set_default_attn_processor"):
+                component.set_default_attn_processor()
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        generator_device = "cpu"
+        inputs = self.get_dummy_inputs(generator_device)
+        torch.manual_seed(0)
+        output = pipe(**inputs)[0]
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            pipe.save_pretrained(tmpdir, safe_serialization=False)
+            pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
+            for component in pipe_loaded.components.values():
+                if hasattr(component, "set_default_attn_processor"):
+                    component.set_default_attn_processor()
+            pipe_loaded.to(torch_device)
+            pipe_loaded.set_progress_bar_config(disable=None)
+
+        self.assertTrue(
+            getattr(pipe_loaded, optional_component) is None,
+            f"`{optional_component}` did not stay set to None after loading.",
+        )
+
+        inputs = self.get_dummy_inputs(generator_device)
+        torch.manual_seed(0)
+        output_loaded = pipe_loaded(**inputs)[0]
+
+        max_diff = np.abs(output.detach().cpu().numpy() - output_loaded.detach().cpu().numpy()).max()
+        self.assertLess(max_diff, expected_max_difference)
 
     @slow
     @require_torch_accelerator
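
For context (not part of the commit), a minimal, self-contained sketch of the round-trip pattern the new test exercises: build the object with the genuinely optional component set to None, save it, reload it, and check both that the slot stays None and that the outputs match within a tolerance. `ToyPipeline` and its JSON save format are invented stand-ins for illustration, not the diffusers API.

    # Toy illustration only; ToyPipeline and its JSON format are invented stand-ins.
    import json
    import os
    import tempfile

    import numpy as np


    class ToyPipeline:
        def __init__(self, scale, transformer_2=None):
            self.scale = scale                   # stands in for the required components
            self.transformer_2 = transformer_2   # the genuinely optional component

        def __call__(self, x):
            out = self.scale * x
            if self.transformer_2 is not None:   # only used when the optional part exists
                out = self.transformer_2 * out
            return out

        def save_pretrained(self, path):
            with open(os.path.join(path, "config.json"), "w") as f:
                json.dump({"scale": self.scale, "transformer_2": self.transformer_2}, f)

        @classmethod
        def from_pretrained(cls, path):
            with open(os.path.join(path, "config.json")) as f:
                return cls(**json.load(f))


    pipe = ToyPipeline(scale=2.0, transformer_2=None)
    x = np.ones(4)
    output = pipe(x)

    with tempfile.TemporaryDirectory() as tmpdir:
        pipe.save_pretrained(tmpdir)
        pipe_loaded = ToyPipeline.from_pretrained(tmpdir)

    # The optional component must survive the round trip as None ...
    assert pipe_loaded.transformer_2 is None, "`transformer_2` did not stay set to None after loading."
    # ... and the reloaded object must reproduce the original output.
    max_diff = np.abs(output - pipe_loaded(x)).max()
    assert max_diff < 1e-4

The real test in the diff above applies the same pattern to the Wan 2.1 T2V pipeline, using its save_pretrained / from_pretrained and a fixed torch seed so the before/after outputs are comparable.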
