@@ -143,9 +143,9 @@ def test_pipeline_call_signature(self):
143143
144144 def _check_for_parameters (parameters , expected_parameters , param_type ):
145145 remaining_parameters = {param for param in parameters if param not in expected_parameters }
146- assert len ( remaining_parameters ) == 0 , (
147- f"Required { param_type } parameters not present: { remaining_parameters } "
148- )
146+ assert (
147+ len ( remaining_parameters ) == 0
148+ ), f"Required { param_type } parameters not present: { remaining_parameters } "
149149
150150 _check_for_parameters (self .params , input_parameters , "input" )
151151 _check_for_parameters (self .intermediate_params , intermediate_parameters , "intermediate" )
@@ -274,9 +274,9 @@ def test_to_device(self):
274274 model_devices = [
275275 component .device .type for component in pipe .components .values () if hasattr (component , "device" )
276276 ]
277- assert all (device == torch_device for device in model_devices ), (
278- "All pipeline components are not on accelerator device"
279- )
277+ assert all (
278+ device == torch_device for device in model_devices
279+ ), "All pipeline components are not on accelerator device"
280280
281281 def test_inference_is_not_nan_cpu (self ):
282282 pipe = self .get_pipeline ()
@@ -318,3 +318,13 @@ def test_num_images_per_prompt(self):
318318 images = pipe (** inputs , num_images_per_prompt = num_images_per_prompt , output = "images" )
319319
320320 assert images .shape [0 ] == batch_size * num_images_per_prompt
321+
322+ @require_accelerator
323+ def test_components_auto_cpu_offload (self ):
324+ base_pipe = self .get_pipeline ().to (torch_device )
325+ for component in base_pipe .components :
326+ assert component .device == torch_device
327+
328+ cm = ComponentsManager ()
329+ cm .enable_auto_cpu_offload (device = torch_device )
330+ offload_pipe = self .get_pipeline (components_manager = cm )
0 commit comments