Skip to content

Commit 7f724a9

Browse files
authored
fix the cpu offload tests (#7544)
fix
1 parent 9bef9f4 commit 7f724a9

File tree

1 file changed

+27
-22
lines changed

1 file changed

+27
-22
lines changed

tests/pipelines/test_pipelines_common.py

Lines changed: 27 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1144,20 +1144,24 @@ def test_cpu_offload_forward_pass_twice(self, expected_max_diff=2e-4):
11441144
self.assertLess(
11451145
max_diff, expected_max_diff, "running CPU offloading 2nd time should not affect the inference results"
11461146
)
1147-
offloaded_modules = [
1148-
v
1147+
offloaded_modules = {
1148+
k: v
11491149
for k, v in pipe.components.items()
11501150
if isinstance(v, torch.nn.Module) and k not in pipe._exclude_from_cpu_offload
1151-
]
1152-
(
1153-
self.assertTrue(all(v.device.type == "cpu" for v in offloaded_modules)),
1154-
f"Not offloaded: {[v for v in offloaded_modules if v.device.type != 'cpu']}",
1151+
}
1152+
self.assertTrue(
1153+
all(v.device.type == "cpu" for v in offloaded_modules.values()),
1154+
f"Not offloaded: {[k for k, v in offloaded_modules.items() if v.device.type != 'cpu']}",
11551155
)
11561156

1157-
offloaded_modules_with_hooks = [v for v in offloaded_modules if hasattr(v, "_hf_hook")]
1158-
(
1159-
self.assertTrue(all(isinstance(v, accelerate.hooks.CpuOffload) for v in offloaded_modules_with_hooks)),
1160-
f"Not installed correct hook: {[v for v in offloaded_modules_with_hooks if not isinstance(v, accelerate.hooks.CpuOffload)]}",
1157+
offloaded_modules_with_incorrect_hooks = {}
1158+
for k, v in offloaded_modules.items():
1159+
if hasattr(v, "_hf_hook") and not isinstance(v._hf_hook, accelerate.hooks.CpuOffload):
1160+
offloaded_modules_with_incorrect_hooks[k] = type(v._hf_hook)
1161+
1162+
self.assertTrue(
1163+
len(offloaded_modules_with_incorrect_hooks) == 0,
1164+
f"Not installed correct hook: {offloaded_modules_with_incorrect_hooks}",
11611165
)
11621166

11631167
@unittest.skipIf(
@@ -1189,22 +1193,23 @@ def test_sequential_offload_forward_pass_twice(self, expected_max_diff=2e-4):
11891193
self.assertLess(
11901194
max_diff, expected_max_diff, "running sequential offloading second time should have the inference results"
11911195
)
1192-
offloaded_modules = [
1193-
v
1196+
offloaded_modules = {
1197+
k: v
11941198
for k, v in pipe.components.items()
11951199
if isinstance(v, torch.nn.Module) and k not in pipe._exclude_from_cpu_offload
1196-
]
1197-
(
1198-
self.assertTrue(all(v.device.type == "meta" for v in offloaded_modules)),
1199-
f"Not offloaded: {[v for v in offloaded_modules if v.device.type != 'meta']}",
1200+
}
1201+
self.assertTrue(
1202+
all(v.device.type == "meta" for v in offloaded_modules.values()),
1203+
f"Not offloaded: {[k for k, v in offloaded_modules.items() if v.device.type != 'meta']}",
12001204
)
1205+
offloaded_modules_with_incorrect_hooks = {}
1206+
for k, v in offloaded_modules.items():
1207+
if hasattr(v, "_hf_hook") and not isinstance(v._hf_hook, accelerate.hooks.AlignDevicesHook):
1208+
offloaded_modules_with_incorrect_hooks[k] = type(v._hf_hook)
12011209

1202-
offloaded_modules_with_hooks = [v for v in offloaded_modules if hasattr(v, "_hf_hook")]
1203-
(
1204-
self.assertTrue(
1205-
all(isinstance(v, accelerate.hooks.AlignDevicesHook) for v in offloaded_modules_with_hooks)
1206-
),
1207-
f"Not installed correct hook: {[v for v in offloaded_modules_with_hooks if not isinstance(v, accelerate.hooks.AlignDevicesHook)]}",
1210+
self.assertTrue(
1211+
len(offloaded_modules_with_incorrect_hooks) == 0,
1212+
f"Not installed correct hook: {offloaded_modules_with_incorrect_hooks}",
12081213
)
12091214

12101215
@unittest.skipIf(

0 commit comments

Comments
 (0)