Skip to content

Commit 7c8fc64

Browse files
committed
add more tests
1 parent b6c0f20 commit 7c8fc64

File tree

2 files changed

+113
-3
lines changed

2 files changed

+113
-3
lines changed

src/diffusers/hooks/group_offloading.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import glob
1516
import hashlib
1617
import os
1718
from contextlib import contextmanager, nullcontext
@@ -907,3 +908,90 @@ def _get_group_onload_device(module: torch.nn.Module) -> torch.device:
907908
if hasattr(submodule, "_diffusers_hook") and submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING) is not None:
908909
return submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING).group.onload_device
909910
raise ValueError("Group offloading is not enabled for the provided module.")
911+
912+
913+
def _get_expected_safetensors_files(
    module: torch.nn.Module,
    offload_to_disk_path: str,
    offload_type: str,
    num_blocks_per_group: Optional[int] = None,
) -> Set[str]:
    """Compute the set of safetensors file paths expected for `module` under disk group offloading.

    Every group is identified by a string id. The id is hashed with SHA-256, truncated to the first 16 hex
    characters and turned into the path `<offload_to_disk_path>/group_<hash>.safetensors`.
    NOTE(review): this is presumably the same naming scheme the offloading hooks use when writing files;
    keep both in sync.

    Args:
        module: The root module whose groups are enumerated.
        offload_to_disk_path: Directory that is joined in front of every expected file name.
        offload_type: Either `"block_level"` or `"leaf_level"`.
        num_blocks_per_group: Number of consecutive blocks per group. Required for `"block_level"`.

    Returns:
        The set of expected file paths.

    Raises:
        ValueError: If `offload_type` is not supported, or if it is `"block_level"` and
            `num_blocks_per_group` is `None`.
    """

    def get_hashed_filename(group_id: str) -> str:
        hashed_id = hashlib.sha256(group_id.encode("utf-8")).hexdigest()
        short_hash = hashed_id[:16]
        return os.path.join(offload_to_disk_path, f"group_{short_hash}.safetensors")

    if offload_type == "block_level":
        if num_blocks_per_group is None:
            raise ValueError("num_blocks_per_group must be provided for 'block_level' offloading.")

        # One file is always expected for the group of unmatched top-level modules and parameters.
        group_ids = ["top_level_unmatched_modules"]

        # Direct ModuleList / Sequential children are chunked into runs of `num_blocks_per_group` blocks.
        for name, submodule in module.named_children():
            if not isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
                continue

            num_blocks = len(submodule)
            for start_idx in range(0, num_blocks, num_blocks_per_group):
                # The last chunk may be shorter, so clamp the inclusive end index to the final block.
                end_idx = min(start_idx + num_blocks_per_group, num_blocks) - 1
                group_ids.append(f"{name}.{start_idx}_to_{end_idx}")

        return {get_hashed_filename(group_id) for group_id in group_ids}

    if offload_type == "leaf_level":
        # Walk the module tree once and reuse the mapping for both the leaf lookup and the parent lookup.
        module_dict = dict(module.named_modules())

        # Each supported leaf layer forms its own group, keyed by the layer's qualified name.
        modules_with_group_offloading = {
            name for name, submodule in module_dict.items() if isinstance(submodule, _SUPPORTED_PYTORCH_LAYERS)
        }

        # Parameters/buffers that do not live under any supported leaf layer are grouped by their parent
        # module; one file is expected per such parent.
        parameters = _gather_parameters_with_no_group_offloading_parent(module, modules_with_group_offloading)
        buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading)
        orphan_parent_names = {
            _find_parent_module_in_module_dict(tensor_name, module_dict) for tensor_name, _ in parameters + buffers
        }

        return {get_hashed_filename(name) for name in modules_with_group_offloading | orphan_parent_names}

    raise ValueError(f"Unsupported offload_type: {offload_type}")
980+
981+
982+
def _check_safetensors_serialization(
    module: torch.nn.Module,
    offload_to_disk_path: str,
    offload_type: str,
    num_blocks_per_group: Optional[int] = None,
) -> "Tuple[bool, Optional[Set[str]], Optional[Set[str]]]":
    """Compare the `*.safetensors` files in `offload_to_disk_path` against the expected set for `module`.

    Args:
        module: The root module for which group offloading files are expected.
        offload_to_disk_path: Directory to inspect.
        offload_type: Offloading strategy, forwarded to `_get_expected_safetensors_files`.
        num_blocks_per_group: Blocks per group, forwarded to `_get_expected_safetensors_files`.

    Returns:
        A tuple `(is_correct, extra_files, missing_files)`:
            - `is_correct` is `True` only when the files on disk match the expected set exactly.
            - `extra_files` holds paths found on disk that were not expected.
            - `missing_files` holds expected paths that were not found on disk.
        If `offload_to_disk_path` is not an existing directory, `(False, None, None)` is returned, so callers
        must handle `None` for both sets.
    """
    # The annotation is a string so that it does not require `Tuple` to be imported at definition time.
    if not os.path.isdir(offload_to_disk_path):
        return False, None, None

    expected_files = _get_expected_safetensors_files(module, offload_to_disk_path, offload_type, num_blocks_per_group)
    # Both sets are built by joining onto the same `offload_to_disk_path`, so the paths compare directly.
    actual_files = set(glob.glob(os.path.join(offload_to_disk_path, "*.safetensors")))
    missing_files = expected_files - actual_files
    extra_files = actual_files - expected_files

    is_correct = not missing_files and not extra_files
    return is_correct, extra_files, missing_files

tests/models/test_modeling_common.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
from parameterized import parameterized
4141
from requests.exceptions import HTTPError
4242

43+
from diffusers.hooks.group_offloading import _check_safetensors_serialization
4344
from diffusers.models import SD3Transformer2DModel, UNet2DConditionModel
4445
from diffusers.models.attention_processor import (
4546
AttnProcessor,
@@ -1697,6 +1698,7 @@ def test_group_offloading_with_layerwise_casting(self, record_stream, offload_ty
16971698
@parameterized.expand([(False, "block_level"), (True, "leaf_level")])
16981699
@require_torch_accelerator
16991700
@torch.no_grad()
1701+
@torch.inference_mode()
17001702
def test_group_offloading_with_disk(self, record_stream, offload_type):
17011703
torch.manual_seed(0)
17021704
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -1705,11 +1707,15 @@ def test_group_offloading_with_disk(self, record_stream, offload_type):
17051707
if not getattr(model, "_supports_group_offloading", True):
17061708
return
17071709

1710+
model.eval()
1711+
output_without_group_offloading = model(**inputs_dict)[0]
1712+
17081713
torch.manual_seed(0)
17091714
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
17101715
model = self.model_class(**init_dict)
17111716
model.eval()
1712-
additional_kwargs = {} if offload_type == "leaf_level" else {"num_blocks_per_group": 1}
1717+
num_blocks_per_group = None if offload_type == "leaf_level" else 1
1718+
additional_kwargs = {} if offload_type == "leaf_level" else {"num_blocks_per_group": num_blocks_per_group}
17131719
with tempfile.TemporaryDirectory() as tmpdir:
17141720
model.enable_group_offload(
17151721
torch_device,
@@ -1720,8 +1726,24 @@ def test_group_offloading_with_disk(self, record_stream, offload_type):
17201726
**additional_kwargs,
17211727
)
17221728
has_safetensors = glob.glob(f"{tmpdir}/*.safetensors")
1723-
assert has_safetensors, "No safetensors found in the directory."
1724-
_ = model(**inputs_dict)[0]
1729+
# Group offloading with disk support related checks.
1730+
self.assertTrue(has_safetensors, "No safetensors found in the directory.")
1731+
is_correct, extra_files, missing_files = _check_safetensors_serialization(
1732+
module=model,
1733+
offload_to_disk_path=tmpdir,
1734+
offload_type=offload_type,
1735+
num_blocks_per_group=num_blocks_per_group,
1736+
)
1737+
if not is_correct:
1738+
if extra_files:
1739+
raise ValueError(f"Found extra files: {', '.join(extra_files)}")
1740+
elif missing_files:
1741+
raise ValueError(f"Following files are missing: {', '.join(missing_files)}")
1742+
1743+
output_with_group_offloading = model(**inputs_dict)[0]
1744+
self.assertTrue(
1745+
torch.allclose(output_without_group_offloading, output_with_group_offloading, atol=1e-4, rtol=1 - 4)
1746+
)
17251747

17261748
def test_auto_model(self, expected_max_diff=5e-5):
17271749
if self.forward_requires_fresh_args:

0 commit comments

Comments
 (0)