File tree Expand file tree Collapse file tree 2 files changed +22
-1
lines changed Expand file tree Collapse file tree 2 files changed +22
-1
lines changed Original file line number Diff line number Diff line change 6767 numpy_to_pil ,
6868)
6969from ..utils .hub_utils import _check_legacy_sharding_variant_format , load_or_create_model_card , populate_model_card
70+ from ..utils .testing_utils import torch_device
7071from ..utils .torch_utils import empty_device_cache , get_device , is_compiled_module
7172
7273
109110 LIBRARIES .append (library )
110111
111112# TODO: support single-device namings
112- SUPPORTED_DEVICE_MAP = ["balanced" , "cuda" ]
113+ SUPPORTED_DEVICE_MAP = ["balanced" ] + [ torch_device ]
113114
114115logger = logging .get_logger (__name__ )
115116
Original file line number Diff line number Diff line change @@ -2339,6 +2339,26 @@ def test_torch_dtype_dict(self):
23392339 f"Component '{ name } ' has dtype { component .dtype } but expected { expected_dtype } " ,
23402340 )
23412341
2342+ @require_torch_accelerator
2343+ def test_pipeline_with_accelerator_device_map (self , expected_max_difference = 1e-4 ):
2344+ components = self .get_dummy_components ()
2345+ pipe = self .pipeline_class (** components )
2346+ pipe = pipe .to (torch_device )
2347+ pipe .set_progress_bar_config (disable = None )
2348+
2349+ torch .manual_seed (0 )
2350+ inputs = self .get_dummy_inputs (torch_device )
2351+ inputs ["generator" ] = torch .manual_seed (0 )
2352+ out = pipe (** inputs )[0 ]
2353+
2354+ with tempfile .TemporaryDirectory () as tmpdir :
2355+ pipe .save_pretrained (tmpdir )
2356+ loaded_pipe = self .pipeline_class .from_pretrained (tmpdir , device_map = torch_device )
2357+ inputs = self .get_dummy_inputs (torch_device )
2358+ loaded_out = loaded_pipe (** inputs )[0 ]
2359+ max_diff = np .abs (to_np (out ) - to_np (loaded_out )).max ()
2360+ self .assertLess (max_diff , expected_max_difference )
2361+
23422362
23432363@is_staging_test
23442364class PipelinePushToHubTester (unittest .TestCase ):
You can’t perform that action at this time.
0 commit comments