@@ -139,6 +139,44 @@ def test_inference(self):
139139 def test_attention_slicing_forward_pass (self ):
140140 pass
141141
142+ # _optional_components include transformer, transformer_2, but only transformer_2 is optional for this wan2.1 t2v pipeline
143+ def test_save_load_optional_components (self , expected_max_difference = 1e-4 ):
144+ optional_component = "transformer_2"
145+
146+ components = self .get_dummy_components ()
147+ components [optional_component ] = None
148+ pipe = self .pipeline_class (** components )
149+ for component in pipe .components .values ():
150+ if hasattr (component , "set_default_attn_processor" ):
151+ component .set_default_attn_processor ()
152+ pipe .to (torch_device )
153+ pipe .set_progress_bar_config (disable = None )
154+
155+ generator_device = "cpu"
156+ inputs = self .get_dummy_inputs (generator_device )
157+ torch .manual_seed (0 )
158+ output = pipe (** inputs )[0 ]
159+
160+ with tempfile .TemporaryDirectory () as tmpdir :
161+ pipe .save_pretrained (tmpdir , safe_serialization = False )
162+ pipe_loaded = self .pipeline_class .from_pretrained (tmpdir )
163+ for component in pipe_loaded .components .values ():
164+ if hasattr (component , "set_default_attn_processor" ):
165+ component .set_default_attn_processor ()
166+ pipe_loaded .to (torch_device )
167+ pipe_loaded .set_progress_bar_config (disable = None )
168+
169+ self .assertTrue (
170+ getattr (pipe_loaded , optional_component ) is None ,
171+ f"`{ optional_component } ` did not stay set to None after loading." ,
172+ )
173+
174+ inputs = self .get_dummy_inputs (generator_device )
175+ torch .manual_seed (0 )
176+ output_loaded = pipe_loaded (** inputs )[0 ]
177+
178+ max_diff = np .abs (output .detach ().cpu ().numpy () - output_loaded .detach ().cpu ().numpy ()).max ()
179+ self .assertLess (max_diff , expected_max_difference )
142180
143181@slow
144182@require_torch_accelerator
0 commit comments