 # limitations under the License.
 
 import gc
-import tempfile
 import unittest
 
 import numpy as np
@@ -128,10 +127,12 @@ def test_inference(self):
         max_diff = np.abs(image_slice.flatten() - expected_slice).max()
         self.assertLessEqual(max_diff, 1e-3)
 
+    @unittest.skip("Not supported.")
     def test_sequential_cpu_offload_forward_pass(self):
         # TODO(YiYi) need to fix later
         pass
 
+    @unittest.skip("Not supported.")
     def test_sequential_offload_forward_pass_twice(self):
         # TODO(YiYi) need to fix later
         pass
@@ -141,99 +142,6 @@ def test_inference_batch_single_identical(self):
             expected_max_diff=1e-3,
         )
 
-    def test_save_load_optional_components(self):
-        components = self.get_dummy_components()
-        pipe = self.pipeline_class(**components)
-        pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
-
-        inputs = self.get_dummy_inputs(torch_device)
-
-        prompt = inputs["prompt"]
-        generator = inputs["generator"]
-        num_inference_steps = inputs["num_inference_steps"]
-        output_type = inputs["output_type"]
-
-        (
-            prompt_embeds,
-            negative_prompt_embeds,
-            prompt_attention_mask,
-            negative_prompt_attention_mask,
-        ) = pipe.encode_prompt(prompt, device=torch_device, dtype=torch.float32, text_encoder_index=0)
-
-        (
-            prompt_embeds_2,
-            negative_prompt_embeds_2,
-            prompt_attention_mask_2,
-            negative_prompt_attention_mask_2,
-        ) = pipe.encode_prompt(
-            prompt,
-            device=torch_device,
-            dtype=torch.float32,
-            text_encoder_index=1,
-        )
-
-        # inputs with prompt converted to embeddings
-        inputs = {
-            "prompt_embeds": prompt_embeds,
-            "prompt_attention_mask": prompt_attention_mask,
-            "negative_prompt_embeds": negative_prompt_embeds,
-            "negative_prompt_attention_mask": negative_prompt_attention_mask,
-            "prompt_embeds_2": prompt_embeds_2,
-            "prompt_attention_mask_2": prompt_attention_mask_2,
-            "negative_prompt_embeds_2": negative_prompt_embeds_2,
-            "negative_prompt_attention_mask_2": negative_prompt_attention_mask_2,
-            "generator": generator,
-            "num_inference_steps": num_inference_steps,
-            "output_type": output_type,
-            "use_resolution_binning": False,
-        }
-
-        # set all optional components to None
-        for optional_component in pipe._optional_components:
-            setattr(pipe, optional_component, None)
-
-        output = pipe(**inputs)[0]
-
-        with tempfile.TemporaryDirectory() as tmpdir:
-            pipe.save_pretrained(tmpdir)
-            pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
-            pipe_loaded.to(torch_device)
-            pipe_loaded.set_progress_bar_config(disable=None)
-
-        for optional_component in pipe._optional_components:
-            self.assertTrue(
-                getattr(pipe_loaded, optional_component) is None,
-                f"`{optional_component}` did not stay set to None after loading.",
-            )
-
-        inputs = self.get_dummy_inputs(torch_device)
-
-        generator = inputs["generator"]
-        num_inference_steps = inputs["num_inference_steps"]
-        output_type = inputs["output_type"]
-
-        # inputs with prompt converted to embeddings
-        inputs = {
-            "prompt_embeds": prompt_embeds,
-            "prompt_attention_mask": prompt_attention_mask,
-            "negative_prompt_embeds": negative_prompt_embeds,
-            "negative_prompt_attention_mask": negative_prompt_attention_mask,
-            "prompt_embeds_2": prompt_embeds_2,
-            "prompt_attention_mask_2": prompt_attention_mask_2,
-            "negative_prompt_embeds_2": negative_prompt_embeds_2,
-            "negative_prompt_attention_mask_2": negative_prompt_attention_mask_2,
-            "generator": generator,
-            "num_inference_steps": num_inference_steps,
-            "output_type": output_type,
-            "use_resolution_binning": False,
-        }
-
-        output_loaded = pipe_loaded(**inputs)[0]
-
-        max_diff = np.abs(to_np(output) - to_np(output_loaded)).max()
-        self.assertLess(max_diff, 1e-4)
-
     def test_feed_forward_chunking(self):
         device = "cpu"
 