
Commit 3fd31ee

[Core] introduce _no_split_modules to ModelMixin (#6396)

* introduce _no_split_modules.
* unnecessary spaces.
* remove unnecessary kwargs and style
* fix: accelerate imports.
* change to _determine_device_map
* add the blocks that have residual connections.
* add: CrossAttnUpBlock2D
* add: testing
* style
* line-spaces
* quality
* add disk offload test without safetensors.
* checking disk offloading percentages.
* change model split
* add: utility for checking multi-gpu requirement.
* model parallelism test
* splits.
* offload folder to test_disk_offload_with_safetensors
* add _no_split_modules
* fix-copies

1 parent b02e211 commit 3fd31ee

File tree

7 files changed: +221 −5 lines changed
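What the change enables, end to end: `ModelMixin.from_pretrained` now accepts a `device_map`, resolves it through accelerate using each model's `_no_split_modules`, and records the final placement on `hf_device_map`. A minimal sketch of the user-facing call (the checkpoint name is only an example; the printed map is illustrative):

import torch
from diffusers import UNet2DConditionModel

# "auto" lets accelerate place whole no-split blocks across GPUs, CPU, and disk.
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5",  # example repo; any UNet2DConditionModel checkpoint works
    subfolder="unet",
    device_map="auto",
    torch_dtype=torch.float16,
)
print(unet.hf_device_map)  # e.g. {"conv_in": 0, "down_blocks": 0, ..., "up_blocks": "cpu"}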

src/diffusers/models/autoencoders/autoencoder_kl.py

Lines changed: 1 addition & 0 deletions

@@ -65,6 +65,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
     """

     _supports_gradient_checkpointing = True
+    _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D"]

     @register_to_config
     def __init__(

src/diffusers/models/modeling_utils.py

Lines changed: 87 additions & 5 deletions
@@ -57,7 +57,8 @@
 if is_accelerate_available():
     import accelerate
-    from accelerate.utils import set_module_tensor_to_device
+    from accelerate import infer_auto_device_map
+    from accelerate.utils import get_balanced_memory, get_max_memory, set_module_tensor_to_device
     from accelerate.utils.versions import is_torch_version
@@ -99,6 +100,29 @@ def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
     return first_tuple[1].dtype


+# Adapted from `transformers` (see modeling_utils.py)
+def _determine_device_map(model: "ModelMixin", device_map, max_memory, torch_dtype):
+    if isinstance(device_map, str):
+        no_split_modules = model._get_no_split_modules(device_map)
+        device_map_kwargs = {"no_split_module_classes": no_split_modules}
+
+        if device_map != "sequential":
+            max_memory = get_balanced_memory(
+                model,
+                dtype=torch_dtype,
+                low_zero=(device_map == "balanced_low_0"),
+                max_memory=max_memory,
+                **device_map_kwargs,
+            )
+        else:
+            max_memory = get_max_memory(max_memory)
+
+        device_map_kwargs["max_memory"] = max_memory
+        device_map = infer_auto_device_map(model, dtype=torch_dtype, **device_map_kwargs)
+
+    return device_map
+
+
 def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
     """
     Reads a checkpoint file, returning properly formatted errors if they arise.
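For orientation, this is roughly what the helper delegates to accelerate when a strategy string is passed; the toy model and memory budget below are assumptions for illustration:

import torch
from accelerate import infer_auto_device_map

# Toy stand-in model: eight ~4 MiB linear layers.
model = torch.nn.Sequential(*[torch.nn.Linear(1024, 1024) for _ in range(8)])

# With "Linear" declared unsplittable, each layer is assigned to one device whole.
device_map = infer_auto_device_map(
    model,
    max_memory={0: 20 * 1024**2, "cpu": "1GiB"},  # ~20 MiB budget for device 0
    no_split_module_classes=["Linear"],
)
print(device_map)  # e.g. {"0": 0, "1": 0, "2": 0, "3": 0, "4": "cpu", ...}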
@@ -201,6 +225,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
     _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
     _supports_gradient_checkpointing = False
     _keys_to_ignore_on_load_unexpected = None
+    _no_split_modules = None

     def __init__(self):
         super().__init__()
@@ -560,6 +585,36 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
             )

+        # change device_map into a map if we passed an int, a str or a torch.device
+        if isinstance(device_map, torch.device):
+            device_map = {"": device_map}
+        elif isinstance(device_map, str) and device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]:
+            try:
+                device_map = {"": torch.device(device_map)}
+            except RuntimeError:
+                raise ValueError(
+                    "When passing device_map as a string, the value needs to be a device name (e.g. cpu, cuda:0) or "
+                    f"'auto', 'balanced', 'balanced_low_0', 'sequential' but found {device_map}."
+                )
+        elif isinstance(device_map, int):
+            if device_map < 0:
+                raise ValueError(
+                    "You can't pass device_map as a negative int. If you want to put the model on the cpu, pass device_map = 'cpu' "
+                )
+            else:
+                device_map = {"": device_map}
+
+        if device_map is not None:
+            if low_cpu_mem_usage is None:
+                low_cpu_mem_usage = True
+            elif not low_cpu_mem_usage:
+                raise ValueError("Passing along a `device_map` requires `low_cpu_mem_usage=True`")
+
+        if low_cpu_mem_usage:
+            if device_map is not None and not is_torch_version(">=", "1.10"):
+                # The max memory utils require PyTorch >= 1.10 to have torch.cuda.mem_get_info.
+                raise ValueError("`low_cpu_mem_usage` and `device_map` require PyTorch >= 1.10.")
+
         # Load config if we don't provide a configuration
         config_path = pretrained_model_name_or_path
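The normalization above can be restated as a standalone function. This is a hypothetical helper written only to illustrate the branches, not diffusers API:

import torch

_AUTO_STRATEGIES = ("auto", "balanced", "balanced_low_0", "sequential")

def normalize_device_map(device_map):
    """Mirror of the branches above: scalars become a whole-model assignment."""
    if isinstance(device_map, torch.device):
        return {"": device_map}
    if isinstance(device_map, str) and device_map not in _AUTO_STRATEGIES:
        return {"": torch.device(device_map)}  # RuntimeError here means an invalid device name
    if isinstance(device_map, int):
        if device_map < 0:
            raise ValueError("device_map can't be a negative int; use device_map='cpu' for CPU")
        return {"": device_map}
    return device_map  # None, a dict, or an auto strategy resolved later

print(normalize_device_map("cuda:0"))  # {'': device(type='cuda', index=0)}
print(normalize_device_map(1))         # {'': 1}
print(normalize_device_map("auto"))    # 'auto'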

@@ -582,10 +637,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             token=token,
             revision=revision,
             subfolder=subfolder,
-            device_map=device_map,
-            max_memory=max_memory,
-            offload_folder=offload_folder,
-            offload_state_dict=offload_state_dict,
             user_agent=user_agent,
             **kwargs,
         )
@@ -690,6 +741,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         else:  # else let accelerate handle loading and dispatching.
             # Load weights and dispatch according to the device_map
             # by default the device_map is None and the weights are loaded on the CPU
+            device_map = _determine_device_map(model, device_map, max_memory, torch_dtype)
             try:
                 accelerate.load_checkpoint_and_dispatch(
                     model,
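For context, the resolved map then feeds accelerate's dispatcher. A self-contained sketch of that call on a toy module (the checkpoint path and device map here are made up for the example):

import torch
import accelerate

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8))
torch.save(model.state_dict(), "toy_checkpoint.bin")

# Loads weights directly onto the devices named in device_map; offload_folder
# is required whenever some weights are mapped to "disk".
model = accelerate.load_checkpoint_and_dispatch(
    model,
    "toy_checkpoint.bin",
    device_map={"0": "cpu", "1": "cpu"},  # toy map: everything on CPU
    offload_folder="offload",
)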
@@ -881,6 +933,36 @@ def _find_mismatched_keys(

         return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs

+    # Adapted from `transformers` modeling_utils.py
+    def _get_no_split_modules(self, device_map: str):
+        """
+        Get the modules of the model that should not be split when using device_map. We iterate through the modules
+        to get the underlying `_no_split_modules`.
+
+        Args:
+            device_map (`str`):
+                The device map value. Options are ["auto", "balanced", "balanced_low_0", "sequential"]
+
+        Returns:
+            `List[str]`: List of modules that should not be split
+        """
+        _no_split_modules = set()
+        modules_to_check = [self]
+        while len(modules_to_check) > 0:
+            module = modules_to_check.pop(-1)
+            # if the module does not appear in _no_split_modules, we also check the children
+            if module.__class__.__name__ not in _no_split_modules:
+                if isinstance(module, ModelMixin):
+                    if module._no_split_modules is None:
+                        raise ValueError(
+                            f"{module.__class__.__name__} does not support `device_map='{device_map}'`. To implement support, the model "
+                            "class needs to implement the `_no_split_modules` attribute."
+                        )
+                    else:
+                        _no_split_modules = _no_split_modules | set(module._no_split_modules)
+                modules_to_check += list(module.children())
+        return list(_no_split_modules)
+
     @property
     def device(self) -> torch.device:
         """

src/diffusers/models/transformers/transformer_2d.py

Lines changed: 1 addition & 0 deletions

@@ -72,6 +72,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
     """

     _supports_gradient_checkpointing = True
+    _no_split_modules = ["BasicTransformerBlock"]

     @register_to_config
     def __init__(

src/diffusers/models/unets/unet_2d_condition.py

Lines changed: 1 addition & 0 deletions

@@ -161,6 +161,7 @@ class conditioning with `class_embed_type` equal to `None`.
     """

     _supports_gradient_checkpointing = True
+    _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D"]

     @register_to_config
     def __init__(
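A quick sketch of why the blocks with residual connections appear in these lists: the skip connection inside `forward` adds a tensor captured before the block's layers ran, so the whole block has to live on one device. The module below is a toy stand-in for illustration only, not `ResnetBlock2D` itself:

import torch

class ToyResBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Linear(8, 8)
        self.conv2 = torch.nn.Linear(8, 8)

    def forward(self, x):
        residual = x                   # captured on the input's device
        h = self.conv2(self.conv1(x))  # if conv1/conv2 were split across devices,
        return h + residual            # this addition would see mismatched devices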

src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py

Lines changed: 1 addition & 0 deletions

@@ -363,6 +363,7 @@ class conditioning with `class_embed_type` equal to `None`.
     """

     _supports_gradient_checkpointing = True
+    _no_split_modules = ["BasicTransformerBlock", "ResnetBlockFlat", "CrossAttnUpBlockFlat"]

     @register_to_config
     def __init__(

tests/models/test_modeling_common.py

Lines changed: 128 additions & 0 deletions
@@ -24,6 +24,7 @@
 import numpy as np
 import requests_mock
 import torch
+from accelerate.utils import compute_module_sizes
 from huggingface_hub import ModelCard, delete_repo
 from huggingface_hub.utils import is_jinja_available
 from requests.exceptions import HTTPError
@@ -39,6 +40,7 @@
     require_torch_2,
     require_torch_accelerator_with_training,
     require_torch_gpu,
+    require_torch_multi_gpu,
     run_test_in_subprocess,
     torch_device,
 )
@@ -200,6 +202,21 @@ class ModelTesterMixin:
     main_input_name = None  # overwrite in model specific tester class
     base_precision = 1e-3
     forward_requires_fresh_args = False
+    model_split_percents = [0.5, 0.7, 0.9]
+
+    def check_device_map_is_respected(self, model, device_map):
+        for param_name, param in model.named_parameters():
+            # Find device in device_map
+            while len(param_name) > 0 and param_name not in device_map:
+                param_name = ".".join(param_name.split(".")[:-1])
+            if param_name not in device_map:
+                raise ValueError("device map is incomplete, it does not contain any device for `param_name`.")
+
+            param_device = device_map[param_name]
+            if param_device in ["cpu", "disk"]:
+                self.assertEqual(param.device, torch.device("meta"))
+            else:
+                self.assertEqual(param.device, torch.device(param_device))

     def test_from_save_pretrained(self, expected_max_diff=5e-5):
         if self.forward_requires_fresh_args:
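To make the helper concrete: `hf_device_map` maps module-name prefixes to devices, and the check walks each parameter name up to its longest prefix present in the map. An illustrative, made-up map and lookup:

# Hypothetical placement for a small UNet; values are device ids or "cpu"/"disk".
hf_device_map = {"conv_in": 0, "down_blocks": 0, "mid_block": 0, "up_blocks": "cpu", "conv_out": "cpu"}

# "up_blocks.0.resnets.0.conv1.weight" has no exact entry, so the lookup
# shortens it to "up_blocks" and expects that parameter on the CPU (or on
# "meta" once accelerate has offloaded it).
name = "up_blocks.0.resnets.0.conv1.weight"
while len(name) > 0 and name not in hf_device_map:
    name = ".".join(name.split(".")[:-1])
print(name, "->", hf_device_map[name])  # up_blocks -> cpu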
@@ -670,6 +687,117 @@ def test_deprecated_kwargs(self):
             " from `_deprecated_kwargs = [<deprecated_argument>]`"
         )

+    @require_torch_gpu
+    def test_cpu_offload(self):
+        config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**config).eval()
+        model = model.to(torch_device)
+
+        torch.manual_seed(0)
+        base_output = model(**inputs_dict)
+
+        model_size = compute_module_sizes(model)[""]
+        # We test several splits of sizes to make sure it works.
+        max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents[1:]]
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            model.cpu().save_pretrained(tmp_dir)
+
+            for max_size in max_gpu_sizes:
+                max_memory = {0: max_size, "cpu": model_size * 2}
+                new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
+                # Making sure part of the model will actually end up offloaded
+                self.assertSetEqual(set(new_model.hf_device_map.values()), {0, "cpu"})
+
+                self.check_device_map_is_respected(new_model, new_model.hf_device_map)
+                torch.manual_seed(0)
+                new_output = new_model(**inputs_dict)
+
+                self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
+
+    @require_torch_gpu
+    def test_disk_offload_without_safetensors(self):
+        config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**config).eval()
+        model = model.to(torch_device)
+
+        torch.manual_seed(0)
+        base_output = model(**inputs_dict)
+
+        model_size = compute_module_sizes(model)[""]
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            model.cpu().save_pretrained(tmp_dir, safe_serialization=False)
+
+            with self.assertRaises(ValueError):
+                max_size = int(self.model_split_percents[1] * model_size)
+                max_memory = {0: max_size, "cpu": max_size}
+                # This errors out because it's missing an offload folder
+                new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
+
+            max_size = int(self.model_split_percents[1] * model_size)
+            max_memory = {0: max_size, "cpu": max_size}
+            new_model = self.model_class.from_pretrained(
+                tmp_dir, device_map="auto", max_memory=max_memory, offload_folder=tmp_dir
+            )
+
+            self.check_device_map_is_respected(new_model, new_model.hf_device_map)
+            torch.manual_seed(0)
+            new_output = new_model(**inputs_dict)
+
+            self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
+
+    @require_torch_gpu
+    def test_disk_offload_with_safetensors(self):
+        config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**config).eval()
+        model = model.to(torch_device)
+
+        torch.manual_seed(0)
+        base_output = model(**inputs_dict)
+
+        model_size = compute_module_sizes(model)[""]
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            model.cpu().save_pretrained(tmp_dir)
+
+            max_size = int(self.model_split_percents[1] * model_size)
+            max_memory = {0: max_size, "cpu": max_size}
+            new_model = self.model_class.from_pretrained(
+                tmp_dir, device_map="auto", offload_folder=tmp_dir, max_memory=max_memory
+            )
+
+            self.check_device_map_is_respected(new_model, new_model.hf_device_map)
+            torch.manual_seed(0)
+            new_output = new_model(**inputs_dict)
+
+            self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
+
+    @require_torch_multi_gpu
+    def test_model_parallelism(self):
+        config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**config).eval()
+        model = model.to(torch_device)
+
+        torch.manual_seed(0)
+        base_output = model(**inputs_dict)
+
+        model_size = compute_module_sizes(model)[""]
+        # We test several splits of sizes to make sure it works.
+        max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents[1:]]
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            model.cpu().save_pretrained(tmp_dir)
+
+            for max_size in max_gpu_sizes:
+                max_memory = {0: max_size, 1: model_size * 2, "cpu": model_size * 2}
+                new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
+                # Making sure the model is actually split across both GPUs
+                self.assertSetEqual(set(new_model.hf_device_map.values()), {0, 1})
+
+                self.check_device_map_is_respected(new_model, new_model.hf_device_map)
+
+                torch.manual_seed(0)
+                new_output = new_model(**inputs_dict)
+
+                self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))


 @is_staging_test
 class ModelPushToHubTester(unittest.TestCase):
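The memory budgets in these tests all come from the same recipe; here it is in isolation on a toy module (the 0.7 split is one of the `model_split_percents` values):

import torch
from accelerate.utils import compute_module_sizes

model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.Linear(64, 64))
model_size = compute_module_sizes(model)[""]  # "" is the whole model, in bytes

# Cap GPU 0 at 70% of the model so the remainder must spill to the CPU.
max_memory = {0: int(0.7 * model_size), "cpu": model_size * 2}
print(model_size, max_memory)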

tests/models/unets/test_models_unet_2d_condition.py

Lines changed: 2 additions & 0 deletions

@@ -300,6 +300,8 @@ def create_custom_diffusion_layers(model, mock_weights: bool = True):
 class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     model_class = UNet2DConditionModel
     main_input_name = "sample"
+    # We override the items here because the unet under consideration is small.
+    model_split_percents = [0.5, 0.3, 0.4]

     @property
     def dummy_input(self):
