Skip to content

Commit 312f022

Browse files
committed
try test fix
1 parent 61143f8 commit 312f022

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

src/diffusers/utils/testing_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1394,9 +1394,9 @@ def get_device_properties() -> DeviceProperties:
13941394
DevicePropertiesUserDict = UserDict
13951395

13961396
if is_torch_available():
1397+
from diffusers.hooks._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
13971398
from diffusers.hooks.group_offloading import (
13981399
_GROUP_ID_LAZY_LEAF,
1399-
_SUPPORTED_PYTORCH_LAYERS,
14001400
_compute_group_hash,
14011401
_find_parent_module_in_module_dict,
14021402
_gather_buffers_with_no_group_offloading_parent,
@@ -1440,13 +1440,13 @@ def get_hashed_filename(group_id: str) -> str:
14401440
elif offload_type == "leaf_level":
14411441
# Handle leaf-level module groups
14421442
for name, submodule in module.named_modules():
1443-
if isinstance(submodule, _SUPPORTED_PYTORCH_LAYERS):
1443+
if isinstance(submodule, _GO_LC_SUPPORTED_PYTORCH_LAYERS):
14441444
# These groups will always have parameters, so a file is expected
14451445
expected_files.add(get_hashed_filename(name))
14461446

14471447
# Handle groups for non-leaf parameters/buffers
14481448
modules_with_group_offloading = {
1449-
name for name, sm in module.named_modules() if isinstance(sm, _SUPPORTED_PYTORCH_LAYERS)
1449+
name for name, sm in module.named_modules() if isinstance(sm, _GO_LC_SUPPORTED_PYTORCH_LAYERS)
14501450
}
14511451
parameters = _gather_parameters_with_no_group_offloading_parent(module, modules_with_group_offloading)
14521452
buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading)

0 commit comments

Comments
 (0)