Merged · Changes from 11 commits
29 changes: 26 additions & 3 deletions src/diffusers/models/modeling_utils.py
@@ -814,14 +814,37 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
information.
- device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
+ device_map (`Union[int, str, torch.device]` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
A map that specifies where each submodule should go. It doesn't need to be defined for each
parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
same device. Defaults to `None`, meaning that the model will be loaded on CPU.

Examples:
@Birch-san (Contributor), Jun 9, 2025

It looks like this is just documenting the scalar cases. The bit that I need docs for is the dictionary convention: `{'': device.type}` as the simplest valid input is extremely hard to guess. There really needs to be an explanation of what the key of the dictionary means.

Member Author

Cc: @SunMarc @stevhliu

How is this documented in transformers?

Member

Also, can you fix the typing of `device_map` in the `DiffusionPipeline` `from_pretrained`? For this specific function, we only allow the `"balanced"` value.

Member Author

Done in 407b67f.

Member Author

@Birch-san I clarified the docs to include the case of `{"": torch.device("cuda")}` and have added tests for it, too. For other possible and valid dict inputs to `device_map`, I would have to refer you to https://huggingface.co/docs/accelerate/en/concept_guides/big_model_inference#the-devicemap. As you can see, it's hard to enumerate them all beforehand without a bit of investigation.

So, I would suggest loading your model with the `"auto"` device_map first, and then printing `model.hf_device_map` to get a much better handle on it. That gives you a reasonable starting point which you can then tweak.
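A minimal sketch of that workflow, reusing the checkpoint from the docstring example below (the printed map shown in the comment is illustrative):

```py
from diffusers import AutoModel

# Let accelerate compute a placement first.
model = AutoModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", device_map="auto"
)

# Keys are dotted submodule names, values are devices. On a single-GPU
# machine this is often just {'': 0}; sharded models get per-module entries.
print(model.hf_device_map)
```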

@Birch-san (Contributor), Jun 10, 2025

It's not clear to me from the accelerate docs how the key is used. The fact that `''` works suggests there's some kind of pattern matching or special-casing, which isn't documented.

Member Author

Going to defer to @SunMarc for that (again).
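In the meantime, a sketch of the convention as the linked accelerate guide describes it: keys are dotted submodule names matched by prefix, and the empty string matches the root module, i.e. the whole model. The commented-out placement below is a hypothetical illustration (submodule name assumed), not a tested input:

```py
import torch

from diffusers import AutoModel

# "" names the root module, so this places the entire model on CUDA.
model = AutoModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    subfolder="unet",
    device_map={"": torch.device("cuda")},
)

# Hypothetical finer-grained map: more specific keys override the root
# entry for the submodules they prefix (here, the UNet's `mid_block`).
# device_map={"": "cpu", "mid_block": 0}
```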


```py
>>> from diffusers import AutoModel
>>> import torch

>>> # This works.
>>> model = AutoModel.from_pretrained(
... "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", device_map="cuda"
... )
>>> # This also works (integer accelerator device ID).
>>> model = AutoModel.from_pretrained(
... "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", device_map=0
... )
>>> # Specifying a supported offloading strategy like "auto" also works.
>>> model = AutoModel.from_pretrained(
... "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", device_map="auto"
... )
```

Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
more information about each option see [designing a device
- map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
+ map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map). You can
+ also refer to the [Diffusers-specific
+ documentation](https://huggingface.co/docs/diffusers/main/en/training/distributed_inference#model-sharding)
+ for more concrete examples.
max_memory (`Dict`, *optional*):
A dictionary device identifier for the maximum memory. Will default to the maximum memory available for
each GPU and the available CPU RAM if unset.
@@ -1387,7 +1410,7 @@ def _load_pretrained_model(
low_cpu_mem_usage: bool = True,
dtype: Optional[Union[str, torch.dtype]] = None,
keep_in_fp32_modules: Optional[List[str]] = None,
- device_map: Dict[str, Union[int, str, torch.device]] = None,
+ device_map: Dict[Union[str, int, torch.device], Union[int, str, torch.device]] = None,
offload_state_dict: Optional[bool] = None,
offload_folder: Optional[Union[str, os.PathLike]] = None,
dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
19 changes: 19 additions & 0 deletions tests/models/unets/test_models_unet_2d_condition.py
@@ -1084,6 +1084,25 @@ def test_load_sharded_checkpoint_device_map_from_hub_local_subfolder(self):
assert loaded_model
assert new_output.sample.shape == (4, 4, 16, 16)

def test_wrong_device_map_raises_error(self):
Contributor

Testing more of the failure code paths (if there are any; I am not fully aware of the relevant parts of the codebase) could be nice.

Member Author

We check that the device_map is respected in https://github.com/huggingface/diffusers/blob/main/tests/models/test_modeling_common.py. Here we're testing invalid device_map values for non-dict entries; I added another one in eb913e2.

Many errors are already handled in accelerate.

So, collectively, I think we should now be good.

with self.assertRaises(ValueError) as err_ctx:
_ = self.model_class.from_pretrained(
"hf-internal-testing/unet2d-sharded-dummy-subfolder", subfolder="unet", device_map=-1
)

msg_substring = "You can't pass device_map as a negative int"
assert msg_substring in str(err_ctx.exception)

@parameterized.expand([0, "cuda", torch.device("cuda"), torch.device("cuda:0")])
@require_torch_gpu
def test_passing_non_dict_device_map_works(self, device_map):
Contributor

Can we test some more cases, like `{"": torch.device("meta"), "decoder": torch.device("cuda")}`? For example, if this were a VAE, the intention here is to not load the encoder weights, but to load the decoder weights directly to device.

Additionally, we should probably run device-map tests for all models IMO (can be taken up in a future PR).

Contributor

I agree this would be fantastic, but we can probably tackle that in a separate PR, and leave the scope of this one to tests/docs/bugfixes/assertions.

Member Author

> Additionally, we should probably run device map tests for all models IMO (can be taken up in future PR)

We already have a bunch of device_map-related tests in https://github.com/huggingface/diffusers/blob/main/tests/models/test_modeling_common.py. I can shift the ones being added in this PR to test_modeling_common.py in a separate PR.

> Can we test some more cases like: {"": torch.device("meta"), "decoder": torch.device("cuda")}? For example, if this was a VAE, the intention here is to not load the encoder weights, but directly load the decoder weights to device.

Feel free to add that in a separate PR.
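If someone picks that up, a hypothetical sketch of such a test (repo and submodule names are assumptions carried over from the tests in this file; untested):

```py
@require_torch_gpu
def test_partial_meta_device_map(self):
    # Hypothetical: keep the root on the meta device and materialize only
    # one submodule on the GPU (the VAE encoder/decoder use case above).
    loaded_model = self.model_class.from_pretrained(
        "hf-internal-testing/unet2d-sharded-dummy-subfolder",
        subfolder="unet",
        device_map={"": torch.device("meta"), "mid_block": torch.device("cuda")},
    )
    # Parameters under `mid_block` should be real CUDA tensors...
    assert all(p.device.type == "cuda" for p in loaded_model.mid_block.parameters())
    # ...while everything else stays on meta (no memory allocated).
    assert loaded_model.conv_in.weight.device.type == "meta"
```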

_, inputs_dict = self.prepare_init_args_and_inputs_for_common()
loaded_model = self.model_class.from_pretrained(
"hf-internal-testing/unet2d-sharded-dummy-subfolder", subfolder="unet", device_map=device_map
)
output = loaded_model(**inputs_dict)
assert output.sample.shape == (4, 4, 16, 16)

@require_peft_backend
def test_load_attn_procs_raise_warning(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()