|  | 
| 12 | 12 | # See the License for the specific language governing permissions and | 
| 13 | 13 | # limitations under the License. | 
| 14 | 14 | 
 | 
|  | 15 | +import glob | 
| 15 | 16 | import hashlib | 
| 16 | 17 | import os | 
| 17 | 18 | from contextlib import contextmanager, nullcontext | 
| @@ -907,3 +908,90 @@ def _get_group_onload_device(module: torch.nn.Module) -> torch.device: | 
| 907 | 908 |         if hasattr(submodule, "_diffusers_hook") and submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING) is not None: | 
| 908 | 909 |             return submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING).group.onload_device | 
| 909 | 910 |     raise ValueError("Group offloading is not enabled for the provided module.") | 
|  | 911 | + | 
|  | 912 | + | 
|  | 913 | +def _get_expected_safetensors_files( | 
|  | 914 | +    module: torch.nn.Module, | 
|  | 915 | +    offload_to_disk_path: str, | 
|  | 916 | +    offload_type: str, | 
|  | 917 | +    num_blocks_per_group: Optional[int] = None, | 
|  | 918 | +) -> Set[str]: | 
|  | 919 | +    expected_files = set() | 
|  | 920 | + | 
|  | 921 | +    def get_hashed_filename(group_id: str) -> str: | 
|  | 922 | +        hashed_id = hashlib.sha256(group_id.encode("utf-8")).hexdigest() | 
|  | 923 | +        short_hash = hashed_id[:16] | 
|  | 924 | +        return os.path.join(offload_to_disk_path, f"group_{short_hash}.safetensors") | 
|  | 925 | + | 
|  | 926 | +    if offload_type == "block_level": | 
|  | 927 | +        if num_blocks_per_group is None: | 
|  | 928 | +            raise ValueError("num_blocks_per_group must be provided for 'block_level' offloading.") | 
|  | 929 | + | 
|  | 930 | +        # Handle groups of ModuleList and Sequential blocks | 
|  | 931 | +        for name, submodule in module.named_children(): | 
|  | 932 | +            if not isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)): | 
|  | 933 | +                continue | 
|  | 934 | + | 
|  | 935 | +            for i in range(0, len(submodule), num_blocks_per_group): | 
|  | 936 | +                current_modules = submodule[i : i + num_blocks_per_group] | 
|  | 937 | +                if not current_modules: | 
|  | 938 | +                    continue | 
|  | 939 | +                start_idx = i | 
|  | 940 | +                end_idx = i + len(current_modules) - 1 | 
|  | 941 | +                group_id = f"{name}.{start_idx}_to_{end_idx}" | 
|  | 942 | +                expected_files.add(get_hashed_filename(group_id)) | 
|  | 943 | + | 
|  | 944 | +        # Handle the group for unmatched top-level modules and parameters | 
|  | 945 | +        group_id = "top_level_unmatched_modules" | 
|  | 946 | +        expected_files.add(get_hashed_filename(group_id)) | 
|  | 947 | + | 
|  | 948 | +    elif offload_type == "leaf_level": | 
|  | 949 | +        # Handle leaf-level module groups | 
|  | 950 | +        for name, submodule in module.named_modules(): | 
|  | 951 | +            if isinstance(submodule, _SUPPORTED_PYTORCH_LAYERS): | 
|  | 952 | +                # These groups will always have parameters, so a file is expected | 
|  | 953 | +                expected_files.add(get_hashed_filename(name)) | 
|  | 954 | + | 
|  | 955 | +        # Handle groups for non-leaf parameters/buffers | 
|  | 956 | +        modules_with_group_offloading = { | 
|  | 957 | +            name for name, sm in module.named_modules() if isinstance(sm, _SUPPORTED_PYTORCH_LAYERS) | 
|  | 958 | +        } | 
|  | 959 | +        parameters = _gather_parameters_with_no_group_offloading_parent(module, modules_with_group_offloading) | 
|  | 960 | +        buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading) | 
|  | 961 | + | 
|  | 962 | +        all_orphans = parameters + buffers | 
|  | 963 | +        if all_orphans: | 
|  | 964 | +            parent_to_tensors = {} | 
|  | 965 | +            module_dict = dict(module.named_modules()) | 
|  | 966 | +            for tensor_name, _ in all_orphans: | 
|  | 967 | +                parent_name = _find_parent_module_in_module_dict(tensor_name, module_dict) | 
|  | 968 | +                if parent_name not in parent_to_tensors: | 
|  | 969 | +                    parent_to_tensors[parent_name] = [] | 
|  | 970 | +                parent_to_tensors[parent_name].append(tensor_name) | 
|  | 971 | + | 
|  | 972 | +            for parent_name in parent_to_tensors: | 
|  | 973 | +                # A file is expected for each parent that gathers orphaned tensors | 
|  | 974 | +                expected_files.add(get_hashed_filename(parent_name)) | 
|  | 975 | + | 
|  | 976 | +    else: | 
|  | 977 | +        raise ValueError(f"Unsupported offload_type: {offload_type}") | 
|  | 978 | + | 
|  | 979 | +    return expected_files | 
|  | 980 | + | 
|  | 981 | + | 
|  | 982 | +def _check_safetensors_serialization( | 
|  | 983 | +    module: torch.nn.Module, | 
|  | 984 | +    offload_to_disk_path: str, | 
|  | 985 | +    offload_type: str, | 
|  | 986 | +    num_blocks_per_group: Optional[int] = None, | 
|  | 987 | +) -> bool: | 
|  | 988 | +    if not os.path.isdir(offload_to_disk_path): | 
|  | 989 | +        return False, None, None | 
|  | 990 | + | 
|  | 991 | +    expected_files = _get_expected_safetensors_files(module, offload_to_disk_path, offload_type, num_blocks_per_group) | 
|  | 992 | +    actual_files = set(glob.glob(os.path.join(offload_to_disk_path, "*.safetensors"))) | 
|  | 993 | +    missing_files = expected_files - actual_files | 
|  | 994 | +    extra_files = actual_files - expected_files | 
|  | 995 | + | 
|  | 996 | +    is_correct = not missing_files and not extra_files | 
|  | 997 | +    return is_correct, extra_files, missing_files | 
0 commit comments